Sync from bytedesk-private: update

This commit is contained in:
jack ning
2024-12-14 10:43:18 +08:00
parent 476eebb101
commit 5e082909e4
3421 changed files with 812709 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
import torch
from torch.nn import functional as F
class DotScorer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
xs_pad: torch.Tensor,
spk_emb: torch.Tensor,
):
# xs_pad: B, T, D
# spk_emb: B, N, D
scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2))
return scores
class CosScorer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
xs_pad: torch.Tensor,
spk_emb: torch.Tensor,
):
# xs_pad: B, T, D
# spk_emb: B, N, D
scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1)
return scores
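# --- Illustrative usage sketch (not part of the original file) ---
# Both scorers map frame features (B, T, D) and speaker embeddings (B, N, D)
# to frame-speaker scores of shape (B, T, N); DotScorer is unnormalized,
# CosScorer is scale-invariant.
if __name__ == "__main__":
    xs_pad = torch.randn(2, 100, 256)  # B, T, D
    spk_emb = torch.randn(2, 4, 256)   # B, N, D
    assert DotScorer()(xs_pad, spk_emb).shape == (2, 100, 4)
    assert CosScorer()(xs_pad, spk_emb).shape == (2, 100, 4)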

View File

@@ -0,0 +1,174 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
from funasr.models.transformer.utils.repeat import repeat
class EncoderLayer(nn.Module):
def __init__(
self,
input_units,
num_units,
kernel_size=3,
activation="tanh",
stride=1,
include_batch_norm=False,
residual=False,
):
super().__init__()
left_padding = math.ceil((kernel_size - stride) / 2)
right_padding = kernel_size - stride - left_padding
self.conv_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.conv1d = nn.Conv1d(
input_units,
num_units,
kernel_size,
stride,
)
self.activation = self.get_activation(activation)
if include_batch_norm:
self.bn = nn.BatchNorm1d(num_units, momentum=0.99, eps=1e-3)
self.residual = residual
self.include_batch_norm = include_batch_norm
self.input_units = input_units
self.num_units = num_units
self.stride = stride
@staticmethod
def get_activation(activation):
if activation == "tanh":
return nn.Tanh()
else:
return nn.ReLU()
def forward(self, xs_pad, ilens=None):
outputs = self.conv1d(self.conv_padding(xs_pad))
if self.residual and self.stride == 1 and self.input_units == self.num_units:
outputs = outputs + xs_pad
if self.include_batch_norm:
outputs = self.bn(outputs)
        # return a tuple so this layer composes with the repeat (MultiSequential) module
return self.activation(outputs), ilens
class ConvEncoder(AbsEncoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
    Convolution encoder in the OpenNMT framework
"""
def __init__(
self,
num_layers,
input_units,
num_units,
kernel_size=3,
dropout_rate=0.3,
position_encoder=None,
activation="tanh",
auxiliary_states=True,
out_units=None,
out_norm=False,
out_residual=False,
include_batchnorm=False,
regularization_weight=0.0,
stride=1,
tf2torch_tensor_name_prefix_torch: str = "speaker_encoder",
tf2torch_tensor_name_prefix_tf: str = "EAND/speaker_encoder",
):
super().__init__()
self._output_size = num_units
self.num_layers = num_layers
self.input_units = input_units
self.num_units = num_units
self.kernel_size = kernel_size
self.dropout_rate = dropout_rate
self.position_encoder = position_encoder
self.out_units = out_units
self.auxiliary_states = auxiliary_states
self.out_norm = out_norm
self.activation = activation
self.out_residual = out_residual
self.include_batch_norm = include_batchnorm
self.regularization_weight = regularization_weight
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
if isinstance(stride, int):
self.stride = [stride] * self.num_layers
else:
self.stride = stride
self.downsample_rate = 1
for s in self.stride:
self.downsample_rate *= s
self.dropout = nn.Dropout(dropout_rate)
self.cnn_a = repeat(
self.num_layers,
lambda lnum: EncoderLayer(
input_units if lnum == 0 else num_units,
num_units,
kernel_size,
activation,
self.stride[lnum],
include_batchnorm,
residual=True if lnum > 0 else False,
),
)
if self.out_units is not None:
            # conv_out below always runs at stride 1, so pad for a "same"-length output
            left_padding = math.ceil((kernel_size - 1) / 2)
            right_padding = kernel_size - 1 - left_padding
self.out_padding = nn.ConstantPad1d((left_padding, right_padding), 0.0)
self.conv_out = nn.Conv1d(
num_units,
out_units,
kernel_size,
)
if self.out_norm:
self.after_norm = LayerNorm(out_units)
    def output_size(self) -> int:
        return self.out_units if self.out_units is not None else self.num_units
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
inputs = xs_pad
if self.position_encoder is not None:
inputs = self.position_encoder(inputs)
if self.dropout_rate > 0:
inputs = self.dropout(inputs)
outputs, _ = self.cnn_a(inputs.transpose(1, 2), ilens)
if self.out_units is not None:
outputs = self.conv_out(self.out_padding(outputs))
outputs = outputs.transpose(1, 2)
if self.out_norm:
outputs = self.after_norm(outputs)
if self.out_residual:
outputs = outputs + inputs
return outputs, ilens, None
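# --- Illustrative usage sketch (not part of the original file; assumes the
# funasr imports at the top of this file resolve) ---
# With stride 1 the convolution stack preserves the time axis, so
# (B, T, input_units) -> (B, T, num_units).
if __name__ == "__main__":
    enc = ConvEncoder(num_layers=3, input_units=80, num_units=128)
    xs = torch.randn(2, 50, 80)
    out, olens, _ = enc(xs, torch.tensor([50, 42]))
    assert out.shape == (2, 50, 128)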

View File

@@ -0,0 +1,672 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class _BatchNorm1d(nn.Module):
def __init__(
self,
input_shape=None,
input_size=None,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True,
combine_batch_time=False,
skip_transpose=False,
):
super().__init__()
self.combine_batch_time = combine_batch_time
self.skip_transpose = skip_transpose
if input_size is None and skip_transpose:
input_size = input_shape[1]
elif input_size is None:
input_size = input_shape[-1]
self.norm = nn.BatchNorm1d(
input_size,
eps=eps,
momentum=momentum,
affine=affine,
track_running_stats=track_running_stats,
)
def forward(self, x):
shape_or = x.shape
if self.combine_batch_time:
if x.ndim == 3:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
else:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
elif not self.skip_transpose:
x = x.transpose(-1, 1)
x_n = self.norm(x)
if self.combine_batch_time:
x_n = x_n.reshape(shape_or)
elif not self.skip_transpose:
x_n = x_n.transpose(1, -1)
return x_n
class _Conv1d(nn.Module):
def __init__(
self,
out_channels,
kernel_size,
input_shape=None,
in_channels=None,
stride=1,
dilation=1,
padding="same",
groups=1,
bias=True,
padding_mode="reflect",
skip_transpose=False,
):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.padding = padding
self.padding_mode = padding_mode
self.unsqueeze = False
self.skip_transpose = skip_transpose
if input_shape is None and in_channels is None:
raise ValueError("Must provide one of input_shape or in_channels")
if in_channels is None:
in_channels = self._check_input_shape(input_shape)
self.conv = nn.Conv1d(
in_channels,
out_channels,
self.kernel_size,
stride=self.stride,
dilation=self.dilation,
padding=0,
groups=groups,
bias=bias,
)
def forward(self, x):
if not self.skip_transpose:
x = x.transpose(1, -1)
if self.unsqueeze:
x = x.unsqueeze(1)
if self.padding == "same":
x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
elif self.padding == "causal":
num_pad = (self.kernel_size - 1) * self.dilation
x = F.pad(x, (num_pad, 0))
elif self.padding == "valid":
pass
else:
raise ValueError("Padding must be 'same', 'valid' or 'causal'. Got " + self.padding)
wx = self.conv(x)
if self.unsqueeze:
wx = wx.squeeze(1)
if not self.skip_transpose:
wx = wx.transpose(1, -1)
return wx
def _manage_padding(
self,
x,
kernel_size: int,
dilation: int,
stride: int,
):
# Detecting input shape
L_in = x.shape[-1]
# Time padding
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
# Applying padding
x = F.pad(x, padding, mode=self.padding_mode)
return x
def _check_input_shape(self, shape):
"""Checks the input shape and returns the number of input channels."""
if len(shape) == 2:
self.unsqueeze = True
in_channels = 1
elif self.skip_transpose:
in_channels = shape[1]
elif len(shape) == 3:
in_channels = shape[2]
else:
raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))
# Kernel size must be odd
if self.kernel_size % 2 == 0:
raise ValueError(
"The field kernel size must be an odd number. Got %s." % (self.kernel_size)
)
return in_channels
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
padding = [kernel_size // 2, kernel_size // 2]
else:
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
return padding
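# Worked example for get_padding_elem (illustrative, not part of the original
# file): with stride 1, L_in=100, kernel_size=5, dilation=1 gives
# L_out = (100 - 4 - 1) // 1 + 1 = 96, hence symmetric padding [2, 2] and a
# "same"-length convolution output of 100.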
# Skip transpose as much as possible for efficiency
class Conv1d(_Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
class BatchNorm1d(_BatchNorm1d):
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
def length_to_mask(length, max_len=None, dtype=None, device=None):
assert len(length.shape) == 1
if max_len is None:
max_len = length.max().long().item() # using arange to generate mask
mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
len(length), max_len
) < length.unsqueeze(1)
if dtype is None:
dtype = length.dtype
if device is None:
device = length.device
mask = torch.as_tensor(mask, dtype=dtype, device=device)
return mask
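# Example (illustrative, not part of the original file):
# length_to_mask(torch.tensor([2, 3]), max_len=4) ->
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 0]])  (dtype follows `length`, here int64)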
class TDNNBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation,
activation=nn.ReLU,
groups=1,
):
super(TDNNBlock, self).__init__()
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dilation=dilation,
groups=groups,
)
self.activation = activation()
self.norm = BatchNorm1d(input_size=out_channels)
def forward(self, x):
return self.norm(self.activation(self.conv(x)))
class Res2NetBlock(torch.nn.Module):
"""An implementation of Res2NetBlock w/ dilation.
Arguments
---------
in_channels : int
The number of channels expected in the input.
out_channels : int
The number of output channels.
scale : int
The scale of the Res2Net block.
kernel_size: int
The kernel size of the Res2Net block.
dilation : int
The dilation of the Res2Net block.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
"""
def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
super(Res2NetBlock, self).__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
in_channel = in_channels // scale
hidden_channel = out_channels // scale
self.blocks = nn.ModuleList(
[
TDNNBlock(
in_channel,
hidden_channel,
kernel_size=kernel_size,
dilation=dilation,
)
for i in range(scale - 1)
]
)
self.scale = scale
def forward(self, x):
y = []
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
if i == 0:
y_i = x_i
elif i == 1:
y_i = self.blocks[i - 1](x_i)
else:
y_i = self.blocks[i - 1](x_i + y_i)
y.append(y_i)
y = torch.cat(y, dim=1)
return y
class SEBlock(nn.Module):
"""An implementation of squeeze-and-excitation block.
Arguments
---------
in_channels : int
The number of input channels.
se_channels : int
The number of output channels after squeeze.
out_channels : int
The number of output channels.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> se_layer = SEBlock(64, 16, 64)
>>> lengths = torch.rand((8,))
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
"""
def __init__(self, in_channels, se_channels, out_channels):
super(SEBlock, self).__init__()
self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
self.relu = torch.nn.ReLU(inplace=True)
self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x, lengths=None):
L = x.shape[-1]
if lengths is not None:
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
total = mask.sum(dim=2, keepdim=True)
s = (x * mask).sum(dim=2, keepdim=True) / total
else:
s = x.mean(dim=2, keepdim=True)
s = self.relu(self.conv1(s))
s = self.sigmoid(self.conv2(s))
return s * x
class AttentiveStatisticsPooling(nn.Module):
"""This class implements an attentive statistic pooling layer for each channel.
It returns the concatenated mean and std of the input tensor.
Arguments
---------
channels: int
The number of input channels.
attention_channels: int
The number of attention channels.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> asp_layer = AttentiveStatisticsPooling(64)
>>> lengths = torch.rand((8,))
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 1, 128])
"""
def __init__(self, channels, attention_channels=128, global_context=True):
super().__init__()
self.eps = 1e-12
self.global_context = global_context
if global_context:
self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
else:
self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
self.tanh = nn.Tanh()
self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)
def forward(self, x, lengths=None):
"""Calculates mean and std for a batch (input tensor).
Arguments
---------
x : torch.Tensor
Tensor of shape [N, C, L].
"""
L = x.shape[-1]
def _compute_statistics(x, m, dim=2, eps=self.eps):
mean = (m * x).sum(dim)
std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
return mean, std
if lengths is None:
lengths = torch.ones(x.shape[0], device=x.device)
# Make binary mask of shape [N, 1, L]
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
# Expand the temporal context of the pooling layer by allowing the
# self-attention to look at global properties of the utterance.
if self.global_context:
# torch.std is unstable for backward computation
# https://github.com/pytorch/pytorch/issues/4320
total = mask.sum(dim=2, keepdim=True).float()
mean, std = _compute_statistics(x, mask / total)
mean = mean.unsqueeze(2).repeat(1, 1, L)
std = std.unsqueeze(2).repeat(1, 1, L)
attn = torch.cat([x, mean, std], dim=1)
else:
attn = x
# Apply layers
attn = self.conv(self.tanh(self.tdnn(attn)))
# Filter out zero-paddings
attn = attn.masked_fill(mask == 0, float("-inf"))
attn = F.softmax(attn, dim=2)
mean, std = _compute_statistics(x, attn)
# Append mean and std of the batch
pooled_stats = torch.cat((mean, std), dim=1)
pooled_stats = pooled_stats.unsqueeze(2)
return pooled_stats
class SERes2NetBlock(nn.Module):
"""An implementation of building block in ECAPA-TDNN, i.e.,
TDNN-Res2Net-TDNN-SEBlock.
    Arguments
    ---------
    in_channels : int
        The number of input channels.
    out_channels : int
        The number of output channels.
    res2net_scale : int
        The scale of the Res2Net block.
    se_channels : int
        The number of hidden channels in the squeeze-and-excitation block.
    kernel_size : int
        The kernel size of the TDNN blocks.
    dilation : int
        The dilation of the Res2Net block.
activation : torch class
A class for constructing the activation layers.
groups: int
Number of blocked connections from input channels to output channels.
Example
-------
>>> x = torch.rand(8, 120, 64).transpose(1, 2)
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
>>> out = conv(x).transpose(1, 2)
>>> out.shape
torch.Size([8, 120, 64])
"""
def __init__(
self,
in_channels,
out_channels,
res2net_scale=8,
se_channels=128,
kernel_size=1,
dilation=1,
activation=torch.nn.ReLU,
groups=1,
):
super().__init__()
self.out_channels = out_channels
self.tdnn1 = TDNNBlock(
in_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.res2net_block = Res2NetBlock(
out_channels, out_channels, res2net_scale, kernel_size, dilation
)
self.tdnn2 = TDNNBlock(
out_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.se_block = SEBlock(out_channels, se_channels, out_channels)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x, lengths=None):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.tdnn1(x)
x = self.res2net_block(x)
x = self.tdnn2(x)
x = self.se_block(x, lengths)
return x + residual
class ECAPA_TDNN(torch.nn.Module):
"""An implementation of the speaker embedding model in a paper.
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
Arguments
---------
activation : torch class
A class for constructing the activation layers.
channels : list of ints
Output channels for TDNN/SERes2Net layer.
kernel_sizes : list of ints
List of kernel sizes for each layer.
dilations : list of ints
List of dilations for kernels in each layer.
lin_neurons : int
Number of neurons in linear layers.
groups : list of ints
List of groups for kernels in each layer.
Example
-------
>>> input_feats = torch.rand([5, 120, 80])
>>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
>>> outputs = compute_embedding(input_feats)
>>> outputs.shape
    torch.Size([5, 120, 192])
"""
def __init__(
self,
input_size,
lin_neurons=192,
activation=torch.nn.ReLU,
channels=[512, 512, 512, 512, 1536],
kernel_sizes=[5, 3, 3, 3, 1],
dilations=[1, 2, 3, 4, 1],
attention_channels=128,
res2net_scale=8,
se_channels=128,
global_context=True,
groups=[1, 1, 1, 1, 1],
window_size=20,
window_shift=1,
):
super().__init__()
assert len(channels) == len(kernel_sizes)
assert len(channels) == len(dilations)
self.channels = channels
self.blocks = nn.ModuleList()
self.window_size = window_size
self.window_shift = window_shift
# The initial TDNN layer
self.blocks.append(
TDNNBlock(
input_size,
channels[0],
kernel_sizes[0],
dilations[0],
activation,
groups[0],
)
)
# SE-Res2Net layers
for i in range(1, len(channels) - 1):
self.blocks.append(
SERes2NetBlock(
channels[i - 1],
channels[i],
res2net_scale=res2net_scale,
se_channels=se_channels,
kernel_size=kernel_sizes[i],
dilation=dilations[i],
activation=activation,
groups=groups[i],
)
)
# Multi-layer feature aggregation
self.mfa = TDNNBlock(
channels[-1],
channels[-1],
kernel_sizes[-1],
dilations[-1],
activation,
groups=groups[-1],
)
# Attentive Statistical Pooling
self.asp = AttentiveStatisticsPooling(
channels[-1],
attention_channels=attention_channels,
global_context=global_context,
)
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
# Final linear transformation
self.fc = Conv1d(
in_channels=channels[-1] * 2,
out_channels=lin_neurons,
kernel_size=1,
)
def windowed_pooling(self, x, lengths=None):
# x: Batch, Channel, Time
tt = x.shape[2]
num_chunk = int(math.ceil(tt / self.window_shift))
pad = self.window_size // 2
x = F.pad(x, (pad, pad, 0, 0), "reflect")
        stat_list = []
        for i in range(num_chunk):
            # pool one window at a time; use a separate variable so that `x`
            # keeps the original padded features for the remaining windows
            st, ed = i * self.window_shift, i * self.window_shift + self.window_size
            chunk = self.asp(
                x[:, :, st:ed],
                lengths=(
                    torch.clamp(lengths - i, 0, self.window_size) if lengths is not None else None
                ),
            )
            chunk = self.asp_bn(chunk)  # B x 2C x 1
            chunk = self.fc(chunk)  # B x lin_neurons x 1
            stat_list.append(chunk)
return torch.cat(stat_list, dim=2)
def forward(self, x, lengths=None):
"""Returns the embedding vector.
Arguments
---------
x : torch.Tensor
Tensor of shape (batch, time, channel).
lengths: torch.Tensor
Tensor of shape (batch, )
"""
# Minimize transpose for efficiency
x = x.transpose(1, 2)
xl = []
for layer in self.blocks:
try:
x = layer(x, lengths=lengths)
except TypeError:
x = layer(x)
xl.append(x)
# Multi-layer feature aggregation
x = torch.cat(xl[1:], dim=1)
x = self.mfa(x)
if self.window_size is None:
# Attentive Statistical Pooling
x = self.asp(x, lengths=lengths)
x = self.asp_bn(x)
# Final linear transformation
x = self.fc(x)
# x = x.transpose(1, 2)
x = x.squeeze(2) # -> B, C
else:
x = self.windowed_pooling(x, lengths)
x = x.transpose(1, 2) # -> B, T, C
return x
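# --- Illustrative usage sketch (not part of the original file) ---
# With the default window_size=20 / window_shift=1 the model emits one
# embedding per frame, (B, T, lin_neurons); window_size=None collapses the
# whole utterance to a single (B, lin_neurons) embedding.
if __name__ == "__main__":
    model = ECAPA_TDNN(80, lin_neurons=192)
    feats = torch.rand(5, 120, 80)
    assert model(feats).shape == (5, 120, 192)
    utt_model = ECAPA_TDNN(80, lin_neurons=192, window_size=None)
    assert utt_model(feats).shape == (5, 192)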

View File

@@ -0,0 +1,180 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.encoder.abs_encoder import AbsEncoder
import math
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.multi_layer_conv import FsmnFeedForward
class FsmnBlock(torch.nn.Module):
def __init__(
self,
n_feat,
dropout_rate,
kernel_size,
fsmn_shift=0,
):
super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(
n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
)
# padding
left_padding = (kernel_size - 1) // 2
if fsmn_shift > 0:
left_padding = left_padding + fsmn_shift
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
    def forward(self, inputs, mask, mask_shift_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shift_chunk is not None:
                mask = mask * mask_shift_chunk
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x = x + inputs
x = self.dropout(x)
return x * mask
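# Illustrative note (not part of the original file): the padding places the
# FSMN memory window around each frame. fsmn_shift == 0 gives a (near-)centered
# window; fsmn_shift > 0 shifts it toward the past, e.g. kernel_size=5,
# fsmn_shift=1 -> left_padding=3, right_padding=1, i.e. each output frame sees
# 3 past frames and only 1 future frame, reducing lookahead latency.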
class EncoderLayer(torch.nn.Module):
def __init__(self, in_size, size, feed_forward, fsmn_block, dropout_rate=0.0):
super().__init__()
self.in_size = in_size
self.size = size
self.ffn = feed_forward
self.memory = fsmn_block
self.dropout = nn.Dropout(dropout_rate)
def forward(
self, xs_pad: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# xs_pad in Batch, Time, Dim
context = self.ffn(xs_pad)[0]
memory = self.memory(context, mask)
memory = self.dropout(memory)
if self.in_size == self.size:
return memory + xs_pad, mask
return memory, mask
class FsmnEncoder(AbsEncoder):
"""Encoder using Fsmn"""
def __init__(
self,
in_units,
filter_size,
fsmn_num_layers,
dnn_num_layers,
num_memory_units=512,
ffn_inner_dim=2048,
dropout_rate=0.0,
shift=0,
position_encoder=None,
sample_rate=1,
out_units=None,
tf2torch_tensor_name_prefix_torch="post_net",
tf2torch_tensor_name_prefix_tf="EAND/post_net",
):
"""Initializes the parameters of the encoder.
        Args:
            in_units: The dimension of the input features.
            filter_size: The total order (kernel size) of each memory block.
            fsmn_num_layers: The number of FSMN layers.
            dnn_num_layers: The number of DNN layers.
            num_memory_units: The number of memory units.
            ffn_inner_dim: The number of units of the inner linear transformation
                in the feed-forward layer.
            dropout_rate: The probability to drop units from the outputs.
            shift: Extra left padding, to control the lookahead delay.
            position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to
                apply on inputs, or ``None``.
"""
super(FsmnEncoder, self).__init__()
self.in_units = in_units
self.filter_size = filter_size
self.fsmn_num_layers = fsmn_num_layers
self.dnn_num_layers = dnn_num_layers
self.num_memory_units = num_memory_units
self.ffn_inner_dim = ffn_inner_dim
self.dropout_rate = dropout_rate
self.shift = shift
if not isinstance(shift, list):
self.shift = [shift for _ in range(self.fsmn_num_layers)]
self.sample_rate = sample_rate
if not isinstance(sample_rate, list):
self.sample_rate = [sample_rate for _ in range(self.fsmn_num_layers)]
self.position_encoder = position_encoder
self.dropout = nn.Dropout(dropout_rate)
self.out_units = out_units
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.fsmn_layers = repeat(
self.fsmn_num_layers,
lambda lnum: EncoderLayer(
in_units if lnum == 0 else num_memory_units,
num_memory_units,
FsmnFeedForward(
in_units if lnum == 0 else num_memory_units,
ffn_inner_dim,
num_memory_units,
1,
dropout_rate,
),
FsmnBlock(num_memory_units, dropout_rate, filter_size, self.shift[lnum]),
),
)
self.dnn_layers = repeat(
dnn_num_layers,
lambda lnum: FsmnFeedForward(
num_memory_units,
ffn_inner_dim,
num_memory_units,
1,
dropout_rate,
),
)
if out_units is not None:
self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)
    def output_size(self) -> int:
        return self.out_units if self.out_units is not None else self.num_memory_units
def forward(
self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
inputs = xs_pad
if self.position_encoder is not None:
inputs = self.position_encoder(inputs)
inputs = self.dropout(inputs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
inputs = self.fsmn_layers(inputs, masks)[0]
inputs = self.dnn_layers(inputs)[0]
if self.out_units is not None:
inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)
return inputs, ilens, None
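# --- Illustrative usage sketch (not part of the original file; assumes the
# funasr imports at the top of this file resolve) ---
# The FSMN stack preserves the time axis: (B, T, in_units) ->
# (B, T, num_memory_units), or (B, T, out_units) when out_units is set.
if __name__ == "__main__":
    enc = FsmnEncoder(in_units=80, filter_size=11, fsmn_num_layers=2, dnn_num_layers=1)
    xs = torch.randn(2, 50, 80)
    out, olens, _ = enc(xs, torch.tensor([50, 40]))
    assert out.shape == (2, 50, 512)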

View File

@@ -0,0 +1,462 @@
import torch
from torch.nn import functional as F
from funasr.models.encoder.abs_encoder import AbsEncoder
from typing import Tuple, Optional
from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
from collections import OrderedDict
import logging
import numpy as np
class BasicLayer(torch.nn.Module):
def __init__(self, in_filters: int, filters: int, stride: int, bn_momentum: float = 0.5):
super().__init__()
self.stride = stride
self.in_filters = in_filters
self.filters = filters
self.bn1 = torch.nn.BatchNorm2d(in_filters, eps=1e-3, momentum=bn_momentum, affine=True)
self.relu1 = torch.nn.ReLU()
self.conv1 = torch.nn.Conv2d(in_filters, filters, 3, stride, bias=False)
self.bn2 = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
self.relu2 = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(filters, filters, 3, 1, bias=False)
if in_filters != filters or stride > 1:
self.conv_sc = torch.nn.Conv2d(in_filters, filters, 1, stride, bias=False)
self.bn_sc = torch.nn.BatchNorm2d(filters, eps=1e-3, momentum=bn_momentum, affine=True)
def proper_padding(self, x, stride):
        # align padding to tf.layers.conv2d with padding="same"
if stride == 1:
return F.pad(x, (1, 1, 1, 1), "constant", 0)
elif stride == 2:
h, w = x.size(2), x.size(3)
# (left, right, top, bottom)
return F.pad(x, (w % 2, 1, h % 2, 1), "constant", 0)
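    # Illustrative note (not part of the original file): for stride 2 the pad
    # (w % 2, 1, h % 2, 1) reproduces TF "same" padding for kernel_size=3,
    # i.e. (0, 1) on an even-sized axis and (1, 1) on an odd-sized one.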
def forward(self, xs_pad, ilens):
identity = xs_pad
if self.in_filters != self.filters or self.stride > 1:
identity = self.conv_sc(identity)
identity = self.bn_sc(identity)
xs_pad = self.relu1(self.bn1(xs_pad))
xs_pad = self.proper_padding(xs_pad, self.stride)
xs_pad = self.conv1(xs_pad)
xs_pad = self.relu2(self.bn2(xs_pad))
xs_pad = self.proper_padding(xs_pad, 1)
xs_pad = self.conv2(xs_pad)
if self.stride == 2:
ilens = (ilens + 1) // self.stride
return xs_pad + identity, ilens
class BasicBlock(torch.nn.Module):
def __init__(self, in_filters, filters, num_layer, stride, bn_momentum=0.5):
super().__init__()
self.num_layer = num_layer
for i in range(num_layer):
layer = BasicLayer(
in_filters if i == 0 else filters, filters, stride if i == 0 else 1, bn_momentum
)
self.add_module("layer_{}".format(i), layer)
def forward(self, xs_pad, ilens):
for i in range(self.num_layer):
xs_pad, ilens = self._modules["layer_{}".format(i)](xs_pad, ilens)
return xs_pad, ilens
class ResNet34(AbsEncoder):
def __init__(
self,
input_size,
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
):
super(ResNet34, self).__init__()
self.use_head_conv = use_head_conv
self.use_head_maxpool = use_head_maxpool
self.num_nodes_pooling_layer = num_nodes_pooling_layer
self.layers_in_block = layers_in_block
self.filters_in_block = filters_in_block
self.input_size = input_size
pre_filters = filters_in_block[0]
if use_head_conv:
self.pre_conv = torch.nn.Conv2d(
1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros"
)
self.pre_conv_bn = torch.nn.BatchNorm2d(
pre_filters, eps=1e-3, momentum=batchnorm_momentum
)
if use_head_maxpool:
self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
for i in range(len(layers_in_block)):
if i == 0:
in_filters = pre_filters if self.use_head_conv else 1
else:
in_filters = filters_in_block[i - 1]
block = BasicBlock(
in_filters,
filters=filters_in_block[i],
num_layer=layers_in_block[i],
stride=1 if i == 0 else 2,
bn_momentum=batchnorm_momentum,
)
self.add_module("block_{}".format(i), block)
self.resnet0_dense = torch.nn.Conv2d(filters_in_block[-1], num_nodes_pooling_layer, 1)
self.resnet0_bn = torch.nn.BatchNorm2d(
num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum
)
self.time_ds_ratio = 8
def output_size(self) -> int:
return self.num_nodes_pooling_layer
def forward(
self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
features = xs_pad
assert (
features.size(-1) == self.input_size
), "Dimension of features {} doesn't match the input_size {}.".format(
features.size(-1), self.input_size
)
features = torch.unsqueeze(features, dim=1)
if self.use_head_conv:
features = self.pre_conv(features)
features = self.pre_conv_bn(features)
features = F.relu(features)
if self.use_head_maxpool:
features = self.head_maxpool(features)
resnet_outs, resnet_out_lens = features, ilens
for i in range(len(self.layers_in_block)):
block = self._modules["block_{}".format(i)]
resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
features = self.resnet0_dense(resnet_outs)
features = F.relu(features)
features = self.resnet0_bn(features)
return features, resnet_out_lens
# Note: for training, this implementation is not equivalent to the TF version
# because of the kernel_regularizer in tf.layers.
# TODO: implement kernel_regularizer in torch with a manual loss term or weight_decay in the optimizer
class ResNet34_SP_L2Reg(AbsEncoder):
def __init__(
self,
input_size,
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="EAND/speech_encoder",
tf_train_steps=720000,
):
super(ResNet34_SP_L2Reg, self).__init__()
self.use_head_conv = use_head_conv
self.use_head_maxpool = use_head_maxpool
self.num_nodes_pooling_layer = num_nodes_pooling_layer
self.layers_in_block = layers_in_block
self.filters_in_block = filters_in_block
self.input_size = input_size
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.tf_train_steps = tf_train_steps
pre_filters = filters_in_block[0]
if use_head_conv:
self.pre_conv = torch.nn.Conv2d(
1, pre_filters, 3, 1, 1, bias=False, padding_mode="zeros"
)
self.pre_conv_bn = torch.nn.BatchNorm2d(
pre_filters, eps=1e-3, momentum=batchnorm_momentum
)
if use_head_maxpool:
self.head_maxpool = torch.nn.MaxPool2d(3, 1, padding=1)
for i in range(len(layers_in_block)):
if i == 0:
in_filters = pre_filters if self.use_head_conv else 1
else:
in_filters = filters_in_block[i - 1]
block = BasicBlock(
in_filters,
filters=filters_in_block[i],
num_layer=layers_in_block[i],
stride=1 if i == 0 else 2,
bn_momentum=batchnorm_momentum,
)
self.add_module("block_{}".format(i), block)
self.resnet0_dense = torch.nn.Conv1d(
filters_in_block[-1] * input_size // 8, num_nodes_pooling_layer, 1
)
self.resnet0_bn = torch.nn.BatchNorm1d(
num_nodes_pooling_layer, eps=1e-3, momentum=batchnorm_momentum
)
self.time_ds_ratio = 8
def output_size(self) -> int:
return self.num_nodes_pooling_layer
def forward(
self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
features = xs_pad
assert (
features.size(-1) == self.input_size
), "Dimension of features {} doesn't match the input_size {}.".format(
features.size(-1), self.input_size
)
features = torch.unsqueeze(features, dim=1)
if self.use_head_conv:
features = self.pre_conv(features)
features = self.pre_conv_bn(features)
features = F.relu(features)
if self.use_head_maxpool:
features = self.head_maxpool(features)
resnet_outs, resnet_out_lens = features, ilens
for i in range(len(self.layers_in_block)):
block = self._modules["block_{}".format(i)]
resnet_outs, resnet_out_lens = block(resnet_outs, resnet_out_lens)
# B, C, T, F
bb, cc, tt, ff = resnet_outs.shape
resnet_outs = torch.reshape(resnet_outs.permute(0, 3, 1, 2), [bb, ff * cc, tt])
features = self.resnet0_dense(resnet_outs)
features = F.relu(features)
features = self.resnet0_bn(features)
return features, resnet_out_lens
class ResNet34Diar(ResNet34):
def __init__(
self,
input_size,
embedding_node="resnet1_dense",
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
num_nodes_resnet1=256,
num_nodes_last_layer=256,
pooling_type="window_shift",
pool_size=20,
stride=1,
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder",
):
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
"""
super(ResNet34Diar, self).__init__(
input_size,
use_head_conv=use_head_conv,
batchnorm_momentum=batchnorm_momentum,
use_head_maxpool=use_head_maxpool,
num_nodes_pooling_layer=num_nodes_pooling_layer,
layers_in_block=layers_in_block,
filters_in_block=filters_in_block,
)
self.embedding_node = embedding_node
self.num_nodes_resnet1 = num_nodes_resnet1
self.num_nodes_last_layer = num_nodes_last_layer
self.pooling_type = pooling_type
self.pool_size = pool_size
self.stride = stride
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
self.resnet1_bn = torch.nn.BatchNorm1d(
num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum
)
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
self.resnet2_bn = torch.nn.BatchNorm1d(
num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum
)
def output_size(self) -> int:
if self.embedding_node.startswith("resnet1"):
return self.num_nodes_resnet1
elif self.embedding_node.startswith("resnet2"):
return self.num_nodes_last_layer
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
endpoints = OrderedDict()
res_out, ilens = super().forward(xs_pad, ilens)
endpoints["resnet0_bn"] = res_out
if self.pooling_type == "frame_gsp":
features = statistic_pooling(res_out, ilens, (3,))
else:
features, ilens = windowed_statistic_pooling(
res_out, ilens, (2, 3), self.pool_size, self.stride
)
features = features.transpose(1, 2)
endpoints["pooling"] = features
features = self.resnet1_dense(features)
endpoints["resnet1_dense"] = features
features = F.relu(features)
endpoints["resnet1_relu"] = features
features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet1_bn"] = features
features = self.resnet2_dense(features)
endpoints["resnet2_dense"] = features
features = F.relu(features)
endpoints["resnet2_relu"] = features
features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet2_bn"] = features
return endpoints[self.embedding_node], ilens, None
class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
def __init__(
self,
input_size,
embedding_node="resnet1_dense",
use_head_conv=True,
batchnorm_momentum=0.5,
use_head_maxpool=False,
num_nodes_pooling_layer=256,
layers_in_block=(3, 4, 6, 3),
filters_in_block=(32, 64, 128, 256),
num_nodes_resnet1=256,
num_nodes_last_layer=256,
pooling_type="window_shift",
pool_size=20,
stride=1,
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder",
):
"""
Author: Speech Lab, Alibaba Group, China
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
super(ResNet34SpL2RegDiar, self).__init__(
input_size,
use_head_conv=use_head_conv,
batchnorm_momentum=batchnorm_momentum,
use_head_maxpool=use_head_maxpool,
num_nodes_pooling_layer=num_nodes_pooling_layer,
layers_in_block=layers_in_block,
filters_in_block=filters_in_block,
)
self.embedding_node = embedding_node
self.num_nodes_resnet1 = num_nodes_resnet1
self.num_nodes_last_layer = num_nodes_last_layer
self.pooling_type = pooling_type
self.pool_size = pool_size
self.stride = stride
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.resnet1_dense = torch.nn.Linear(num_nodes_pooling_layer * 2, num_nodes_resnet1)
self.resnet1_bn = torch.nn.BatchNorm1d(
num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum
)
self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
self.resnet2_bn = torch.nn.BatchNorm1d(
num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum
)
def output_size(self) -> int:
if self.embedding_node.startswith("resnet1"):
return self.num_nodes_resnet1
elif self.embedding_node.startswith("resnet2"):
return self.num_nodes_last_layer
return self.num_nodes_pooling_layer
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
endpoints = OrderedDict()
res_out, ilens = super().forward(xs_pad, ilens)
endpoints["resnet0_bn"] = res_out
if self.pooling_type == "frame_gsp":
features = statistic_pooling(res_out, ilens, (2,))
else:
features, ilens = windowed_statistic_pooling(
res_out, ilens, (2,), self.pool_size, self.stride
)
features = features.transpose(1, 2)
endpoints["pooling"] = features
features = self.resnet1_dense(features)
endpoints["resnet1_dense"] = features
features = F.relu(features)
endpoints["resnet1_relu"] = features
features = self.resnet1_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet1_bn"] = features
features = self.resnet2_dense(features)
endpoints["resnet2_dense"] = features
features = F.relu(features)
endpoints["resnet2_relu"] = features
features = self.resnet2_bn(features.transpose(1, 2)).transpose(1, 2)
endpoints["resnet2_bn"] = features
return endpoints[self.embedding_node], ilens, None
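# --- Illustrative usage sketch (not part of the original file; assumes the
# funasr imports at the top of this file resolve) ---
# ResNet34Diar downsamples time by 8, applies (windowed) statistic pooling,
# and returns the endpoint named by embedding_node.
if __name__ == "__main__":
    enc = ResNet34Diar(input_size=80)
    xs = torch.randn(2, 160, 80)
    out, olens, _ = enc(xs, torch.tensor([160, 120]))
    print(out.shape)  # (B, T', num_nodes_resnet1); last dim is 256 by default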

View File

@@ -0,0 +1,333 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.models.scama.chunk_utilis import overlap_chunk
import numpy as np
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sond.attention import MultiHeadSelfAttention
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
class EncoderLayer(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            # match the four-value return of the normal path at the end of this method
            return x, mask, cache, mask_att_chunk_encoder
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1
)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_att_chunk_encoder
class SelfAttentionEncoder(AbsEncoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
    Self-attention encoder in the OpenNMT framework
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
out_units=None,
):
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
elif input_layer == "null":
self.embed = None
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
self.encoders = repeat(
num_blocks,
lambda lnum: (
EncoderLayer(
output_size,
output_size,
MultiHeadSelfAttention(
attention_heads,
output_size,
output_size,
attention_dropout_rate,
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
)
if lnum > 0
else EncoderLayer(
input_size,
output_size,
MultiHeadSelfAttention(
attention_heads,
input_size if input_layer == "pe" or input_layer == "null" else output_size,
output_size,
attention_dropout_rate,
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
)
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
self.out_units = out_units
if out_units is not None:
self.output_linear = nn.Linear(output_size, out_units)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad = xs_pad * self.output_size() ** 0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
xs_pad = self.dropout(xs_pad)
# encoder_outs = self.encoders0(xs_pad, masks)
# xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
if self.out_units is not None:
xs_pad = self.output_linear(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
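# --- Illustrative usage sketch (not part of the original file; assumes the
# funasr imports at the top of this file resolve) ---
# The default "conv2d" front end subsamples time by 4, so
# (B, T, input_size) -> (B, ~T/4, output_size).
if __name__ == "__main__":
    enc = SelfAttentionEncoder(input_size=80, output_size=256, num_blocks=2)
    xs = torch.randn(2, 64, 80)
    out, olens, _ = enc(xs, torch.tensor([64, 48]))
    print(out.shape, olens)  # roughly (2, 15, 256), tensor([15, 11])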