mirror of
https://gitee.com/270580156/weiyu.git
synced 2026-05-16 12:18:10 +00:00
Sync from bytedesk-private: update
This commit is contained in:
0
modules/python/vendors/FunASR/funasr/models/sond/encoder/__init__.py
vendored
Normal file
0
modules/python/vendors/FunASR/funasr/models/sond/encoder/__init__.py
vendored
Normal file
32
modules/python/vendors/FunASR/funasr/models/sond/encoder/ci_scorers.py
vendored
Normal file
32
modules/python/vendors/FunASR/funasr/models/sond/encoder/ci_scorers.py
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class DotScorer(torch.nn.Module):
    """Score each frame against each speaker embedding with a dot product."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        xs_pad: torch.Tensor,
        spk_emb: torch.Tensor,
    ):
        # xs_pad: B, T, D
        # spk_emb: B, N, D
        # Batched inner products between every frame and every speaker
        # embedding produce a (B, T, N) score matrix.
        return xs_pad @ spk_emb.transpose(1, 2)
|
||||
|
||||
|
||||
class CosScorer(torch.nn.Module):
    """Score each frame against each speaker embedding with cosine similarity."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        xs_pad: torch.Tensor,
        spk_emb: torch.Tensor,
    ):
        # xs_pad: B, T, D
        # spk_emb: B, N, D
        # Broadcast to (B, T, N, D) frame/speaker pairs, then reduce over the
        # feature axis to get a (B, T, N) similarity matrix.
        frames = xs_pad.unsqueeze(2)
        speakers = spk_emb.unsqueeze(1)
        return F.cosine_similarity(frames, speakers, dim=-1)
|
||||
174
modules/python/vendors/FunASR/funasr/models/sond/encoder/conv_encoder.py
vendored
Normal file
174
modules/python/vendors/FunASR/funasr/models/sond/encoder/conv_encoder.py
vendored
Normal file
@@ -0,0 +1,174 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
import math
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
    """One padded Conv1d layer with optional residual add and batch norm."""

    def __init__(
        self,
        input_units,
        num_units,
        kernel_size=3,
        activation="tanh",
        stride=1,
        include_batch_norm=False,
        residual=False,
    ):
        super().__init__()
        # "Same"-style padding: distribute the kernel overhang across both
        # sides, putting the extra frame on the left.
        pad_left = math.ceil((kernel_size - stride) / 2)
        pad_right = kernel_size - stride - pad_left
        self.conv_padding = nn.ConstantPad1d((pad_left, pad_right), 0.0)
        self.conv1d = nn.Conv1d(
            input_units,
            num_units,
            kernel_size,
            stride,
        )
        self.activation = self.get_activation(activation)
        if include_batch_norm:
            self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
        self.residual = residual
        self.include_batch_norm = include_batch_norm
        self.input_units = input_units
        self.num_units = num_units
        self.stride = stride

    @staticmethod
    def get_activation(activation):
        """Map an activation name to a module; anything but "tanh" is ReLU."""
        return nn.Tanh() if activation == "tanh" else nn.ReLU()

    def forward(self, xs_pad, ilens=None):
        """Apply conv -> (residual) -> (batch norm) -> activation on (B, C, T)."""
        outputs = self.conv1d(self.conv_padding(xs_pad))
        # The shortcut is only shape-safe when stride and widths match.
        if self.residual and self.stride == 1 and self.input_units == self.num_units:
            outputs = outputs + xs_pad

        if self.include_batch_norm:
            outputs = self.bn(outputs)

        # Return a tuple so this layer composes with the repeat() container.
        return self.activation(outputs), ilens
|
||||
|
||||
|
||||
class ConvEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Convolution encoder in OpenNMT framework

    Stacks ``num_layers`` padded Conv1d EncoderLayer blocks (via repeat()),
    with optional output projection, layer norm, and input residual.
    """

    def __init__(
        self,
        num_layers,
        input_units,
        num_units,
        kernel_size=3,
        dropout_rate=0.3,
        position_encoder=None,
        activation="tanh",
        auxiliary_states=True,
        out_units=None,
        out_norm=False,
        out_residual=False,
        include_batchnorm=False,
        regularization_weight=0.0,
        stride=1,
        tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
        tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
    ):
        super().__init__()
        self._output_size = num_units

        self.num_layers = num_layers
        self.input_units = input_units
        self.num_units = num_units
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.position_encoder = position_encoder
        self.out_units = out_units
        self.auxiliary_states = auxiliary_states
        self.out_norm = out_norm
        self.activation = activation
        self.out_residual = out_residual
        self.include_batch_norm = include_batchnorm
        self.regularization_weight = regularization_weight
        # Prefixes used when converting TensorFlow checkpoints to torch;
        # renaming attributes here would break that mapping.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        # A scalar stride is broadcast to one stride per layer.
        if isinstance(stride, int):
            self.stride = [stride] * self.num_layers
        else:
            self.stride = stride
        # Overall temporal downsampling is the product of per-layer strides.
        self.downsample_rate = 1
        for s in self.stride:
            self.downsample_rate *= s

        self.dropout = nn.Dropout(dropout_rate)
        self.cnn_a = repeat(
            self.num_layers,
            lambda lnum: EncoderLayer(
                input_units if lnum == 0 else num_units,
                num_units,
                kernel_size,
                activation,
                self.stride[lnum],
                include_batchnorm,
                # Only layers after the first get a residual connection.
                residual=True if lnum > 0 else False,
            ),
        )

        if self.out_units is not None:
            # NOTE(review): this uses the raw `stride` argument — if a list was
            # passed, this arithmetic would fail; confirm callers pass an int
            # when out_units is set.
            left_padding = math.ceil((kernel_size - stride) / 2)
            right_padding = kernel_size - stride - left_padding
            self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
            self.conv_out = nn.Conv1d(
                num_units,
                out_units,
                kernel_size,
            )

        if self.out_norm:
            self.after_norm = LayerNorm(out_units)

    def output_size(self) -> int:
        # NOTE(review): when out_units is set, the emitted feature width is
        # out_units, not num_units — confirm against callers.
        return self.num_units

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode (B, T, D) features; returns (outputs, ilens, None)."""

        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        if self.dropout_rate > 0:
            inputs = self.dropout(inputs)

        # Conv layers run channels-first: (B, T, D) -> (B, D, T).
        outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)

        if self.out_units is not None:
            outputs = self.conv_out(self.out_padding(outputs))

        outputs = outputs.transpose(1, 2)
        if self.out_norm:
            outputs = self.after_norm(outputs)

        if self.out_residual:
            outputs = outputs + inputs

        return outputs, ilens, None
|
||||
672
modules/python/vendors/FunASR/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
vendored
Normal file
672
modules/python/vendors/FunASR/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
vendored
Normal file
@@ -0,0 +1,672 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class _BatchNorm1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_shape=None,
|
||||
input_size=None,
|
||||
eps=1e-05,
|
||||
momentum=0.1,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
combine_batch_time=False,
|
||||
skip_transpose=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.combine_batch_time = combine_batch_time
|
||||
self.skip_transpose = skip_transpose
|
||||
|
||||
if input_size is None and skip_transpose:
|
||||
input_size = input_shape[1]
|
||||
elif input_size is None:
|
||||
input_size = input_shape[-1]
|
||||
|
||||
self.norm = nn.BatchNorm1d(
|
||||
input_size,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
affine=affine,
|
||||
track_running_stats=track_running_stats,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shape_or = x.shape
|
||||
if self.combine_batch_time:
|
||||
if x.ndim == 3:
|
||||
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
|
||||
else:
|
||||
x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
|
||||
|
||||
elif not self.skip_transpose:
|
||||
x = x.transpose(-1, 1)
|
||||
|
||||
x_n = self.norm(x)
|
||||
|
||||
if self.combine_batch_time:
|
||||
x_n = x_n.reshape(shape_or)
|
||||
elif not self.skip_transpose:
|
||||
x_n = x_n.transpose(1, -1)
|
||||
|
||||
return x_n
|
||||
|
||||
|
||||
class _Conv1d(nn.Module):
    """Conv1d with 'same' / 'causal' / 'valid' padding handled manually.

    By default expects (batch, time, channels) and transposes internally;
    skip_transpose=True works channels-first with no layout changes.
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        # Set by _check_input_shape for 2-D inputs: a channel axis is added
        # before the conv and removed after.
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        # Padding is applied manually in forward(), so the conv pads nothing.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the convolution; input layout depends on skip_transpose."""
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)

        elif self.padding == "causal":
            # Pad only on the left so the conv never sees future frames.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))

        elif self.padding == "valid":
            pass

        else:
            raise ValueError("Padding must be 'same', 'valid' or 'causal'. Got " + self.padding)

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        """Pad x so the convolution preserves the time length."""
        # Detecting input shape
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""

        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))

        # Kernel size must be odd
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s." % (self.kernel_size)
            )
        return in_channels
|
||||
|
||||
|
||||
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Return [left, right] padding sizes for a 'same'-style convolution."""
    if stride > 1:
        # Strided convs simply take half-kernel padding on each side.
        half = kernel_size // 2
        return [half, half]

    # Unstrided: pad by however much the dilated kernel would shrink the input.
    L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
    shrink = (L_in - L_out) // 2
    return [shrink, shrink]
|
||||
|
||||
|
||||
# Skip transpose as much as possible for efficiency
|
||||
class Conv1d(_Conv1d):
    """_Conv1d specialized to channels-first input (no layout transposes)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, skip_transpose=True, **kwargs)
|
||||
|
||||
|
||||
class BatchNorm1d(_BatchNorm1d):
    """_BatchNorm1d specialized to channels-first input (no layout transposes)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, skip_transpose=True, **kwargs)
|
||||
|
||||
|
||||
def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Build a (batch, max_len) mask with ones up to each sequence length."""
    assert len(length.shape) == 1

    if max_len is None:
        # The longest sequence in the batch defines the mask width.
        max_len = length.max().long().item()

    positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
    mask = positions.expand(len(length), max_len) < length.unsqueeze(1)

    # Fall back to the input's dtype/device unless the caller overrides them.
    out_dtype = length.dtype if dtype is None else dtype
    out_device = length.device if device is None else device
    return torch.as_tensor(mask, dtype=out_dtype, device=out_device)
|
||||
|
||||
|
||||
class TDNNBlock(nn.Module):
    """Conv1d -> activation -> BatchNorm1d, the basic TDNN building block."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=groups,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        """Run conv, activation, and norm on channels-first (B, C, T) input."""
        out = self.conv(x)
        out = self.activation(out)
        return self.norm(out)
|
||||
|
||||
|
||||
class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    The input is split into `scale` channel groups; each group (after the
    first) is processed by a TDNNBlock that also sees the previous group's
    output, giving multi-scale receptive fields.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block.
    kernel_size: int
        The kernel size of the Res2Net block.
    dilation : int
        The dilation of the Res2Net block.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        width_in = in_channels // scale
        width_out = out_channels // scale

        # One TDNN per group except the first, which passes through untouched.
        self.blocks = nn.ModuleList(
            [
                TDNNBlock(
                    width_in,
                    width_out,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for _ in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, x):
        """Process channel groups hierarchically and re-concatenate them."""
        chunks = torch.chunk(x, self.scale, dim=1)
        outputs = [chunks[0]]
        prev = None
        for i in range(1, self.scale):
            if i == 1:
                prev = self.blocks[0](chunks[1])
            else:
                # Each later group also receives the previous group's output.
                prev = self.blocks[i - 1](chunks[i] + prev)
            outputs.append(prev)
        return torch.cat(outputs, dim=1)
|
||||
|
||||
|
||||
class SEBlock(nn.Module):
    """An implementation of squeeze-and-excitation block.

    Pools the input over time, squeezes through a bottleneck, and rescales
    each channel by a sigmoid gate.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of output channels after squeeze.
    out_channels : int
        The number of output channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> se_layer = SEBlock(64, 16, 64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        """Gate a (B, C, T) input; uses a masked time-mean when lengths given."""
        T = x.shape[-1]
        if lengths is None:
            pooled = x.mean(dim=2, keepdim=True)
        else:
            # Relative lengths -> absolute frame counts -> binary (B, 1, T) mask.
            mask = length_to_mask(lengths * T, max_len=T, device=x.device)
            mask = mask.unsqueeze(1)
            denom = mask.sum(dim=2, keepdim=True)
            pooled = (x * mask).sum(dim=2, keepdim=True) / denom

        gate = self.relu(self.conv1(pooled))
        gate = self.sigmoid(self.conv2(gate))

        return gate * x
|
||||
|
||||
|
||||
class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Arguments
    ---------
    channels: int
        The number of input channels.
    attention_channels: int
        The number of attention channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> asp_layer = AttentiveStatisticsPooling(64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 1, 128])
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        # Floor for the variance before sqrt, keeping gradients finite.
        self.eps = 1e-12
        self.global_context = global_context
        # With global context the attention net also sees the utterance-level
        # mean and std, so its input is three times the channel count.
        if global_context:
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        lengths : torch.Tensor, optional
            Relative lengths in [0, 1], shape [N]; defaults to full length.

        Returns a [N, 2*C, 1] tensor of concatenated weighted mean and std.
        """
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            # Weighted mean/std along `dim`; `m` are non-negative weights that
            # sum to 1 over that dim (mask/total or softmax attention).
            mean = (m * x).sum(dim)
            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            # Tile utterance-level stats over time and stack with the frames.
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats
|
||||
|
||||
|
||||
class SERes2NetBlock(nn.Module):
    """An implementation of building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SEBlock, wrapped with a residual shortcut.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    out_channels: int
        The number of output channels.
    res2net_scale: int
        The scale of the Res2Net block.
    se_channels : int
        Bottleneck width of the squeeze-and-excitation block.
    kernel_size: int
        The kernel size of the TDNN blocks.
    dilation: int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers.
    groups: int
        Number of blocked connections from input channels to output channels.

    Example
    -------
    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
    >>> out = conv(x).transpose(1, 2)
    >>> out.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        groups=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.res2net_block = Res2NetBlock(
            out_channels, out_channels, res2net_scale, kernel_size, dilation
        )
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            groups=groups,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        # A 1x1 projection is only needed when channel counts differ;
        # otherwise the shortcut is the identity.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        """Apply the TDNN-Res2Net-TDNN-SE pipeline plus shortcut on (B, C, T)."""
        residual = self.shortcut(x) if self.shortcut else x

        out = self.tdnn1(x)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out, lengths)

        return out + residual
|
||||
|
||||
|
||||
class ECAPA_TDNN(torch.nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    Arguments
    ---------
    input_size : int
        Width of the input features.
    lin_neurons : int
        Number of neurons in linear layers (embedding size).
    activation : torch class
        A class for constructing the activation layers.
    channels : list of ints
        Output channels for TDNN/SERes2Net layer.
    kernel_sizes : list of ints
        List of kernel sizes for each layer.
    dilations : list of ints
        List of dilations for kernels in each layer.
    groups : list of ints
        List of groups for kernels in each layer.
    window_size : int or None
        When an int, statistics are pooled over sliding windows of this many
        frames (one embedding per hop); when None, a single utterance-level
        embedding is produced.
    window_shift : int
        Hop between successive pooling windows.

    Example
    -------
    >>> input_feats = torch.rand([5, 120, 80])
    >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
    >>> outputs = compute_embedding(input_feats)
    >>> outputs.shape
    torch.Size([5, 1, 192])
    """

    def __init__(
        self,
        input_size,
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=[1, 1, 1, 1, 1],
        window_size=20,
        window_shift=1,
    ):

        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.blocks = nn.ModuleList()
        self.window_size = window_size
        self.window_shift = window_shift

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                groups[0],
            )
        )

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    groups=groups[i],
                )
            )

        # Multi-layer feature aggregation over the concatenated block outputs
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            groups=groups[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    def windowed_pooling(self, x, lengths=None):
        """Pool statistics over sliding windows of the feature map.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, channel, time).
        lengths : torch.Tensor, optional
            Relative utterance lengths, shape (batch,).

        Returns one embedding per window: (batch, lin_neurons, num_chunk).
        """
        # x: Batch, Channel, Time
        tt = x.shape[2]
        num_chunk = int(math.ceil(tt / self.window_shift))
        pad = self.window_size // 2
        # Keep the padded features in their own variable. The original code
        # reassigned `x` to the pooled output inside the loop, so every window
        # after the first sliced the previous iteration's statistics instead
        # of the feature map — that clobbering bug is fixed here.
        feats = F.pad(x, (pad, pad, 0, 0), "reflect")
        stat_list = []

        for i in range(num_chunk):
            # B x C statistics for one window
            st, ed = i * self.window_shift, i * self.window_shift + self.window_size
            stats = self.asp(
                feats[:, :, st:ed],
                lengths=(
                    torch.clamp(lengths - i, 0, self.window_size) if lengths is not None else None
                ),
            )
            stats = self.asp_bn(stats)
            stats = self.fc(stats)
            stat_list.append(stats)

        return torch.cat(stat_list, dim=2)

    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        lengths: torch.Tensor
            Tensor of shape (batch, )
        """
        # Minimize transpose for efficiency
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            # SE-Res2Net blocks take lengths; the plain TDNN block does not.
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        if self.window_size is None:
            # Attentive Statistical Pooling
            x = self.asp(x, lengths=lengths)
            x = self.asp_bn(x)
            # Final linear transformation
            x = self.fc(x)
            x = x.squeeze(2)  # -> B, C
        else:
            x = self.windowed_pooling(x, lengths)
            x = x.transpose(1, 2)  # -> B, T, C
        return x
|
||||
180
modules/python/vendors/FunASR/funasr/models/sond/encoder/fsmn_encoder.py
vendored
Normal file
180
modules/python/vendors/FunASR/funasr/models/sond/encoder/fsmn_encoder.py
vendored
Normal file
@@ -0,0 +1,180 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import numpy as np
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
import math
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
from funasr.models.transformer.utils.multi_layer_conv import FsmnFeedForward
|
||||
|
||||
|
||||
class FsmnBlock(torch.nn.Module):
    """Depthwise FSMN memory block with a residual connection."""

    def __init__(
        self,
        n_feat,
        dropout_rate,
        kernel_size,
        fsmn_shift=0,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        # Depthwise (groups == channels) conv implements the FSMN memory.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Asymmetric padding; a positive shift moves the window further left
        # (more history, less look-ahead).
        lpad = (kernel_size - 1) // 2
        if fsmn_shift > 0:
            lpad += fsmn_shift
        self.pad_fn = nn.ConstantPad1d((lpad, kernel_size - 1 - lpad), 0.0)

    def forward(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the memory block on (B, T, D) input, masking padded frames."""
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        memory = self.fsmn_block(self.pad_fn(inputs.transpose(1, 2)))
        memory = memory.transpose(1, 2) + inputs
        memory = self.dropout(memory)
        # NOTE(review): a None mask would raise here; callers always pass one.
        return memory * mask
|
||||
|
||||
|
||||
class EncoderLayer(torch.nn.Module):
    """FSMN encoder layer: feed-forward -> memory block -> dropout (+residual)."""

    def __init__(self, in_size, size, feed_forward, fsmn_block, dropout_rate=0.0):
        super().__init__()
        self.in_size = in_size
        self.size = size
        self.ffn = feed_forward
        self.memory = fsmn_block
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self, xs_pad: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # xs_pad in Batch, Time, Dim
        context = self.ffn(xs_pad)[0]
        memory = self.dropout(self.memory(context, mask))

        # Residual connection only when input and output widths agree.
        if self.in_size == self.size:
            return memory + xs_pad, mask

        return memory, mask
|
||||
|
||||
|
||||
class FsmnEncoder(AbsEncoder):
    """Encoder using Fsmn"""

    def __init__(
        self,
        in_units,
        filter_size,
        fsmn_num_layers,
        dnn_num_layers,
        num_memory_units=512,
        ffn_inner_dim=2048,
        dropout_rate=0.0,
        shift=0,
        position_encoder=None,
        sample_rate=1,
        out_units=None,
        tf2torch_tensor_name_prefix_torch="post_net",
        tf2torch_tensor_name_prefix_tf="EAND/post_net",
    ):
        """Initializes the parameters of the encoder.

        Args:
            filter_size: the total order of memory block
            fsmn_num_layers: The number of fsmn layers.
            dnn_num_layers: The number of dnn layers
            num_units: The number of memory units.
            ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
            dropout_rate: The probability to drop units from the outputs.
            shift: left padding, to control delay
            position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to
            apply on inputs or ``None``.
        """
        super(FsmnEncoder, self).__init__()
        self.in_units = in_units
        self.filter_size = filter_size
        self.fsmn_num_layers = fsmn_num_layers
        self.dnn_num_layers = dnn_num_layers
        self.num_memory_units = num_memory_units
        self.ffn_inner_dim = ffn_inner_dim
        self.dropout_rate = dropout_rate
        # Scalar shift/sample_rate values are broadcast to one entry per layer.
        self.shift = shift
        if not isinstance(shift, list):
            self.shift = [shift for _ in range(self.fsmn_num_layers)]
        self.sample_rate = sample_rate
        if not isinstance(sample_rate, list):
            self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)]
        self.position_encoder = position_encoder
        self.dropout = nn.Dropout(dropout_rate)
        self.out_units = out_units
        # Prefixes used when converting TensorFlow checkpoints to torch.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        # Stack of (FsmnFeedForward -> FsmnBlock) layers; only the first layer
        # takes the raw input width.
        self.fsmn_layers = repeat(
            self.fsmn_num_layers,
            lambda lnum: EncoderLayer(
                in_units if lnum == 0 else num_memory_units,
                num_memory_units,
                FsmnFeedForward(
                    in_units if lnum == 0 else num_memory_units,
                    ffn_inner_dim,
                    num_memory_units,
                    1,
                    dropout_rate,
                ),
                FsmnBlock(num_memory_units, dropout_rate, filter_size, self.shift[lnum]),
            ),
        )

        # Plain feed-forward layers applied after the FSMN stack.
        self.dnn_layers = repeat(
            dnn_num_layers,
            lambda lnum: FsmnFeedForward(
                num_memory_units,
                ffn_inner_dim,
                num_memory_units,
                1,
                dropout_rate,
            ),
        )
        if out_units is not None:
            # Optional 1x1 projection to the requested output width.
            self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)

    def output_size(self) -> int:
        # NOTE(review): when out_units is set, the actual emitted width is
        # out_units, not num_memory_units — confirm against callers.
        return self.num_memory_units

    def forward(
        self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode (B, T, D) features; returns (outputs, ilens, None)."""
        inputs = xs_pad
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)

        inputs = self.dropout(inputs)
        # Non-padded positions -> True; shape (B, 1, T) for the FSMN blocks.
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        inputs = self.fsmn_layers(inputs, masks)[0]
        inputs = self.dnn_layers(inputs)[0]

        if self.out_units is not None:
            inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)

        return inputs, ilens, None
|
||||
462
modules/python/vendors/FunASR/funasr/models/sond/encoder/resnet34_encoder.py
vendored
Normal file
462
modules/python/vendors/FunASR/funasr/models/sond/encoder/resnet34_encoder.py
vendored
Normal file
@@ -0,0 +1,462 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from typing import Tuple, Optional
|
||||
from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
|
||||
from collections import OrderedDict
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BasicLayer(torch.nn.Module):
    """Pre-activation residual layer (BN -> ReLU -> Conv, twice) for ResNet34.

    A 1x1 strided projection shortcut is added whenever the channel count or
    the spatial resolution changes, mirroring the original TF implementation.
    """

    def __init__(self, in_filters: int, filters: int, stride: int, bn_momentum: float = 0.5):
        """
        Args:
            in_filters: number of input channels.
            filters: number of output channels.
            stride: spatial stride of the first convolution (1 or 2).
            bn_momentum: BatchNorm momentum (TF-style default 0.5).
        """
        super().__init__()
        self.stride = stride
        self.in_filters = in_filters
        self.filters = filters

        self.bn1 = torch.nn.BatchNorm2d(in_filters, eps=1e-3, momentum=bn_momentum, affine=True)
        self.relu1 = torch.nn.ReLU()
        self.conv1 = torch.nn.Conv2d(in_filters, filters, 3, stride, bias=False)

        self.bn2 = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
        self.relu2 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(filters, filters, 3, 1, bias=False)

        # Projection shortcut when the residual branch changes shape.
        if in_filters != filters or stride > 1:
            self.conv_sc = torch.nn.Conv2d(in_filters, filters, 1, stride, bias=False)
            self.bn_sc = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)

    def proper_padding(self, x, stride):
        # Align padding mode to tf.layers.conv2d with padding_mode="same"
        # for a 3x3 kernel.
        if stride == 1:
            return F.pad(x, (1, 1, 1, 1), "constant", 0)
        elif stride == 2:
            h, w = x.size(2), x.size(3)
            # (left, right, top, bottom): TF puts the extra pixel bottom-right.
            return F.pad(x, (w % 2, 1, h % 2, 1), "constant", 0)
        # Bug fix: the original fell through and implicitly returned None for
        # any other stride, which then crashed later inside the convolution
        # with an unrelated error. Fail fast with a clear message instead.
        raise ValueError("proper_padding only supports stride 1 or 2, got {}".format(stride))

    def forward(self, xs_pad, ilens):
        """Apply the residual layer.

        Args:
            xs_pad: feature maps (B, C, H, W); H is the time axis.
            ilens: valid time lengths per utterance (B,).

        Returns:
            Tuple of (output maps, lengths updated for the stride).
        """
        identity = xs_pad
        if self.in_filters != self.filters or self.stride > 1:
            identity = self.conv_sc(identity)
            identity = self.bn_sc(identity)

        xs_pad = self.relu1(self.bn1(xs_pad))
        xs_pad = self.proper_padding(xs_pad, self.stride)
        xs_pad = self.conv1(xs_pad)

        xs_pad = self.relu2(self.bn2(xs_pad))
        xs_pad = self.proper_padding(xs_pad, 1)
        xs_pad = self.conv2(xs_pad)

        if self.stride == 2:
            # Ceil-division keeps lengths consistent with "same" padding.
            ilens = (ilens + 1) // self.stride

        return xs_pad + identity, ilens
|
||||
|
||||
|
||||
class BasicBlock(torch.nn.Module):
    """A stage of `num_layer` BasicLayer modules; only the first layer strides."""

    def __init__(self, in_filters, filters, num_layer, stride, bn_momentum=0.5):
        super().__init__()
        self.num_layer = num_layer

        for idx in range(num_layer):
            # Submodule names ("layer_{i}") are kept stable so that existing
            # checkpoints (state dicts) still load.
            self.add_module(
                "layer_{}".format(idx),
                BasicLayer(
                    in_filters if idx == 0 else filters,
                    filters,
                    stride if idx == 0 else 1,
                    bn_momentum,
                ),
            )

    def forward(self, xs_pad, ilens):
        """Run the layers in order, threading (features, lengths) through."""
        for idx in range(self.num_layer):
            sub = getattr(self, "layer_{}".format(idx))
            xs_pad, ilens = sub(xs_pad, ilens)
        return xs_pad, ilens
|
||||
|
||||
|
||||
class ResNet34(AbsEncoder):
    """ResNet-34 style speech encoder producing (B, C, T', F') feature maps.

    Three of the four residual stages downsample by 2, giving an overall
    time downsampling ratio of 8 (see ``time_ds_ratio``).
    """

    def __init__(
        self,
        input_size,
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
    ):
        super(ResNet34, self).__init__()

        self.use_head_conv = use_head_conv
        self.use_head_maxpool = use_head_maxpool
        self.num_nodes_pooling_layer = num_nodes_pooling_layer
        self.layers_in_block = layers_in_block
        self.filters_in_block = filters_in_block
        self.input_size = input_size

        head_filters = filters_in_block[0]
        if use_head_conv:
            # Stem convolution applied before the residual stages.
            self.pre_conv = torch.nn.Conv2d(
                1, head_filters, 3, 1, 1, bias=False, padding_mode="zeros"
            )
            self.pre_conv_bn = torch.nn.BatchNorm2d(
                head_filters, eps=1e-3, momentum=batchnorm_momentum
            )

        if use_head_maxpool:
            self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)

        for idx in range(len(layers_in_block)):
            if idx == 0:
                prev_width = head_filters if self.use_head_conv else 1
            else:
                prev_width = filters_in_block[idx - 1]
            # "block_{i}" names are kept stable for checkpoint compatibility.
            self.add_module(
                "block_{}".format(idx),
                BasicBlock(
                    prev_width,
                    filters=filters_in_block[idx],
                    num_layer=layers_in_block[idx],
                    stride=1 if idx == 0 else 2,
                    bn_momentum=batchnorm_momentum,
                ),
            )

        # 1x1 conv acting as a per-position dense layer on the feature maps.
        self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
        self.resnet0_bn = torch.nn.BatchNorm2d(
            num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum
        )

        self.time_ds_ratio = 8

    def output_size(self) -> int:
        """Channel count of the emitted feature maps."""
        return self.num_nodes_pooling_layer

    def forward(
        self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode (B, T, F) features into (B, C, T', F') maps with lengths."""
        feats = xs_pad
        assert (
            feats.size(-1) == self.input_size
        ), "Dimension of features {} doesn't match the input_size {}.".format(
            feats.size(-1), self.input_size
        )
        # Add a singleton channel axis: (B, T, F) -> (B, 1, T, F).
        feats = torch.unsqueeze(feats, dim=1)
        if self.use_head_conv:
            feats = F.relu(self.pre_conv_bn(self.pre_conv(feats)))

        if self.use_head_maxpool:
            feats = self.head_maxpool(feats)

        out, out_lens = feats, ilens
        for idx in range(len(self.layers_in_block)):
            out, out_lens = getattr(self, "block_{}".format(idx))(out, out_lens)

        # NOTE: ReLU precedes BatchNorm here, matching the ported TF graph.
        out = self.resnet0_bn(F.relu(self.resnet0_dense(out)))

        return out, out_lens
|
||||
|
||||
|
||||
# Note: for training, this implementation is not equivalent to the TF version
# because of the kernel_regularizer in tf.layers.
# TODO: implement kernel_regularizer in torch via a manual loss term or
# weight_decay in the optimizer.
class ResNet34_SP_L2Reg(AbsEncoder):
    """ResNet-34 speech encoder variant that flattens (channel, frequency)
    into a single axis before a Conv1d "dense" projection, yielding (B, D, T)
    sequence features.
    """

    def __init__(
        self,
        input_size,
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
        tf_train_steps=720000,
    ):
        super(ResNet34_SP_L2Reg, self).__init__()

        self.use_head_conv = use_head_conv
        self.use_head_maxpool = use_head_maxpool
        self.num_nodes_pooling_layer = num_nodes_pooling_layer
        self.layers_in_block = layers_in_block
        self.filters_in_block = filters_in_block
        self.input_size = input_size
        # Kept for TF->torch checkpoint name mapping utilities.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.tf_train_steps = tf_train_steps

        head_filters = filters_in_block[0]
        if use_head_conv:
            self.pre_conv = torch.nn.Conv2d(
                1, head_filters, 3, 1, 1, bias=False, padding_mode="zeros"
            )
            self.pre_conv_bn = torch.nn.BatchNorm2d(
                head_filters, eps=1e-3, momentum=batchnorm_momentum
            )

        if use_head_maxpool:
            self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)

        for idx in range(len(layers_in_block)):
            if idx == 0:
                prev_width = head_filters if self.use_head_conv else 1
            else:
                prev_width = filters_in_block[idx - 1]
            self.add_module(
                "block_{}".format(idx),
                BasicBlock(
                    prev_width,
                    filters=filters_in_block[idx],
                    num_layer=layers_in_block[idx],
                    stride=1 if idx == 0 else 2,
                    bn_momentum=batchnorm_momentum,
                ),
            )

        # Frequency axis is downsampled by 8, so the flattened channel size
        # is filters * input_size / 8.
        self.resnet0_dense = torch.nn.Conv1d(
            filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1
        )
        self.resnet0_bn = torch.nn.BatchNorm1d(
            num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum
        )

        self.time_ds_ratio = 8

    def output_size(self) -> int:
        """Feature dimension of the emitted (B, D, T) sequence."""
        return self.num_nodes_pooling_layer

    def forward(
        self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode (B, T, F) features into (B, D, T') sequence features."""
        feats = xs_pad
        assert (
            feats.size(-1) == self.input_size
        ), "Dimension of features {} doesn't match the input_size {}.".format(
            feats.size(-1), self.input_size
        )
        feats = torch.unsqueeze(feats, dim=1)
        if self.use_head_conv:
            feats = F.relu(self.pre_conv_bn(self.pre_conv(feats)))

        if self.use_head_maxpool:
            feats = self.head_maxpool(feats)

        out, out_lens = feats, ilens
        for idx in range(len(self.layers_in_block)):
            out, out_lens = getattr(self, "block_{}".format(idx))(out, out_lens)

        # Flatten (channel, frequency) into one axis: (B, C, T, F) -> (B, F*C, T).
        bsz, ch, tlen, freq = out.shape
        out = torch.reshape(out.permute(0, 3, 1, 2), [bsz, freq * ch, tlen])
        out = self.resnet0_bn(F.relu(self.resnet0_dense(out)))

        return out, out_lens
|
||||
|
||||
|
||||
class ResNet34Diar(ResNet34):
    """ResNet34 backbone plus pooling and two dense layers for diarization
    embeddings; the returned tensor is the endpoint named by ``embedding_node``.
    """

    def __init__(
        self,
        input_size,
        embedding_node="resnet1_dense",
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        num_nodes_resnet1=256,
        num_nodes_last_layer=256,
        pooling_type="window_shift",
        pool_size=20,
        stride=1,
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder",
    ):
        """
        Author: Speech Lab, Alibaba Group, China
        SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
        https://arxiv.org/abs/2211.10243
        """
        super(ResNet34Diar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
            batchnorm_momentum=batchnorm_momentum,
            use_head_maxpool=use_head_maxpool,
            num_nodes_pooling_layer=num_nodes_pooling_layer,
            layers_in_block=layers_in_block,
            filters_in_block=filters_in_block,
        )

        self.embedding_node = embedding_node
        self.num_nodes_resnet1 = num_nodes_resnet1
        self.num_nodes_last_layer = num_nodes_last_layer
        self.pooling_type = pooling_type
        self.pool_size = pool_size
        self.stride = stride
        # Kept for TF->torch checkpoint name mapping utilities.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        # Statistic pooling concatenates mean and std, doubling the dimension.
        self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(
            num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum
        )

        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(
            num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum
        )

    def output_size(self) -> int:
        """Dimension of the endpoint selected by ``embedding_node``."""
        if self.embedding_node.startswith("resnet1"):
            return self.num_nodes_resnet1
        elif self.embedding_node.startswith("resnet2"):
            return self.num_nodes_last_layer

        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Compute speaker embeddings; returns the requested endpoint tensor."""
        endpoints = OrderedDict()
        res_out, ilens = super().forward(xs_pad, ilens)
        endpoints["resnet0_bn"] = res_out

        if self.pooling_type == "frame_gsp":
            # Global statistic pooling over the frequency axis only.
            emb = statistic_pooling(res_out, ilens, (3,))
        else:
            emb, ilens = windowed_statistic_pooling(
                res_out, ilens, (2, 3), self.pool_size, self.stride
            )
        # (B, D, T) -> (B, T, D) for the Linear layers below.
        # NOTE(review): reconstructed as applying to both pooling branches —
        # confirm against upstream FunASR.
        emb = emb.transpose(1, 2)
        endpoints["pooling"] = emb

        emb = self.resnet1_dense(emb)
        endpoints["resnet1_dense"] = emb
        emb = F.relu(emb)
        endpoints["resnet1_relu"] = emb
        emb = self.resnet1_bn(emb.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet1_bn"] = emb

        emb = self.resnet2_dense(emb)
        endpoints["resnet2_dense"] = emb
        emb = F.relu(emb)
        endpoints["resnet2_relu"] = emb
        emb = self.resnet2_bn(emb.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet2_bn"] = emb

        return endpoints[self.embedding_node], ilens, None
|
||||
|
||||
|
||||
class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
    """ResNet34_SP_L2Reg backbone plus pooling and two dense layers for
    diarization embeddings; returns the endpoint named by ``embedding_node``.
    """

    def __init__(
        self,
        input_size,
        embedding_node="resnet1_dense",
        use_head_conv=True,
        batchnorm_momentum=0.5,
        use_head_maxpool=False,
        num_nodes_pooling_layer=256,
        layers_in_block=(3, 4, 6, 3),
        filters_in_block=(32, 64, 128, 256),
        num_nodes_resnet1=256,
        num_nodes_last_layer=256,
        pooling_type="window_shift",
        pool_size=20,
        stride=1,
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder",
    ):
        """
        Author: Speech Lab, Alibaba Group, China
        TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
        https://arxiv.org/abs/2303.05397
        """
        super(ResNet34SpL2RegDiar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
            batchnorm_momentum=batchnorm_momentum,
            use_head_maxpool=use_head_maxpool,
            num_nodes_pooling_layer=num_nodes_pooling_layer,
            layers_in_block=layers_in_block,
            filters_in_block=filters_in_block,
        )

        self.embedding_node = embedding_node
        self.num_nodes_resnet1 = num_nodes_resnet1
        self.num_nodes_last_layer = num_nodes_last_layer
        self.pooling_type = pooling_type
        self.pool_size = pool_size
        self.stride = stride
        # Kept for TF->torch checkpoint name mapping utilities.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf

        # Statistic pooling concatenates mean and std, doubling the dimension.
        self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(
            num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum
        )

        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(
            num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum
        )

    def output_size(self) -> int:
        """Dimension of the endpoint selected by ``embedding_node``."""
        if self.embedding_node.startswith("resnet1"):
            return self.num_nodes_resnet1
        elif self.embedding_node.startswith("resnet2"):
            return self.num_nodes_last_layer

        return self.num_nodes_pooling_layer

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Compute speaker embeddings; returns the requested endpoint tensor."""
        endpoints = OrderedDict()
        res_out, ilens = super().forward(xs_pad, ilens)
        endpoints["resnet0_bn"] = res_out

        if self.pooling_type == "frame_gsp":
            # Global statistic pooling over the time axis of the (B, D, T) map.
            emb = statistic_pooling(res_out, ilens, (2,))
        else:
            emb, ilens = windowed_statistic_pooling(
                res_out, ilens, (2,), self.pool_size, self.stride
            )
            # (B, D, T') -> (B, T', D) for the Linear layers below.
            # NOTE(review): reconstructed as applying to the windowed branch
            # only — confirm against upstream FunASR.
            emb = emb.transpose(1, 2)
        endpoints["pooling"] = emb

        emb = self.resnet1_dense(emb)
        endpoints["resnet1_dense"] = emb
        emb = F.relu(emb)
        endpoints["resnet1_relu"] = emb
        emb = self.resnet1_bn(emb.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet1_bn"] = emb

        emb = self.resnet2_dense(emb)
        endpoints["resnet2_dense"] = emb
        emb = F.relu(emb)
        endpoints["resnet2_relu"] = emb
        emb = self.resnet2_bn(emb.transpose(1, 2)).transpose(1, 2)
        endpoints["resnet2_bn"] = emb

        return endpoints[self.embedding_node], ilens, None
|
||||
333
modules/python/vendors/FunASR/funasr/models/sond/encoder/self_attention_encoder.py
vendored
Normal file
333
modules/python/vendors/FunASR/funasr/models/sond/encoder/self_attention_encoder.py
vendored
Normal file
@@ -0,0 +1,333 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from funasr.models.scama.chunk_utilis import overlap_chunk
|
||||
import numpy as np
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.sond.attention import MultiHeadSelfAttention
|
||||
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
|
||||
from funasr.models.transformer.layer_norm import LayerNorm
|
||||
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
|
||||
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
|
||||
from funasr.models.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from funasr.models.transformer.utils.repeat import repeat
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
|
||||
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
|
||||
from funasr.models.transformer.utils.subsampling import TooShortUttError
|
||||
from funasr.models.transformer.utils.subsampling import check_short_utt
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
    """Transformer encoder layer with optional stochastic depth.

    Combines a self-attention sub-layer and a position-wise feed-forward
    sub-layer, each with LayerNorm (pre- or post-norm), dropout and a
    residual connection. When ``in_size != size`` the attention sub-layer
    changes the feature dimension, so its residual connection is omitted.
    """

    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object.

        Args:
            in_size: input feature dimension of this layer.
            size: output feature dimension of this layer.
            self_attn: self-attention module, called as
                ``self_attn(x, mask, mask_att_chunk_encoder=...)``.
            feed_forward: position-wise feed-forward module.
            dropout_rate: dropout probability for both sub-layer outputs.
            normalize_before: if True use pre-norm, otherwise post-norm.
            concat_after: if True, concatenate the attention input and output
                and apply a linear projection instead of a plain residual add.
            stochastic_depth_rate: probability of skipping the whole layer at
                training time (stochastic depth).
        """
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
            mask_att_chunk_encoder: optional chunk-attention mask forwarded to
                the attention module and to the next layer in a stacked model.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            torch.Tensor: ``cache``, unchanged.
            torch.Tensor: ``mask_att_chunk_encoder``, unchanged.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            # Bug fix: return the same 4-tuple as the normal path below. The
            # original returned only (x, mask) here, which silently dropped
            # `cache` and `mask_att_chunk_encoder` for all subsequent layers
            # whenever a layer was skipped inside a `repeat(...)` stack.
            return x, mask, cache, mask_att_chunk_encoder

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            # Concatenate attention input and output, then project back to `size`.
            x_concat = torch.cat(
                (x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                # Dimension change: no residual connection possible.
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_att_chunk_encoder
|
||||
|
||||
|
||||
class SelfAttentionEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Self attention encoder in OpenNMT framework
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
        out_units=None,
    ):
        # NOTE(review): `interctc_layer_idx` uses a mutable default list; it
        # is only read here, so this appears harmless but is worth confirming.
        super().__init__()
        self._output_size = output_size

        # Select the input embedding / subsampling front-end.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            # Token-id inputs: embedding table plus sinusoidal positions.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            # No front-end; add a linear projection only if dimensions differ.
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            # Positional encoding only; features keep their input dimension.
            self.embed = SinusoidalPositionEncoder()
        elif input_layer == "null":
            self.embed = None
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        # Select the position-wise feed-forward implementation.
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # The first layer (lnum == 0) maps the (possibly different) input
        # dimension to the model dimension; all later layers are square.
        self.encoders = repeat(
            num_blocks,
            lambda lnum: (
                EncoderLayer(
                    output_size,
                    output_size,
                    MultiHeadSelfAttention(
                        attention_heads,
                        output_size,
                        output_size,
                        attention_dropout_rate,
                    ),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                )
                if lnum > 0
                else EncoderLayer(
                    input_size,
                    output_size,
                    MultiHeadSelfAttention(
                        attention_heads,
                        input_size if input_layer == "pe" or input_layer == "null" else output_size,
                        output_size,
                        attention_dropout_rate,
                    ),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                )
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # Intermediate-CTC configuration (indices of layers whose outputs are
        # also returned; optionally conditioned on the CTC posterior).
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Set externally when interctc_use_conditioning is enabled.
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)
        # Kept for TF->torch checkpoint name mapping utilities.
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.out_units = out_units
        # Optional final projection on top of the encoder output.
        if out_units is not None:
            self.output_linear = nn.Linear(output_size, out_units)

    def output_size(self) -> int:
        # NOTE(review): when `out_units` is set, the emitted feature dimension
        # is `out_units`, not `_output_size` — confirm callers rely on this.
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
            ctc: CTC module used only when intermediate-CTC conditioning is on.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        # Scale features by sqrt(model_dim), as in the standard Transformer.
        xs_pad = xs_pad * self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            # Conv subsampling shortens the sequence; guard against inputs that
            # are too short to survive the subsampling.
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        xs_pad = self.dropout(xs_pad)
        # encoder_outs = self.encoders0(xs_pad, masks)
        # xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            # Fast path: run the whole stack in one call.
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            # Layer-by-layer so intermediate outputs can be captured.
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        # Condition subsequent layers on the CTC posterior.
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        if self.out_units is not None:
            xs_pad = self.output_linear(xs_pad)
        # Output lengths recovered from the (possibly subsampled) masks.
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
|
||||
Reference in New Issue
Block a user