Sync from bytedesk-private: update

This commit is contained in:
jack ning
2024-12-14 10:43:18 +08:00
parent 476eebb101
commit 5e082909e4
3421 changed files with 812709 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Initialize sub package."""

View File

@@ -0,0 +1,82 @@
import torch
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
def get_power_spectral_density_matrix(
xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
) -> ComplexTensor:
"""Return cross-channel power spectral density (PSD) matrix
Args:
xs (ComplexTensor): (..., F, C, T)
mask (torch.Tensor): (..., F, C, T)
normalization (bool):
eps (float):
Returns
psd (ComplexTensor): (..., F, C, C)
"""
# outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
# Averaging mask along C: (..., C, T) -> (..., T)
mask = mask.mean(dim=-2)
# Normalized mask along T: (..., T)
if normalization:
# If assuming the tensor is padded with zero, the summation along
# the time axis is same regardless of the padding length.
mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
# psd: (..., T, C, C)
psd = psd_Y * mask[..., None, None]
# (..., T, C, C) -> (..., C, C)
psd = psd.sum(dim=-3)
return psd
def get_mvdr_vector(
psd_s: ComplexTensor,
psd_n: ComplexTensor,
reference_vector: torch.Tensor,
eps: float = 1e-15,
) -> ComplexTensor:
"""Return the MVDR(Minimum Variance Distortionless Response) vector:
h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
Reference:
On optimal frequency-domain multichannel linear filtering
for noise reduction; M. Souden et al., 2010;
https://ieeexplore.ieee.org/document/5089420
Args:
psd_s (ComplexTensor): (..., F, C, C)
psd_n (ComplexTensor): (..., F, C, C)
reference_vector (torch.Tensor): (..., C)
eps (float):
Returns:
beamform_vector (ComplexTensor)r: (..., F, C)
"""
# Add eps
C = psd_n.size(-1)
eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
eye = eye.view(*shape)
psd_n += eps * eye
# numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
return beamform_vector
def apply_beamforming_vector(beamform_vector: ComplexTensor, mix: ComplexTensor) -> ComplexTensor:
# (..., C) x (..., C, T) -> (..., T)
es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
return es

View File

@@ -0,0 +1,184 @@
"""Beamformer module."""
from distutils.version import LooseVersion
from typing import Sequence
from typing import Tuple
from typing import Union
import torch
try:
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
EPS = torch.finfo(torch.double).eps
is_torch_1_8_plus = LooseVersion(torch.__version__) >= LooseVersion("1.8.0")
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
def new_complex_like(
ref: Union[torch.Tensor, ComplexTensor],
real_imag: Tuple[torch.Tensor, torch.Tensor],
):
if isinstance(ref, ComplexTensor):
return ComplexTensor(*real_imag)
elif is_torch_complex_tensor(ref):
return torch.complex(*real_imag)
else:
raise ValueError("Please update your PyTorch version to 1.9+ for complex support.")
def is_torch_complex_tensor(c):
return not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
def is_complex(c):
return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
def to_double(c):
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
return c.to(dtype=torch.complex128)
else:
return c.double()
def to_float(c):
if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
return c.to(dtype=torch.complex64)
else:
return c.float()
def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
if not isinstance(seq, (list, tuple)):
raise TypeError(
"cat(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
)
if isinstance(seq[0], ComplexTensor):
return FC.cat(seq, *args, **kwargs)
else:
return torch.cat(seq, *args, **kwargs)
def complex_norm(c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False) -> torch.Tensor:
if not is_complex(c):
raise TypeError("Input is not a complex tensor.")
if is_torch_complex_tensor(c):
return torch.norm(c, dim=dim, keepdim=keepdim)
else:
return torch.sqrt((c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS)
def einsum(equation, *operands):
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
# NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
# mixed input with complex and real tensors.
if len(operands) == 1:
if isinstance(operands[0], (tuple, list)):
operands = operands[0]
complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
return complex_module.einsum(equation, *operands)
elif len(operands) != 2:
op0 = operands[0]
same_type = all(op.dtype == op0.dtype for op in operands[1:])
if same_type:
_einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
return _einsum(equation, *operands)
else:
raise ValueError("0 or More than 2 operands are not supported.")
a, b = operands
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
return FC.einsum(equation, a, b)
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
if not torch.is_complex(a):
o_real = torch.einsum(equation, a, b.real)
o_imag = torch.einsum(equation, a, b.imag)
return torch.complex(o_real, o_imag)
elif not torch.is_complex(b):
o_real = torch.einsum(equation, a.real, b)
o_imag = torch.einsum(equation, a.imag, b)
return torch.complex(o_real, o_imag)
else:
return torch.einsum(equation, a, b)
else:
return torch.einsum(equation, a, b)
def inverse(c: Union[torch.Tensor, ComplexTensor]) -> Union[torch.Tensor, ComplexTensor]:
if isinstance(c, ComplexTensor):
return c.inverse2()
else:
return c.inverse()
def matmul(
a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
) -> Union[torch.Tensor, ComplexTensor]:
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
# NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
# multiplication between complex and real tensors.
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
return FC.matmul(a, b)
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
if not torch.is_complex(a):
o_real = torch.matmul(a, b.real)
o_imag = torch.matmul(a, b.imag)
return torch.complex(o_real, o_imag)
elif not torch.is_complex(b):
o_real = torch.matmul(a.real, b)
o_imag = torch.matmul(a.imag, b)
return torch.complex(o_real, o_imag)
else:
return torch.matmul(a, b)
else:
return torch.matmul(a, b)
def trace(a: Union[torch.Tensor, ComplexTensor]):
# NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
# support bacth processing. Use FC.trace() as fallback.
return FC.trace(a)
def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
if isinstance(a, ComplexTensor):
return FC.reverse(a, dim=dim)
else:
return torch.flip(a, dims=(dim,))
def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
"""Solve the linear equation ax = b."""
# NOTE: Do not mix ComplexTensor and torch.complex in the input!
# NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
# mixed input with complex and real tensors.
if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
return FC.solve(b, a, return_LU=False)
else:
return matmul(inverse(a), b)
elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
if torch.is_complex(a) and torch.is_complex(b):
return torch.linalg.solve(a, b)
else:
return matmul(inverse(a), b)
else:
if is_torch_1_8_plus:
return torch.linalg.solve(a, b)
else:
return torch.solve(b, a)[0]
def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
if not isinstance(seq, (list, tuple)):
raise TypeError(
"stack(): argument 'tensors' (position 1) must be tuple of Tensors, " "not Tensor"
)
if isinstance(seq[0], ComplexTensor):
return FC.stack(seq, *args, **kwargs)
else:
return torch.stack(seq, *args, **kwargs)

View File

@@ -0,0 +1,161 @@
"""DNN beamformer module."""
from typing import Tuple
import torch
from torch.nn import functional as F
from funasr.frontends.utils.beamformer import apply_beamforming_vector
from funasr.frontends.utils.beamformer import get_mvdr_vector
from funasr.frontends.utils.beamformer import (
get_power_spectral_density_matrix, # noqa: H301
)
from funasr.frontends.utils.mask_estimator import MaskEstimator
from torch_complex.tensor import ComplexTensor
class DNN_Beamformer(torch.nn.Module):
"""DNN mask based Beamformer
Citation:
Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
https://arxiv.org/abs/1703.04783
"""
def __init__(
self,
bidim,
btype="blstmp",
blayers=3,
bunits=300,
bprojs=320,
bnmask=2,
dropout_rate=0.0,
badim=320,
ref_channel: int = -1,
beamformer_type="mvdr",
):
super().__init__()
self.mask = MaskEstimator(btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask)
self.ref = AttentionReference(bidim, badim)
self.ref_channel = ref_channel
self.nmask = bnmask
if beamformer_type != "mvdr":
raise ValueError("Not supporting beamformer_type={}".format(beamformer_type))
self.beamformer_type = beamformer_type
def forward(
self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
"""The forward function
Notation:
B: Batch
C: Channel
T: Time or Sequence length
F: Freq
Args:
data (ComplexTensor): (B, T, C, F)
ilens (torch.Tensor): (B,)
Returns:
enhanced (ComplexTensor): (B, T, F)
ilens (torch.Tensor): (B,)
"""
def apply_beamforming(data, ilens, psd_speech, psd_noise):
# u: (B, C)
if self.ref_channel < 0:
u, _ = self.ref(psd_speech, ilens)
else:
# (optional) Create onehot vector for fixed reference microphone
u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)), device=data.device)
u[..., self.ref_channel].fill_(1)
ws = get_mvdr_vector(psd_speech, psd_noise, u)
enhanced = apply_beamforming_vector(ws, data)
return enhanced, ws
# data (B, T, C, F) -> (B, F, C, T)
data = data.permute(0, 3, 2, 1)
# mask: (B, F, C, T)
masks, _ = self.mask(data, ilens)
assert self.nmask == len(masks)
if self.nmask == 2: # (mask_speech, mask_noise)
mask_speech, mask_noise = masks
psd_speech = get_power_spectral_density_matrix(data, mask_speech)
psd_noise = get_power_spectral_density_matrix(data, mask_noise)
enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
# (..., F, T) -> (..., T, F)
enhanced = enhanced.transpose(-1, -2)
mask_speech = mask_speech.transpose(-1, -3)
else: # multi-speaker case: (mask_speech1, ..., mask_noise)
mask_speech = list(masks[:-1])
mask_noise = masks[-1]
psd_speeches = [get_power_spectral_density_matrix(data, mask) for mask in mask_speech]
psd_noise = get_power_spectral_density_matrix(data, mask_noise)
enhanced = []
ws = []
for i in range(self.nmask - 1):
psd_speech = psd_speeches.pop(i)
# treat all other speakers' psd_speech as noises
enh, w = apply_beamforming(data, ilens, psd_speech, sum(psd_speeches) + psd_noise)
psd_speeches.insert(i, psd_speech)
# (..., F, T) -> (..., T, F)
enh = enh.transpose(-1, -2)
mask_speech[i] = mask_speech[i].transpose(-1, -3)
enhanced.append(enh)
ws.append(w)
return enhanced, ilens, mask_speech
class AttentionReference(torch.nn.Module):
def __init__(self, bidim, att_dim):
super().__init__()
self.mlp_psd = torch.nn.Linear(bidim, att_dim)
self.gvec = torch.nn.Linear(att_dim, 1)
def forward(
self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
) -> Tuple[torch.Tensor, torch.LongTensor]:
"""The forward function
Args:
psd_in (ComplexTensor): (B, F, C, C)
ilens (torch.Tensor): (B,)
scaling (float):
Returns:
u (torch.Tensor): (B, C)
ilens (torch.Tensor): (B,)
"""
B, _, C = psd_in.size()[:3]
assert psd_in.size(2) == psd_in.size(3), psd_in.size()
# psd_in: (B, F, C, C)
psd = psd_in.masked_fill(torch.eye(C, dtype=torch.bool, device=psd_in.device), 0)
# psd: (B, F, C, C) -> (B, C, F)
psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
# Calculate amplitude
psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
# (B, C, F) -> (B, C, F2)
mlp_psd = self.mlp_psd(psd_feat)
# (B, C, F2) -> (B, C, 1) -> (B, C)
e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
u = F.softmax(scaling * e, dim=-1)
return u, ilens

View File

@@ -0,0 +1,93 @@
from typing import Tuple
from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor
from funasr.frontends.utils.mask_estimator import MaskEstimator
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class DNN_WPE(torch.nn.Module):
def __init__(
self,
wtype: str = "blstmp",
widim: int = 257,
wlayers: int = 3,
wunits: int = 300,
wprojs: int = 320,
dropout_rate: float = 0.0,
taps: int = 5,
delay: int = 3,
use_dnn_mask: bool = True,
iterations: int = 1,
normalization: bool = False,
):
super().__init__()
self.iterations = iterations
self.taps = taps
self.delay = delay
self.normalization = normalization
self.use_dnn_mask = use_dnn_mask
self.inverse_power = True
if self.use_dnn_mask:
self.mask_est = MaskEstimator(
wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
)
def forward(
self, data: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
"""The forward function
Notation:
B: Batch
C: Channel
T: Time or Sequence length
F: Freq or Some dimension of the feature vector
Args:
data: (B, C, T, F)
ilens: (B,)
Returns:
data: (B, C, T, F)
ilens: (B,)
"""
# (B, T, C, F) -> (B, F, C, T)
enhanced = data = data.permute(0, 3, 2, 1)
mask = None
for i in range(self.iterations):
# Calculate power: (..., C, T)
power = enhanced.real**2 + enhanced.imag**2
if i == 0 and self.use_dnn_mask:
# mask: (B, F, C, T)
(mask,), _ = self.mask_est(enhanced, ilens)
if self.normalization:
# Normalize along T
mask = mask / mask.sum(dim=-1)[..., None]
# (..., C, T) * (..., C, T) -> (..., C, T)
power = power * mask
# Averaging along the channel axis: (..., C, T) -> (..., T)
power = power.mean(dim=-2)
# enhanced: (..., C, T) -> (..., C, T)
enhanced = wpe_one_iteration(
data.contiguous(),
power,
taps=self.taps,
delay=self.delay,
inverse_power=self.inverse_power,
)
enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
# (B, F, C, T) -> (B, T, C, F)
enhanced = enhanced.permute(0, 3, 2, 1)
if mask is not None:
mask = mask.transpose(-1, -3)
return enhanced, ilens, mask

View File

@@ -0,0 +1,259 @@
from typing import List
from typing import Tuple
from typing import Union
import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class FeatureTransform(torch.nn.Module):
def __init__(
self,
# Mel options,
fs: int = 16000,
n_fft: int = 512,
n_mels: int = 80,
fmin: float = 0.0,
fmax: float = None,
# Normalization
stats_file: str = None,
apply_uttmvn: bool = True,
uttmvn_norm_means: bool = True,
uttmvn_norm_vars: bool = False,
):
super().__init__()
self.apply_uttmvn = apply_uttmvn
self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
self.stats_file = stats_file
if stats_file is not None:
self.global_mvn = GlobalMVN(stats_file)
else:
self.global_mvn = None
if self.apply_uttmvn is not None:
self.uttmvn = UtteranceMVN(norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars)
else:
self.uttmvn = None
def forward(
self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
) -> Tuple[torch.Tensor, torch.LongTensor]:
# (B, T, F) or (B, T, C, F)
if x.dim() not in (3, 4):
raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
if not torch.is_tensor(ilens):
ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)
if x.dim() == 4:
# h: (B, T, C, F) -> h: (B, T, F)
if self.training:
# Select 1ch randomly
ch = np.random.randint(x.size(2))
h = x[:, :, ch, :]
else:
# Use the first channel
h = x[:, :, 0, :]
else:
h = x
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
h = h.real**2 + h.imag**2
h, _ = self.logmel(h, ilens)
if self.stats_file is not None:
h, _ = self.global_mvn(h, ilens)
if self.apply_uttmvn:
h, _ = self.uttmvn(h, ilens)
return h, ilens
class LogMel(torch.nn.Module):
"""Convert STFT to fbank feats
The arguments is same as librosa.filters.mel
Args:
fs: number > 0 [scalar] sampling rate of the incoming signal
n_fft: int > 0 [scalar] number of FFT components
n_mels: int > 0 [scalar] number of Mel bands to generate
fmin: float >= 0 [scalar] lowest frequency (in Hz)
fmax: float >= 0 [scalar] highest frequency (in Hz).
If `None`, use `fmax = fs / 2.0`
htk: use HTK formula instead of Slaney
norm: {None, 1, np.inf} [scalar]
if 1, divide the triangular mel weights by the width of the mel band
(area normalization). Otherwise, leave all the triangles aiming for
a peak value of 1.0
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
n_mels: int = 80,
fmin: float = 0.0,
fmax: float = None,
htk: bool = False,
norm=1,
):
super().__init__()
_mel_options = dict(
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
)
self.mel_options = _mel_options
# Note(kamo): The mel matrix of librosa is different from kaldi.
melmat = librosa.filters.mel(**_mel_options)
# melmat: (D2, D1) -> (D1, D2)
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
def extra_repr(self):
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
def forward(
self, feat: torch.Tensor, ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
mel_feat = torch.matmul(feat, self.melmat)
logmel_feat = (mel_feat + 1e-20).log()
# Zero padding
logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
return logmel_feat, ilens
class GlobalMVN(torch.nn.Module):
"""Apply global mean and variance normalization
Args:
stats_file(str): npy file of 1-dim array or text file.
From the _first element to
the {(len(array) - 1) / 2}th element are treated as
the sum of features,
and the rest excluding the last elements are
treated as the sum of the square value of features,
and the last elements eqauls to the number of samples.
std_floor(float):
"""
def __init__(
self,
stats_file: str,
norm_means: bool = True,
norm_vars: bool = True,
eps: float = 1.0e-20,
):
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.stats_file = stats_file
stats = np.load(stats_file)
stats = stats.astype(float)
assert (len(stats) - 1) % 2 == 0, stats.shape
count = stats.flatten()[-1]
mean = stats[: (len(stats) - 1) // 2] / count
var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
std = np.maximum(np.sqrt(var), eps)
self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))
def extra_repr(self):
return (
f"stats_file={self.stats_file}, "
f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
)
def forward(
self, x: torch.Tensor, ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
# feat: (B, T, D)
if self.norm_means:
x += self.bias.type_as(x)
x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
if self.norm_vars:
x *= self.scale.type_as(x)
return x, ilens
class UtteranceMVN(torch.nn.Module):
def __init__(self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20):
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.eps = eps
def extra_repr(self):
return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
def forward(
self, x: torch.Tensor, ilens: torch.LongTensor
) -> Tuple[torch.Tensor, torch.LongTensor]:
return utterance_mvn(
x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps
)
def utterance_mvn(
x: torch.Tensor,
ilens: torch.LongTensor,
norm_means: bool = True,
norm_vars: bool = False,
eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
"""Apply utterance mean and variance normalization
Args:
x: (B, T, D), assumed zero padded
ilens: (B, T, D)
norm_means:
norm_vars:
eps:
"""
ilens_ = ilens.type_as(x)
# mean: (B, D)
mean = x.sum(dim=1) / ilens_[:, None]
if norm_means:
x -= mean[:, None, :]
x_ = x
else:
x_ = x - mean[:, None, :]
# Zero padding
x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0)
if norm_vars:
var = x_.pow(2).sum(dim=1) / ilens_[:, None]
var = torch.clamp(var, min=eps)
x /= var.sqrt()[:, None, :]
x_ = x
return x_, ilens
def feature_transform_for(args, n_fft):
return FeatureTransform(
# Mel options,
fs=args.fbank_fs,
n_fft=n_fft,
n_mels=args.n_mels,
fmin=args.fbank_fmin,
fmax=args.fbank_fmax,
# Normalization
stats_file=args.stats_file,
apply_uttmvn=args.apply_uttmvn,
uttmvn_norm_means=args.uttmvn_norm_means,
uttmvn_norm_vars=args.uttmvn_norm_vars,
)

View File

@@ -0,0 +1,151 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy
import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
from funasr.frontends.utils.dnn_beamformer import DNN_Beamformer
from funasr.frontends.utils.dnn_wpe import DNN_WPE
class Frontend(nn.Module):
def __init__(
self,
idim: int,
# WPE options
use_wpe: bool = False,
wtype: str = "blstmp",
wlayers: int = 3,
wunits: int = 300,
wprojs: int = 320,
wdropout_rate: float = 0.0,
taps: int = 5,
delay: int = 3,
use_dnn_mask_for_wpe: bool = True,
# Beamformer options
use_beamformer: bool = False,
btype: str = "blstmp",
blayers: int = 3,
bunits: int = 300,
bprojs: int = 320,
bnmask: int = 2,
badim: int = 320,
ref_channel: int = -1,
bdropout_rate=0.0,
):
super().__init__()
self.use_beamformer = use_beamformer
self.use_wpe = use_wpe
self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
# use frontend for all the data,
# e.g. in the case of multi-speaker speech separation
self.use_frontend_for_all = bnmask > 2
if self.use_wpe:
if self.use_dnn_mask_for_wpe:
# Use DNN for power estimation
# (Not observed significant gains)
iterations = 1
else:
# Performing as conventional WPE, without DNN Estimator
iterations = 2
self.wpe = DNN_WPE(
wtype=wtype,
widim=idim,
wunits=wunits,
wprojs=wprojs,
wlayers=wlayers,
taps=taps,
delay=delay,
dropout_rate=wdropout_rate,
iterations=iterations,
use_dnn_mask=use_dnn_mask_for_wpe,
)
else:
self.wpe = None
if self.use_beamformer:
self.beamformer = DNN_Beamformer(
btype=btype,
bidim=idim,
bunits=bunits,
bprojs=bprojs,
blayers=blayers,
bnmask=bnmask,
dropout_rate=bdropout_rate,
badim=badim,
ref_channel=ref_channel,
)
else:
self.beamformer = None
def forward(
self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
assert len(x) == len(ilens), (len(x), len(ilens))
# (B, T, F) or (B, T, C, F)
if x.dim() not in (3, 4):
raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
if not torch.is_tensor(ilens):
ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)
mask = None
h = x
if h.dim() == 4:
if self.training:
choices = [(False, False)] if not self.use_frontend_for_all else []
if self.use_wpe:
choices.append((True, False))
if self.use_beamformer:
choices.append((False, True))
use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
else:
use_wpe = self.use_wpe
use_beamformer = self.use_beamformer
# 1. WPE
if use_wpe:
# h: (B, T, C, F) -> h: (B, T, C, F)
h, ilens, mask = self.wpe(h, ilens)
# 2. Beamformer
if use_beamformer:
# h: (B, T, C, F) -> h: (B, T, F)
h, ilens, mask = self.beamformer(h, ilens)
return h, ilens, mask
def frontend_for(args, idim):
return Frontend(
idim=idim,
# WPE options
use_wpe=args.use_wpe,
wtype=args.wtype,
wlayers=args.wlayers,
wunits=args.wunits,
wprojs=args.wprojs,
wdropout_rate=args.wdropout_rate,
taps=args.wpe_taps,
delay=args.wpe_delay,
use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
# Beamformer options
use_beamformer=args.use_beamformer,
btype=args.btype,
blayers=args.blayers,
bunits=args.bunits,
bprojs=args.bprojs,
bnmask=args.bnmask,
badim=args.badim,
ref_channel=args.ref_channel,
bdropout_rate=args.bdropout_rate,
)

View File

@@ -0,0 +1,79 @@
import librosa
import torch
from typing import Tuple
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class LogMel(torch.nn.Module):
"""Convert STFT to fbank feats
The arguments is same as librosa.filters.mel
Args:
fs: number > 0 [scalar] sampling rate of the incoming signal
n_fft: int > 0 [scalar] number of FFT components
n_mels: int > 0 [scalar] number of Mel bands to generate
fmin: float >= 0 [scalar] lowest frequency (in Hz)
fmax: float >= 0 [scalar] highest frequency (in Hz).
If `None`, use `fmax = fs / 2.0`
htk: use HTK formula instead of Slaney
"""
def __init__(
self,
fs: int = 16000,
n_fft: int = 512,
n_mels: int = 80,
fmin: float = None,
fmax: float = None,
htk: bool = False,
log_base: float = None,
):
super().__init__()
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
_mel_options = dict(
sr=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.mel_options = _mel_options
self.log_base = log_base
# Note(kamo): The mel matrix of librosa is different from kaldi.
melmat = librosa.filters.mel(**_mel_options)
# melmat: (D2, D1) -> (D1, D2)
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
def extra_repr(self):
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
def forward(
self,
feat: torch.Tensor,
ilens: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
mel_feat = torch.matmul(feat, self.melmat)
mel_feat = torch.clamp(mel_feat, min=1e-10)
if self.log_base is None:
logmel_feat = mel_feat.log()
elif self.log_base == 2.0:
logmel_feat = mel_feat.log2()
elif self.log_base == 10.0:
logmel_feat = mel_feat.log10()
else:
logmel_feat = mel_feat.log() / torch.log(self.log_base)
# Zero padding
if ilens is not None:
logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
else:
ilens = feat.new_full([feat.size(0)], fill_value=feat.size(1), dtype=torch.long)
return logmel_feat, ilens

View File

@@ -0,0 +1,75 @@
from typing import Tuple
import numpy as np
import torch
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.language_model.rnn.encoders import RNN
from funasr.models.language_model.rnn.encoders import RNNP
class MaskEstimator(torch.nn.Module):
def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
super().__init__()
subsample = np.ones(layers + 1, dtype=np.int32)
typ = type.lstrip("vgg").rstrip("p")
if type[-1] == "p":
self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
else:
self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
self.type = type
self.nmask = nmask
self.linears = torch.nn.ModuleList([torch.nn.Linear(projs, idim) for _ in range(nmask)])
def forward(
self, xs: ComplexTensor, ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
"""The forward function
Args:
xs: (B, F, C, T)
ilens: (B,)
Returns:
hs (torch.Tensor): The hidden vector (B, F, C, T)
masks: A tuple of the masks. (B, F, C, T)
ilens: (B,)
"""
assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
_, _, C, input_length = xs.size()
# (B, F, C, T) -> (B, C, T, F)
xs = xs.permute(0, 2, 3, 1)
# Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
xs = (xs.real**2 + xs.imag**2) ** 0.5
# xs: (B, C, T, F) -> xs: (B * C, T, F)
xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
# ilens: (B,) -> ilens_: (B * C)
ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
# xs: (B * C, T, F) -> xs: (B * C, T, D)
xs, _, _ = self.brnn(xs, ilens_)
# xs: (B * C, T, D) -> xs: (B, C, T, D)
xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
masks = []
for linear in self.linears:
# xs: (B, C, T, D) -> mask:(B, C, T, F)
mask = linear(xs)
mask = torch.sigmoid(mask)
# Zero padding
mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
# (B, C, T, F) -> (B, F, C, T)
mask = mask.permute(0, 3, 1, 2)
# Take cares of multi gpu cases: If input_length > max(ilens)
if mask.size(-1) < input_length:
mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
masks.append(mask)
return tuple(masks), ilens

View File

@@ -0,0 +1,226 @@
from distutils.version import LooseVersion
from typing import Optional
from typing import Tuple
from typing import Union
import torch
try:
from torch_complex.tensor import ComplexTensor
except:
print("Please install torch_complex firstly")
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.frontends.utils.complex_utils import is_complex
import librosa
import numpy as np
is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
is_torch_1_7_plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
class Stft(torch.nn.Module):
def __init__(
self,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
):
super().__init__()
self.n_fft = n_fft
if win_length is None:
self.win_length = n_fft
else:
self.win_length = win_length
self.hop_length = hop_length
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f"{window}_window"):
if window.lower() != "povey":
raise ValueError(f"{window} window is not implemented")
self.window = window
def extra_repr(self):
return (
f"n_fft={self.n_fft}, "
f"win_length={self.win_length}, "
f"hop_length={self.hop_length}, "
f"center={self.center}, "
f"normalized={self.normalized}, "
f"onesided={self.onesided}"
)
def forward(
self, input: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""STFT forward function.
Args:
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
ilens: (Batch)
Returns:
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
"""
bs = input.size(0)
if input.dim() == 3:
multi_channel = True
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
input = input.transpose(1, 2).reshape(-1, input.size(1))
else:
multi_channel = False
# NOTE(kamo):
# The default behaviour of torch.stft is compatible with librosa.stft
# about padding and scaling.
# Note that it's different from scipy.signal.stft
# output: (Batch, Freq, Frames, 2=real_imag)
# or (Batch, Channel, Freq, Frames, 2=real_imag)
if self.window is not None:
if self.window.lower() == "povey":
window = torch.hann_window(
self.win_length, periodic=False, device=input.device, dtype=input.dtype
).pow(0.85)
else:
window_func = getattr(torch, f"{self.window}_window")
window = window_func(self.win_length, dtype=input.dtype, device=input.device)
else:
window = None
# For the compatibility of ARM devices, which do not support
# torch.stft() due to the lake of MKL.
if input.is_cuda or torch.backends.mkl.is_available():
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
normalized=self.normalized,
onesided=self.onesided,
)
if is_torch_1_7_plus:
stft_kwargs["return_complex"] = False
output = torch.stft(input, **stft_kwargs)
else:
if self.training:
raise NotImplementedError(
"stft is implemented with librosa on this device, which does not "
"support the training mode."
)
# use stft_kwargs to flexibly control different PyTorch versions' kwargs
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
)
if window is not None:
# pad the given window to n_fft
n_pad_left = (self.n_fft - window.shape[0]) // 2
n_pad_right = self.n_fft - window.shape[0] - n_pad_left
stft_kwargs["window"] = torch.cat(
[torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
).numpy()
else:
win_length = self.win_length if self.win_length is not None else self.n_fft
stft_kwargs["window"] = torch.ones(win_length)
output = []
# iterate over istances in a batch
for i, instance in enumerate(input):
stft = librosa.stft(input[i].numpy(), **stft_kwargs)
output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
output = torch.stack(output, 0)
if not self.onesided:
len_conj = self.n_fft - output.shape[1]
conj = output[:, 1 : 1 + len_conj].flip(1)
conj[:, :, :, -1].data *= -1
output = torch.cat([output, conj], 1)
if self.normalized:
output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
# output: (Batch, Freq, Frames, 2=real_imag)
# -> (Batch, Frames, Freq, 2=real_imag)
output = output.transpose(1, 2)
if multi_channel:
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)
if ilens is not None:
if self.center:
pad = self.n_fft // 2
ilens = ilens + 2 * pad
olens = (ilens - self.n_fft) // self.hop_length + 1
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
else:
olens = None
return output, olens
def inverse(
self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Inverse STFT.
Args:
input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
ilens: (batch,)
Returns:
wavs: (batch, samples)
ilens: (batch,)
"""
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
istft = torch.functional.istft
else:
try:
import torchaudio
except ImportError:
raise ImportError("Please install torchaudio>=0.3.0 or use torch>=1.6.0")
if not hasattr(torchaudio.functional, "istft"):
raise ImportError("Please install torchaudio>=0.3.0 or use torch>=1.6.0")
istft = torchaudio.functional.istft
if self.window is not None:
window_func = getattr(torch, f"{self.window}_window")
if is_complex(input):
datatype = input.real.dtype
else:
datatype = input.dtype
window = window_func(self.win_length, dtype=datatype, device=input.device)
else:
window = None
if is_complex(input):
input = torch.stack([input.real, input.imag], dim=-1)
elif input.shape[-1] != 2:
raise TypeError("Invalid input type")
input = input.transpose(1, 2)
wavs = istft(
input,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=window,
center=self.center,
normalized=self.normalized,
onesided=self.onesided,
length=ilens.max() if ilens is not None else ilens,
)
return wavs, ilens