mirror of
https://gitee.com/270580156/weiyu.git
synced 2026-05-19 05:37:53 +00:00
Sync from bytedesk-private: update
This commit is contained in:
1
modules/python/vendors/FunASR/funasr/frontends/utils/__init__.py
vendored
Normal file
1
modules/python/vendors/FunASR/funasr/frontends/utils/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
"""Initialize sub package."""
|
||||
82
modules/python/vendors/FunASR/funasr/frontends/utils/beamformer.py
vendored
Normal file
82
modules/python/vendors/FunASR/funasr/frontends/utils/beamformer.py
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch
|
||||
from torch_complex import functional as FC
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
|
||||
def get_power_spectral_density_matrix(
|
||||
xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
|
||||
) -> ComplexTensor:
|
||||
"""Return cross-channel power spectral density (PSD) matrix
|
||||
|
||||
Args:
|
||||
xs (ComplexTensor): (..., F, C, T)
|
||||
mask (torch.Tensor): (..., F, C, T)
|
||||
normalization (bool):
|
||||
eps (float):
|
||||
Returns
|
||||
psd (ComplexTensor): (..., F, C, C)
|
||||
|
||||
"""
|
||||
# outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
|
||||
psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
|
||||
|
||||
# Averaging mask along C: (..., C, T) -> (..., T)
|
||||
mask = mask.mean(dim=-2)
|
||||
|
||||
# Normalized mask along T: (..., T)
|
||||
if normalization:
|
||||
# If assuming the tensor is padded with zero, the summation along
|
||||
# the time axis is same regardless of the padding length.
|
||||
mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
|
||||
|
||||
# psd: (..., T, C, C)
|
||||
psd = psd_Y * mask[..., None, None]
|
||||
# (..., T, C, C) -> (..., C, C)
|
||||
psd = psd.sum(dim=-3)
|
||||
|
||||
return psd
|
||||
|
||||
|
||||
def get_mvdr_vector(
|
||||
psd_s: ComplexTensor,
|
||||
psd_n: ComplexTensor,
|
||||
reference_vector: torch.Tensor,
|
||||
eps: float = 1e-15,
|
||||
) -> ComplexTensor:
|
||||
"""Return the MVDR(Minimum Variance Distortionless Response) vector:
|
||||
|
||||
h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
|
||||
|
||||
Reference:
|
||||
On optimal frequency-domain multichannel linear filtering
|
||||
for noise reduction; M. Souden et al., 2010;
|
||||
https://ieeexplore.ieee.org/document/5089420
|
||||
|
||||
Args:
|
||||
psd_s (ComplexTensor): (..., F, C, C)
|
||||
psd_n (ComplexTensor): (..., F, C, C)
|
||||
reference_vector (torch.Tensor): (..., C)
|
||||
eps (float):
|
||||
Returns:
|
||||
beamform_vector (ComplexTensor)r: (..., F, C)
|
||||
"""
|
||||
# Add eps
|
||||
C = psd_n.size(-1)
|
||||
eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
|
||||
shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
|
||||
eye = eye.view(*shape)
|
||||
psd_n += eps * eye
|
||||
|
||||
# numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
|
||||
numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
|
||||
# ws: (..., C, C) / (...,) -> (..., C, C)
|
||||
ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
|
||||
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
|
||||
beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
|
||||
return beamform_vector
|
||||
|
||||
|
||||
def apply_beamforming_vector(beamform_vector: ComplexTensor, mix: ComplexTensor) -> ComplexTensor:
|
||||
# (..., C) x (..., C, T) -> (..., T)
|
||||
es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
|
||||
return es
|
||||
184
modules/python/vendors/FunASR/funasr/frontends/utils/complex_utils.py
vendored
Normal file
184
modules/python/vendors/FunASR/funasr/frontends/utils/complex_utils.py
vendored
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Beamformer module."""
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch_complex import functional as FC
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except:
|
||||
print("Please install torch_complex firstly")
|
||||
|
||||
|
||||
EPS = torch.finfo(torch.double).eps
|
||||
is_torch_1_8_plus = LooseVersion(torch.__version__) >= LooseVersion("1.8.0")
|
||||
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
|
||||
|
||||
|
||||
def new_complex_like(
|
||||
ref: Union[torch.Tensor, ComplexTensor],
|
||||
real_imag: Tuple[torch.Tensor, torch.Tensor],
|
||||
):
|
||||
if isinstance(ref, ComplexTensor):
|
||||
return ComplexTensor(*real_imag)
|
||||
elif is_torch_complex_tensor(ref):
|
||||
return torch.complex(*real_imag)
|
||||
else:
|
||||
raise ValueError("Please update your PyTorch version to 1.9+ for complex support.")
|
||||
|
||||
|
||||
def is_torch_complex_tensor(c):
|
||||
return not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
|
||||
|
||||
|
||||
def is_complex(c):
|
||||
return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
|
||||
|
||||
|
||||
def to_double(c):
|
||||
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
|
||||
return c.to(dtype=torch.complex128)
|
||||
else:
|
||||
return c.double()
|
||||
|
||||
|
||||
def to_float(c):
|
||||
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
|
||||
return c.to(dtype=torch.complex64)
|
||||
else:
|
||||
return c.float()
|
||||
|
||||
|
||||
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
|
||||
if not isinstance(seq, (list, tuple)):
|
||||
raise TypeError(
|
||||
"cat(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
|
||||
)
|
||||
if isinstance(seq[0], ComplexTensor):
|
||||
return FC.cat(seq, *args, **kwargs)
|
||||
else:
|
||||
return torch.cat(seq, *args, **kwargs)
|
||||
|
||||
|
||||
def complex_norm(c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False) -> torch.Tensor:
|
||||
if not is_complex(c):
|
||||
raise TypeError("Input is not a complex tensor.")
|
||||
if is_torch_complex_tensor(c):
|
||||
return torch.norm(c, dim=dim, keepdim=keepdim)
|
||||
else:
|
||||
return torch.sqrt((c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS)
|
||||
|
||||
|
||||
def einsum(equation, *operands):
|
||||
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
|
||||
# NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
|
||||
# mixed input with complex and real tensors.
|
||||
if len(operands) == 1:
|
||||
if isinstance(operands[0], (tuple, list)):
|
||||
operands = operands[0]
|
||||
complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
|
||||
return complex_module.einsum(equation, *operands)
|
||||
elif len(operands) != 2:
|
||||
op0 = operands[0]
|
||||
same_type = all(op.dtype == op0.dtype for op in operands[1:])
|
||||
if same_type:
|
||||
_einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
|
||||
return _einsum(equation, *operands)
|
||||
else:
|
||||
raise ValueError("0 or More than 2 operands are not supported.")
|
||||
a, b = operands
|
||||
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
|
||||
return FC.einsum(equation, a, b)
|
||||
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
|
||||
if not torch.is_complex(a):
|
||||
o_real = torch.einsum(equation, a, b.real)
|
||||
o_imag = torch.einsum(equation, a, b.imag)
|
||||
return torch.complex(o_real, o_imag)
|
||||
elif not torch.is_complex(b):
|
||||
o_real = torch.einsum(equation, a.real, b)
|
||||
o_imag = torch.einsum(equation, a.imag, b)
|
||||
return torch.complex(o_real, o_imag)
|
||||
else:
|
||||
return torch.einsum(equation, a, b)
|
||||
else:
|
||||
return torch.einsum(equation, a, b)
|
||||
|
||||
|
||||
def inverse(c: Union[torch.Tensor, ComplexTensor]) -> Union[torch.Tensor, ComplexTensor]:
|
||||
if isinstance(c, ComplexTensor):
|
||||
return c.inverse2()
|
||||
else:
|
||||
return c.inverse()
|
||||
|
||||
|
||||
def matmul(
|
||||
a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
|
||||
) -> Union[torch.Tensor, ComplexTensor]:
|
||||
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
|
||||
# NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
|
||||
# multiplication between complex and real tensors.
|
||||
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
|
||||
return FC.matmul(a, b)
|
||||
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
|
||||
if not torch.is_complex(a):
|
||||
o_real = torch.matmul(a, b.real)
|
||||
o_imag = torch.matmul(a, b.imag)
|
||||
return torch.complex(o_real, o_imag)
|
||||
elif not torch.is_complex(b):
|
||||
o_real = torch.matmul(a.real, b)
|
||||
o_imag = torch.matmul(a.imag, b)
|
||||
return torch.complex(o_real, o_imag)
|
||||
else:
|
||||
return torch.matmul(a, b)
|
||||
else:
|
||||
return torch.matmul(a, b)
|
||||
|
||||
|
||||
def trace(a: Union[torch.Tensor, ComplexTensor]):
|
||||
# NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
|
||||
# support bacth processing. Use FC.trace() as fallback.
|
||||
return FC.trace(a)
|
||||
|
||||
|
||||
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
|
||||
if isinstance(a, ComplexTensor):
|
||||
return FC.reverse(a, dim=dim)
|
||||
else:
|
||||
return torch.flip(a, dims=(dim,))
|
||||
|
||||
|
||||
def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
|
||||
"""Solve the linear equation ax = b."""
|
||||
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
|
||||
# NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
|
||||
# mixed input with complex and real tensors.
|
||||
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
|
||||
if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
|
||||
return FC.solve(b, a, return_LU=False)
|
||||
else:
|
||||
return matmul(inverse(a), b)
|
||||
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
|
||||
if torch.is_complex(a) and torch.is_complex(b):
|
||||
return torch.linalg.solve(a, b)
|
||||
else:
|
||||
return matmul(inverse(a), b)
|
||||
else:
|
||||
if is_torch_1_8_plus:
|
||||
return torch.linalg.solve(a, b)
|
||||
else:
|
||||
return torch.solve(b, a)[0]
|
||||
|
||||
|
||||
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
|
||||
if not isinstance(seq, (list, tuple)):
|
||||
raise TypeError(
|
||||
"stack(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
|
||||
)
|
||||
if isinstance(seq[0], ComplexTensor):
|
||||
return FC.stack(seq, *args, **kwargs)
|
||||
else:
|
||||
return torch.stack(seq, *args, **kwargs)
|
||||
161
modules/python/vendors/FunASR/funasr/frontends/utils/dnn_beamformer.py
vendored
Normal file
161
modules/python/vendors/FunASR/funasr/frontends/utils/dnn_beamformer.py
vendored
Normal file
@@ -0,0 +1,161 @@
|
||||
"""DNN beamformer module."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from funasr.frontends.utils.beamformer import apply_beamforming_vector
|
||||
from funasr.frontends.utils.beamformer import get_mvdr_vector
|
||||
from funasr.frontends.utils.beamformer import (
|
||||
get_power_spectral_density_matrix, # noqa: H301
|
||||
)
|
||||
from funasr.frontends.utils.mask_estimator import MaskEstimator
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
|
||||
class DNN_Beamformer(torch.nn.Module):
|
||||
"""DNN mask based Beamformer
|
||||
|
||||
Citation:
|
||||
Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
|
||||
https://arxiv.org/abs/1703.04783
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bidim,
|
||||
btype="blstmp",
|
||||
blayers=3,
|
||||
bunits=300,
|
||||
bprojs=320,
|
||||
bnmask=2,
|
||||
dropout_rate=0.0,
|
||||
badim=320,
|
||||
ref_channel: int = -1,
|
||||
beamformer_type="mvdr",
|
||||
):
|
||||
super().__init__()
|
||||
self.mask = MaskEstimator(btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask)
|
||||
self.ref = AttentionReference(bidim, badim)
|
||||
self.ref_channel = ref_channel
|
||||
|
||||
self.nmask = bnmask
|
||||
|
||||
if beamformer_type != "mvdr":
|
||||
raise ValueError("Not supporting beamformer_type={}".format(beamformer_type))
|
||||
self.beamformer_type = beamformer_type
|
||||
|
||||
def forward(
|
||||
self, data: ComplexTensor, ilens: torch.LongTensor
|
||||
) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
|
||||
"""The forward function
|
||||
|
||||
Notation:
|
||||
B: Batch
|
||||
C: Channel
|
||||
T: Time or Sequence length
|
||||
F: Freq
|
||||
|
||||
Args:
|
||||
data (ComplexTensor): (B, T, C, F)
|
||||
ilens (torch.Tensor): (B,)
|
||||
Returns:
|
||||
enhanced (ComplexTensor): (B, T, F)
|
||||
ilens (torch.Tensor): (B,)
|
||||
|
||||
"""
|
||||
|
||||
def apply_beamforming(data, ilens, psd_speech, psd_noise):
|
||||
# u: (B, C)
|
||||
if self.ref_channel < 0:
|
||||
u, _ = self.ref(psd_speech, ilens)
|
||||
else:
|
||||
# (optional) Create onehot vector for fixed reference microphone
|
||||
u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)), device=data.device)
|
||||
u[..., self.ref_channel].fill_(1)
|
||||
|
||||
ws = get_mvdr_vector(psd_speech, psd_noise, u)
|
||||
enhanced = apply_beamforming_vector(ws, data)
|
||||
|
||||
return enhanced, ws
|
||||
|
||||
# data (B, T, C, F) -> (B, F, C, T)
|
||||
data = data.permute(0, 3, 2, 1)
|
||||
|
||||
# mask: (B, F, C, T)
|
||||
masks, _ = self.mask(data, ilens)
|
||||
assert self.nmask == len(masks)
|
||||
|
||||
if self.nmask == 2: # (mask_speech, mask_noise)
|
||||
mask_speech, mask_noise = masks
|
||||
|
||||
psd_speech = get_power_spectral_density_matrix(data, mask_speech)
|
||||
psd_noise = get_power_spectral_density_matrix(data, mask_noise)
|
||||
|
||||
enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
|
||||
|
||||
# (..., F, T) -> (..., T, F)
|
||||
enhanced = enhanced.transpose(-1, -2)
|
||||
mask_speech = mask_speech.transpose(-1, -3)
|
||||
else: # multi-speaker case: (mask_speech1, ..., mask_noise)
|
||||
mask_speech = list(masks[:-1])
|
||||
mask_noise = masks[-1]
|
||||
|
||||
psd_speeches = [get_power_spectral_density_matrix(data, mask) for mask in mask_speech]
|
||||
psd_noise = get_power_spectral_density_matrix(data, mask_noise)
|
||||
|
||||
enhanced = []
|
||||
ws = []
|
||||
for i in range(self.nmask - 1):
|
||||
psd_speech = psd_speeches.pop(i)
|
||||
# treat all other speakers' psd_speech as noises
|
||||
enh, w = apply_beamforming(data, ilens, psd_speech, sum(psd_speeches) + psd_noise)
|
||||
psd_speeches.insert(i, psd_speech)
|
||||
|
||||
# (..., F, T) -> (..., T, F)
|
||||
enh = enh.transpose(-1, -2)
|
||||
mask_speech[i] = mask_speech[i].transpose(-1, -3)
|
||||
|
||||
enhanced.append(enh)
|
||||
ws.append(w)
|
||||
|
||||
return enhanced, ilens, mask_speech
|
||||
|
||||
|
||||
class AttentionReference(torch.nn.Module):
|
||||
def __init__(self, bidim, att_dim):
|
||||
super().__init__()
|
||||
self.mlp_psd = torch.nn.Linear(bidim, att_dim)
|
||||
self.gvec = torch.nn.Linear(att_dim, 1)
|
||||
|
||||
def forward(
|
||||
self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
|
||||
) -> Tuple[torch.Tensor, torch.LongTensor]:
|
||||
"""The forward function
|
||||
|
||||
Args:
|
||||
psd_in (ComplexTensor): (B, F, C, C)
|
||||
ilens (torch.Tensor): (B,)
|
||||
scaling (float):
|
||||
Returns:
|
||||
u (torch.Tensor): (B, C)
|
||||
ilens (torch.Tensor): (B,)
|
||||
"""
|
||||
B, _, C = psd_in.size()[:3]
|
||||
assert psd_in.size(2) == psd_in.size(3), psd_in.size()
|
||||
# psd_in: (B, F, C, C)
|
||||
psd = psd_in.masked_fill(torch.eye(C, dtype=torch.bool, device=psd_in.device), 0)
|
||||
# psd: (B, F, C, C) -> (B, C, F)
|
||||
psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
|
||||
|
||||
# Calculate amplitude
|
||||
psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
|
||||
|
||||
# (B, C, F) -> (B, C, F2)
|
||||
mlp_psd = self.mlp_psd(psd_feat)
|
||||
# (B, C, F2) -> (B, C, 1) -> (B, C)
|
||||
e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
|
||||
u = F.softmax(scaling * e, dim=-1)
|
||||
return u, ilens
|
||||
93
modules/python/vendors/FunASR/funasr/frontends/utils/dnn_wpe.py
vendored
Normal file
93
modules/python/vendors/FunASR/funasr/frontends/utils/dnn_wpe.py
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import Tuple
|
||||
|
||||
from pytorch_wpe import wpe_one_iteration
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr.frontends.utils.mask_estimator import MaskEstimator
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class DNN_WPE(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
wtype: str = "blstmp",
|
||||
widim: int = 257,
|
||||
wlayers: int = 3,
|
||||
wunits: int = 300,
|
||||
wprojs: int = 320,
|
||||
dropout_rate: float = 0.0,
|
||||
taps: int = 5,
|
||||
delay: int = 3,
|
||||
use_dnn_mask: bool = True,
|
||||
iterations: int = 1,
|
||||
normalization: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.iterations = iterations
|
||||
self.taps = taps
|
||||
self.delay = delay
|
||||
|
||||
self.normalization = normalization
|
||||
self.use_dnn_mask = use_dnn_mask
|
||||
|
||||
self.inverse_power = True
|
||||
|
||||
if self.use_dnn_mask:
|
||||
self.mask_est = MaskEstimator(
|
||||
wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, data: ComplexTensor, ilens: torch.LongTensor
|
||||
) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
|
||||
"""The forward function
|
||||
|
||||
Notation:
|
||||
B: Batch
|
||||
C: Channel
|
||||
T: Time or Sequence length
|
||||
F: Freq or Some dimension of the feature vector
|
||||
|
||||
Args:
|
||||
data: (B, C, T, F)
|
||||
ilens: (B,)
|
||||
Returns:
|
||||
data: (B, C, T, F)
|
||||
ilens: (B,)
|
||||
"""
|
||||
# (B, T, C, F) -> (B, F, C, T)
|
||||
enhanced = data = data.permute(0, 3, 2, 1)
|
||||
mask = None
|
||||
|
||||
for i in range(self.iterations):
|
||||
# Calculate power: (..., C, T)
|
||||
power = enhanced.real**2 + enhanced.imag**2
|
||||
if i == 0 and self.use_dnn_mask:
|
||||
# mask: (B, F, C, T)
|
||||
(mask,), _ = self.mask_est(enhanced, ilens)
|
||||
if self.normalization:
|
||||
# Normalize along T
|
||||
mask = mask / mask.sum(dim=-1)[..., None]
|
||||
# (..., C, T) * (..., C, T) -> (..., C, T)
|
||||
power = power * mask
|
||||
|
||||
# Averaging along the channel axis: (..., C, T) -> (..., T)
|
||||
power = power.mean(dim=-2)
|
||||
|
||||
# enhanced: (..., C, T) -> (..., C, T)
|
||||
enhanced = wpe_one_iteration(
|
||||
data.contiguous(),
|
||||
power,
|
||||
taps=self.taps,
|
||||
delay=self.delay,
|
||||
inverse_power=self.inverse_power,
|
||||
)
|
||||
|
||||
enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
|
||||
|
||||
# (B, F, C, T) -> (B, T, C, F)
|
||||
enhanced = enhanced.permute(0, 3, 2, 1)
|
||||
if mask is not None:
|
||||
mask = mask.transpose(-1, -3)
|
||||
return enhanced, ilens, mask
|
||||
259
modules/python/vendors/FunASR/funasr/frontends/utils/feature_transform.py
vendored
Normal file
259
modules/python/vendors/FunASR/funasr/frontends/utils/feature_transform.py
vendored
Normal file
@@ -0,0 +1,259 @@
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class FeatureTransform(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
# Mel options,
|
||||
fs: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 80,
|
||||
fmin: float = 0.0,
|
||||
fmax: float = None,
|
||||
# Normalization
|
||||
stats_file: str = None,
|
||||
apply_uttmvn: bool = True,
|
||||
uttmvn_norm_means: bool = True,
|
||||
uttmvn_norm_vars: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_uttmvn = apply_uttmvn
|
||||
|
||||
self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
self.stats_file = stats_file
|
||||
if stats_file is not None:
|
||||
self.global_mvn = GlobalMVN(stats_file)
|
||||
else:
|
||||
self.global_mvn = None
|
||||
|
||||
if self.apply_uttmvn is not None:
|
||||
self.uttmvn = UtteranceMVN(norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars)
|
||||
else:
|
||||
self.uttmvn = None
|
||||
|
||||
def forward(
|
||||
self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
|
||||
) -> Tuple[torch.Tensor, torch.LongTensor]:
|
||||
# (B, T, F) or (B, T, C, F)
|
||||
if x.dim() not in (3, 4):
|
||||
raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
|
||||
if not torch.is_tensor(ilens):
|
||||
ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)
|
||||
|
||||
if x.dim() == 4:
|
||||
# h: (B, T, C, F) -> h: (B, T, F)
|
||||
if self.training:
|
||||
# Select 1ch randomly
|
||||
ch = np.random.randint(x.size(2))
|
||||
h = x[:, :, ch, :]
|
||||
else:
|
||||
# Use the first channel
|
||||
h = x[:, :, 0, :]
|
||||
else:
|
||||
h = x
|
||||
|
||||
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
|
||||
h = h.real**2 + h.imag**2
|
||||
|
||||
h, _ = self.logmel(h, ilens)
|
||||
if self.stats_file is not None:
|
||||
h, _ = self.global_mvn(h, ilens)
|
||||
if self.apply_uttmvn:
|
||||
h, _ = self.uttmvn(h, ilens)
|
||||
|
||||
return h, ilens
|
||||
|
||||
|
||||
class LogMel(torch.nn.Module):
|
||||
"""Convert STFT to fbank feats
|
||||
|
||||
The arguments is same as librosa.filters.mel
|
||||
|
||||
Args:
|
||||
fs: number > 0 [scalar] sampling rate of the incoming signal
|
||||
n_fft: int > 0 [scalar] number of FFT components
|
||||
n_mels: int > 0 [scalar] number of Mel bands to generate
|
||||
fmin: float >= 0 [scalar] lowest frequency (in Hz)
|
||||
fmax: float >= 0 [scalar] highest frequency (in Hz).
|
||||
If `None`, use `fmax = fs / 2.0`
|
||||
htk: use HTK formula instead of Slaney
|
||||
norm: {None, 1, np.inf} [scalar]
|
||||
if 1, divide the triangular mel weights by the width of the mel band
|
||||
(area normalization). Otherwise, leave all the triangles aiming for
|
||||
a peak value of 1.0
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 80,
|
||||
fmin: float = 0.0,
|
||||
fmax: float = None,
|
||||
htk: bool = False,
|
||||
norm=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
_mel_options = dict(
|
||||
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
|
||||
)
|
||||
self.mel_options = _mel_options
|
||||
|
||||
# Note(kamo): The mel matrix of librosa is different from kaldi.
|
||||
melmat = librosa.filters.mel(**_mel_options)
|
||||
# melmat: (D2, D1) -> (D1, D2)
|
||||
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
|
||||
|
||||
def extra_repr(self):
|
||||
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
|
||||
|
||||
def forward(
|
||||
self, feat: torch.Tensor, ilens: torch.LongTensor
|
||||
) -> Tuple[torch.Tensor, torch.LongTensor]:
|
||||
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
|
||||
mel_feat = torch.matmul(feat, self.melmat)
|
||||
|
||||
logmel_feat = (mel_feat + 1e-20).log()
|
||||
# Zero padding
|
||||
logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
|
||||
return logmel_feat, ilens
|
||||
|
||||
|
||||
class GlobalMVN(torch.nn.Module):
|
||||
"""Apply global mean and variance normalization
|
||||
|
||||
Args:
|
||||
stats_file(str): npy file of 1-dim array or text file.
|
||||
From the _first element to
|
||||
the {(len(array) - 1) / 2}th element are treated as
|
||||
the sum of features,
|
||||
and the rest excluding the last elements are
|
||||
treated as the sum of the square value of features,
|
||||
and the last elements eqauls to the number of samples.
|
||||
std_floor(float):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats_file: str,
|
||||
norm_means: bool = True,
|
||||
norm_vars: bool = True,
|
||||
eps: float = 1.0e-20,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
|
||||
self.stats_file = stats_file
|
||||
stats = np.load(stats_file)
|
||||
|
||||
stats = stats.astype(float)
|
||||
assert (len(stats) - 1) % 2 == 0, stats.shape
|
||||
|
||||
count = stats.flatten()[-1]
|
||||
mean = stats[: (len(stats) - 1) // 2] / count
|
||||
var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
|
||||
std = np.maximum(np.sqrt(var), eps)
|
||||
|
||||
self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
|
||||
self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"stats_file={self.stats_file}, "
|
||||
f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, ilens: torch.LongTensor
|
||||
) -> Tuple[torch.Tensor, torch.LongTensor]:
|
||||
# feat: (B, T, D)
|
||||
if self.norm_means:
|
||||
x += self.bias.type_as(x)
|
||||
x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
|
||||
|
||||
if self.norm_vars:
|
||||
x *= self.scale.type_as(x)
|
||||
return x, ilens
|
||||
|
||||
|
||||
class UtteranceMVN(torch.nn.Module):
|
||||
def __init__(self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20):
|
||||
super().__init__()
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.eps = eps
|
||||
|
||||
def extra_repr(self):
|
||||
return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, ilens: torch.LongTensor
|
||||
) -> Tuple[torch.Tensor, torch.LongTensor]:
|
||||
return utterance_mvn(
|
||||
x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
|
||||
)
|
||||
|
||||
|
||||
def utterance_mvn(
|
||||
x: torch.Tensor,
|
||||
ilens: torch.LongTensor,
|
||||
norm_means: bool = True,
|
||||
norm_vars: bool = False,
|
||||
eps: float = 1.0e-20,
|
||||
) -> Tuple[torch.Tensor, torch.LongTensor]:
|
||||
"""Apply utterance mean and variance normalization
|
||||
|
||||
Args:
|
||||
x: (B, T, D), assumed zero padded
|
||||
ilens: (B, T, D)
|
||||
norm_means:
|
||||
norm_vars:
|
||||
eps:
|
||||
|
||||
"""
|
||||
ilens_ = ilens.type_as(x)
|
||||
# mean: (B, D)
|
||||
mean = x.sum(dim=1) / ilens_[:, None]
|
||||
|
||||
if norm_means:
|
||||
x -= mean[:, None, :]
|
||||
x_ = x
|
||||
else:
|
||||
x_ = x - mean[:, None, :]
|
||||
|
||||
# Zero padding
|
||||
x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
|
||||
if norm_vars:
|
||||
var = x_.pow(2).sum(dim=1) / ilens_[:, None]
|
||||
var = torch.clamp(var, min=eps)
|
||||
x /= var.sqrt()[:, None, :]
|
||||
x_ = x
|
||||
return x_, ilens
|
||||
|
||||
|
||||
def feature_transform_for(args, n_fft):
|
||||
return FeatureTransform(
|
||||
# Mel options,
|
||||
fs=args.fbank_fs,
|
||||
n_fft=n_fft,
|
||||
n_mels=args.n_mels,
|
||||
fmin=args.fbank_fmin,
|
||||
fmax=args.fbank_fmax,
|
||||
# Normalization
|
||||
stats_file=args.stats_file,
|
||||
apply_uttmvn=args.apply_uttmvn,
|
||||
uttmvn_norm_means=args.uttmvn_norm_means,
|
||||
uttmvn_norm_vars=args.uttmvn_norm_vars,
|
||||
)
|
||||
151
modules/python/vendors/FunASR/funasr/frontends/utils/frontend.py
vendored
Normal file
151
modules/python/vendors/FunASR/funasr/frontends/utils/frontend.py
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr.frontends.utils.dnn_beamformer import DNN_Beamformer
|
||||
from funasr.frontends.utils.dnn_wpe import DNN_WPE
|
||||
|
||||
|
||||
class Frontend(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
idim: int,
|
||||
# WPE options
|
||||
use_wpe: bool = False,
|
||||
wtype: str = "blstmp",
|
||||
wlayers: int = 3,
|
||||
wunits: int = 300,
|
||||
wprojs: int = 320,
|
||||
wdropout_rate: float = 0.0,
|
||||
taps: int = 5,
|
||||
delay: int = 3,
|
||||
use_dnn_mask_for_wpe: bool = True,
|
||||
# Beamformer options
|
||||
use_beamformer: bool = False,
|
||||
btype: str = "blstmp",
|
||||
blayers: int = 3,
|
||||
bunits: int = 300,
|
||||
bprojs: int = 320,
|
||||
bnmask: int = 2,
|
||||
badim: int = 320,
|
||||
ref_channel: int = -1,
|
||||
bdropout_rate=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_beamformer = use_beamformer
|
||||
self.use_wpe = use_wpe
|
||||
self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
|
||||
# use frontend for all the data,
|
||||
# e.g. in the case of multi-speaker speech separation
|
||||
self.use_frontend_for_all = bnmask > 2
|
||||
|
||||
if self.use_wpe:
|
||||
if self.use_dnn_mask_for_wpe:
|
||||
# Use DNN for power estimation
|
||||
# (Not observed significant gains)
|
||||
iterations = 1
|
||||
else:
|
||||
# Performing as conventional WPE, without DNN Estimator
|
||||
iterations = 2
|
||||
|
||||
self.wpe = DNN_WPE(
|
||||
wtype=wtype,
|
||||
widim=idim,
|
||||
wunits=wunits,
|
||||
wprojs=wprojs,
|
||||
wlayers=wlayers,
|
||||
taps=taps,
|
||||
delay=delay,
|
||||
dropout_rate=wdropout_rate,
|
||||
iterations=iterations,
|
||||
use_dnn_mask=use_dnn_mask_for_wpe,
|
||||
)
|
||||
else:
|
||||
self.wpe = None
|
||||
|
||||
if self.use_beamformer:
|
||||
self.beamformer = DNN_Beamformer(
|
||||
btype=btype,
|
||||
bidim=idim,
|
||||
bunits=bunits,
|
||||
bprojs=bprojs,
|
||||
blayers=blayers,
|
||||
bnmask=bnmask,
|
||||
dropout_rate=bdropout_rate,
|
||||
badim=badim,
|
||||
ref_channel=ref_channel,
|
||||
)
|
||||
else:
|
||||
self.beamformer = None
|
||||
|
||||
def forward(
|
||||
self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
|
||||
) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
|
||||
assert len(x) == len(ilens), (len(x), len(ilens))
|
||||
# (B, T, F) or (B, T, C, F)
|
||||
if x.dim() not in (3, 4):
|
||||
raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
|
||||
if not torch.is_tensor(ilens):
|
||||
ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
|
||||
|
||||
mask = None
|
||||
h = x
|
||||
if h.dim() == 4:
|
||||
if self.training:
|
||||
choices = [(False, False)] if not self.use_frontend_for_all else []
|
||||
if self.use_wpe:
|
||||
choices.append((True, False))
|
||||
|
||||
if self.use_beamformer:
|
||||
choices.append((False, True))
|
||||
|
||||
use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
|
||||
|
||||
else:
|
||||
use_wpe = self.use_wpe
|
||||
use_beamformer = self.use_beamformer
|
||||
|
||||
# 1. WPE
|
||||
if use_wpe:
|
||||
# h: (B, T, C, F) -> h: (B, T, C, F)
|
||||
h, ilens, mask = self.wpe(h, ilens)
|
||||
|
||||
# 2. Beamformer
|
||||
if use_beamformer:
|
||||
# h: (B, T, C, F) -> h: (B, T, F)
|
||||
h, ilens, mask = self.beamformer(h, ilens)
|
||||
|
||||
return h, ilens, mask
|
||||
|
||||
|
||||
def frontend_for(args, idim):
|
||||
return Frontend(
|
||||
idim=idim,
|
||||
# WPE options
|
||||
use_wpe=args.use_wpe,
|
||||
wtype=args.wtype,
|
||||
wlayers=args.wlayers,
|
||||
wunits=args.wunits,
|
||||
wprojs=args.wprojs,
|
||||
wdropout_rate=args.wdropout_rate,
|
||||
taps=args.wpe_taps,
|
||||
delay=args.wpe_delay,
|
||||
use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
|
||||
# Beamformer options
|
||||
use_beamformer=args.use_beamformer,
|
||||
btype=args.btype,
|
||||
blayers=args.blayers,
|
||||
bunits=args.bunits,
|
||||
bprojs=args.bprojs,
|
||||
bnmask=args.bnmask,
|
||||
badim=args.badim,
|
||||
ref_channel=args.ref_channel,
|
||||
bdropout_rate=args.bdropout_rate,
|
||||
)
|
||||
79
modules/python/vendors/FunASR/funasr/frontends/utils/log_mel.py
vendored
Normal file
79
modules/python/vendors/FunASR/funasr/frontends/utils/log_mel.py
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
import librosa
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class LogMel(torch.nn.Module):
|
||||
"""Convert STFT to fbank feats
|
||||
|
||||
The arguments is same as librosa.filters.mel
|
||||
|
||||
Args:
|
||||
fs: number > 0 [scalar] sampling rate of the incoming signal
|
||||
n_fft: int > 0 [scalar] number of FFT components
|
||||
n_mels: int > 0 [scalar] number of Mel bands to generate
|
||||
fmin: float >= 0 [scalar] lowest frequency (in Hz)
|
||||
fmax: float >= 0 [scalar] highest frequency (in Hz).
|
||||
If `None`, use `fmax = fs / 2.0`
|
||||
htk: use HTK formula instead of Slaney
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 80,
|
||||
fmin: float = None,
|
||||
fmax: float = None,
|
||||
htk: bool = False,
|
||||
log_base: float = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
_mel_options = dict(
|
||||
sr=fs,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
htk=htk,
|
||||
)
|
||||
self.mel_options = _mel_options
|
||||
self.log_base = log_base
|
||||
|
||||
# Note(kamo): The mel matrix of librosa is different from kaldi.
|
||||
melmat = librosa.filters.mel(**_mel_options)
|
||||
# melmat: (D2, D1) -> (D1, D2)
|
||||
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
|
||||
|
||||
def extra_repr(self):
|
||||
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
ilens: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
|
||||
mel_feat = torch.matmul(feat, self.melmat)
|
||||
mel_feat = torch.clamp(mel_feat, min=1e-10)
|
||||
|
||||
if self.log_base is None:
|
||||
logmel_feat = mel_feat.log()
|
||||
elif self.log_base == 2.0:
|
||||
logmel_feat = mel_feat.log2()
|
||||
elif self.log_base == 10.0:
|
||||
logmel_feat = mel_feat.log10()
|
||||
else:
|
||||
logmel_feat = mel_feat.log() / torch.log(self.log_base)
|
||||
|
||||
# Zero padding
|
||||
if ilens is not None:
|
||||
logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
|
||||
else:
|
||||
ilens = feat.new_full([feat.size(0)], fill_value=feat.size(1), dtype=torch.long)
|
||||
return logmel_feat, ilens
|
||||
75
modules/python/vendors/FunASR/funasr/frontends/utils/mask_estimator.py
vendored
Normal file
75
modules/python/vendors/FunASR/funasr/frontends/utils/mask_estimator.py
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.language_model.rnn.encoders import RNN
|
||||
from funasr.models.language_model.rnn.encoders import RNNP
|
||||
|
||||
|
||||
class MaskEstimator(torch.nn.Module):
|
||||
def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
|
||||
super().__init__()
|
||||
subsample = np.ones(layers + 1, dtype=np.int32)
|
||||
|
||||
typ = type.lstrip("vgg").rstrip("p")
|
||||
if type[-1] == "p":
|
||||
self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
|
||||
else:
|
||||
self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
|
||||
|
||||
self.type = type
|
||||
self.nmask = nmask
|
||||
self.linears = torch.nn.ModuleList([torch.nn.Linear(projs, idim) for _ in range(nmask)])
|
||||
|
||||
def forward(
|
||||
self, xs: ComplexTensor, ilens: torch.LongTensor
|
||||
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
|
||||
"""The forward function
|
||||
|
||||
Args:
|
||||
xs: (B, F, C, T)
|
||||
ilens: (B,)
|
||||
Returns:
|
||||
hs (torch.Tensor): The hidden vector (B, F, C, T)
|
||||
masks: A tuple of the masks. (B, F, C, T)
|
||||
ilens: (B,)
|
||||
"""
|
||||
assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
|
||||
_, _, C, input_length = xs.size()
|
||||
# (B, F, C, T) -> (B, C, T, F)
|
||||
xs = xs.permute(0, 2, 3, 1)
|
||||
|
||||
# Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
|
||||
xs = (xs.real**2 + xs.imag**2) ** 0.5
|
||||
# xs: (B, C, T, F) -> xs: (B * C, T, F)
|
||||
xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
|
||||
# ilens: (B,) -> ilens_: (B * C)
|
||||
ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
|
||||
|
||||
# xs: (B * C, T, F) -> xs: (B * C, T, D)
|
||||
xs, _, _ = self.brnn(xs, ilens_)
|
||||
# xs: (B * C, T, D) -> xs: (B, C, T, D)
|
||||
xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
|
||||
|
||||
masks = []
|
||||
for linear in self.linears:
|
||||
# xs: (B, C, T, D) -> mask:(B, C, T, F)
|
||||
mask = linear(xs)
|
||||
|
||||
mask = torch.sigmoid(mask)
|
||||
# Zero padding
|
||||
mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
|
||||
|
||||
# (B, C, T, F) -> (B, F, C, T)
|
||||
mask = mask.permute(0, 3, 1, 2)
|
||||
|
||||
# Take cares of multi gpu cases: If input_length > max(ilens)
|
||||
if mask.size(-1) < input_length:
|
||||
mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
|
||||
masks.append(mask)
|
||||
|
||||
return tuple(masks), ilens
|
||||
226
modules/python/vendors/FunASR/funasr/frontends/utils/stft.py
vendored
Normal file
226
modules/python/vendors/FunASR/funasr/frontends/utils/stft.py
vendored
Normal file
@@ -0,0 +1,226 @@
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except:
|
||||
print("Please install torch_complex firstly")
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.frontends.utils.complex_utils import is_complex
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
|
||||
|
||||
|
||||
is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
|
||||
|
||||
|
||||
class Stft(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft: int = 512,
|
||||
win_length: int = None,
|
||||
hop_length: int = 128,
|
||||
window: Optional[str] = "hann",
|
||||
center: bool = True,
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
if win_length is None:
|
||||
self.win_length = n_fft
|
||||
else:
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.normalized = normalized
|
||||
self.onesided = onesided
|
||||
if window is not None and not hasattr(torch, f"{window}_window"):
|
||||
if window.lower() != "povey":
|
||||
raise ValueError(f"{window} window is not implemented")
|
||||
self.window = window
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"n_fft={self.n_fft}, "
|
||||
f"win_length={self.win_length}, "
|
||||
f"hop_length={self.hop_length}, "
|
||||
f"center={self.center}, "
|
||||
f"normalized={self.normalized}, "
|
||||
f"onesided={self.onesided}"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, ilens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""STFT forward function.
|
||||
|
||||
Args:
|
||||
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
|
||||
ilens: (Batch)
|
||||
Returns:
|
||||
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
|
||||
|
||||
"""
|
||||
bs = input.size(0)
|
||||
if input.dim() == 3:
|
||||
multi_channel = True
|
||||
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
|
||||
input = input.transpose(1, 2).reshape(-1, input.size(1))
|
||||
else:
|
||||
multi_channel = False
|
||||
|
||||
# NOTE(kamo):
|
||||
# The default behaviour of torch.stft is compatible with librosa.stft
|
||||
# about padding and scaling.
|
||||
# Note that it's different from scipy.signal.stft
|
||||
|
||||
# output: (Batch, Freq, Frames, 2=real_imag)
|
||||
# or (Batch, Channel, Freq, Frames, 2=real_imag)
|
||||
if self.window is not None:
|
||||
if self.window.lower() == "povey":
|
||||
window = torch.hann_window(
|
||||
self.win_length, periodic=False, device=input.device, dtype=input.dtype
|
||||
).pow(0.85)
|
||||
else:
|
||||
window_func = getattr(torch, f"{self.window}_window")
|
||||
window = window_func(self.win_length, dtype=input.dtype, device=input.device)
|
||||
else:
|
||||
window = None
|
||||
|
||||
# For the compatibility of ARM devices, which do not support
|
||||
# torch.stft() due to the lake of MKL.
|
||||
if input.is_cuda or torch.backends.mkl.is_available():
|
||||
stft_kwargs = dict(
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
center=self.center,
|
||||
window=window,
|
||||
normalized=self.normalized,
|
||||
onesided=self.onesided,
|
||||
)
|
||||
if is_torch_1_7_plus:
|
||||
stft_kwargs["return_complex"] = False
|
||||
output = torch.stft(input, **stft_kwargs)
|
||||
else:
|
||||
if self.training:
|
||||
raise NotImplementedError(
|
||||
"stft is implemented with librosa on this device, which does not "
|
||||
"support the training mode."
|
||||
)
|
||||
|
||||
# use stft_kwargs to flexibly control different PyTorch versions' kwargs
|
||||
stft_kwargs = dict(
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
center=self.center,
|
||||
window=window,
|
||||
)
|
||||
|
||||
if window is not None:
|
||||
# pad the given window to n_fft
|
||||
n_pad_left = (self.n_fft - window.shape[0]) // 2
|
||||
n_pad_right = self.n_fft - window.shape[0] - n_pad_left
|
||||
stft_kwargs["window"] = torch.cat(
|
||||
[torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
|
||||
).numpy()
|
||||
else:
|
||||
win_length = self.win_length if self.win_length is not None else self.n_fft
|
||||
stft_kwargs["window"] = torch.ones(win_length)
|
||||
|
||||
output = []
|
||||
# iterate over istances in a batch
|
||||
for i, instance in enumerate(input):
|
||||
stft = librosa.stft(input[i].numpy(), **stft_kwargs)
|
||||
output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
|
||||
output = torch.stack(output, 0)
|
||||
if not self.onesided:
|
||||
len_conj = self.n_fft - output.shape[1]
|
||||
conj = output[:, 1 : 1 + len_conj].flip(1)
|
||||
conj[:, :, :, -1].data *= -1
|
||||
output = torch.cat([output, conj], 1)
|
||||
if self.normalized:
|
||||
output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
|
||||
|
||||
# output: (Batch, Freq, Frames, 2=real_imag)
|
||||
# -> (Batch, Frames, Freq, 2=real_imag)
|
||||
output = output.transpose(1, 2)
|
||||
if multi_channel:
|
||||
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
|
||||
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
|
||||
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)
|
||||
|
||||
if ilens is not None:
|
||||
if self.center:
|
||||
pad = self.n_fft // 2
|
||||
ilens = ilens + 2 * pad
|
||||
|
||||
olens = (ilens - self.n_fft) // self.hop_length + 1
|
||||
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
|
||||
else:
|
||||
olens = None
|
||||
|
||||
return output, olens
|
||||
|
||||
def inverse(
|
||||
self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Inverse STFT.
|
||||
|
||||
Args:
|
||||
input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
|
||||
ilens: (batch,)
|
||||
Returns:
|
||||
wavs: (batch, samples)
|
||||
ilens: (batch,)
|
||||
"""
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
istft = torch.functional.istft
|
||||
else:
|
||||
try:
|
||||
import torchaudio
|
||||
except ImportError:
|
||||
raise ImportError("Please install torchaudio>=0.3.0 or use torch>=1.6.0")
|
||||
|
||||
if not hasattr(torchaudio.functional, "istft"):
|
||||
raise ImportError("Please install torchaudio>=0.3.0 or use torch>=1.6.0")
|
||||
istft = torchaudio.functional.istft
|
||||
|
||||
if self.window is not None:
|
||||
window_func = getattr(torch, f"{self.window}_window")
|
||||
if is_complex(input):
|
||||
datatype = input.real.dtype
|
||||
else:
|
||||
datatype = input.dtype
|
||||
window = window_func(self.win_length, dtype=datatype, device=input.device)
|
||||
else:
|
||||
window = None
|
||||
|
||||
if is_complex(input):
|
||||
input = torch.stack([input.real, input.imag], dim=-1)
|
||||
elif input.shape[-1] != 2:
|
||||
raise TypeError("Invalid input type")
|
||||
input = input.transpose(1, 2)
|
||||
|
||||
wavs = istft(
|
||||
input,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=window,
|
||||
center=self.center,
|
||||
normalized=self.normalized,
|
||||
onesided=self.onesided,
|
||||
length=ilens.max() if ilens is not None else ilens,
|
||||
)
|
||||
|
||||
return wavs, ilens
|
||||
Reference in New Issue
Block a user