Sync from bytedesk-private: update

This commit is contained in:
jack ning
2024-12-14 10:43:18 +08:00
parent 476eebb101
commit 5e082909e4
3421 changed files with 812709 additions and 0 deletions

View File

View File

@@ -0,0 +1,286 @@
""" This implementation is adapted from https://github.com/wenet-e2e/wekws/blob/main/wekws/bin/compute_det.py."""
import os
import json
import logging
import argparse
import threading
import kaldiio
import torch
from funasr.utils.kws_utils import split_mixed_label
class thread_wrapper(threading.Thread):
def __init__(self, func, args=()):
super(thread_wrapper, self).__init__()
self.func = func
self.args = args
self.result = []
def run(self):
self.result = self.func(*self.args)
def get_result(self):
try:
return self.result
except Exception:
return None
def space_mixed_label(input_str):
splits = split_mixed_label(input_str)
space_str = ''.join(f'{sub} ' for sub in splits)
return space_str.strip()
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
if line.strip() != '':
lists.append(line.strip())
return lists
def make_pair(wav_lists, trans_lists):
logging.info('make pair for wav-trans list')
trans_table = {}
for line in trans_lists:
arr = line.strip().replace('\t', ' ').split()
if len(arr) < 2:
logging.debug('invalid line in trans file: {}'.format(
line.strip()))
continue
trans_table[arr[0]] = line.replace(arr[0],'').strip()
lists = []
for line in wav_lists:
arr = line.strip().replace('\t', ' ').split()
if len(arr) == 2 and arr[0] in trans_table:
lists.append(
dict(key=arr[0],
txt=trans_table[arr[0]],
wav=arr[1],
sample_rate=16000))
else:
logging.debug("can't find corresponding trans for key: {}".format(
arr[0]))
continue
return lists
def count_duration(tid, data_lists):
results = []
for obj in data_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
try:
rate, waveform = kaldiio.load_mat(wav_file)
waveform = torch.tensor(waveform, dtype=torch.float32)
waveform = waveform.unsqueeze(0)
frames = len(waveform[0])
duration = frames / float(rate)
except:
logging.info(f'load file failed: {wav_file}')
duration = 0.0
obj['duration'] = duration
results.append(obj)
return results
def load_data_and_score(keywords_list, data_file, trans_file, score_file):
# score_table: {uttid: [keywordlist]}
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
# read score file and store in table
for line in fin:
arr = line.strip().split()
key = arr[0]
is_detected = arr[1]
if is_detected == 'detected':
if key not in score_table:
score_table.update(
{key: {
'kw': space_mixed_label(arr[2]),
'confi': float(arr[3])
}})
else:
if key not in score_table:
score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
wav_lists = read_lists(data_file)
trans_lists = read_lists(trans_file)
data_lists = make_pair(wav_lists, trans_lists)
logging.info(f'origin list samples: {len(data_lists)}')
# count duration for each wave
num_workers = 8
start = 0
step = int(len(data_lists) / num_workers)
tasks = []
for idx in range(num_workers):
if idx != num_workers - 1:
task = thread_wrapper(count_duration,
(idx, data_lists[start:start + step]))
else:
task = thread_wrapper(count_duration, (idx, data_lists[start:]))
task.start()
tasks.append(task)
start += step
duration_lists = []
for task in tasks:
task.join()
duration_lists += task.get_result()
logging.info(f'after list samples: {len(duration_lists)}')
# build empty structure for keyword-filler infos
keyword_filler_table = {}
for keyword in keywords_list:
keyword = space_mixed_label(keyword)
keyword_filler_table[keyword] = {}
keyword_filler_table[keyword]['keyword_table'] = {}
keyword_filler_table[keyword]['keyword_duration'] = 0.0
keyword_filler_table[keyword]['filler_table'] = {}
keyword_filler_table[keyword]['filler_duration'] = 0.0
for obj in duration_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
assert 'duration' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
txt = space_mixed_label(txt)
txt_regstr_lrblk = ' ' + txt + ' '
duration = obj['duration']
assert key in score_table
for keyword in keywords_list:
keyword = space_mixed_label(keyword)
keyword_regstr_lrblk = ' ' + keyword + ' '
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['keyword_table'].update(
{key: score_table[key]['confi']})
else:
# uttrance detected but not match this keyword
keyword_filler_table[keyword]['keyword_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['keyword_duration'] += duration
else:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
else:
# uttrance if detected, which is not FA for this keyword
keyword_filler_table[keyword]['filler_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['filler_duration'] += duration
return keyword_filler_table
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--keywords',
type=str,
required=True,
help='preset keyword str, input all keywords')
parser.add_argument('--test_data', required=True, help='test data file')
parser.add_argument('--trans_data',
required=True,
default='',
help='transcription of test data')
parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step',
type=float,
default=0.001,
help='threshold step')
parser.add_argument('--stats_dir',
required=True,
help='to save det stats files')
args = parser.parse_args()
root_logger = logging.getLogger()
handlers = root_logger.handlers[:]
for handler in handlers:
root_logger.removeHandler(handler)
handler.close()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
keywords_list = args.keywords.strip().split(',')
keyword_filler_table = load_data_and_score(keywords_list, args.test_data,
args.trans_data,
args.score_file)
stats_files = {}
for keyword in keywords_list:
keyword = space_mixed_label(keyword)
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
filler_dur = keyword_filler_table[keyword]['filler_duration']
filler_num = len(keyword_filler_table[keyword]['filler_table'])
if keyword_num <= 0:
print('Can\'t compute det for {} without positive sample'.format(keyword))
continue
if filler_num <= 0:
print('Can\'t compute det for {} without negative sample'.format(keyword))
continue
logging.info('Computing det for {}'.format(keyword))
logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
keyword_dur / 3600.0, keyword_num))
logging.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0))
stats_file = os.path.join(args.stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
with open(stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
num_true_detect = 0
# transverse the all keyword_table
for key, confi in keyword_filler_table[keyword][
'keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:
num_true_detect += 1
num_false_alarm = 0
# transverse the all filler_table
for key, confi in keyword_filler_table[keyword][
'filler_table'].items():
if confi >= threshold:
num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}')
# false_reject_rate = num_false_reject / keyword_num
true_detect_rate = num_true_detect / keyword_num
num_false_alarm = max(num_false_alarm, 1e-6)
false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
false_alarm_rate = num_false_alarm / filler_num
fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
threshold, true_detect_rate, false_alarm_rate,
false_alarm_per_hour))
threshold += args.step
stats_files[keyword] = stats_file

View File

@@ -0,0 +1,70 @@
from pathlib import Path
from typing import Union
import warnings
class DatadirWriter:
"""Writer class to create kaldi like data directory.
Examples:
>>> with DatadirWriter("output") as writer:
... # output/sub.txt is created here
... subwriter = writer["sub.txt"]
... # Write "uttidA some/where/a.wav"
... subwriter["uttidA"] = "some/where/a.wav"
... subwriter["uttidB"] = "some/where/b.wav"
"""
def __init__(self, p: Union[Path, str]):
self.path = Path(p)
self.chilidren = {}
self.fd = None
self.has_children = False
self.keys = set()
def __enter__(self):
return self
def __getitem__(self, key: str) -> "DatadirWriter":
if self.fd is not None:
raise RuntimeError("This writer points out a file")
if key not in self.chilidren:
w = DatadirWriter((self.path / key))
self.chilidren[key] = w
self.has_children = True
retval = self.chilidren[key]
return retval
def __setitem__(self, key: str, value: str):
if self.has_children:
raise RuntimeError("This writer points out a directory")
if key in self.keys:
warnings.warn(f"Duplicated: {key}")
if self.fd is None:
self.path.parent.mkdir(parents=True, exist_ok=True)
self.fd = self.path.open("w", encoding="utf-8")
self.keys.add(key)
self.fd.write(f"{key} {value}\n")
self.fd.flush()
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
if self.has_children:
prev_child = None
for child in self.chilidren.values():
child.close()
if prev_child is not None and prev_child.keys != child.keys:
warnings.warn(
f"Ids are mismatching between " f"{prev_child.path} and {child.path}"
)
prev_child = child
elif self.fd is not None:
self.fd.close()

View File

@@ -0,0 +1,61 @@
import importlib.util
import importlib.util
import inspect
import os.path
import sys
def load_module_from_path(file_path):
"""
从给定的文件路径动态加载模块。
:param file_path: 模块文件的绝对路径。
:return: 加载的模块
"""
module_name = file_path.split("/")[-1].replace(".py", "")
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def import_module_from_path(file_path: str):
if file_path.startswith("http"):
from funasr.download.file import download_from_url
file_path = download_from_url(file_path)
file_dir = os.path.dirname(file_path)
# file_name = os.path.basename(file_path)
module_name = file_path.split("/")[-1].replace(".py", "")
if len(file_dir) < 1:
file_dir = "./"
sys.path.append(file_dir)
try:
importlib.import_module(module_name)
print(f"Loading remote code successfully: {file_path}")
except Exception as e:
print(f"Loading remote code failed: {file_path}, {e}")
#
# def load_module_from_path(module_name, file_path):
# """
# 从给定的文件路径动态加载模块。
#
# :param module_name: 动态加载的模块的名称。
# :param file_path: 模块文件的绝对路径。
# :return: 加载的模块
# """
# # 创建加载模块的spec规格
# spec = importlib.util.spec_from_file_location(module_name, file_path)
#
# # 根据spec创建模块
# module = importlib.util.module_from_spec(spec)
#
# # 执行模块的代码来实际加载它
# spec.loader.exec_module(module)
#
# return module

View File

@@ -0,0 +1,202 @@
import os
import torch
import functools
def export(
model, data_in=None, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
):
model_scripts = model.export(**kwargs)
export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
os.makedirs(export_dir, exist_ok=True)
if not isinstance(model_scripts, (list, tuple)):
model_scripts = (model_scripts,)
for m in model_scripts:
m.eval()
if type == "onnx":
_onnx(
m,
data_in=data_in,
quantize=quantize,
opset_version=opset_version,
export_dir=export_dir,
**kwargs,
)
elif type == "torchscript":
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Exporting torchscripts on device {}".format(device))
_torchscripts(m, path=export_dir, device=device)
elif type == "bladedisc":
assert (
torch.cuda.is_available()
), "Currently bladedisc optimization for FunASR only supports GPU"
# bladedisc only optimizes encoder/decoder modules
if hasattr(m, "encoder") and hasattr(m, "decoder"):
_bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True)
else:
_torchscripts(m, path=export_dir, device="cuda")
print("output dir: {}".format(export_dir))
return export_dir
def _onnx(
model,
data_in=None,
quantize: bool = False,
opset_version: int = 14,
export_dir: str = None,
**kwargs,
):
dummy_input = model.export_dummy_inputs()
verbose = kwargs.get("verbose", False)
if isinstance(model.export_name, str):
export_name = model.export_name + ".onnx"
else:
export_name = model.export_name()
model_path = os.path.join(export_dir, export_name)
torch.onnx.export(
model,
dummy_input,
model_path,
verbose=verbose,
opset_version=opset_version,
input_names=model.export_input_names(),
output_names=model.export_output_names(),
dynamic_axes=model.export_dynamic_axes(),
)
if quantize:
from onnxruntime.quantization import QuantType, quantize_dynamic
import onnx
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
onnx_model = onnx.load(model_path)
nodes = [n.name for n in onnx_model.graph.node]
nodes_to_exclude = [
m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
]
print("Quantizing model from {} to {}".format(model_path, quant_model_path))
quantize_dynamic(
model_input=model_path,
model_output=quant_model_path,
op_types_to_quantize=["MatMul"],
per_channel=True,
reduce_range=False,
weight_type=QuantType.QUInt8,
nodes_to_exclude=nodes_to_exclude,
)
def _torchscripts(model, path, device="cuda"):
dummy_input = model.export_dummy_inputs()
if device == "cuda":
model = model.cuda()
if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.cuda()
else:
dummy_input = tuple([i.cuda() for i in dummy_input])
model_script = torch.jit.trace(model, dummy_input)
if isinstance(model.export_name, str):
model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
else:
model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
model = model.eval()
try:
import torch_blade
except Exception as e:
print(
f"Warning, if you are exporting bladedisc, please install it and try it again: pip install -U torch_blade\n"
)
torch_config = torch_blade.config.Config()
torch_config.enable_fp16 = enable_fp16
with torch.no_grad(), torch_config:
opt_model = torch_blade.optimize(
model,
allow_tracing=True,
model_inputs=model_inputs,
)
return opt_model
def _rescale_input_hook(m, x, scale):
if len(x) > 1:
return (x[0] / scale, *x[1:])
else:
return (x[0] / scale,)
def _rescale_output_hook(m, x, y, scale):
if isinstance(y, tuple):
return (y[0] / scale, *y[1:])
else:
return y / scale
def _rescale_encoder_model(model, input_data):
# Calculate absmax
absmax = torch.tensor(0).cuda()
def stat_input_hook(m, x, y):
val = x[0] if isinstance(x, tuple) else x
absmax.copy_(torch.max(absmax, val.detach().abs().max()))
encoders = model.encoder.model.encoders
hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
model = model.cuda()
model(*input_data)
for h in hooks:
h.remove()
# Rescale encoder modules
fp16_scale = int(2 * absmax // 65536)
print(f"rescale encoder modules with factor={fp16_scale}")
model.encoder.model.encoders0.register_forward_pre_hook(
functools.partial(_rescale_input_hook, scale=fp16_scale),
)
for name, m in model.encoder.model.named_modules():
if name.endswith("self_attn"):
m.register_forward_hook(functools.partial(_rescale_output_hook, scale=fp16_scale))
if name.endswith("feed_forward.w_2"):
state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
m.load_state_dict(state_dict)
def _bladedisc_opt_for_encdec(model, path, enable_fp16):
# Get input data
# TODO: better to use real data
input_data = model.export_dummy_inputs()
if isinstance(input_data, torch.Tensor):
input_data = input_data.cuda()
else:
input_data = tuple([i.cuda() for i in input_data])
# Get input data for decoder module
decoder_inputs = list()
def get_input_hook(m, x):
decoder_inputs.extend(list(x))
hook = model.decoder.register_forward_pre_hook(get_input_hook)
model = model.cuda()
model(*input_data)
hook.remove()
# Prevent FP16 overflow
if enable_fp16:
_rescale_encoder_model(model, input_data)
# Export and optimize encoder/decoder modules
model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
model_script = torch.jit.trace(model, input_data)
model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))

View File

@@ -0,0 +1,36 @@
import subprocess
def install_requirements(requirements_path):
try:
result = subprocess.run(
["pip", "install", "-r", requirements_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
# check status
if result.returncode == 0:
print("install model requirements successfully")
return True
else:
print("fail to install model requirements! ")
print("error", result.stderr)
return False
except Exception as e:
result = subprocess.run(
["pip", "install", "-r", requirements_path],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
# check status
if result.returncode == 0:
print("install model requirements successfully")
return True
else:
print("fail to install model requirements! ")
print("error", result.stderr)
return False

View File

@@ -0,0 +1,284 @@
import re
import logging
import torch
import math
from collections import defaultdict
from typing import List, Optional, Tuple
symbol_str = '[!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~\s]+'
def split_mixed_label(input_str):
tokens = []
s = input_str.lower()
while len(s) > 0:
match = re.match(r'[A-Za-z!?,<>()\']+', s)
if match is not None:
word = match.group(0)
else:
word = s[0:1]
tokens.append(word)
s = s.replace(word, '', 1).strip(' ')
return tokens
def query_token_set(txt, symbol_table, lexicon_table):
tokens_str = tuple()
tokens_idx = tuple()
if txt in symbol_table:
tokens_str = tokens_str + (txt, )
tokens_idx = tokens_idx + (symbol_table[txt], )
return tokens_str, tokens_idx
parts = split_mixed_label(txt)
for part in parts:
if part == '!sil' or part == '(sil)' or part == '<sil>':
tokens_str = tokens_str + ('!sil', )
elif part == '<blank>' or part == '<blank>':
tokens_str = tokens_str + ('<blank>', )
elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
tokens_str = tokens_str + ('<unk>', )
elif part in symbol_table:
tokens_str = tokens_str + (part, )
elif part in lexicon_table:
for ch in lexicon_table[part]:
tokens_str = tokens_str + (ch, )
else:
part = re.sub(symbol_str, '', part)
for ch in part:
tokens_str = tokens_str + (ch, )
for ch in tokens_str:
if ch in symbol_table:
tokens_idx = tokens_idx + (symbol_table[ch], )
elif ch == '!sil':
if 'sil' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['sil'], )
else:
tokens_idx = tokens_idx + (symbol_table['<blank>'], )
elif ch == '<unk>':
if '<unk>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<unk>'], )
else:
tokens_idx = tokens_idx + (symbol_table['<blank>'], )
else:
if '<unk>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<unk>'], )
logging.info(f'\'{ch}\' is not in token set, replace with <unk>')
else:
tokens_idx = tokens_idx + (symbol_table['<blank>'], )
logging.info(f'\'{ch}\' is not in token set, replace with <blank>')
return tokens_str, tokens_idx
class KwsCtcPrefixDecoder():
"""Decoder interface wrapper for CTCPrefixDecode."""
def __init__(
self,
ctc: torch.nn.Module,
keywords: str,
token_list: list,
seg_dict: dict,
):
"""Initialize class.
Args:
ctc (torch.nn.Module): The CTC implementation.
For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
"""
self.ctc = ctc
self.token_list = token_list
token_table = {}
for token in token_list:
token_table[token] = token_list.index(token)
self.keywords_idxset = {0}
self.keywords_token = {}
self.keywords_str = keywords
keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
for keyword in keywords_list:
strs, indexs = query_token_set(keyword, token_table, seg_dict)
self.keywords_token[keyword] = {}
self.keywords_token[keyword]['token_id'] = indexs
self.keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) for i in indexs)
[ self.keywords_idxset.add(i) for i in indexs ]
def beam_search(
self,
logits: torch.Tensor,
logits_lengths: torch.Tensor,
keywords_tokenset: set = None,
score_beam_size: int = 3,
path_beam_size: int = 20,
) -> Tuple[List[List[int]], torch.Tensor]:
""" CTC prefix beam search inner implementation
Args:
logits (torch.Tensor): (1, max_len, vocab_size)
logits_lengths (torch.Tensor): (1, )
keywords_tokenset (set): token set for filtering score
score_beam_size (int): beam size for score
path_beam_size (int): beam size for path
Returns:
List[List[int]]: nbest results
"""
maxlen = logits.size(0)
ctc_probs = logits
cur_hyps = [(tuple(), (1.0, 0.0, []))]
# CTC beam search step by step
for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (0.0, 0.0, []))
# 2.1 First beam prune: select topk best
top_k_probs, top_k_index = probs.topk(
score_beam_size) # (score_beam_size,)
# filter prob score that is too small
filter_probs = []
filter_index = []
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
if keywords_tokenset is not None:
if prob > 0.05 and idx in keywords_tokenset:
filter_probs.append(prob)
filter_index.append(idx)
else:
if prob > 0.05:
filter_probs.append(prob)
filter_index.append(idx)
if len(filter_index) == 0:
continue
for s in filter_index:
ps = probs[s].item()
# print(f'frame:{t}, token:{s}, score:{ps}')
for prefix, (pb, pnb, cur_nodes) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pb = n_pb + pb * ps + pnb * ps
nodes = cur_nodes.copy()
next_hyps[prefix] = (n_pb, n_pnb, nodes)
elif s == last:
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
# Update *ss -> *s;
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy()
if ps > nodes[-1]['prob']: # update frame and prob
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes)
if not math.isclose(pb, 0.0, abs_tol=0.000001):
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes:
if ps > nodes[-1]['prob']: # update frame and prob
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
else:
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
# 2.2 Second beam prune
next_hyps = sorted(next_hyps.items(),
key=lambda x: (x[1][0] + x[1][1]),
reverse=True)
cur_hyps = next_hyps[:path_beam_size]
hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
return hyps
def is_sublist(self, main_list, check_list):
if len(main_list) < len(check_list):
return -1
if len(main_list) == len(check_list):
return 0 if main_list == check_list else -1
for i in range(len(main_list) - len(check_list)):
if main_list[i] == check_list[0]:
for j in range(len(check_list)):
if main_list[i + j] != check_list[j]:
break
else:
return i
else:
return -1
def _decode_inside(
self,
logits: torch.Tensor,
logits_lengths: torch.Tensor,
):
hyps = self.beam_search(logits, logits_lengths, self.keywords_idxset)
hit_keyword = None
hit_score = 1.0
# start = 0; end = 0
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
prefix_nodes = one_hyp[2]
assert len(prefix_ids) == len(prefix_nodes)
for word in self.keywords_token.keys():
lab = self.keywords_token[word]['token_id']
offset = self.is_sublist(prefix_ids, lab)
if offset != -1:
hit_keyword = word
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
if hit_keyword is not None:
hit_score = math.sqrt(hit_score)
break
if hit_keyword is not None:
return True, hit_keyword, hit_score
else:
return False, None, None
def decode(self, x: torch.Tensor):
"""Get an initial state for decoding.
Args:
x (torch.Tensor): The encoded feature tensor
Returns: decode result
"""
raw_logp = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
xlen = torch.tensor([raw_logp.size(1)])
return self._decode_inside(raw_logp, xlen)

View File

@@ -0,0 +1,217 @@
import os
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
import torchaudio
import time
import logging
from torch.nn.utils.rnn import pad_sequence
try:
from funasr.download.file import download_from_url
except:
print("urllib is not installed, if you infer from url, please install it first.")
import pdb
import subprocess
from subprocess import CalledProcessError, run
def is_ffmpeg_installed():
try:
output = subprocess.check_output(["ffmpeg", "-version"], stderr=subprocess.STDOUT)
return "ffmpeg version" in output.decode("utf-8")
except (subprocess.CalledProcessError, FileNotFoundError):
return False
use_ffmpeg = False
if is_ffmpeg_installed():
use_ffmpeg = True
else:
print(
"Notice: ffmpeg is not installed. torchaudio is used to load audio\n"
"If you want to use ffmpeg backend to load audio, please install it by:"
"\n\tsudo apt install ffmpeg # ubuntu"
"\n\t# brew install ffmpeg # mac"
)
def load_audio_text_image_video(
data_or_path_or_list,
fs: int = 16000,
audio_fs: int = 16000,
data_type="sound",
tokenizer=None,
**kwargs,
):
if isinstance(data_or_path_or_list, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
data_types = [data_type] * len(data_or_path_or_list)
data_or_path_or_list_ret = [[] for d in data_type]
for i, (data_type_i, data_or_path_or_list_i) in enumerate(
zip(data_types, data_or_path_or_list)
):
for j, (data_type_j, data_or_path_or_list_j) in enumerate(
zip(data_type_i, data_or_path_or_list_i)
):
data_or_path_or_list_j = load_audio_text_image_video(
data_or_path_or_list_j,
fs=fs,
audio_fs=audio_fs,
data_type=data_type_j,
tokenizer=tokenizer,
**kwargs,
)
data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
return data_or_path_or_list_ret
else:
return [
load_audio_text_image_video(
audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs
)
for audio in data_or_path_or_list
]
if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith(
("http://", "https://")
): # download url to local file
data_or_path_or_list = download_from_url(data_or_path_or_list)
if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
if data_type is None or data_type == "sound":
# if use_ffmpeg:
# data_or_path_or_list = _load_audio_ffmpeg(data_or_path_or_list, sr=fs)
# data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
# else:
# data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
# if kwargs.get("reduce_channels", True):
# data_or_path_or_list = data_or_path_or_list.mean(0)
try:
data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
if kwargs.get("reduce_channels", True):
data_or_path_or_list = data_or_path_or_list.mean(0)
except:
data_or_path_or_list = _load_audio_ffmpeg(data_or_path_or_list, sr=fs)
data_or_path_or_list = torch.from_numpy(
data_or_path_or_list
).squeeze() # [n_samples,]
elif data_type == "text" and tokenizer is not None:
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
elif data_type == "image": # undo
pass
elif data_type == "video": # undo
pass
# if data_in is a file or url, set is_final=True
if "cache" in kwargs:
kwargs["cache"]["is_final"] = True
kwargs["cache"]["is_streaming_input"] = False
elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
data_or_path_or_list = torch.from_numpy(data_or_path_or_list) # .squeeze() # [n_samples,]
elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
data_mat = kaldiio.load_mat(data_or_path_or_list)
if isinstance(data_mat, tuple):
audio_fs, mat = data_mat
else:
mat = data_mat
if mat.dtype == "int16" or mat.dtype == "int32":
mat = mat.astype(np.float64)
mat = mat / 32768
if mat.ndim == 2:
mat = mat[:, 0]
data_or_path_or_list = mat
else:
pass
# print(f"unsupport data type: {data_or_path_or_list}, return raw data")
if audio_fs != fs and data_type != "text":
resampler = torchaudio.transforms.Resample(audio_fs, fs)
data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
return data_or_path_or_list
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in "iu":
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype("float32")
if dtype.kind != "f":
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, torch.Tensor):
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, (list, tuple)):
data_list, data_len = [], []
for data_i in data:
if isinstance(data_i, np.ndarray):
data_i = torch.from_numpy(data_i)
data_list.append(data_i)
data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
data, data_len = frontend(data, data_len, **kwargs)
if isinstance(data_len, (list, tuple)):
data_len = torch.tensor([data_len])
return data.to(torch.float32), data_len.to(torch.int32)
def _load_audio_ffmpeg(file: str, sr: int = 16000):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

View File

@@ -0,0 +1,119 @@
import os
import io
import shutil
import logging
from collections import OrderedDict
import numpy as np
from omegaconf import DictConfig, OmegaConf
def statistic_model_parameters(model, prefix=None):
var_dict = model.state_dict()
numel = 0
for i, key in enumerate(
sorted(list([x for x in var_dict.keys() if "num_batches_tracked" not in x]))
):
if prefix is None or key.startswith(prefix):
numel += var_dict[key].numel()
return numel
def int2vec(x, vec_dim=8, dtype=np.int32):
b = ("{:0" + str(vec_dim) + "b}").format(x)
# little-endian order: lower bit first
return (np.array(list(b)[::-1]) == "1").astype(dtype)
def seq2arr(seq, vec_dim=8):
return np.row_stack([int2vec(int(x), vec_dim) for x in seq])
def load_scp_as_dict(scp_path, value_type="str", kv_sep=" "):
with io.open(scp_path, "r", encoding="utf-8") as f:
ret_dict = OrderedDict()
for one_line in f.readlines():
one_line = one_line.strip()
pos = one_line.find(kv_sep)
key, value = one_line[:pos], one_line[pos + 1 :]
if value_type == "list":
value = value.split(" ")
ret_dict[key] = value
return ret_dict
def load_scp_as_list(scp_path, value_type="str", kv_sep=" "):
with io.open(scp_path, "r", encoding="utf8") as f:
ret_dict = []
for one_line in f.readlines():
one_line = one_line.strip()
pos = one_line.find(kv_sep)
key, value = one_line[:pos], one_line[pos + 1 :]
if value_type == "list":
value = value.split(" ")
ret_dict.append((key, value))
return ret_dict
def deep_update(original, update):
for key, value in update.items():
if isinstance(value, dict) and key in original:
if len(value) == 0:
original[key] = value
deep_update(original[key], value)
else:
original[key] = value
def prepare_model_dir(**kwargs):
os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
OmegaConf.save(config=kwargs, f=yaml_file)
logging.info(f"kwargs: {kwargs}")
logging.info("config.yaml is saved to: %s", yaml_file)
model_path = kwargs.get("model_path", None)
if model_path is not None:
config_json = os.path.join(model_path, "configuration.json")
if os.path.exists(config_json):
shutil.copy(
config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json")
)
def extract_filename_without_extension(file_path):
"""
从给定的文件路径中提取文件名(不包含路径和扩展名)
:param file_path: 完整的文件路径
:return: 文件名(不含路径和扩展名)
"""
# 首先使用os.path.basename获取路径中的文件名部分含扩展名
filename_with_extension = os.path.basename(file_path)
# 然后使用os.path.splitext分离文件名和扩展名
filename, extension = os.path.splitext(filename_with_extension)
# 返回不包含扩展名的文件名
return filename
def smart_remove(path):
"""Intelligently removes files, empty directories, and non-empty directories recursively."""
# Check if the provided path exists
if not os.path.exists(path):
print(f"{path} does not exist.")
return
# If the path is a file, delete it
if os.path.isfile(path):
os.remove(path)
print(f"File {path} has been deleted.")
# If the path is a directory
elif os.path.isdir(path):
try:
# Attempt to remove an empty directory
os.rmdir(path)
print(f"Empty directory {path} has been deleted.")
except OSError:
# If the directory is not empty, remove it along with all its contents
shutil.rmtree(path)
print(f"Non-empty directory {path} has been recursively deleted.")

View File

@@ -0,0 +1,423 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
if "\u4e00" <= ch <= "\u9fff" or "\u0030" <= ch <= "\u0039" or ch == "@":
return True
return False
def isAllChinese(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(" ", "")
cur = cur.replace("</s>", "")
cur = cur.replace("<s>", "")
cur = cur.replace("<unk>", "")
cur = cur.replace("<OOV>", "")
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if isChinese(ch) is False:
return False
return True
def isAllAlpha(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(" ", "")
cur = cur.replace("</s>", "")
cur = cur.replace("<s>", "")
cur = cur.replace("<unk>", "")
cur = cur.replace("<OOV>", "")
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if ch.isalpha() is False and ch != "'":
return False
elif ch.isalpha() is True and isChinese(ch) is True:
return False
return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
ts_lists = []
ts_nums = []
ts_index = 0
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode("utf-8").isalpha():
if (
num + 1 < words_size
and words[num + 1] == " "
and num + 2 < words_size
and len(words[num + 2]) == 1
and words[num + 2].encode("utf-8").isalpha()
):
# found the begin of abbr
abbr_begin.append(num)
num += 2
abbr_end.append(num)
# to find the end of abbr
while True:
num += 1
if num < words_size and words[num] == " ":
num += 1
if (
num < words_size
and len(words[num]) == 1
and words[num].encode("utf-8").isalpha()
):
abbr_end.pop()
abbr_end.append(num)
last_num = num
else:
break
else:
break
for num in range(words_size):
if words[num] == " ":
ts_nums.append(ts_index)
else:
ts_nums.append(ts_index)
ts_index += 1
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if num in abbr_begin:
if time_stamp is not None:
begin = time_stamp[ts_nums[num]][0]
abbr_word = words[num].upper()
num += 1
while num < words_size:
if num in abbr_end:
abbr_word += words[num].upper()
last_num = num
break
else:
if words[num].encode("utf-8").isalpha():
abbr_word += words[num].upper()
num += 1
word_lists.append(abbr_word)
if time_stamp is not None:
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
else:
word_lists.append(words[num])
# length of time_stamp may not equal to length of words because of the (somehow improper) threshold set in timestamp_tools.py line 46, e.g., length of time_stamp can be zero but length of words is not.
# Moreover, move "word_lists.append(words[num])" into if clause, to keep length of word_lists and length of ts_lists equal.
if time_stamp is not None and ts_nums[num] < len(time_stamp) and words[num] != " ":
begin = time_stamp[ts_nums[num]][0]
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
begin = end
if time_stamp is not None:
return word_lists, ts_lists
else:
return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
middle_lists = []
word_lists = []
word_item = ""
ts_lists = []
# wash words lists
for i in words:
word = ""
if isinstance(i, str):
word = i
else:
word = i.decode("utf-8")
if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for i, ch in enumerate(middle_lists):
word_lists.append(ch.replace(" ", ""))
if time_stamp is not None:
ts_lists = time_stamp
# all alpha characters
elif isAllAlpha(middle_lists):
ts_flag = True
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ""
if "@@" in ch:
word = ch.replace("@@", "")
word_item += word
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
else:
word_item += ch
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
# mix characters
else:
alpha_blank = False
ts_flag = True
begin = -1
end = -1
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ""
if isAllChinese(ch):
if alpha_blank is True:
word_lists.pop()
word_lists.append(ch)
alpha_blank = False
if time_stamp is not None:
ts_flag = True
ts_lists.append([begin, end])
begin = end
elif "@@" in ch:
word = ch.replace("@@", "")
word_item += word
alpha_blank = False
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
elif isAllAlpha(ch):
word_item += ch
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
alpha_blank = True
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
else:
word_lists.append(ch)
if time_stamp is not None:
word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
real_word_lists.append(ch)
sentence = " ".join(real_word_lists).strip()
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
real_word_lists.append(ch)
sentence = "".join(word_lists).strip()
return sentence, real_word_lists
def sentence_postprocess_sentencepiece(words):
middle_lists = []
word_lists = []
word_item = ""
# wash words lists
for i in words:
word = ""
if isinstance(i, str):
word = i
else:
word = i.decode("utf-8")
if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
continue
else:
middle_lists.append(word)
# all alpha characters
for i, ch in enumerate(middle_lists):
word = ""
if "\u2581" in ch and i == 0:
word_item = ""
word = ch.replace("\u2581", "")
word_item += word
elif "\u2581" in ch and i != 0:
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
word = ch.replace("\u2581", "")
word_item += word
else:
word_item += ch
if word_item is not None:
word_lists.append(word_item)
# word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
if ch == "i":
ch = ch.replace("i", "I")
elif ch == "i'm":
ch = ch.replace("i'm", "I'm")
elif ch == "i've":
ch = ch.replace("i've", "I've")
elif ch == "i'll":
ch = ch.replace("i'll", "I'll")
real_word_lists.append(ch)
sentence = "".join(word_lists)
return sentence, real_word_lists
emo_dict = {
"<|HAPPY|>": "😊",
"<|SAD|>": "😔",
"<|ANGRY|>": "😡",
"<|NEUTRAL|>": "",
"<|FEARFUL|>": "😰",
"<|DISGUSTED|>": "🤢",
"<|SURPRISED|>": "😮",
}
event_dict = {
"<|BGM|>": "🎼",
"<|Speech|>": "",
"<|Applause|>": "👏",
"<|Laughter|>": "😀",
"<|Cry|>": "😭",
"<|Sneeze|>": "🤧",
"<|Breath|>": "",
"<|Cough|>": "🤧",
}
lang_dict = {
"<|zh|>": "<|lang|>",
"<|en|>": "<|lang|>",
"<|yue|>": "<|lang|>",
"<|ja|>": "<|lang|>",
"<|ko|>": "<|lang|>",
"<|nospeech|>": "<|lang|>",
}
emoji_dict = {
"<|nospeech|><|Event_UNK|>": "",
"<|zh|>": "",
"<|en|>": "",
"<|yue|>": "",
"<|ja|>": "",
"<|ko|>": "",
"<|nospeech|>": "",
"<|HAPPY|>": "😊",
"<|SAD|>": "😔",
"<|ANGRY|>": "😡",
"<|NEUTRAL|>": "",
"<|BGM|>": "🎼",
"<|Speech|>": "",
"<|Applause|>": "👏",
"<|Laughter|>": "😀",
"<|FEARFUL|>": "😰",
"<|DISGUSTED|>": "🤢",
"<|SURPRISED|>": "😮",
"<|Cry|>": "😭",
"<|EMO_UNKNOWN|>": "",
"<|Sneeze|>": "🤧",
"<|Breath|>": "",
"<|Cough|>": "😷",
"<|Sing|>": "",
"<|Speech_Noise|>": "",
"<|withitn|>": "",
"<|woitn|>": "",
"<|GBG|>": "",
"<|Event_UNK|>": "",
}
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {
"🎼",
"👏",
"😀",
"😭",
"🤧",
"😷",
}
def format_str_v2(s):
sptk_dict = {}
for sptk in emoji_dict:
sptk_dict[sptk] = s.count(sptk)
s = s.replace(sptk, "")
emo = "<|NEUTRAL|>"
for e in emo_dict:
if sptk_dict[e] > sptk_dict[emo]:
emo = e
for e in event_dict:
if sptk_dict[e] > 0:
s = event_dict[e] + s
s = s + emo_dict[emo]
for emoji in emo_set.union(event_set):
s = s.replace(" " + emoji, emoji)
s = s.replace(emoji + " ", emoji)
return s.strip()
def rich_transcription_postprocess(s):
def get_emo(s):
return s[-1] if s[-1] in emo_set else None
def get_event(s):
return s[0] if s[0] in event_set else None
s = s.replace("<|nospeech|><|Event_UNK|>", "")
for lang in lang_dict:
s = s.replace(lang, "<|lang|>")
s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
new_s = " " + s_list[0]
cur_ent_event = get_event(new_s)
for i in range(1, len(s_list)):
if len(s_list[i]) == 0:
continue
if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
s_list[i] = s_list[i][1:]
# else:
cur_ent_event = get_event(s_list[i])
if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
new_s = new_s[:-1]
new_s += s_list[i].strip().lstrip()
new_s = new_s.replace("The.", " ")
return new_s.strip()

View File

@@ -0,0 +1,197 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
"""
import io
from typing import Union
import librosa as sf
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from torch import nn
from funasr.utils.modelscope_file import File
def check_audio_list(audio: list):
audio_dur = 0
for i in range(len(audio)):
seg = audio[i]
assert seg[1] >= seg[0], "modelscope error: Wrong time stamps."
assert isinstance(seg[2], np.ndarray), "modelscope error: Wrong data type."
assert (
int(seg[1] * 16000) - int(seg[0] * 16000) == seg[2].shape[0]
), "modelscope error: audio data in list is inconsistent with time length."
if i > 0:
assert seg[0] >= audio[i - 1][1], "modelscope error: Wrong time stamps."
audio_dur += seg[1] - seg[0]
return audio_dur
# assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
def sv_preprocess(inputs: Union[np.ndarray, list]):
output = []
for i in range(len(inputs)):
if isinstance(inputs[i], str):
file_bytes = File.read(inputs[i])
data, fs = sf.load(io.BytesIO(file_bytes), dtype="float32")
if len(data.shape) == 2:
data = data[:, 0]
data = torch.from_numpy(data).unsqueeze(0)
data = data.squeeze(0)
elif isinstance(inputs[i], np.ndarray):
assert len(inputs[i].shape) == 1, "modelscope error: Input array should be [N, T]"
data = inputs[i]
if data.dtype in ["int16", "int32", "int64"]:
data = (data / (1 << 15)).astype("float32")
else:
data = data.astype("float32")
data = torch.from_numpy(data)
else:
raise ValueError(
"modelscope error: The input type is restricted to audio address and nump array."
)
output.append(data)
return output
def sv_chunk(vad_segments: list, fs=16000) -> list:
config = {
"seg_dur": 1.5,
"seg_shift": 0.75,
}
def seg_chunk(seg_data):
seg_st = seg_data[0]
data = seg_data[2]
chunk_len = int(config["seg_dur"] * fs)
chunk_shift = int(config["seg_shift"] * fs)
last_chunk_ed = 0
seg_res = []
for chunk_st in range(0, data.shape[0], chunk_shift):
chunk_ed = min(chunk_st + chunk_len, data.shape[0])
if chunk_ed <= last_chunk_ed:
break
last_chunk_ed = chunk_ed
chunk_st = max(0, chunk_ed - chunk_len)
chunk_data = data[chunk_st:chunk_ed]
if chunk_data.shape[0] < chunk_len:
chunk_data = np.pad(chunk_data, (0, chunk_len - chunk_data.shape[0]), "constant")
seg_res.append([chunk_st / fs + seg_st, chunk_ed / fs + seg_st, chunk_data])
return seg_res
segs = []
for i, s in enumerate(vad_segments):
segs.extend(seg_chunk(s))
return segs
def extract_feature(audio):
features = []
for au in audio:
feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
features = torch.cat(features)
return features
def postprocess(
segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray
) -> list:
assert len(segments) == len(labels)
labels = correct_labels(labels)
distribute_res = []
for i in range(len(segments)):
distribute_res.append([segments[i][0], segments[i][1], labels[i]])
# merge the same speakers chronologically
distribute_res = merge_seque(distribute_res)
# accquire speaker center
spk_embs = []
for i in range(labels.max() + 1):
spk_emb = embeddings[labels == i].mean(0)
spk_embs.append(spk_emb)
spk_embs = np.stack(spk_embs)
def is_overlapped(t1, t2):
if t1 > t2 + 1e-4:
return True
return False
# distribute the overlap region
for i in range(1, len(distribute_res)):
if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
distribute_res[i][0] = p
distribute_res[i - 1][1] = p
# smooth the result
distribute_res = smooth(distribute_res)
return distribute_res
def correct_labels(labels):
labels_id = 0
id2id = {}
new_labels = []
for i in labels:
if i not in id2id:
id2id[i] = labels_id
labels_id += 1
new_labels.append(id2id[i])
return np.array(new_labels)
def merge_seque(distribute_res):
res = [distribute_res[0]]
for i in range(1, len(distribute_res)):
if distribute_res[i][2] != res[-1][2] or distribute_res[i][0] > res[-1][1]:
res.append(distribute_res[i])
else:
res[-1][1] = distribute_res[i][1]
return res
def smooth(res, mindur=1):
# short segments are assigned to nearest speakers.
for i in range(len(res)):
res[i][0] = round(res[i][0], 2)
res[i][1] = round(res[i][1], 2)
if res[i][1] - res[i][0] < mindur:
if i == 0:
res[i][2] = res[i + 1][2]
elif i == len(res) - 1:
res[i][2] = res[i - 1][2]
elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
res[i][2] = res[i - 1][2]
else:
res[i][2] = res[i + 1][2]
# merge the speakers
res = merge_seque(res)
return res
def distribute_spk(sentence_list, sd_time_list):
sd_sentence_list = []
for d in sentence_list:
sentence_start = d["ts_list"][0][0]
sentence_end = d["ts_list"][-1][1]
sentence_spk = 0
max_overlap = 0
for sd_time in sd_time_list:
spk_st, spk_ed, spk = sd_time
spk_st = spk_st * 1000
spk_ed = spk_ed * 1000
overlap = max(min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
if overlap > max_overlap:
max_overlap = overlap
sentence_spk = spk
d["spk"] = sentence_spk
sd_sentence_list.append(d)
return sd_sentence_list

View File

@@ -0,0 +1,278 @@
import torch
import codecs
import logging
import argparse
import numpy as np
# import edit_distance
from itertools import zip_longest
def cif_wo_hidden(alphas, threshold):
batch_size, len_time = alphas.size()
# loop varss
integrate = torch.zeros([batch_size], device=alphas.device)
# intermediate vars along time
list_fires = []
for t in range(len_time):
alpha = alphas[:, t]
integrate += alpha
list_fires.append(integrate)
fire_place = integrate >= threshold
integrate = torch.where(
fire_place,
integrate - torch.ones([batch_size], device=alphas.device) * threshold,
integrate,
)
fires = torch.stack(list_fires, 1)
return fires
def ts_prediction_lfr6_standard(
us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
):
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 12 # 3 times upsampled
TIME_RATE=10.0 * 6 / 1000 / upsample_rate
if len(us_alphas.shape) == 2:
alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
alphas, peaks = us_alphas, us_peaks
if char_list[-1] == "</s>":
char_list = char_list[:-1]
fire_place = (
torch.where(peaks >= 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
) # total offset
if len(fire_place) != len(char_list) + 1:
alphas /= alphas.sum() / (len(char_list) + 1)
alphas = alphas.unsqueeze(0)
peaks = cif_wo_hidden(alphas, threshold=1.0 - 1e-4)[0]
fire_place = (
torch.where(peaks >= 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift
) # total offset
num_frames = peaks.shape[0]
timestamp_list = []
new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
# fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
# assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
# char_list.insert(0, '<sil>')
timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
new_char_list.append("<sil>")
# tokens timestamp
for i in range(len(fire_place) - 1):
new_char_list.append(char_list[i])
if MAX_TOKEN_DURATION < 0 or fire_place[i + 1] - fire_place[i] <= MAX_TOKEN_DURATION:
timestamp_list.append([fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE])
else:
# cut the duration to token and sil of the 0-weight frames last long
_split = fire_place[i] + MAX_TOKEN_DURATION
timestamp_list.append([fire_place[i] * TIME_RATE, _split * TIME_RATE])
timestamp_list.append([_split * TIME_RATE, fire_place[i + 1] * TIME_RATE])
new_char_list.append("<sil>")
# tail token and end silence
# new_char_list.append(char_list[-1])
if num_frames - fire_place[-1] > START_END_THRESHOLD:
_end = (num_frames + fire_place[-1]) * 0.5
# _end = fire_place[-1]
timestamp_list[-1][1] = _end * TIME_RATE
timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
new_char_list.append("<sil>")
else:
if len(timestamp_list)>0:
timestamp_list[-1][1] = num_frames * TIME_RATE
if vad_offset: # add offset time in model with vad
for i in range(len(timestamp_list)):
timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
res_txt = ""
for char, timestamp in zip(new_char_list, timestamp_list):
# if char != '<sil>':
if not sil_in_str and char == "<sil>":
continue
res_txt += "{} {} {};".format(
char, str(timestamp[0] + 0.0005)[:5], str(timestamp[1] + 0.0005)[:5]
)
res = []
for char, timestamp in zip(new_char_list, timestamp_list):
if char != "<sil>":
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
return res_txt, res
def timestamp_sentence(
punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
punc_list = ["", "", "", ""]
res = []
if text_postprocessed is None:
return res
if timestamp_postprocessed is None:
return res
if len(timestamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
if punc_id_list is None or len(punc_id_list) == 0:
res.append(
{
"text": text_postprocessed.split(),
"start": timestamp_postprocessed[0][0],
"end": timestamp_postprocessed[-1][1],
"timestamp": timestamp_postprocessed,
}
)
return res
if len(punc_id_list) != len(timestamp_postprocessed):
logging.warning("length mismatch between punc and timestamp")
sentence_text = ""
sentence_text_seg = ""
ts_list = []
sentence_start = timestamp_postprocessed[0][0]
sentence_end = timestamp_postprocessed[0][1]
texts = text_postprocessed.split()
punc_stamp_text_list = list(
zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None)
)
for punc_stamp_text in punc_stamp_text_list:
punc_id, timestamp, text = punc_stamp_text
if sentence_start is None and timestamp is not None:
sentence_start = timestamp[0]
# sentence_text += text if text is not None else ''
if text is not None:
if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
sentence_text += " " + text
elif len(sentence_text) and (
"a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
):
sentence_text += " " + text
else:
sentence_text += text
sentence_text_seg += text + " "
ts_list.append(timestamp)
punc_id = int(punc_id) if punc_id is not None else 1
sentence_end = timestamp[1] if timestamp is not None else sentence_end
sentence_text_seg = (
sentence_text_seg[:-1] if sentence_text_seg[-1] == " " else sentence_text_seg
)
if punc_id > 1:
sentence_text += punc_list[punc_id - 2]
if return_raw_text:
res.append(
{
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"timestamp": ts_list,
"raw_text": sentence_text_seg,
}
)
else:
res.append(
{
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"timestamp": ts_list,
}
)
sentence_text = ""
sentence_text_seg = ""
ts_list = []
sentence_start = None
return res
def timestamp_sentence_en(
punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
punc_list = [",", ".", "?", ","]
res = []
if text_postprocessed is None:
return res
if timestamp_postprocessed is None:
return res
if len(timestamp_postprocessed) == 0:
return res
if len(text_postprocessed) == 0:
return res
if punc_id_list is None or len(punc_id_list) == 0:
res.append(
{
"text": text_postprocessed.split(),
"start": timestamp_postprocessed[0][0],
"end": timestamp_postprocessed[-1][1],
"timestamp": timestamp_postprocessed,
}
)
return res
if len(punc_id_list) != len(timestamp_postprocessed):
logging.warning("length mismatch between punc and timestamp")
sentence_text = ""
sentence_text_seg = ""
ts_list = []
sentence_start = timestamp_postprocessed[0][0]
sentence_end = timestamp_postprocessed[0][1]
texts = text_postprocessed.split()
punc_stamp_text_list = list(
zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None)
)
is_sentence_start = True
for punc_stamp_text in punc_stamp_text_list:
punc_id, timestamp, text = punc_stamp_text
# sentence_text += text if text is not None else ''
if text is not None:
if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
sentence_text += " " + text
elif len(sentence_text) and (
"a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
):
sentence_text += " " + text
else:
sentence_text += text
sentence_text_seg += text + " "
ts_list.append(timestamp)
punc_id = int(punc_id) if punc_id is not None else 1
sentence_end = timestamp[1] if timestamp is not None else sentence_end
sentence_text = sentence_text[1:] if sentence_text[0] == ' ' else sentence_text
if is_sentence_start:
sentence_start = timestamp[0] if timestamp is not None else sentence_start
is_sentence_start = False
if punc_id > 1:
is_sentence_start = True
sentence_text += punc_list[punc_id - 2]
sentence_text_seg = (
sentence_text_seg[:-1] if sentence_text_seg[-1] == " " else sentence_text_seg
)
if return_raw_text:
res.append(
{
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"timestamp": ts_list,
"raw_text": sentence_text_seg,
}
)
else:
res.append(
{
"text": sentence_text,
"start": sentence_start,
"end": sentence_end,
"timestamp": ts_list,
}
)
sentence_text = ""
sentence_text_seg = ""
ts_list = []
return res

View File

@@ -0,0 +1,84 @@
from typing import Optional
import torch
import torch.nn as nn
import numpy as np
class MakePadMask(nn.Module):
def __init__(self, max_seq_len=512, flip=True):
super().__init__()
if flip:
self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
else:
self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
"""Make mask tensor containing indices of padded part.
This implementation creates the same mask tensor with original make_pad_mask,
which can be converted into onnx format.
Dimension length of xs should be 2 or 3.
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if xs is not None and len(xs.shape) == 3:
if length_dim == 1:
lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
else:
lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
if maxlen is not None:
m = maxlen
elif xs is not None:
m = xs.shape[-1]
else:
m = torch.max(lengths)
mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)
if length_dim == 1:
return mask.transpose(1, 2)
else:
return mask
class sequence_mask(nn.Module):
def __init__(self, max_seq_len=512, flip=True):
super().__init__()
def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
if max_seq_len is None:
max_seq_len = lengths.max()
row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
def normalize(
input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
if out is None:
denom = input.norm(p, dim, keepdim=True).expand_as(input)
return input / denom
else:
denom = input.norm(p, dim, keepdim=True).expand_as(input)
return torch.div(input, denom, out=out)
def subsequent_mask(size: torch.Tensor):
return torch.ones(size, size).tril()
def MakePadMask_test():
feats_length = torch.tensor([10]).type(torch.long)
mask_fn = MakePadMask()
mask = mask_fn(feats_length)
print(mask)
if __name__ == "__main__":
MakePadMask_test()

View File

@@ -0,0 +1,149 @@
from distutils.util import strtobool
from typing import Optional
from typing import Tuple
from typing import Union
import humanfriendly
def str2bool(value: str) -> bool:
return bool(strtobool(value))
def remove_parenthesis(value: str):
value = value.strip()
if value.startswith("(") and value.endswith(")"):
value = value[1:-1]
elif value.startswith("[") and value.endswith("]"):
value = value[1:-1]
return value
def remove_quotes(value: str):
value = value.strip()
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
elif value.startswith("'") and value.endswith("'"):
value = value[1:-1]
return value
def int_or_none(value: str) -> Optional[int]:
"""int_or_none.
Examples:
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=int_or_none)
>>> parser.parse_args(['--foo', '456'])
Namespace(foo=456)
>>> parser.parse_args(['--foo', 'none'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'null'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'nil'])
Namespace(foo=None)
"""
if value.strip().lower() in ("none", "null", "nil"):
return None
return int(value)
def float_or_none(value: str) -> Optional[float]:
"""float_or_none.
Examples:
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=float_or_none)
>>> parser.parse_args(['--foo', '4.5'])
Namespace(foo=4.5)
>>> parser.parse_args(['--foo', 'none'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'null'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'nil'])
Namespace(foo=None)
"""
if value.strip().lower() in ("none", "null", "nil"):
return None
return float(value)
def humanfriendly_parse_size_or_none(value) -> Optional[float]:
if value.strip().lower() in ("none", "null", "nil"):
return None
return humanfriendly.parse_size(value)
def str_or_int(value: str) -> Union[str, int]:
try:
return int(value)
except ValueError:
return value
def str_or_none(value: str) -> Optional[str]:
"""str_or_none.
Examples:
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=str_or_none)
>>> parser.parse_args(['--foo', 'aaa'])
Namespace(foo='aaa')
>>> parser.parse_args(['--foo', 'none'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'null'])
Namespace(foo=None)
>>> parser.parse_args(['--foo', 'nil'])
Namespace(foo=None)
"""
if value.strip().lower() in ("none", "null", "nil"):
return None
return value
def str2pair_str(value: str) -> Tuple[str, str]:
"""str2pair_str.
Examples:
>>> import argparse
>>> str2pair_str('abc,def ')
('abc', 'def')
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', type=str2pair_str)
>>> parser.parse_args(['--foo', 'abc,def'])
Namespace(foo=('abc', 'def'))
"""
value = remove_parenthesis(value)
a, b = value.split(",")
# Workaround for configargparse issues:
# If the list values are given from yaml file,
# the value givent to type() is shaped as python-list,
# e.g. ['a', 'b', 'c'],
# so we need to remove double quotes from it.
return remove_quotes(a), remove_quotes(b)
def str2triple_str(value: str) -> Tuple[str, str, str]:
"""str2triple_str.
Examples:
>>> str2triple_str('abc,def ,ghi')
('abc', 'def', 'ghi')
"""
value = remove_parenthesis(value)
a, b, c = value.split(",")
# Workaround for configargparse issues:
# If the list values are given from yaml file,
# the value givent to type() is shaped as python-list,
# e.g. ['a', 'b', 'c'],
# so we need to remove quotes from it.
return remove_quotes(a), remove_quotes(b), remove_quotes(c)

View File

@@ -0,0 +1,59 @@
import torch
from torch.nn.utils.rnn import pad_sequence
def slice_padding_fbank(speech, speech_lengths, vad_segments):
speech_list = []
speech_lengths_list = []
for i, segment in enumerate(vad_segments):
bed_idx = int(segment[0][0] * 16)
end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
speech_i = speech[0, bed_idx:end_idx]
speech_lengths_i = end_idx - bed_idx
speech_list.append(speech_i)
speech_lengths_list.append(speech_lengths_i)
feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
return feats_pad, speech_lengths_pad
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
speech_list = []
speech_lengths_list = []
for i, segment in enumerate(vad_segments):
bed_idx = int(segment[0][0] * 16)
end_idx = min(int(segment[0][1] * 16), speech_lengths)
speech_i = speech[bed_idx:end_idx]
speech_lengths_i = end_idx - bed_idx
speech_list.append(speech_i)
speech_lengths_list.append(speech_lengths_i)
return speech_list, speech_lengths_list
def merge_vad(vad_result, max_length=15000, min_length=0):
new_result = []
if len(vad_result) <= 1:
return vad_result
time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
time_step = sorted(list(set(time_step)))
if len(time_step) == 0:
return []
bg = 0
for i in range(len(time_step) - 1):
time = time_step[i]
if time_step[i + 1] - bg < max_length:
continue
if time - bg > min_length:
new_result.append([bg, time])
# if time - bg < max_length * 1.5:
# new_result.append([bg, time])
# else:
# split_num = int(time - bg) // max_length + 1
# spl_l = int(time - bg) // split_num
# for j in range(split_num):
# new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
bg = time
new_result.append([bg, time_step[-1]])
return new_result

View File

@@ -0,0 +1,32 @@
from packaging import version
from funasr import __version__ # Ensure that __version__ is defined in your package's __init__.py
def get_pypi_version(package_name):
import requests
url = f"https://pypi.org/pypi/{package_name}/json"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
return version.parse(data["info"]["version"])
else:
raise Exception("Failed to retrieve version information from PyPI.")
def check_for_update(disable=False):
current_version = version.parse(__version__)
print(f"funasr version: {current_version}.")
if disable:
return
print(
"Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
)
pypi_version = get_pypi_version("funasr")
if current_version < pypi_version:
print(f"New version is available: {pypi_version}.")
print('Please use the command "pip install -U funasr" to upgrade.')
else:
print(f"You are using the latest version of funasr-{current_version}")