Sync from bytedesk-private: update

This commit is contained in:
jack ning
2024-12-14 10:43:18 +08:00
parent 476eebb101
commit 5e082909e4
3421 changed files with 812709 additions and 0 deletions

View File

@@ -0,0 +1,116 @@
from pathlib import Path
from typing import Tuple
from typing import Union
import numpy as np
import torch
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@tables.register("normalize_classes", "GlobalMVN")
class GlobalMVN(torch.nn.Module):
"""Apply global mean and variance normalization
TODO(kamo): Make this class portable somehow
Args:
stats_file: npy file
norm_means: Apply mean normalization
norm_vars: Apply var normalization
eps:
"""
def __init__(
self,
stats_file: Union[Path, str],
norm_means: bool = True,
norm_vars: bool = True,
eps: float = 1.0e-20,
):
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.eps = eps
stats_file = Path(stats_file)
self.stats_file = stats_file
stats = np.load(stats_file)
if isinstance(stats, np.ndarray):
# Kaldi like stats
count = stats[0].flatten()[-1]
mean = stats[0, :-1] / count
var = stats[1, :-1] / count - mean * mean
else:
# New style: Npz file
count = stats["count"]
sum_v = stats["sum"]
sum_square_v = stats["sum_square"]
mean = sum_v / count
var = sum_square_v / count - mean * mean
std = np.sqrt(np.maximum(var, eps))
self.register_buffer("mean", torch.from_numpy(mean))
self.register_buffer("std", torch.from_numpy(std))
def extra_repr(self):
return (
f"stats_file={self.stats_file}, "
f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
)
def forward(
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function
Args:
x: (B, L, ...)
ilens: (B,)
"""
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
norm_means = self.norm_means
norm_vars = self.norm_vars
self.mean = self.mean.to(x.device, x.dtype)
self.std = self.std.to(x.device, x.dtype)
mask = make_pad_mask(ilens, x, 1)
# feat: (B, T, D)
if norm_means:
if x.requires_grad:
x = x - self.mean
else:
x -= self.mean
if x.requires_grad:
x = x.masked_fill(mask, 0.0)
else:
x.masked_fill_(mask, 0.0)
if norm_vars:
x /= self.std
return x, ilens
def inverse(
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
norm_means = self.norm_means
norm_vars = self.norm_vars
self.mean = self.mean.to(x.device, x.dtype)
self.std = self.std.to(x.device, x.dtype)
mask = make_pad_mask(ilens, x, 1)
if x.requires_grad:
x = x.masked_fill(mask, 0.0)
else:
x.masked_fill_(mask, 0.0)
if norm_vars:
x *= self.std
# feat: (B, T, D)
if norm_means:
x += self.mean
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
return x, ilens

View File

@@ -0,0 +1,87 @@
from typing import Tuple
import torch
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@tables.register("normalize_classes", "UtteranceMVN")
class UtteranceMVN(torch.nn.Module):
def __init__(
self,
norm_means: bool = True,
norm_vars: bool = False,
eps: float = 1.0e-20,
):
super().__init__()
self.norm_means = norm_means
self.norm_vars = norm_vars
self.eps = eps
def extra_repr(self):
return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
def forward(
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function
Args:
x: (B, L, ...)
ilens: (B,)
"""
return utterance_mvn(
x,
ilens,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
eps=self.eps,
)
def utterance_mvn(
x: torch.Tensor,
ilens: torch.Tensor = None,
norm_means: bool = True,
norm_vars: bool = False,
eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply utterance mean and variance normalization
Args:
x: (B, T, D), assumed zero padded
ilens: (B,)
norm_means:
norm_vars:
eps:
"""
if ilens is None:
ilens = x.new_full([x.size(0)], x.size(1))
ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
# Zero padding
if x.requires_grad:
x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
else:
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
# mean: (B, 1, D)
mean = x.sum(dim=1, keepdim=True) / ilens_
if norm_means:
x -= mean
if norm_vars:
var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
std = torch.clamp(var.sqrt(), min=eps)
x = x / std.sqrt()
return x, ilens
else:
if norm_vars:
y = x - mean
y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
std = torch.clamp(var.sqrt(), min=eps)
x /= std
return x, ilens