mirror of
https://gitee.com/270580156/weiyu.git
synced 2026-05-19 13:48:10 +00:00
Sync from bytedesk-private: update
This commit is contained in:
0
modules/python/vendors/FunASR/funasr/utils/__init__.py
vendored
Normal file
0
modules/python/vendors/FunASR/funasr/utils/__init__.py
vendored
Normal file
286
modules/python/vendors/FunASR/funasr/utils/compute_det_ctc.py
vendored
Normal file
286
modules/python/vendors/FunASR/funasr/utils/compute_det_ctc.py
vendored
Normal file
@@ -0,0 +1,286 @@
|
||||
""" This implementation is adapted from https://github.com/wenet-e2e/wekws/blob/main/wekws/bin/compute_det.py."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
import threading
|
||||
|
||||
import kaldiio
|
||||
import torch
|
||||
from funasr.utils.kws_utils import split_mixed_label
|
||||
|
||||
|
||||
class thread_wrapper(threading.Thread):
    """Thread that runs ``func(*args)`` and keeps its return value.

    The value returned by ``func`` is stored on the instance and can be read
    with :meth:`get_result` after the thread has been joined.
    """

    def __init__(self, func, args=()):
        super().__init__()
        self.func = func
        self.args = args
        # Placeholder that is visible until ``run`` stores the real value.
        self.result = []

    def run(self):
        """Invoke the wrapped callable and remember what it returned."""
        self.result = self.func(*self.args)

    def get_result(self):
        """Return the stored result (the initial empty list if ``run`` has not finished)."""
        return self.result
|
||||
|
||||
|
||||
def space_mixed_label(input_str):
    """Return the tokens of ``split_mixed_label(input_str)`` joined by single spaces.

    Leading and trailing whitespace of the joined string is stripped.
    """
    return ' '.join(split_mixed_label(input_str)).strip()
|
||||
|
||||
|
||||
def read_lists(list_file):
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Args:
        list_file: path of the text file to read.

    Returns:
        A list with one stripped string per line that is not empty after
        stripping, in file order.
    """
    with open(list_file, 'r', encoding='utf8') as fin:
        stripped_lines = (raw.strip() for raw in fin)
        return [entry for entry in stripped_lines if entry != '']
|
||||
|
||||
|
||||
def make_pair(wav_lists, trans_lists):
    """Join wav-list lines and transcript lines on their utterance key.

    Args:
        wav_lists: iterable of lines shaped ``<key> <wav>`` (whitespace or
            tab separated).
        trans_lists: iterable of lines shaped ``<key> <transcript ...>``.

    Returns:
        A list of dicts with the keys ``key``, ``txt``, ``wav`` and
        ``sample_rate`` (always 16000), one per wav line that has exactly two
        fields and whose key has a transcript. Other lines are logged at
        debug level and skipped.
    """
    logging.info('make pair for wav-trans list')

    trans_table = {}
    for line in trans_lists:
        # Split off the key only once: removing the key with str.replace would
        # also delete every later occurrence of it inside the transcript.
        fields = line.strip().split(None, 1)
        if len(fields) < 2:
            logging.debug('invalid line in trans file: {}'.format(
                line.strip()))
            continue
        key, text = fields
        trans_table[key] = text

    lists = []
    for line in wav_lists:
        arr = line.split()
        if len(arr) == 2 and arr[0] in trans_table:
            lists.append(
                dict(key=arr[0],
                     txt=trans_table[arr[0]],
                     wav=arr[1],
                     sample_rate=16000))
        else:
            # A blank line has no key to report; avoid indexing an empty list.
            logging.debug("can't find corresponding trans for key: {}".format(
                arr[0] if arr else ''))

    return lists
|
||||
|
||||
|
||||
def count_duration(tid, data_lists):
    """Attach a ``duration`` (seconds) to every utterance dict in ``data_lists``.

    Each wav is loaded with ``kaldiio.load_mat`` and its duration is the
    number of samples divided by the sample rate. If loading fails for any
    reason the failure is logged and the duration is set to 0.0.

    Args:
        tid: worker index supplied by the caller; not used by this function.
        data_lists: list of dicts that must contain ``key``, ``wav`` and ``txt``.

    Returns:
        A list holding the same dict objects (updated in place) in input order.
    """
    results = []

    for obj in data_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        wav_file = obj['wav']

        try:
            rate, waveform = kaldiio.load_mat(wav_file)
            waveform = torch.tensor(waveform, dtype=torch.float32)
            waveform = waveform.unsqueeze(0)
            frames = len(waveform[0])
            duration = frames / float(rate)
        except Exception:
            # Best effort: an unreadable file counts as zero duration. Only
            # Exception is caught so Ctrl-C / SystemExit still propagate.
            logging.info(f'load file failed: {wav_file}')
            duration = 0.0

        obj['duration'] = duration
        results.append(obj)

    return results
|
||||
|
||||
|
||||
def load_data_and_score(keywords_list, data_file, trans_file, score_file):
    """Build per-keyword positive/filler score tables from test data and scores.

    Args:
        keywords_list: keyword strings; each is normalized with
            ``space_mixed_label`` before use.
        data_file: wav list file, one ``<key> <wav>`` entry per line.
        trans_file: transcript file, one ``<key> <transcript>`` entry per line.
        score_file: score file whose lines are ``<key> detected <keyword>
            <confidence>`` or ``<key> <anything else>``.

    Returns:
        A dict keyed by normalized keyword. Each value holds
        ``keyword_table`` / ``filler_table`` (``{utt key: confidence}``) and
        ``keyword_duration`` / ``filler_duration`` (summed seconds). An
        utterance goes to ``keyword_table`` when its transcript contains the
        keyword and to ``filler_table`` otherwise; the confidence is the
        recorded one when the detected keyword equals this keyword, else -1.0.
    """
    # score_table: {uttid: {'kw': detected keyword or 'unknown', 'confi': score}}
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            if not arr:
                # A blank line carries no score; indexing it would raise IndexError.
                continue
            key = arr[0]
            is_detected = arr[1]
            if key in score_table:
                # Only the first entry seen for an utterance is kept.
                continue
            if is_detected == 'detected':
                score_table[key] = {
                    'kw': space_mixed_label(arr[2]),
                    'confi': float(arr[3])
                }
            else:
                score_table[key] = {'kw': 'unknown', 'confi': -1.0}

    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
    data_lists = make_pair(wav_lists, trans_lists)
    logging.info(f'origin list samples: {len(data_lists)}')

    # count duration for each wave; the last worker takes the remainder
    num_workers = 8
    start = 0
    step = int(len(data_lists) / num_workers)
    tasks = []
    for idx in range(num_workers):
        if idx != num_workers - 1:
            task = thread_wrapper(count_duration,
                                  (idx, data_lists[start:start + step]))
        else:
            task = thread_wrapper(count_duration, (idx, data_lists[start:]))
        task.start()
        tasks.append(task)
        start += step

    duration_lists = []
    for task in tasks:
        task.join()
        duration_lists += task.get_result()
    logging.info(f'after list samples: {len(duration_lists)}')

    # Normalize the keywords once instead of once per utterance.
    spaced_keywords = [space_mixed_label(keyword) for keyword in keywords_list]

    # build empty structure for keyword-filler infos
    keyword_filler_table = {}
    for keyword in spaced_keywords:
        keyword_filler_table[keyword] = {
            'keyword_table': {},
            'keyword_duration': 0.0,
            'filler_table': {},
            'filler_duration': 0.0,
        }

    for obj in duration_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        assert 'duration' in obj

        key = obj['key']
        # Pad with blanks so keywords only match on whole-token boundaries.
        txt_regstr_lrblk = ' ' + space_mixed_label(obj['txt']) + ' '
        duration = obj['duration']
        assert key in score_table
        score = score_table[key]

        for keyword in spaced_keywords:
            entry = keyword_filler_table[keyword]
            # The recorded confidence counts only when this very keyword was
            # the one detected; any other outcome is scored as -1.0.
            confi = score['confi'] if keyword == score['kw'] else -1.0
            if (' ' + keyword + ' ') in txt_regstr_lrblk:
                entry['keyword_table'][key] = confi
                entry['keyword_duration'] += duration
            else:
                entry['filler_table'][key] = confi
                entry['filler_duration'] += duration

    return keyword_filler_table
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: for every keyword, sweep a detection
    # threshold from 0 to 1 and write one stats file with the lines
    # "<threshold> <true detect rate> <false alarm rate> <false alarms per hour>".
    parser = argparse.ArgumentParser(description='compute det curve')
    parser.add_argument('--keywords',
                        type=str,
                        required=True,
                        help='preset keyword str, input all keywords')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--trans_data',
                        required=True,
                        default='',
                        help='transcription of test data')
    parser.add_argument('--score_file', required=True, help='score file')
    parser.add_argument('--step',
                        type=float,
                        default=0.001,
                        help='threshold step')
    parser.add_argument('--stats_dir',
                        required=True,
                        help='to save det stats files')
    args = parser.parse_args()

    # Drop handlers installed earlier so basicConfig below takes effect.
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
        handler.close()

    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    keywords_list = args.keywords.strip().split(',')
    keyword_filler_table = load_data_and_score(keywords_list, args.test_data,
                                               args.trans_data,
                                               args.score_file)

    # The stats files are opened for writing below; make sure the directory exists.
    os.makedirs(args.stats_dir, exist_ok=True)

    stats_files = {}
    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
        keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
        filler_dur = keyword_filler_table[keyword]['filler_duration']
        filler_num = len(keyword_filler_table[keyword]['filler_table'])
        if keyword_num <= 0:
            print('Can\'t compute det for {} without positive sample'.format(keyword))
            continue
        if filler_num <= 0:
            print('Can\'t compute det for {} without negative sample'.format(keyword))
            continue
        if filler_dur <= 0:
            # Every filler wav failed to load (duration 0.0); the per-hour
            # false-alarm figure would divide by zero.
            print('Can\'t compute det for {} without filler duration'.format(keyword))
            continue

        logging.info('Computing det for {}'.format(keyword))
        logging.info('  Keyword duration: {} Hours, wave number: {}'.format(
            keyword_dur / 3600.0, keyword_num))
        logging.info('  Filler duration: {} Hours'.format(filler_dur / 3600.0))

        stats_file = os.path.join(args.stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
        with open(stats_file, 'w', encoding='utf8') as fout:
            threshold = 0.0
            while threshold <= 1.0:
                num_false_reject = 0
                num_true_detect = 0
                # traverse the whole keyword_table
                for key, confi in keyword_filler_table[keyword][
                        'keyword_table'].items():
                    if confi < threshold:
                        num_false_reject += 1
                    else:
                        num_true_detect += 1

                num_false_alarm = 0
                # traverse the whole filler_table
                for key, confi in keyword_filler_table[keyword][
                        'filler_table'].items():
                    if confi >= threshold:
                        num_false_alarm += 1
                        # print(f'false alarm: {keyword}, {key}, {confi}')

                # false_reject_rate = num_false_reject / keyword_num
                true_detect_rate = num_true_detect / keyword_num

                # Floor the count at 1e-6 so the written rates are never exactly zero.
                num_false_alarm = max(num_false_alarm, 1e-6)
                false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
                false_alarm_rate = num_false_alarm / filler_num

                fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
                    threshold, true_detect_rate, false_alarm_rate,
                    false_alarm_per_hour))
                threshold += args.step

        stats_files[keyword] = stats_file
|
||||
70
modules/python/vendors/FunASR/funasr/utils/datadir_writer.py
vendored
Normal file
70
modules/python/vendors/FunASR/funasr/utils/datadir_writer.py
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
|
||||
class DatadirWriter:
    """Writer that creates a kaldi-like data directory.

    A writer is either a directory node (indexed with ``writer[name]`` to get
    a child writer) or a file node (assigned with ``writer[key] = value`` to
    append a ``"key value"`` line); it cannot be both.

    Examples:
        >>> with DatadirWriter("output") as writer:
        ...     # output/sub.txt is created here
        ...     subwriter = writer["sub.txt"]
        ...     # Write "uttidA some/where/a.wav"
        ...     subwriter["uttidA"] = "some/where/a.wav"
        ...     subwriter["uttidB"] = "some/where/b.wav"

    """

    def __init__(self, p: Union[Path, str]):
        self.path = Path(p)
        # Child writers by name (attribute spelling kept for compatibility).
        self.chilidren = {}
        # File handle; opened lazily on the first item assignment.
        self.fd = None
        self.has_children = False
        # Keys written so far, used to detect duplicates and mismatches.
        self.keys = set()

    def __enter__(self):
        return self

    def __getitem__(self, key: str) -> "DatadirWriter":
        """Return the child writer for ``key``, creating it on first access."""
        if self.fd is not None:
            raise RuntimeError("This writer points out a file")

        child = self.chilidren.get(key)
        if child is None:
            child = DatadirWriter(self.path / key)
            self.chilidren[key] = child
            self.has_children = True
        return child

    def __setitem__(self, key: str, value: str):
        """Append the line ``"key value"`` to this writer's file and flush it."""
        if self.has_children:
            raise RuntimeError("This writer points out a directory")
        if key in self.keys:
            warnings.warn(f"Duplicated: {key}")

        if self.fd is None:
            # First write: create the parent directories, then open the file.
            self.path.parent.mkdir(parents=True, exist_ok=True)
            self.fd = self.path.open("w", encoding="utf-8")

        self.keys.add(key)
        self.fd.write(f"{key} {value}\n")
        self.fd.flush()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close this file, or every child, warning when sibling key sets differ."""
        if not self.has_children:
            if self.fd is not None:
                self.fd.close()
            return

        previous = None
        for current in self.chilidren.values():
            current.close()
            if previous is not None and previous.keys != current.keys:
                warnings.warn(
                    f"Ids are mismatching between {previous.path} and {current.path}"
                )
            previous = current
|
||||
61
modules/python/vendors/FunASR/funasr/utils/dynamic_import.py
vendored
Normal file
61
modules/python/vendors/FunASR/funasr/utils/dynamic_import.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
import importlib.util
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
|
||||
def load_module_from_path(file_path):
    """Dynamically load a module from the given file path.

    The module name is the file's base name without its extension.

    :param file_path: path of the module file.
    :return: the loaded module.
    :raises ImportError: if no import spec or loader can be built for the path
        (for example, an unsupported file extension).
    """
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for unrecognized file types;
        # fail with a clear error instead of an AttributeError further down.
        raise ImportError(f"Cannot load module from path: {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def import_module_from_path(file_path: str):
    """Import the module stored at ``file_path`` by name, via ``sys.path``.

    A path starting with ``"http"`` is first downloaded with
    ``funasr.download.file.download_from_url``. The file's directory is
    appended to ``sys.path`` (once) and the module is imported under the
    file's base name. Success or failure is printed; import errors are not
    raised to the caller.

    :param file_path: local path or http(s) URL of a Python source file.
    """
    if file_path.startswith("http"):
        from funasr.download.file import download_from_url

        file_path = download_from_url(file_path)

    file_dir = os.path.dirname(file_path) or "./"
    module_name = os.path.splitext(os.path.basename(file_path))[0]
    # Avoid growing sys.path with duplicates when called repeatedly.
    if file_dir not in sys.path:
        sys.path.append(file_dir)
    try:
        importlib.import_module(module_name)
        print(f"Loading remote code successfully: {file_path}")
    except Exception as e:
        # Deliberate best effort: report the failure and carry on.
        print(f"Loading remote code failed: {file_path}, {e}")
|
||||
|
||||
|
||||
#
|
||||
# def load_module_from_path(module_name, file_path):
|
||||
# """
|
||||
# 从给定的文件路径动态加载模块。
|
||||
#
|
||||
# :param module_name: 动态加载的模块的名称。
|
||||
# :param file_path: 模块文件的绝对路径。
|
||||
# :return: 加载的模块
|
||||
# """
|
||||
# # 创建加载模块的spec(规格)
|
||||
# spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
#
|
||||
# # 根据spec创建模块
|
||||
# module = importlib.util.module_from_spec(spec)
|
||||
#
|
||||
# # 执行模块的代码来实际加载它
|
||||
# spec.loader.exec_module(module)
|
||||
#
|
||||
# return module
|
||||
202
modules/python/vendors/FunASR/funasr/utils/export_utils.py
vendored
Normal file
202
modules/python/vendors/FunASR/funasr/utils/export_utils.py
vendored
Normal file
@@ -0,0 +1,202 @@
|
||||
import os
|
||||
import torch
|
||||
import functools
|
||||
|
||||
|
||||
def export(
    model, data_in=None, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
):
    """Export ``model`` (or each sub-model it yields) in the requested format.

    Args:
        model: object whose ``export(**kwargs)`` returns one exportable module
            or a list/tuple of them.
        data_in: forwarded to the ONNX exporter.
        quantize: forwarded to the ONNX exporter.
        opset_version: ONNX opset version, forwarded to the ONNX exporter.
        type: ``"onnx"``, ``"torchscript"`` or ``"bladedisc"``; any other
            value only puts the modules in eval mode.
        **kwargs: passed to ``model.export`` and to the ONNX exporter.
            ``output_dir`` selects the export directory; without it the
            directory of ``init_param`` is used.

    Returns:
        The export directory path.
    """
    model_scripts = model.export(**kwargs)
    export_dir = kwargs.get("output_dir")
    if export_dir is None:
        # Resolve the fallback lazily so a missing ``init_param`` is harmless
        # whenever ``output_dir`` is supplied. A bare file name has an empty
        # dirname, which os.makedirs rejects, hence the "." default.
        export_dir = os.path.dirname(kwargs.get("init_param")) or "."
    os.makedirs(export_dir, exist_ok=True)

    if not isinstance(model_scripts, (list, tuple)):
        model_scripts = (model_scripts,)
    for m in model_scripts:
        m.eval()
        if type == "onnx":
            _onnx(
                m,
                data_in=data_in,
                quantize=quantize,
                opset_version=opset_version,
                export_dir=export_dir,
                **kwargs,
            )
        elif type == "torchscript":
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print("Exporting torchscripts on device {}".format(device))
            _torchscripts(m, path=export_dir, device=device)
        elif type == "bladedisc":
            assert (
                torch.cuda.is_available()
            ), "Currently bladedisc optimization for FunASR only supports GPU"
            # bladedisc only optimizes encoder/decoder modules
            if hasattr(m, "encoder") and hasattr(m, "decoder"):
                _bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True)
            else:
                _torchscripts(m, path=export_dir, device="cuda")
        print("output dir: {}".format(export_dir))

    return export_dir
|
||||
|
||||
|
||||
def _onnx(
    model,
    data_in=None,
    quantize: bool = False,
    opset_version: int = 14,
    export_dir: str = None,
    **kwargs,
):
    """Export ``model`` to ONNX in ``export_dir`` and optionally quantize it.

    The file name is ``model.export_name + ".onnx"`` when ``export_name`` is a
    string, otherwise whatever ``model.export_name()`` returns. With
    ``quantize`` set, a dynamically quantized copy is written next to it with
    ``".onnx"`` replaced by ``"_quant.onnx"``; only ``MatMul`` ops are
    quantized and nodes whose name contains ``output``, ``bias_encoder`` or
    ``bias_decoder`` are excluded. ``data_in`` is accepted but not used here.
    """
    dummy_input = model.export_dummy_inputs()
    verbose = kwargs.get("verbose", False)

    name_attr = model.export_name
    export_name = name_attr + ".onnx" if isinstance(name_attr, str) else model.export_name()
    model_path = os.path.join(export_dir, export_name)

    torch.onnx.export(
        model,
        dummy_input,
        model_path,
        verbose=verbose,
        opset_version=opset_version,
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes(),
    )

    if not quantize:
        return

    # Imported lazily: only needed when quantization is requested.
    from onnxruntime.quantization import QuantType, quantize_dynamic
    import onnx

    quant_model_path = model_path.replace(".onnx", "_quant.onnx")
    onnx_model = onnx.load(model_path)
    excluded_markers = ("output", "bias_encoder", "bias_decoder")
    nodes_to_exclude = [
        node.name
        for node in onnx_model.graph.node
        if any(marker in node.name for marker in excluded_markers)
    ]
    print("Quantizing model from {} to {}".format(model_path, quant_model_path))
    quantize_dynamic(
        model_input=model_path,
        model_output=quant_model_path,
        op_types_to_quantize=["MatMul"],
        per_channel=True,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
        nodes_to_exclude=nodes_to_exclude,
    )
|
||||
|
||||
|
||||
def _torchscripts(model, path, device="cuda"):
    """Trace ``model`` with its dummy inputs and save the TorchScript file in ``path``.

    With ``device == "cuda"`` the model and dummy inputs are moved to the GPU
    before tracing. The file name is the export name (``model.export_name``,
    called first when it is not a string) with every ``"onnx"`` replaced by
    ``"torchscript"``.
    """
    dummy_input = model.export_dummy_inputs()

    if device == "cuda":
        model = model.cuda()
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple(item.cuda() for item in dummy_input)

    model_script = torch.jit.trace(model, dummy_input)
    raw_name = model.export_name if isinstance(model.export_name, str) else model.export_name()
    file_name = f"{raw_name}".replace("onnx", "torchscript")
    model_script.save(os.path.join(path, file_name))
|
||||
|
||||
|
||||
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
    """Optimize ``model`` with ``torch_blade`` and return the optimized module.

    Args:
        model: module to optimize; it is put in eval mode first.
        model_inputs: example inputs handed to ``torch_blade.optimize``.
        enable_fp16: value assigned to the torch_blade config's ``enable_fp16``.

    Raises:
        ImportError: if ``torch_blade`` cannot be imported.
    """
    model = model.eval()
    try:
        import torch_blade
    except Exception as e:
        print(
            "Warning, if you are exporting bladedisc, please install it and try it again: pip install -U torch_blade\n"
        )
        # Nothing below can work without torch_blade; stop here with a clear
        # error rather than failing later on an unbound name.
        raise ImportError("torch_blade is required for bladedisc export") from e
    torch_config = torch_blade.config.Config()
    torch_config.enable_fp16 = enable_fp16
    with torch.no_grad(), torch_config:
        opt_model = torch_blade.optimize(
            model,
            allow_tracing=True,
            model_inputs=model_inputs,
        )
    return opt_model
|
||||
|
||||
|
||||
def _rescale_input_hook(m, x, scale):
    """Forward pre-hook: divide the first positional input by ``scale``.

    ``x`` is the sequence of positional inputs; the remaining inputs are
    passed through unchanged. Always returns a tuple.
    """
    first, rest = x[0], x[1:]
    return (first / scale, *rest)
|
||||
|
||||
|
||||
def _rescale_output_hook(m, x, y, scale):
    """Forward hook: divide the module output by ``scale``.

    For a tuple output only the first element is divided; the others are
    kept as they are.
    """
    if not isinstance(y, tuple):
        return y / scale
    head, *tail = y
    return (head / scale, *tail)
|
||||
|
||||
|
||||
def _rescale_encoder_model(model, input_data):
    """Rescale encoder modules in place so activations stay within FP16 range.

    A forward pass with ``input_data`` records the largest absolute input
    value seen by the modules in ``model.encoder.model.encoders``. From it a
    scale factor is derived and applied by: a pre-hook dividing the input of
    ``encoders0``, forward hooks dividing the output of every ``*self_attn``
    module, and dividing the parameters of every ``*feed_forward.w_2`` module.
    Requires CUDA.
    """
    # Calculate absmax
    absmax = torch.tensor(0).cuda()

    def stat_input_hook(m, x, y):
        val = x[0] if isinstance(x, tuple) else x
        absmax.copy_(torch.max(absmax, val.detach().abs().max()))

    encoders = model.encoder.model.encoders
    hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
    try:
        model = model.cuda()
        model(*input_data)
    finally:
        # Always detach the statistics hooks, even if the forward pass fails.
        for h in hooks:
            h.remove()

    # Rescale encoder modules
    # The floor division yields 0 when 2 * absmax < 65536; dividing by 0 would
    # turn activations into inf/nan, so fall back to 1 (no rescaling).
    fp16_scale = max(1, int(2 * absmax // 65536))
    print(f"rescale encoder modules with factor={fp16_scale}")
    model.encoder.model.encoders0.register_forward_pre_hook(
        functools.partial(_rescale_input_hook, scale=fp16_scale),
    )
    for name, m in model.encoder.model.named_modules():
        if name.endswith("self_attn"):
            m.register_forward_hook(functools.partial(_rescale_output_hook, scale=fp16_scale))
        if name.endswith("feed_forward.w_2"):
            state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
            m.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def _bladedisc_opt_for_encdec(model, path, enable_fp16):
    """Optimize the encoder and decoder with bladedisc and save the traced model.

    Args:
        model: module exposing ``encoder``, ``decoder``,
            ``export_dummy_inputs()`` and ``export_name``.
        path: directory that receives ``<export name>_blade.torchscript``.
        enable_fp16: when true the encoder is rescaled against FP16 overflow
            first; the flag is also forwarded to the bladedisc optimizer.

    Requires CUDA.
    """
    # Get input data
    # TODO: better to use real data
    input_data = model.export_dummy_inputs()
    if isinstance(input_data, torch.Tensor):
        input_data = input_data.cuda()
    else:
        input_data = tuple([i.cuda() for i in input_data])

    # Get input data for decoder module
    decoder_inputs = list()

    def get_input_hook(m, x):
        decoder_inputs.extend(list(x))

    hook = model.decoder.register_forward_pre_hook(get_input_hook)
    try:
        model = model.cuda()
        model(*input_data)
    finally:
        # Do not leave the capture hook attached if the forward pass fails.
        hook.remove()

    # Prevent FP16 overflow
    if enable_fp16:
        _rescale_encoder_model(model, input_data)

    # Export and optimize encoder/decoder modules
    model.encoder = _bladedisc_opt(model.encoder, input_data[:2], enable_fp16=enable_fp16)
    model.decoder = _bladedisc_opt(
        model.decoder, tuple(decoder_inputs), enable_fp16=enable_fp16
    )
    model_script = torch.jit.trace(model, input_data)
    # export_name may be a plain string or a callable, as the other exporters
    # in this file already allow for.
    export_name = model.export_name if isinstance(model.export_name, str) else model.export_name()
    model_script.save(os.path.join(path, f"{export_name}_blade.torchscript"))
|
||||
36
modules/python/vendors/FunASR/funasr/utils/install_model_requirements.py
vendored
Normal file
36
modules/python/vendors/FunASR/funasr/utils/install_model_requirements.py
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
import subprocess
|
||||
|
||||
|
||||
def install_requirements(requirements_path):
    """Install the packages listed in ``requirements_path`` with pip.

    The ``pip`` executable is tried first. If launching it raises (for
    example, ``pip`` is not on PATH), the install is retried through the
    running interpreter with ``python -m pip``.

    Args:
        requirements_path: path of a pip requirements file.

    Returns:
        True if pip exited with status 0, False otherwise.
    """
    import sys

    install_args = ["install", "-r", requirements_path]
    try:
        return _run_pip_install(["pip", *install_args])
    except Exception:
        # Re-running the identical command could not succeed; go through the
        # current interpreter's pip module instead.
        return _run_pip_install([sys.executable, "-m", "pip", *install_args])


def _run_pip_install(command):
    """Run one pip command line and report whether it exited with status 0."""
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )

    # check status
    if result.returncode == 0:
        print("install model requirements successfully")
        return True
    print("fail to install model requirements! ")
    print("error", result.stderr)
    return False
|
||||
284
modules/python/vendors/FunASR/funasr/utils/kws_utils.py
vendored
Normal file
284
modules/python/vendors/FunASR/funasr/utils/kws_utils.py
vendored
Normal file
@@ -0,0 +1,284 @@
|
||||
import re
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
# Regex character class of punctuation and whitespace that query_token_set
# strips from unknown parts. The backslash before "s" is doubled so the string
# holds the same regex (\s) without relying on an invalid escape sequence.
symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~\\s]+'
|
||||
|
||||
|
||||
def split_mixed_label(input_str):
    """Split a lower-cased label into word tokens and single characters.

    A run of ASCII letters and the characters ``!?,<>()'`` becomes one token;
    any other leading character becomes a token of its own. Spaces around the
    remaining text are dropped after each token is taken.

    Returns:
        The list of tokens in order of appearance.
    """
    pieces = []
    remaining = input_str.lower()
    while remaining:
        found = re.match(r'[A-Za-z!?,<>()\']+', remaining)
        piece = found.group(0) if found is not None else remaining[:1]
        pieces.append(piece)
        # The piece is always a prefix of the remaining text.
        remaining = remaining[len(piece):].strip(' ')
    return pieces
|
||||
|
||||
|
||||
def query_token_set(txt, symbol_table, lexicon_table):
    """Map ``txt`` to a tuple of token strings and a tuple of token ids.

    If ``txt`` itself is in ``symbol_table`` it is returned as a single token.
    Otherwise ``txt`` is split with ``split_mixed_label`` and each part is
    mapped in this order: silence spellings -> ``'!sil'``, ``'<blank>'``,
    noise spellings -> ``'<unk>'``, a symbol-table entry, the lexicon
    expansion, or finally its characters after removing ``symbol_str`` matches.

    Token strings missing from ``symbol_table`` fall back to ``'sil'`` (for
    ``'!sil'``), then ``'<unk>'``, then ``'<blank>'``.

    Returns:
        ``(tokens_str, tokens_idx)`` as two tuples.
    """
    if txt in symbol_table:
        return (txt, ), (symbol_table[txt], )

    sil_forms = ('!sil', '(sil)', '<sil>')
    noise_forms = ('(noise)', 'noise)', '(noise', '<noise>')

    str_tokens = []
    for part in split_mixed_label(txt):
        if part in sil_forms:
            str_tokens.append('!sil')
        elif part == '<blank>':
            str_tokens.append('<blank>')
        elif part in noise_forms:
            str_tokens.append('<unk>')
        elif part in symbol_table:
            str_tokens.append(part)
        elif part in lexicon_table:
            str_tokens.extend(lexicon_table[part])
        else:
            str_tokens.extend(re.sub(symbol_str, '', part))

    idx_tokens = []
    for tok in str_tokens:
        if tok in symbol_table:
            idx_tokens.append(symbol_table[tok])
        elif tok == '!sil':
            alias = 'sil' if 'sil' in symbol_table else '<blank>'
            idx_tokens.append(symbol_table[alias])
        elif tok == '<unk>':
            alias = '<unk>' if '<unk>' in symbol_table else '<blank>'
            idx_tokens.append(symbol_table[alias])
        elif '<unk>' in symbol_table:
            idx_tokens.append(symbol_table['<unk>'])
            logging.info(f'\'{tok}\' is not in token set, replace with <unk>')
        else:
            idx_tokens.append(symbol_table['<blank>'])
            logging.info(f'\'{tok}\' is not in token set, replace with <blank>')

    return tuple(str_tokens), tuple(idx_tokens)
|
||||
|
||||
|
||||
class KwsCtcPrefixDecoder():
    """Keyword spotting decoder built on CTC prefix beam search."""

    def __init__(
        self,
        ctc: torch.nn.Module,
        keywords: str,
        token_list: list,
        seg_dict: dict,
    ):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC module; ``decode`` calls its
                ``softmax`` method.
            keywords (str): comma-separated keywords; spaces are removed.
            token_list (list): vocabulary; a token's id is its first index.
            seg_dict (dict): lexicon passed to ``query_token_set``.

        """
        self.ctc = ctc
        self.token_list = token_list

        # First occurrence wins, matching list.index without its O(n^2) cost.
        token_table = {}
        for index, token in enumerate(token_list):
            token_table.setdefault(token, index)

        # Token ids allowed during search; id 0 is treated as blank by beam_search.
        self.keywords_idxset = {0}
        self.keywords_token = {}
        self.keywords_str = keywords
        keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
        for keyword in keywords_list:
            strs, indexs = query_token_set(keyword, token_table, seg_dict)
            self.keywords_token[keyword] = {}
            self.keywords_token[keyword]['token_id'] = indexs
            self.keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) for i in indexs)
            self.keywords_idxset.update(indexs)

    def beam_search(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
        keywords_tokenset: set = None,
        score_beam_size: int = 3,
        path_beam_size: int = 20,
    ) -> List[Tuple[Tuple[int, ...], float, List[dict]]]:
        """ CTC prefix beam search inner implementation

        Args:
            logits (torch.Tensor): per-frame token probabilities indexed as
                ``logits[t]`` -> (vocab_size,)
            logits_lengths (torch.Tensor): accepted but not used here
            keywords_tokenset (set): token set for filtering score
            score_beam_size (int): beam size for score
            path_beam_size (int): beam size for path

        Returns:
            List of ``(prefix, score, nodes)`` sorted by descending score.
            ``prefix`` is a tuple of token ids, ``score`` is ``pb + pnb`` and
            ``nodes`` holds one ``dict(token, frame, prob)`` per prefix token.
        """

        maxlen = logits.size(0)
        ctc_probs = logits
        cur_hyps = [(tuple(), (1.0, 0.0, []))]

        # CTC beam search step by step
        for t in range(0, maxlen):
            probs = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb, nodes), default value (0.0, 0.0, [])
            next_hyps = defaultdict(lambda: (0.0, 0.0, []))

            # 2.1 First beam prune: select topk best
            top_k_probs, top_k_index = probs.topk(
                score_beam_size)  # (score_beam_size,)

            # filter prob score that is too small
            filter_index = []
            for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
                if keywords_tokenset is not None:
                    if prob > 0.05 and idx in keywords_tokenset:
                        filter_index.append(idx)
                else:
                    if prob > 0.05:
                        filter_index.append(idx)

            if len(filter_index) == 0:
                continue

            for s in filter_index:
                ps = probs[s].item()
                # print(f'frame:{t}, token:{s}, score:{ps}')

                for prefix, (pb, pnb, cur_nodes) in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == 0:  # blank
                        n_pb, n_pnb, nodes = next_hyps[prefix]
                        n_pb = n_pb + pb * ps + pnb * ps
                        nodes = cur_nodes.copy()
                        next_hyps[prefix] = (n_pb, n_pnb, nodes)
                    elif s == last:
                        if not math.isclose(pnb, 0.0, abs_tol=0.000001):
                            # Update *ss -> *s;
                            n_pb, n_pnb, nodes = next_hyps[prefix]
                            n_pnb = n_pnb + pnb * ps
                            nodes = cur_nodes.copy()
                            if ps > nodes[-1]['prob']:  # update frame and prob
                                nodes[-1]['prob'] = ps
                                nodes[-1]['frame'] = t
                            next_hyps[prefix] = (n_pb, n_pnb, nodes)

                        if not math.isclose(pb, 0.0, abs_tol=0.000001):
                            # Update *s-s -> *ss, - is for blank
                            n_prefix = prefix + (s, )
                            n_pb, n_pnb, nodes = next_hyps[n_prefix]
                            n_pnb = n_pnb + pb * ps
                            nodes = cur_nodes.copy()
                            nodes.append(dict(token=s, frame=t,
                                              prob=ps))  # to record token prob
                            next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
                    else:
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb, nodes = next_hyps[n_prefix]
                        if nodes:
                            if ps > nodes[-1]['prob']:  # update frame and prob
                                nodes[-1]['prob'] = ps
                                nodes[-1]['frame'] = t
                        else:
                            nodes = cur_nodes.copy()
                            nodes.append(dict(token=s, frame=t,
                                              prob=ps))  # to record token prob
                        n_pnb = n_pnb + pb * ps + pnb * ps
                        next_hyps[n_prefix] = (n_pb, n_pnb, nodes)

            # 2.2 Second beam prune
            next_hyps = sorted(next_hyps.items(),
                               key=lambda x: (x[1][0] + x[1][1]),
                               reverse=True)

            cur_hyps = next_hyps[:path_beam_size]

        hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
        return hyps

    def is_sublist(self, main_list, check_list):
        """Return the first offset where ``check_list`` occurs contiguously in
        ``main_list``, or -1 if it does not occur.

        The last candidate offset (``len(main_list) - len(check_list)``) is
        included, so a match at the very end of ``main_list`` is found.
        """
        span = len(check_list)
        for offset in range(len(main_list) - span + 1):
            if all(main_list[offset + j] == check_list[j] for j in range(span)):
                return offset
        return -1

    def _decode_inside(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
    ):
        """Search for the first keyword contained in any beam hypothesis.

        Returns:
            ``(True, keyword, score)`` for the first hit, where ``score`` is
            the square root of the product of the matched tokens' recorded
            probabilities, or ``(False, None, None)`` when nothing matches.
        """
        hyps = self.beam_search(logits, logits_lengths, self.keywords_idxset)

        hit_keyword = None
        hit_score = 1.0
        # start = 0; end = 0
        for one_hyp in hyps:
            prefix_ids = one_hyp[0]
            # path_score = one_hyp[1]
            prefix_nodes = one_hyp[2]
            assert len(prefix_ids) == len(prefix_nodes)
            for word in self.keywords_token.keys():
                lab = self.keywords_token[word]['token_id']
                offset = self.is_sublist(prefix_ids, lab)
                if offset != -1:
                    hit_keyword = word
                    for idx in range(offset, offset + len(lab)):
                        hit_score *= prefix_nodes[idx]['prob']
                    break
            if hit_keyword is not None:
                hit_score = math.sqrt(hit_score)
                break

        if hit_keyword is not None:
            return True, hit_keyword, hit_score
        else:
            return False, None, None

    def decode(self, x: torch.Tensor):
        """Run keyword detection on one encoded utterance.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: ``(detected, keyword, score)`` as produced by ``_decode_inside``

        """

        raw_logp = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
        # beam_search iterates over dim 0, so that is the sequence length.
        xlen = torch.tensor([raw_logp.size(0)])

        return self._decode_inside(raw_logp, xlen)
|
||||
217
modules/python/vendors/FunASR/funasr/utils/load_utils.py
vendored
Normal file
217
modules/python/vendors/FunASR/funasr/utils/load_utils.py
vendored
Normal file
@@ -0,0 +1,217 @@
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import torch.distributed as dist
|
||||
import numpy as np
|
||||
import kaldiio
|
||||
import librosa
|
||||
import torchaudio
|
||||
import time
|
||||
import logging
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
try:
|
||||
from funasr.download.file import download_from_url
|
||||
except:
|
||||
print("urllib is not installed, if you infer from url, please install it first.")
|
||||
import pdb
|
||||
import subprocess
|
||||
from subprocess import CalledProcessError, run
|
||||
|
||||
|
||||
def is_ffmpeg_installed():
    """Return True when ``ffmpeg -version`` runs and prints a version banner.

    A missing executable or a non-zero exit status yields False.
    """
    probe = ["ffmpeg", "-version"]
    try:
        banner = subprocess.check_output(probe, stderr=subprocess.STDOUT)
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False
    return "ffmpeg version" in banner.decode("utf-8")
|
||||
|
||||
|
||||
# Probe for the ffmpeg CLI once at import time and tell the user how to
# install it when it is missing.
use_ffmpeg = is_ffmpeg_installed()
if not use_ffmpeg:
    print(
        "Notice: ffmpeg is not installed. torchaudio is used to load audio\n"
        "If you want to use ffmpeg backend to load audio, please install it by:"
        "\n\tsudo apt install ffmpeg # ubuntu"
        "\n\t# brew install ffmpeg # mac"
    )
|
||||
|
||||
|
||||
def load_audio_text_image_video(
    data_or_path_or_list,
    fs: int = 16000,
    audio_fs: int = 16000,
    data_type="sound",
    tokenizer=None,
    **kwargs,
):
    """Load one input, or a (nested) list of inputs, for inference.

    Dispatch, in order:

    * list/tuple input with a list/tuple ``data_type``: every element is a
      sequence aligned with ``data_type``; each field is loaded recursively
      and the result is one list per entry of ``data_type``.
    * list/tuple input with a scalar ``data_type``: every element is loaded
      recursively with the same settings (including ``tokenizer``).
    * ``http://`` / ``https://`` string: downloaded with
      ``download_from_url`` and then treated as a local path.
    * existing local path: ``"sound"``/``None`` is decoded with torchaudio,
      falling back to the ffmpeg CLI on any error; ``"text"`` is encoded
      when a tokenizer is given; ``"image"``/``"video"`` pass through.
      If ``kwargs`` has a ``"cache"`` dict it is marked as final,
      non-streaming input.
    * other string with ``data_type == "text"`` and a tokenizer: encoded.
    * ``numpy.ndarray``: wrapped with ``torch.from_numpy``.
    * string with ``data_type == "kaldi_ark"``: read with
      ``kaldiio.load_mat``; integer matrices are scaled by 1/32768 and only
      the first column of a 2-D matrix is kept.
    * anything else is returned unchanged.

    Unless ``data_type == "text"``, the result is finally resampled from
    ``audio_fs`` to ``fs`` when the two differ.

    Args:
        data_or_path_or_list: the input described above.
        fs: target sampling rate.
        audio_fs: sampling rate of the incoming audio; overwritten by the
            rate reported by the decoder when a file is read.
        data_type: ``"sound"``, ``"text"``, ``"image"``, ``"video"``,
            ``"kaldi_ark"``, ``None``, or a list/tuple of those.
        tokenizer: object with an ``encode(str)`` method, used for text.
        **kwargs: ``reduce_channels`` (default True) averages channels of
            decoded audio; ``cache`` is updated as described above.
    """
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            data_types = [data_type] * len(data_or_path_or_list)
            data_or_path_or_list_ret = [[] for _ in data_type]
            for data_type_i, data_or_path_or_list_i in zip(data_types, data_or_path_or_list):
                for j, (data_type_j, data_or_path_or_list_j) in enumerate(
                    zip(data_type_i, data_or_path_or_list_i)
                ):
                    data_or_path_or_list_j = load_audio_text_image_video(
                        data_or_path_or_list_j,
                        fs=fs,
                        audio_fs=audio_fs,
                        data_type=data_type_j,
                        tokenizer=tokenizer,
                        **kwargs,
                    )
                    data_or_path_or_list_ret[j].append(data_or_path_or_list_j)

            return data_or_path_or_list_ret

        # Forward the tokenizer too, so a list of texts is handled exactly
        # like a single text (it used to be dropped here).
        return [
            load_audio_text_image_video(
                item,
                fs=fs,
                audio_fs=audio_fs,
                data_type=data_type,
                tokenizer=tokenizer,
                **kwargs,
            )
            for item in data_or_path_or_list
        ]

    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith(
        ("http://", "https://")
    ):  # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)

    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):  # local file
        if data_type is None or data_type == "sound":
            # Keep the path separately so the fallback still sees it even if
            # torchaudio failed after the variable was rebound.
            path = data_or_path_or_list
            try:
                data_or_path_or_list, audio_fs = torchaudio.load(path)
                if kwargs.get("reduce_channels", True):
                    data_or_path_or_list = data_or_path_or_list.mean(0)
            except Exception:
                data_or_path_or_list = _load_audio_ffmpeg(path, sr=fs)
                data_or_path_or_list = torch.from_numpy(
                    data_or_path_or_list
                ).squeeze()  # [n_samples,]
                # ffmpeg already resampled to ``fs``; record that so the
                # resampling step below does not run a second time.
                audio_fs = fs
        elif data_type == "text" and tokenizer is not None:
            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
        elif data_type == "image":  # undo
            pass
        elif data_type == "video":  # undo
            pass

        # if data_in is a file or url, set is_final=True
        if "cache" in kwargs:
            kwargs["cache"]["is_final"] = True
            kwargs["cache"]["is_streaming_input"] = False
    elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
        data_or_path_or_list = torch.from_numpy(data_or_path_or_list)  # .squeeze()  # [n_samples,]
    elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
        data_mat = kaldiio.load_mat(data_or_path_or_list)
        if isinstance(data_mat, tuple):
            audio_fs, mat = data_mat
        else:
            mat = data_mat
        if mat.dtype == "int16" or mat.dtype == "int32":
            mat = mat.astype(np.float64)
            mat = mat / 32768
        if mat.ndim == 2:
            mat = mat[:, 0]
        data_or_path_or_list = mat
    else:
        # Unsupported input: return it untouched.
        pass

    if audio_fs != fs and data_type != "text":
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
    return data_or_path_or_list
|
||||
|
||||
|
||||
def load_bytes(input):
    """Convert raw 16-bit PCM bytes to float32 samples in [-1.0, 1.0).

    Args:
        input: bytes-like object holding native-endian int16 samples.

    Returns:
        1-D ``numpy.ndarray`` of dtype float32: each sample divided by 32768.

    Raises:
        ValueError: from ``numpy.frombuffer`` when the buffer length is not
            a multiple of two bytes.
    """
    samples = np.frombuffer(input, dtype=np.int16)

    # Derive the scale from the integer type so the arithmetic documents
    # itself: for int16 ``abs_max`` is 32768 and ``offset`` is 0.
    info = np.iinfo(samples.dtype)
    abs_max = 2 ** (info.bits - 1)
    offset = info.min + abs_max

    return (samples.astype(np.float32) - offset) / abs_max
|
||||
|
||||
|
||||
def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
    """Batch the input and run it through ``frontend``.

    Args:
        data: a numpy array or tensor (a 1-D input gets a leading batch
            axis), or a list/tuple of 1-D arrays/tensors that are padded
            into one batch.
        data_len: lengths for array/tensor input; defaults to
            ``[data.shape[1]]``. Recomputed for list/tuple input.
        data_type: unused here.
        frontend: callable ``frontend(data, data_len, **kwargs)`` returning
            ``(features, lengths)``.

    Returns:
        ``(features as float32, lengths as int32)``.
    """
    if isinstance(data, (list, tuple)):
        tensors = [
            torch.from_numpy(item) if isinstance(item, np.ndarray) else item for item in data
        ]
        data_len = [item.shape[0] for item in tensors]
        data = pad_sequence(tensors, batch_first=True)  # data: [batch, N]
    elif isinstance(data, (np.ndarray, torch.Tensor)):
        if isinstance(data, np.ndarray):
            data = torch.from_numpy(data)
        if data.dim() < 2:
            data = data[None, :]  # data: [batch, N]
        if data_len is None:
            data_len = [data.shape[1]]

    data, data_len = frontend(data, data_len, **kwargs)

    # A frontend may hand the lengths back as a plain sequence.
    if isinstance(data_len, (list, tuple)):
        data_len = torch.tensor([data_len])
    return data.to(torch.float32), data_len.to(torch.int32)
|
||||
|
||||
|
||||
def _load_audio_ffmpeg(file: str, sr: int = 16000):
    """
    Decode an audio file to a mono waveform with the ffmpeg CLI.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate ffmpeg is asked to output

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        When ffmpeg exits with a non-zero status; the message carries its
        stderr output.
    """

    # ffmpeg down-mixes to one channel, resamples to ``sr`` and writes raw
    # signed 16-bit little-endian PCM to stdout. Requires ffmpeg in PATH.
    # fmt: off
    command = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        completed = run(command, capture_output=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    pcm = np.frombuffer(completed.stdout, np.int16)
    return pcm.flatten().astype(np.float32) / 32768.0
|
||||
119
modules/python/vendors/FunASR/funasr/utils/misc.py
vendored
Normal file
119
modules/python/vendors/FunASR/funasr/utils/misc.py
vendored
Normal file
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
import io
|
||||
import shutil
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
import numpy as np
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
|
||||
def statistic_model_parameters(model, prefix=None):
    """Count the elements stored in ``model.state_dict()``.

    Entries whose name contains ``"num_batches_tracked"`` are ignored. When
    ``prefix`` is given, only entries whose name starts with it are counted.
    """
    state = model.state_dict()
    return sum(
        tensor.numel()
        for name, tensor in state.items()
        if "num_batches_tracked" not in name
        and (prefix is None or name.startswith(prefix))
    )
|
||||
|
||||
|
||||
def int2vec(x, vec_dim=8, dtype=np.int32):
    """Return the binary digits of ``x`` as an array, lowest bit first.

    ``x`` is formatted as a zero-padded binary string of at least
    ``vec_dim`` characters; each character becomes 1 when it is ``"1"`` and
    0 otherwise.
    """
    digits = format(x, "0" + str(vec_dim) + "b")
    # little-endian order: lower bit first
    flags = np.array([digit == "1" for digit in reversed(digits)])
    return flags.astype(dtype)
|
||||
|
||||
|
||||
def seq2arr(seq, vec_dim=8):
    """Stack the ``int2vec`` encoding of every item of ``seq`` row by row.

    Each item is converted with ``int()`` first. ``np.vstack`` is used
    instead of ``np.row_stack``, which is a deprecated alias of it in
    NumPy 2.0.
    """
    return np.vstack([int2vec(int(x), vec_dim) for x in seq])
|
||||
|
||||
|
||||
def load_scp_as_dict(scp_path, value_type="str", kv_sep=" "):
    """Read a ``key<sep>value`` file into an ``OrderedDict``.

    Each line is stripped and split at the first ``kv_sep``. Blank lines
    are skipped. A line without a separator becomes a key with an empty
    value (it used to lose its last character). Later duplicates of a key
    overwrite earlier ones.

    Args:
        scp_path: path of the UTF-8 text file.
        value_type: ``"list"`` splits the value on single spaces; any other
            setting keeps the value as a string.
        kv_sep: separator between key and value; may be longer than one
            character.

    Returns:
        ``OrderedDict`` in file order.
    """
    ret_dict = OrderedDict()
    with io.open(scp_path, "r", encoding="utf-8") as f:
        for one_line in f:
            one_line = one_line.strip()
            if not one_line:
                continue
            key, _, value = one_line.partition(kv_sep)
            if value_type == "list":
                value = value.split(" ")
            ret_dict[key] = value
    return ret_dict
|
||||
|
||||
|
||||
def load_scp_as_list(scp_path, value_type="str", kv_sep=" "):
    """Read a ``key<sep>value`` file into a list of ``(key, value)`` tuples.

    Each line is stripped and split at the first ``kv_sep``. Blank lines
    are skipped. A line without a separator becomes a key with an empty
    value (it used to lose its last character). Duplicate keys are kept.

    Args:
        scp_path: path of the UTF-8 text file.
        value_type: ``"list"`` splits the value on single spaces; any other
            setting keeps the value as a string.
        kv_sep: separator between key and value; may be longer than one
            character.

    Returns:
        List of ``(key, value)`` tuples in file order.
    """
    entries = []
    with io.open(scp_path, "r", encoding="utf8") as f:
        for one_line in f:
            one_line = one_line.strip()
            if not one_line:
                continue
            key, _, value = one_line.partition(kv_sep)
            if value_type == "list":
                value = value.split(" ")
            entries.append((key, value))
    return entries
|
||||
|
||||
|
||||
def deep_update(original, update):
    """Recursively merge ``update`` into ``original`` in place.

    For every key of ``update``:

    * a non-empty ``dict`` value is merged into ``original[key]`` when that
      key already holds a mutable mapping;
    * in every other case (new key, non-dict value, empty dict, or an
      existing value such as ``None`` or a scalar that cannot be merged
      into) the value replaces ``original[key]``. Previously the last
      situation raised ``TypeError``.

    Returns None.
    """
    from collections.abc import MutableMapping

    for key, value in update.items():
        mergeable = (
            isinstance(value, dict)
            and len(value) > 0
            and key in original
            and isinstance(original[key], MutableMapping)
        )
        if mergeable:
            deep_update(original[key], value)
        else:
            original[key] = value
|
||||
|
||||
|
||||
def prepare_model_dir(**kwargs):
    """Create the output directory and store the run configuration in it.

    ``kwargs["output_dir"]`` (default ``"./"``) is created if needed and all
    keyword arguments are saved there as ``config.yaml``. When
    ``kwargs["model_path"]`` is set and contains ``configuration.json``,
    that file is copied into the output directory as well.
    """
    output_dir = kwargs.get("output_dir", "./")
    os.makedirs(output_dir, exist_ok=True)

    yaml_file = os.path.join(output_dir, "config.yaml")
    OmegaConf.save(config=kwargs, f=yaml_file)
    logging.info(f"kwargs: {kwargs}")
    logging.info("config.yaml is saved to: %s", yaml_file)

    model_path = kwargs.get("model_path", None)
    if model_path is None:
        return

    config_json = os.path.join(model_path, "configuration.json")
    if os.path.exists(config_json):
        shutil.copy(config_json, os.path.join(output_dir, "configuration.json"))
|
||||
|
||||
|
||||
def extract_filename_without_extension(file_path):
    """
    Extract the file name from a path, without directories or extension.

    :param file_path: full path of the file
    :return: file name without its directory part and without its last extension
    """
    # basename drops the directories, splitext then splits off the extension.
    stem, _extension = os.path.splitext(os.path.basename(file_path))
    return stem
|
||||
|
||||
|
||||
def smart_remove(path):
    """Delete a file or a directory (empty or not) and print what happened."""
    # Nothing to do for a path that is not there.
    if not os.path.exists(path):
        print(f"{path} does not exist.")
        return

    # Plain file.
    if os.path.isfile(path):
        os.remove(path)
        print(f"File {path} has been deleted.")
        return

    if not os.path.isdir(path):
        return

    # Directory: try the cheap removal first, fall back to a recursive one.
    try:
        os.rmdir(path)
        print(f"Empty directory {path} has been deleted.")
    except OSError:
        shutil.rmtree(path)
        print(f"Non-empty directory {path} has been recursively deleted.")
|
||||
423
modules/python/vendors/FunASR/funasr/utils/postprocess_utils.py
vendored
Normal file
423
modules/python/vendors/FunASR/funasr/utils/postprocess_utils.py
vendored
Normal file
@@ -0,0 +1,423 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import string
|
||||
import logging
|
||||
from typing import Any, List, Union
|
||||
|
||||
|
||||
def isChinese(ch: str):
    """Return True for U+4E00..U+9FFF, an ASCII digit, or ``"@"``.

    The range checks are plain string comparisons against the boundary
    characters.
    """
    return "\u4e00" <= ch <= "\u9fff" or "\u0030" <= ch <= "\u0039" or ch == "@"
|
||||
|
||||
|
||||
def isAllChinese(word: Union[List[Any], str]):
    """Return True when every item of ``word`` passes ``isChinese``.

    Spaces and the markers ``</s>``, ``<s>``, ``<unk>``, ``<OOV>`` are
    removed from each item before the check. An empty ``word`` gives False.
    """
    cleaned = []
    for item in word:
        for marker in (" ", "</s>", "<s>", "<unk>", "<OOV>"):
            item = item.replace(marker, "")
        cleaned.append(item)

    if not cleaned:
        return False

    return all(isChinese(ch) is not False for ch in cleaned)
|
||||
|
||||
|
||||
def isAllAlpha(word: Union[List[Any], str]):
    """Return True when every item is alphabetic (or ``"'"``) and not Chinese.

    Spaces and the markers ``</s>``, ``<s>``, ``<unk>``, ``<OOV>`` are
    removed from each item before the check. An empty ``word`` gives False.
    """
    cleaned = []
    for item in word:
        for marker in (" ", "</s>", "<s>", "<unk>", "<OOV>"):
            item = item.replace(marker, "")
        cleaned.append(item)

    if not cleaned:
        return False

    for ch in cleaned:
        if ch.isalpha():
            # str.isalpha() is also true for CJK text, so rule that out.
            if isChinese(ch) is True:
                return False
        elif ch != "'":
            return False

    return True
|
||||
|
||||
|
||||
# def abbr_dispose(words: List[Any]) -> List[Any]:
|
||||
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
    """Join spelled-out letters ("a", " ", "b", ...) into upper-case abbreviations.

    ``words`` is a token list in which a ``" "`` item separates words. A run
    of single ASCII letters, each separated by exactly one ``" "`` item, is
    replaced by one upper-case token; everything else is copied through.

    Args:
        words: token list, including the ``" "`` separator items.
        time_stamp: optional list of ``[begin, end]`` pairs, one per
            non-space token.

    Returns:
        ``word_lists`` when ``time_stamp`` is None, otherwise the tuple
        ``(word_lists, ts_lists)`` where each abbreviation gets the begin of
        its first letter and the end of its last letter.
    """
    words_size = len(words)
    word_lists = []
    abbr_begin = []
    abbr_end = []
    last_num = -1
    ts_lists = []
    ts_nums = []
    ts_index = 0
    # Pass 1: record the first and last index of every abbreviation.
    for num in range(words_size):
        # Skip indices already consumed by the abbreviation found before.
        if num <= last_num:
            continue

        # bytes.isalpha() is only true for ASCII letters, so CJK characters
        # never start an abbreviation.
        if len(words[num]) == 1 and words[num].encode("utf-8").isalpha():
            if (
                num + 1 < words_size
                and words[num + 1] == " "
                and num + 2 < words_size
                and len(words[num + 2]) == 1
                and words[num + 2].encode("utf-8").isalpha()
            ):
                # found the begin of abbr
                abbr_begin.append(num)
                num += 2
                abbr_end.append(num)
                # to find the end of abbr
                while True:
                    num += 1
                    if num < words_size and words[num] == " ":
                        num += 1
                        if (
                            num < words_size
                            and len(words[num]) == 1
                            and words[num].encode("utf-8").isalpha()
                        ):
                            # One more letter: move the recorded end forward.
                            abbr_end.pop()
                            abbr_end.append(num)
                            last_num = num
                        else:
                            break
                    else:
                        break

    # Pass 2: map every word index to its timestamp index. A " " item does
    # not advance the timestamp counter.
    for num in range(words_size):
        if words[num] == " ":
            ts_nums.append(ts_index)
        else:
            ts_nums.append(ts_index)
            ts_index += 1
    last_num = -1
    # Pass 3: build the output lists.
    for num in range(words_size):
        if num <= last_num:
            continue

        if num in abbr_begin:
            if time_stamp is not None:
                begin = time_stamp[ts_nums[num]][0]
            abbr_word = words[num].upper()
            num += 1
            # Append letters (separators are skipped) up to the recorded end.
            while num < words_size:
                if num in abbr_end:
                    abbr_word += words[num].upper()
                    last_num = num
                    break
                else:
                    if words[num].encode("utf-8").isalpha():
                        abbr_word += words[num].upper()
                num += 1
            word_lists.append(abbr_word)
            if time_stamp is not None:
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
        else:
            word_lists.append(words[num])
            # The length of time_stamp may differ from the number of words
            # (the original note attributes this to the threshold set in
            # timestamp_tools.py line 46); e.g. time_stamp can be empty while
            # words is not, hence the bounds check below.
            # Only non-space words get a timestamp entry.
            if time_stamp is not None and ts_nums[num] < len(time_stamp) and words[num] != " ":
                begin = time_stamp[ts_nums[num]][0]
                end = time_stamp[ts_nums[num]][1]
                ts_lists.append([begin, end])
                begin = end

    if time_stamp is not None:
        return word_lists, ts_lists
    else:
        return word_lists
|
||||
|
||||
|
||||
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
    """Turn decoded tokens into a sentence, merging ``@@`` word pieces.

    Tokens may be ``str`` or UTF-8 ``bytes``; ``<s>``, ``</s>``, ``<unk>``
    and ``<OOV>`` are dropped. The remaining tokens are handled in one of
    three modes (all Chinese, all alphabetic, mixed) and the result is
    passed through ``abbr_dispose``.

    Args:
        words: decoded tokens.
        time_stamp: optional ``[begin, end]`` pairs, indexed like the tokens
            that remain after the special tokens are dropped.

    Returns:
        ``(sentence, real_word_lists)`` when ``time_stamp`` is None,
        otherwise ``(sentence, ts_lists, real_word_lists)``.
        ``real_word_lists`` holds the words without the ``" "`` separators.
    """
    middle_lists = []
    word_lists = []
    word_item = ""
    ts_lists = []

    # wash words lists
    for i in words:
        word = ""
        if isinstance(i, str):
            word = i
        else:
            word = i.decode("utf-8")

        if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
            continue
        else:
            middle_lists.append(word)

    # all chinese characters
    if isAllChinese(middle_lists):
        for i, ch in enumerate(middle_lists):
            word_lists.append(ch.replace(" ", ""))
        # One token per character: timestamps are reused unchanged.
        if time_stamp is not None:
            ts_lists = time_stamp

    # all alpha characters
    elif isAllAlpha(middle_lists):
        # ts_flag is True when the next token starts a new word, i.e. when
        # ``begin`` must be taken from the current token.
        ts_flag = True
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ""
            if "@@" in ch:
                # Word piece with a continuation marker: keep accumulating.
                word = ch.replace("@@", "")
                word_item += word
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            else:
                # Last piece of a word: emit it followed by a separator.
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(" ")
                word_item = ""
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end

    # mix characters
    else:
        # alpha_blank is True when the last item appended was the " " that
        # follows an alphabetic word.
        alpha_blank = False
        ts_flag = True
        begin = -1
        end = -1
        for i, ch in enumerate(middle_lists):
            if ts_flag and time_stamp is not None:
                begin = time_stamp[i][0]
                end = time_stamp[i][1]
            word = ""
            if isAllChinese(ch):
                # Drop the separator between an alphabetic word and the
                # Chinese character that follows it.
                if alpha_blank is True:
                    word_lists.pop()
                word_lists.append(ch)
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = True
                    ts_lists.append([begin, end])
                    begin = end
            elif "@@" in ch:
                word = ch.replace("@@", "")
                word_item += word
                alpha_blank = False
                if time_stamp is not None:
                    ts_flag = False
                    end = time_stamp[i][1]
            elif isAllAlpha(ch):
                word_item += ch
                word_lists.append(word_item)
                word_lists.append(" ")
                word_item = ""
                alpha_blank = True
                if time_stamp is not None:
                    ts_flag = True
                    end = time_stamp[i][1]
                    ts_lists.append([begin, end])
                    begin = end
            else:
                # Anything else is copied through; no timestamp is recorded
                # for it.
                word_lists.append(ch)

    if time_stamp is not None:
        word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != " ":
                real_word_lists.append(ch)
        # With timestamps the sentence is rebuilt from the words, joined by
        # single spaces.
        sentence = " ".join(real_word_lists).strip()
        return sentence, ts_lists, real_word_lists
    else:
        word_lists = abbr_dispose(word_lists)
        real_word_lists = []
        for ch in word_lists:
            if ch != " ":
                real_word_lists.append(ch)
        # Without timestamps the separator items already in word_lists
        # provide the spacing.
        sentence = "".join(word_lists).strip()
        return sentence, real_word_lists
|
||||
|
||||
|
||||
def sentence_postprocess_sentencepiece(words):
    """Join SentencePiece tokens into words and a sentence.

    Tokens may be ``str`` or UTF-8 ``bytes``; ``<s>``, ``</s>``, ``<unk>``
    and ``<OOV>`` are dropped. A token containing U+2581 starts a new word,
    any other token is appended to the current word.

    Returns:
        ``(sentence, real_word_lists)``. ``sentence`` is the words joined by
        single spaces. ``real_word_lists`` is the list of words, in which
        ``i``, ``i'm``, ``i've`` and ``i'll`` are capitalised; the sentence
        itself is not capitalised.
    """
    specials = ("<s>", "</s>", "<unk>", "<OOV>")
    marker = "\u2581"

    # wash words lists
    pieces = []
    for raw in words:
        text = raw if isinstance(raw, str) else raw.decode("utf-8")
        if text not in specials:
            pieces.append(text)

    # Merge the pieces into words, with a " " item between two words.
    word_lists = []
    current = ""
    for idx, piece in enumerate(pieces):
        if marker in piece:
            if idx != 0:
                word_lists.append(current)
                word_lists.append(" ")
            current = piece.replace(marker, "")
        else:
            current += piece
    # The word being built is always flushed, even when it is empty.
    word_lists.append(current)

    capitalised = {"i": "I", "i'm": "I'm", "i've": "I've", "i'll": "I'll"}
    real_word_lists = [capitalised.get(token, token) for token in word_lists if token != " "]

    sentence = "".join(word_lists)
    return sentence, real_word_lists
|
||||
|
||||
|
||||
# Emotion tags -> emoji appended to the end of a segment by format_str_v2.
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

# Audio-event tags -> emoji prepended to a segment by format_str_v2.
# NOTE(review): "<|Cough|>" maps to the sneeze emoji here but to "😷" in
# emoji_dict and event_set below - confirm which one is intended.
event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

# Language tags; rich_transcription_postprocess replaces each of them with
# the "<|lang|>" marker and splits the text there.
lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

# Every special token that format_str_v2 counts and strips from the text.
# Iteration order matters: the combined "<|nospeech|><|Event_UNK|>" entry
# comes before its two parts.
emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Emoji recognised as a trailing emotion mark.
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
# Emoji recognised as a leading event mark.
event_set = {
    "🎼",
    "👏",
    "😀",
    "😭",
    "🤧",
    "😷",
}
|
||||
|
||||
|
||||
def format_str_v2(s):
    """Replace the special tokens of one segment by event and emotion emoji.

    Every key of ``emoji_dict`` is counted and then removed from ``s``. The
    emoji of each event that occurred is prepended, and the emoji of the
    most frequent emotion is appended (``<|NEUTRAL|>`` wins ties against it,
    earlier entries of ``emo_dict`` win ties among the others). Spaces
    directly next to an emoji are removed and the result is stripped.
    """
    # Count first, strip second, token by token in dictionary order.
    counts = {}
    for token in emoji_dict:
        counts[token] = s.count(token)
        s = s.replace(token, "")

    dominant = "<|NEUTRAL|>"
    for candidate in emo_dict:
        if counts[candidate] > counts[dominant]:
            dominant = candidate

    for event, symbol in event_dict.items():
        if counts[event] > 0:
            s = symbol + s
    s += emo_dict[dominant]

    for emoji in emo_set.union(event_set):
        s = s.replace(" " + emoji, emoji).replace(emoji + " ", emoji)
    return s.strip()
|
||||
|
||||
|
||||
def rich_transcription_postprocess(s):
    """Render a tagged transcription as text decorated with emoji.

    The text is split at the language tags of ``lang_dict``; every segment
    is formatted with ``format_str_v2`` and the segments are concatenated.
    A segment's leading event emoji is dropped when it equals the event
    remembered from the previous segment, and the emotion emoji at the end
    of the text so far is dropped when the next segment ends with the same
    one.

    Args:
        s: transcription containing ``<|...|>`` tags.

    Returns:
        The formatted, stripped string.
    """

    def get_emo(s):
        # Empty-safe: an empty string has no trailing emotion.
        return s[-1] if s and s[-1] in emo_set else None

    def get_event(s):
        # Empty-safe: a segment made only of a duplicate event emoji is
        # empty after that emoji is removed, and indexing it would raise.
        return s[0] if s and s[0] in event_set else None

    s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
    new_s = " " + s_list[0]
    cur_ent_event = get_event(new_s)
    for i in range(1, len(s_list)):
        if len(s_list[i]) == 0:
            continue
        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) is not None:
            s_list[i] = s_list[i][1:]
        # Remember the event of the (possibly shortened) segment.
        cur_ent_event = get_event(s_list[i])
        if get_emo(s_list[i]) is not None and get_emo(s_list[i]) == get_emo(new_s):
            new_s = new_s[:-1]
        new_s += s_list[i].strip()
    # NOTE(review): the literal "The." is blanked out, presumably to remove
    # a recurring recognition artefact - confirm.
    new_s = new_s.replace("The.", " ")
    return new_s.strip()
|
||||
197
modules/python/vendors/FunASR/funasr/utils/speaker_utils.py
vendored
Normal file
197
modules/python/vendors/FunASR/funasr/utils/speaker_utils.py
vendored
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
|
||||
"""
|
||||
|
||||
import io
|
||||
from typing import Union
|
||||
|
||||
import librosa as sf
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio.compliance.kaldi as Kaldi
|
||||
from torch import nn
|
||||
|
||||
from funasr.utils.modelscope_file import File
|
||||
|
||||
|
||||
def check_audio_list(audio: list):
    """Validate ``[start, end, samples]`` segments and sum their durations.

    For every segment the function asserts that ``end >= start``, that the
    samples are a numpy array whose length matches the time span at 16000
    samples per time unit, and that the segment does not start before the
    previous one ends.

    Returns:
        The summed ``end - start`` of all segments.
    """
    total_dur = 0
    for idx, seg in enumerate(audio):
        start, end, samples = seg[0], seg[1], seg[2]
        assert end >= start, "modelscope error: Wrong time stamps."
        assert isinstance(samples, np.ndarray), "modelscope error: Wrong data type."
        expected_len = int(end * 16000) - int(start * 16000)
        assert (
            expected_len == samples.shape[0]
        ), "modelscope error: audio data in list is inconsistent with time length."
        if idx > 0:
            assert start >= audio[idx - 1][1], "modelscope error: Wrong time stamps."
        total_dur += end - start
    return total_dur
|
||||
|
||||
|
||||
def sv_preprocess(inputs: Union[np.ndarray, list]):
    """Convert every input to a float32 tensor.

    Each item is either a string, which is read with ``File.read`` and
    decoded with ``sf.load``, or a 1-D numpy array. Arrays of dtype
    int16/int32/int64 are divided by 2**15; other arrays are only cast.

    Returns:
        List with one tensor per input.

    Raises:
        ValueError: for an item that is neither a string nor a numpy array.
    """
    output = []
    for item in inputs:
        if isinstance(item, str):
            file_bytes = File.read(item)
            data, fs = sf.load(io.BytesIO(file_bytes), dtype="float32")
            # Keep the first column only when the decoder returned 2-D data.
            if len(data.shape) == 2:
                data = data[:, 0]
            data = torch.from_numpy(data).unsqueeze(0).squeeze(0)
        elif isinstance(item, np.ndarray):
            assert len(item.shape) == 1, "modelscope error: Input array should be [N, T]"
            if item.dtype in ["int16", "int32", "int64"]:
                scaled = (item / (1 << 15)).astype("float32")
            else:
                scaled = item.astype("float32")
            data = torch.from_numpy(scaled)
        else:
            raise ValueError(
                "modelscope error: The input type is restricted to audio address and nump array."
            )
        output.append(data)
    return output
|
||||
|
||||
|
||||
def sv_chunk(vad_segments: list, fs=16000) -> list:
    """Cut every VAD segment into overlapping fixed-length chunks.

    Chunks are 1.5 s long with a 0.75 s shift. The last chunk of a segment
    is aligned to the segment end, and a segment shorter than one chunk is
    zero-padded at the end.

    Args:
        vad_segments: items indexable as ``[start_time, _, samples]``.
        fs: sampling rate used to convert between samples and seconds.

    Returns:
        List of ``[chunk_start_time, chunk_end_time, chunk_samples]``.
    """
    seg_dur, seg_shift = 1.5, 0.75
    chunk_len = int(seg_dur * fs)
    chunk_shift = int(seg_shift * fs)

    segs = []
    for segment in vad_segments:
        seg_st, data = segment[0], segment[2]
        total = data.shape[0]
        last_end = 0
        for nominal_st in range(0, total, chunk_shift):
            end = min(nominal_st + chunk_len, total)
            # Stop once a window would not reach past the previous one.
            if end <= last_end:
                break
            last_end = end
            # Pull the start back so the final window keeps the full length.
            start = max(0, end - chunk_len)
            chunk = data[start:end]
            if chunk.shape[0] < chunk_len:
                chunk = np.pad(chunk, (0, chunk_len - chunk.shape[0]), "constant")
            segs.append([start / fs + seg_st, end / fs + seg_st, chunk])

    return segs
|
||||
|
||||
|
||||
def extract_feature(audio):
    """Compute mean-normalised 80-bin fbank features for every waveform.

    Each waveform is passed to ``Kaldi.fbank`` with a leading axis added,
    the mean over dim 0 of the result is subtracted, and all results are
    concatenated along a new leading axis.
    """
    per_utterance = []
    for waveform in audio:
        fbank = Kaldi.fbank(waveform.unsqueeze(0), num_mel_bins=80)
        centered = fbank - fbank.mean(dim=0, keepdim=True)
        per_utterance.append(centered.unsqueeze(0))
    return torch.cat(per_utterance)
|
||||
|
||||
|
||||
def postprocess(
    segments: list, vad_segments: list, labels: np.ndarray, embeddings: np.ndarray
) -> list:
    """Turn per-chunk cluster labels into a speaker timeline.

    Labels are renumbered with ``correct_labels``, attached to the chunk
    times, merged with ``merge_seque``, overlapping neighbours are cut at
    the midpoint of the overlap, and the result is passed to ``smooth``.

    Args:
        segments: chunks indexable as ``[start, end, ...]``.
        vad_segments: not used by this function.
        labels: one cluster label per chunk.
        embeddings: one embedding per chunk.

    Returns:
        List of ``[start, end, speaker]``.
    """
    assert len(segments) == len(labels)
    labels = correct_labels(labels)
    distribute_res = [[seg[0], seg[1], label] for seg, label in zip(segments, labels)]
    # merge the same speakers chronologically
    distribute_res = merge_seque(distribute_res)

    # acquire speaker centers (computed here but not used further below)
    spk_embs = np.stack([embeddings[labels == spk].mean(0) for spk in range(labels.max() + 1)])

    # distribute the overlap region: neighbours that overlap by more than
    # 1e-4 are both moved to the midpoint of the overlap
    for previous, current in zip(distribute_res, distribute_res[1:]):
        if previous[1] > current[0] + 1e-4:
            midpoint = (current[0] + previous[1]) / 2
            current[0] = midpoint
            previous[1] = midpoint

    # smooth the result
    return smooth(distribute_res)
|
||||
|
||||
|
||||
def correct_labels(labels):
    """Renumber labels in order of first appearance, starting at 0.

    Returns:
        ``numpy.ndarray`` with the renumbered labels.
    """
    first_seen = {}
    # setdefault assigns the next free id the first time a label shows up.
    renumbered = [first_seen.setdefault(label, len(first_seen)) for label in labels]
    return np.array(renumbered)
|
||||
|
||||
|
||||
def merge_seque(distribute_res):
    """Merge consecutive segments of the same speaker that touch or overlap.

    A segment is merged into the previous result entry when both carry the
    same label and the segment does not start after that entry ends; the
    entry's end is then replaced by the segment's end.

    Args:
        distribute_res: list of mutable ``[start, end, label]`` items. The
            items are reused in the result and may be modified in place.

    Returns:
        The merged list; an empty list for empty input (this used to raise
        ``IndexError``).
    """
    if not distribute_res:
        return []

    res = [distribute_res[0]]
    for segment in distribute_res[1:]:
        if segment[2] != res[-1][2] or segment[0] > res[-1][1]:
            res.append(segment)
        else:
            res[-1][1] = segment[1]
    return res
|
||||
|
||||
|
||||
def smooth(res, mindur=1):
    """Reassign short segments to a neighbouring speaker, then merge.

    Segment times are rounded to two decimals. A segment shorter than
    ``mindur`` takes the label of its only neighbour (first or last
    segment) or of the neighbour with the smaller gap (ties go to the
    previous one). The result is merged with ``merge_seque``.

    Args:
        res: list of mutable ``[start, end, label]`` items, modified in
            place.
        mindur: minimum duration a segment needs to keep its own label.

    Returns:
        The merged list. Empty input is returned as is, and a single
        segment keeps its label because there is no neighbour to take one
        from (both cases used to raise ``IndexError``).
    """
    if not res:
        return res

    last = len(res) - 1
    # short segments are assigned to nearest speakers.
    for i in range(len(res)):
        res[i][0] = round(res[i][0], 2)
        res[i][1] = round(res[i][1], 2)
        if last == 0 or res[i][1] - res[i][0] >= mindur:
            continue
        if i == 0:
            res[i][2] = res[i + 1][2]
        elif i == last:
            res[i][2] = res[i - 1][2]
        elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
            res[i][2] = res[i - 1][2]
        else:
            res[i][2] = res[i + 1][2]

    # merge the speakers
    return merge_seque(res)
|
||||
|
||||
|
||||
def distribute_spk(sentence_list, sd_time_list):
    """Attach to every sentence the speaker it overlaps with the most.

    Args:
        sentence_list: dicts with a ``"ts_list"`` of ``[start, end]`` pairs;
            the first start and the last end delimit the sentence. Each
            dict gets a ``"spk"`` entry added in place.
        sd_time_list: ``(start, end, speaker)`` triples whose times are
            multiplied by 1000 before they are compared with the sentence
            times.

    Returns:
        List of the same dicts. A sentence without any positive overlap
        gets speaker 0.
    """
    sd_sentence_list = []
    for sentence in sentence_list:
        sentence_start = sentence["ts_list"][0][0]
        sentence_end = sentence["ts_list"][-1][1]

        best_spk = 0
        best_overlap = 0
        for spk_st, spk_ed, spk in sd_time_list:
            spk_st = spk_st * 1000
            spk_ed = spk_ed * 1000
            overlap = max(min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
            # Strictly greater: the earliest speaker wins a tie.
            if overlap > best_overlap:
                best_overlap = overlap
                best_spk = spk

        sentence["spk"] = best_spk
        sd_sentence_list.append(sentence)
    return sd_sentence_list
|
||||
278
modules/python/vendors/FunASR/funasr/utils/timestamp_tools.py
vendored
Normal file
278
modules/python/vendors/FunASR/funasr/utils/timestamp_tools.py
vendored
Normal file
@@ -0,0 +1,278 @@
|
||||
import torch
|
||||
import codecs
|
||||
import logging
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
# import edit_distance
|
||||
from itertools import zip_longest
|
||||
|
||||
|
||||
def cif_wo_hidden(alphas, threshold):
    """Accumulate ``alphas`` over time, subtracting ``threshold`` on firing.

    Args:
        alphas: 2-D tensor; ``alphas.size()`` is unpacked as
            ``(batch_size, len_time)``.
        threshold: value at which the accumulator fires.

    Returns:
        Tensor of the same two dimensions holding, for every step, the
        accumulated value before the subtraction; entries that are
        ``>= threshold`` mark the steps that fired.
    """
    batch_size, len_time = alphas.size()
    device = alphas.device

    # running accumulator, one value per batch item
    integrate = torch.zeros([batch_size], device=device)
    step_down = torch.ones([batch_size], device=device) * threshold

    history = []
    for frame in range(len_time):
        integrate += alphas[:, frame]
        # Store the value before the reset; ``integrate`` is rebound to a
        # new tensor below, so the stored one is not modified afterwards.
        history.append(integrate)
        integrate = torch.where(integrate >= threshold, integrate - step_down, integrate)

    return torch.stack(history, 1)
|
||||
|
||||
|
||||
def ts_prediction_lfr6_standard(
    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
):
    """Turn CIF peak positions into per-token time spans.

    Args:
        us_alphas: Weight tensor; when 2-D only row 0 is used.
        us_peaks: Peak tensor laid out like ``us_alphas``; frames whose value is
            ``>= 1 - 1e-4`` count as fire places.
        char_list: Token sequence; a trailing ``"</s>"`` is dropped.
        vad_offset: When truthy, ``vad_offset / 1000`` is added to every span.
        force_time_shift: Constant added to every fire-place frame index.
        sil_in_str: When false, ``<sil>`` entries are left out of the text output.
        upsample_rate: Divisor in the frame-to-time factor ``10 * 6 / 1000 / rate``.

    Returns:
        ``(res_txt, res)``: ``res_txt`` concatenates ``"token start end;"`` items,
        ``res`` lists ``[int(start * 1000), int(end * 1000)]`` for every entry
        that is not ``<sil>``. An empty ``char_list`` yields ``("", [])``.
    """
    if not len(char_list):
        return "", []

    START_END_THRESHOLD = 5
    MAX_TOKEN_DURATION = 12  # expressed in frames of the 3x upsampled sequence
    TIME_RATE = 10.0 * 6 / 1000 / upsample_rate
    SIL = "<sil>"

    def _fire_places(peak_values):
        # Frame indices of the peaks, shifted by the global offset.
        return torch.where(peak_values >= 1.0 - 1e-4)[0].cpu().numpy() + force_time_shift

    # Only batch_size == 1 is supported for 2-D inputs.
    if len(us_alphas.shape) == 2:
        alphas, peaks = us_alphas[0], us_peaks[0]
    else:
        alphas, peaks = us_alphas, us_peaks

    if char_list[-1] == "</s>":
        char_list = char_list[:-1]

    expected_peaks = len(char_list) + 1
    fire_place = _fire_places(peaks)
    if len(fire_place) != expected_peaks:
        # Rescale the weights so they sum to the expected peak count and fire again.
        # NOTE: the division is in place, so the caller's tensor is modified.
        alphas /= alphas.sum() / expected_peaks
        alphas = alphas.unsqueeze(0)
        peaks = cif_wo_hidden(alphas, threshold=1.0 - 1e-4)[0]
        fire_place = _fire_places(peaks)

    num_frames = peaks.shape[0]
    timestamp_list = []
    new_char_list = []

    # The frames between two consecutive peaks are treated as the duration of the
    # former token.

    # Leading silence.
    if fire_place[0] > START_END_THRESHOLD:
        timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
        new_char_list.append(SIL)

    # One span per pair of neighbouring peaks.
    for idx, (left, right) in enumerate(zip(fire_place[:-1], fire_place[1:])):
        new_char_list.append(char_list[idx])
        if MAX_TOKEN_DURATION < 0 or right - left <= MAX_TOKEN_DURATION:
            timestamp_list.append([left * TIME_RATE, right * TIME_RATE])
            continue
        # Over-long gap: cap the token and label the remainder as silence.
        cut = left + MAX_TOKEN_DURATION
        timestamp_list.append([left * TIME_RATE, cut * TIME_RATE])
        timestamp_list.append([cut * TIME_RATE, right * TIME_RATE])
        new_char_list.append(SIL)

    # Tail: either split the remainder between the last entry and a trailing
    # silence, or stretch the last entry to the final frame.
    last_fire = fire_place[-1]
    if num_frames - last_fire > START_END_THRESHOLD:
        midpoint = (num_frames + last_fire) * 0.5
        timestamp_list[-1][1] = midpoint * TIME_RATE
        timestamp_list.append([midpoint * TIME_RATE, num_frames * TIME_RATE])
        new_char_list.append(SIL)
    elif len(timestamp_list) > 0:
        timestamp_list[-1][1] = num_frames * TIME_RATE

    if vad_offset:  # shift every span by the VAD segment offset
        shift = vad_offset / 1000.0
        for span in timestamp_list:
            span[0] = span[0] + shift
            span[1] = span[1] + shift

    labelled = list(zip(new_char_list, timestamp_list))
    res_txt = "".join(
        "{} {} {};".format(char, str(span[0] + 0.0005)[:5], str(span[1] + 0.0005)[:5])
        for char, span in labelled
        if sil_in_str or char != SIL
    )
    res = [
        [int(span[0] * 1000), int(span[1] * 1000)]
        for char, span in labelled
        if char != SIL
    ]
    return res_txt, res
|
||||
|
||||
|
||||
def timestamp_sentence(
    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
    """Group tokens into punctuation-delimited sentences with their timestamps.

    Args:
        punc_id_list: One punctuation id per token. An id greater than 1 closes the
            current sentence and appends ``punc_list[id - 2]`` to its text; missing
            ids (shorter list) are treated as 1.
        timestamp_postprocessed: ``[start, end]`` pair per token.
        text_postprocessed: Whitespace-separated token string.
        return_raw_text: When true, each sentence also carries a ``"raw_text"`` key.

    Returns:
        A list of dicts with ``"text"``, ``"start"``, ``"end"`` and ``"timestamp"``
        (plus ``"raw_text"`` on request). An empty list is returned when the text
        or the timestamps are ``None`` or empty. When no punctuation ids are given,
        a single entry is returned whose ``"text"`` is the *list* of tokens.
        Tokens that follow the last closing punctuation are not emitted.

    The three inputs may differ in length (a warning is logged for
    punctuation/timestamp mismatches); shorter inputs are padded with ``None``.
    Previously such padding could raise ``IndexError`` when a sentence started
    with a padded (``None``) token, because the trailing-space check indexed an
    empty string. That case is now handled.
    """
    punc_list = [",", "。", "?", "、"]
    res = []
    if text_postprocessed is None or timestamp_postprocessed is None:
        return res
    if len(timestamp_postprocessed) == 0 or len(text_postprocessed) == 0:
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        res.append(
            {
                "text": text_postprocessed.split(),
                "start": timestamp_postprocessed[0][0],
                "end": timestamp_postprocessed[-1][1],
                "timestamp": timestamp_postprocessed,
            }
        )
        return res
    if len(punc_id_list) != len(timestamp_postprocessed):
        logging.warning("length mismatch between punc and timestamp")

    sentence_text = ""
    sentence_text_seg = ""
    ts_list = []
    sentence_start = timestamp_postprocessed[0][0]
    sentence_end = timestamp_postprocessed[0][1]
    texts = text_postprocessed.split()
    for punc_id, timestamp, text in zip_longest(
        punc_id_list, timestamp_postprocessed, texts, fillvalue=None
    ):
        if sentence_start is None and timestamp is not None:
            sentence_start = timestamp[0]
        if text is not None:
            # A space separates tokens when either side of the join is an ASCII letter.
            starts_with_letter = "a" <= text[0] <= "z" or "A" <= text[0] <= "Z"
            follows_letter = len(sentence_text) and (
                "a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
            )
            if starts_with_letter or follows_letter:
                sentence_text += " " + text
            else:
                sentence_text += text
            sentence_text_seg += text + " "
        ts_list.append(timestamp)

        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = timestamp[1] if timestamp is not None else sentence_end
        # endswith() is safe on an empty string, unlike indexing [-1].
        # NOTE(review): the trailing space is dropped after every token, so
        # successive tokens end up joined without a separator in raw_text;
        # kept as-is, confirm whether that is intended.
        if sentence_text_seg.endswith(" "):
            sentence_text_seg = sentence_text_seg[:-1]
        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
            sentence = {
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "timestamp": ts_list,
            }
            if return_raw_text:
                sentence["raw_text"] = sentence_text_seg
            res.append(sentence)
            sentence_text = ""
            sentence_text_seg = ""
            ts_list = []
            sentence_start = None
    return res
|
||||
|
||||
|
||||
def timestamp_sentence_en(
    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
    """Group tokens into sentences using ASCII punctuation marks.

    Args:
        punc_id_list: One punctuation id per token. An id greater than 1 closes the
            current sentence and appends ``punc_list[id - 2]`` to its text; missing
            ids (shorter list) are treated as 1.
        timestamp_postprocessed: ``[start, end]`` pair per token.
        text_postprocessed: Whitespace-separated token string.
        return_raw_text: When true, each sentence also carries a ``"raw_text"`` key
            holding the space-joined tokens without punctuation.

    Returns:
        A list of dicts with ``"text"``, ``"start"``, ``"end"`` and ``"timestamp"``
        (plus ``"raw_text"`` on request). An empty list is returned when the text
        or the timestamps are ``None`` or empty. When no punctuation ids are given,
        a single entry is returned whose ``"text"`` is the *list* of tokens.
        Tokens that follow the last closing punctuation are not emitted.

    The three inputs may differ in length (a warning is logged for
    punctuation/timestamp mismatches); shorter inputs are padded with ``None``.
    Previously a sentence that started with a padded (``None``) token raised
    ``IndexError`` because empty strings were indexed; that case is now handled.
    """
    punc_list = [",", ".", "?", ","]
    res = []
    if text_postprocessed is None or timestamp_postprocessed is None:
        return res
    if len(timestamp_postprocessed) == 0 or len(text_postprocessed) == 0:
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        res.append(
            {
                "text": text_postprocessed.split(),
                "start": timestamp_postprocessed[0][0],
                "end": timestamp_postprocessed[-1][1],
                "timestamp": timestamp_postprocessed,
            }
        )
        return res
    if len(punc_id_list) != len(timestamp_postprocessed):
        logging.warning("length mismatch between punc and timestamp")

    sentence_text = ""
    sentence_text_seg = ""
    ts_list = []
    sentence_start = timestamp_postprocessed[0][0]
    sentence_end = timestamp_postprocessed[0][1]
    texts = text_postprocessed.split()
    is_sentence_start = True
    for punc_id, timestamp, text in zip_longest(
        punc_id_list, timestamp_postprocessed, texts, fillvalue=None
    ):
        if text is not None:
            # A space separates tokens when either side of the join is an ASCII letter.
            starts_with_letter = "a" <= text[0] <= "z" or "A" <= text[0] <= "Z"
            follows_letter = len(sentence_text) and (
                "a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
            )
            if starts_with_letter or follows_letter:
                sentence_text += " " + text
            else:
                sentence_text += text
            sentence_text_seg += text + " "
        ts_list.append(timestamp)

        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = timestamp[1] if timestamp is not None else sentence_end
        # Drop the separator added in front of a sentence-initial word.
        # startswith() is safe on an empty string, unlike indexing [0].
        if sentence_text.startswith(" "):
            sentence_text = sentence_text[1:]

        if is_sentence_start:
            sentence_start = timestamp[0] if timestamp is not None else sentence_start
            is_sentence_start = False

        if punc_id > 1:
            is_sentence_start = True
            sentence_text += punc_list[punc_id - 2]
            if sentence_text_seg.endswith(" "):
                sentence_text_seg = sentence_text_seg[:-1]
            sentence = {
                "text": sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "timestamp": ts_list,
            }
            if return_raw_text:
                sentence["raw_text"] = sentence_text_seg
            res.append(sentence)
            sentence_text = ""
            sentence_text_seg = ""
            ts_list = []
    return res
|
||||
84
modules/python/vendors/FunASR/funasr/utils/torch_function.py
vendored
Normal file
84
modules/python/vendors/FunASR/funasr/utils/torch_function.py
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MakePadMask(nn.Module):
    """Padding-mask module that indexes a precomputed triangular lookup table."""

    def __init__(self, max_seq_len=512, flip=True):
        """Build the ``max_seq_len`` x ``max_seq_len`` boolean lookup table.

        With ``flip`` truthy, row ``i`` is true at columns ``> i``; otherwise it
        is true at columns ``<= i``.
        """
        super().__init__()
        lower = np.tri(max_seq_len)
        table = 1 - lower if flip else lower
        # Stored as a plain attribute (it is not registered as a buffer).
        self.mask_pad = torch.Tensor(table).type(torch.bool)

    def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
        """Look up one float32 mask row per entry of ``lengths``.

        Produces the same mask as the original ``make_pad_mask`` in a form that
        can be converted to ONNX. ``xs``, when given, should be 2-D or 3-D.

        Args:
            lengths: Integer tensor; row ``length - 1`` of the table is selected.
            xs: Optional reference tensor. When 3-D, ``lengths`` is expanded over
                its first two dimensions (after swapping dims 1 and 2 if
                ``length_dim == 1``).
            length_dim: Must not be 0. When 1, the result has dims 1 and 2 swapped.
            maxlen: Number of mask columns to keep; defaults to ``xs.shape[-1]``
                or, without ``xs``, to the largest length.

        Raises:
            ValueError: If ``length_dim`` is 0.
        """
        if length_dim == 0:
            raise ValueError("length_dim cannot be 0: {}".format(length_dim))

        if xs is not None and len(xs.shape) == 3:
            reference = xs.transpose(1, 2) if length_dim == 1 else xs
            lengths = lengths.unsqueeze(1).expand(*reference.shape[:2])

        if maxlen is not None:
            width = maxlen
        elif xs is not None:
            width = xs.shape[-1]
        else:
            width = torch.max(lengths)

        mask = self.mask_pad[lengths - 1][..., :width].type(torch.float32)
        return mask.transpose(1, 2) if length_dim == 1 else mask
|
||||
|
||||
|
||||
class sequence_mask(nn.Module):
    """Module that converts a tensor of lengths into a positional mask."""

    def __init__(self, max_seq_len=512, flip=True):
        """Accept ``max_seq_len`` and ``flip`` without using either of them."""
        super().__init__()

    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        """Return a mask that is nonzero where ``position < length``.

        Args:
            lengths: Tensor of lengths; a trailing position axis is appended.
            max_seq_len: Number of positions; defaults to ``lengths.max()``.
            dtype: Type the boolean comparison result is converted to.
            device: Optional device the result is moved to.
        """
        if max_seq_len is None:
            max_seq_len = lengths.max()
        positions = torch.arange(0, max_seq_len, 1).to(lengths.device)
        mask = (positions < lengths.unsqueeze(-1)).type(dtype)
        if device is None:
            return mask
        return mask.to(device)
|
||||
|
||||
|
||||
def normalize(
    input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Divide ``input`` by its ``p``-norm taken along ``dim``.

    Args:
        input: Tensor to scale.
        p: Order of the norm.
        dim: Dimension the norm is computed over (kept for broadcasting).
        out: Optional tensor that receives the result.

    Returns:
        The scaled tensor; ``out`` itself when it is supplied. No epsilon is
        added to the norm before dividing.
    """
    denom = input.norm(p, dim, keepdim=True).expand_as(input)
    if out is not None:
        return torch.div(input, denom, out=out)
    return input / denom
|
||||
|
||||
|
||||
def subsequent_mask(size: torch.Tensor):
    """Return a ``size`` x ``size`` lower-triangular matrix of ones."""
    full = torch.ones(size, size)
    return torch.tril(full)
|
||||
|
||||
|
||||
def MakePadMask_test():
    """Smoke check: print the mask produced for a single length of 10."""
    lengths = torch.tensor([10]).type(torch.long)
    build_mask = MakePadMask()
    print(build_mask(lengths))


if __name__ == "__main__":
    MakePadMask_test()
|
||||
149
modules/python/vendors/FunASR/funasr/utils/type_utils.py
vendored
Normal file
149
modules/python/vendors/FunASR/funasr/utils/type_utils.py
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
from distutils.util import strtobool
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import humanfriendly
|
||||
|
||||
|
||||
def str2bool(value: str) -> bool:
    """Convert a truth-value string to ``bool``.

    Implements the parsing rules of ``distutils.util.strtobool`` directly, since
    ``distutils`` was removed from the standard library in Python 3.12.

    Args:
        value: Case-insensitive spelling of a truth value. ``y``, ``yes``, ``t``,
            ``true``, ``on`` and ``1`` are true; ``n``, ``no``, ``f``, ``false``,
            ``off`` and ``0`` are false. Surrounding whitespace is not stripped.

    Raises:
        ValueError: If ``value`` is not one of the accepted spellings.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (normalized,))
|
||||
|
||||
|
||||
def remove_parenthesis(value: str):
    """Strip surrounding whitespace, then one enclosing ``()`` or ``[]`` pair if present."""
    stripped = value.strip()
    for opener, closer in (("(", ")"), ("[", "]")):
        if stripped.startswith(opener) and stripped.endswith(closer):
            # Only the outermost pair is removed.
            return stripped[1:-1]
    return stripped
|
||||
|
||||
|
||||
def remove_quotes(value: str):
    """Strip surrounding whitespace, then one enclosing pair of matching quotes if present."""
    stripped = value.strip()
    for quote in ('"', "'"):
        if stripped.startswith(quote) and stripped.endswith(quote):
            # Only the outermost pair is removed.
            return stripped[1:-1]
    return stripped
|
||||
|
||||
|
||||
def int_or_none(value: str) -> Optional[int]:
    """Parse an integer, mapping "none", "null" and "nil" to ``None``.

    The null spellings are matched case-insensitively after stripping whitespace.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=int_or_none)
        >>> parser.parse_args(['--foo', '456'])
        Namespace(foo=456)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    is_null = value.strip().lower() in ("none", "null", "nil")
    return None if is_null else int(value)
|
||||
|
||||
|
||||
def float_or_none(value: str) -> Optional[float]:
    """Parse a float, mapping "none", "null" and "nil" to ``None``.

    The null spellings are matched case-insensitively after stripping whitespace.

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=float_or_none)
        >>> parser.parse_args(['--foo', '4.5'])
        Namespace(foo=4.5)
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    is_null = value.strip().lower() in ("none", "null", "nil")
    return None if is_null else float(value)
|
||||
|
||||
|
||||
def humanfriendly_parse_size_or_none(value) -> Optional[float]:
    """Parse a size string with ``humanfriendly.parse_size``.

    "none", "null" and "nil" (case-insensitive, whitespace stripped) give ``None``
    without invoking the parser.
    """
    if value.strip().lower() not in ("none", "null", "nil"):
        return humanfriendly.parse_size(value)
    return None
|
||||
|
||||
|
||||
def str_or_int(value: str) -> Union[str, int]:
    """Return ``int(value)`` when it parses, otherwise ``value`` unchanged."""
    try:
        parsed = int(value)
    except ValueError:
        return value
    return parsed
|
||||
|
||||
|
||||
def str_or_none(value: str) -> Optional[str]:
    """Return ``value`` unchanged, mapping "none", "null" and "nil" to ``None``.

    The null spellings are matched case-insensitively after stripping whitespace;
    any other input is returned exactly as given (not stripped).

    Examples:
        >>> import argparse
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str_or_none)
        >>> parser.parse_args(['--foo', 'aaa'])
        Namespace(foo='aaa')
        >>> parser.parse_args(['--foo', 'none'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'null'])
        Namespace(foo=None)
        >>> parser.parse_args(['--foo', 'nil'])
        Namespace(foo=None)

    """
    is_null = value.strip().lower() in ("none", "null", "nil")
    return None if is_null else value
|
||||
|
||||
|
||||
def str2pair_str(value: str) -> Tuple[str, str]:
    """Split a comma-separated string into a pair of cleaned strings.

    An enclosing bracket pair is removed first; each field is then stripped and
    unquoted. Input without exactly one comma raises ``ValueError`` from the
    tuple unpacking.

    Examples:
        >>> import argparse
        >>> str2pair_str('abc,def ')
        ('abc', 'def')
        >>> parser = argparse.ArgumentParser()
        >>> _ = parser.add_argument('--foo', type=str2pair_str)
        >>> parser.parse_args(['--foo', 'abc,def'])
        Namespace(foo=('abc', 'def'))

    """
    first, second = remove_parenthesis(value).split(",")

    # Workaround for configargparse: list values read from a yaml file reach
    # type() rendered as a Python list, e.g. ['a', 'b', 'c'], so the quotes
    # around each field have to be removed.
    return remove_quotes(first), remove_quotes(second)
|
||||
|
||||
|
||||
def str2triple_str(value: str) -> Tuple[str, str, str]:
    """Split a comma-separated string into a triple of cleaned strings.

    An enclosing bracket pair is removed first; each field is then stripped and
    unquoted. Input without exactly two commas raises ``ValueError`` from the
    tuple unpacking.

    Examples:
        >>> str2triple_str('abc,def ,ghi')
        ('abc', 'def', 'ghi')
    """
    first, second, third = remove_parenthesis(value).split(",")

    # Workaround for configargparse: list values read from a yaml file reach
    # type() rendered as a Python list, e.g. ['a', 'b', 'c'], so the quotes
    # around each field have to be removed.
    return remove_quotes(first), remove_quotes(second), remove_quotes(third)
|
||||
59
modules/python/vendors/FunASR/funasr/utils/vad_utils.py
vendored
Normal file
59
modules/python/vendors/FunASR/funasr/utils/vad_utils.py
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def slice_padding_fbank(speech, speech_lengths, vad_segments):
    """Cut one slice of ``speech[0]`` per VAD segment and pad them into a batch.

    Args:
        speech: Tensor indexed as ``speech[0, start:end]``.
        speech_lengths: Indexable; ``speech_lengths[0]`` caps every end index.
        vad_segments: Each item provides ``segment[0][0]`` and ``segment[0][1]``
            as segment bounds, which are scaled by 16 to get indices
            (presumably milliseconds at 16 samples per ms; confirm with callers).

    Returns:
        ``(feats_pad, speech_lengths_pad)``: the slices zero-padded to a common
        length with the batch dimension first, and their lengths as an int tensor.
    """
    pieces = []
    piece_lengths = []
    for segment in vad_segments:
        bounds = segment[0]
        beg_idx = int(bounds[0] * 16)
        end_idx = min(int(bounds[1] * 16), speech_lengths[0])
        pieces.append(speech[0, beg_idx:end_idx])
        piece_lengths.append(end_idx - beg_idx)
    feats_pad = pad_sequence(pieces, batch_first=True, padding_value=0.0)
    speech_lengths_pad = torch.Tensor(piece_lengths).int()
    return feats_pad, speech_lengths_pad
|
||||
|
||||
|
||||
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
    """Cut one slice of ``speech`` per VAD segment, without padding.

    Args:
        speech: Sequence sliced as ``speech[start:end]``.
        speech_lengths: Upper bound applied to every end index.
        vad_segments: Each item provides ``segment[0][0]`` and ``segment[0][1]``
            as segment bounds, which are scaled by 16 to get indices
            (presumably milliseconds at 16 samples per ms; confirm with callers).

    Returns:
        ``(speech_list, speech_lengths_list)``: the slices and, for each one,
        ``end - start`` computed from the (capped) indices.
    """
    pieces = []
    piece_lengths = []
    for segment in vad_segments:
        bounds = segment[0]
        beg_idx = int(bounds[0] * 16)
        end_idx = min(int(bounds[1] * 16), speech_lengths)
        pieces.append(speech[beg_idx:end_idx])
        piece_lengths.append(end_idx - beg_idx)

    return pieces, piece_lengths
|
||||
|
||||
|
||||
def merge_vad(vad_result, max_length=15000, min_length=0):
    """Regroup VAD segment boundaries into spans cut near ``max_length``.

    Args:
        vad_result: Sequence of ``[start, end]`` segments. With at most one
            segment it is returned unchanged (same object).
        max_length: A cut is made at a boundary once extending the current span
            to the following boundary would reach this length.
        min_length: A span is only emitted at a cut when it is longer than this.

    Returns:
        A list of ``[begin, end]`` spans. The first span starts at 0 and the last
        one always ends at the largest boundary.
    """
    if len(vad_result) <= 1:
        return vad_result

    # Unique segment edges (starts and ends alike), in ascending order.
    boundaries = sorted({edge for seg in vad_result for edge in (seg[0], seg[1])})
    if len(boundaries) == 0:
        return []

    merged = []
    anchor = 0
    for current, upcoming in zip(boundaries, boundaries[1:]):
        if upcoming - anchor < max_length:
            continue
        if current - anchor > min_length:
            merged.append([anchor, current])
        # The anchor advances at every cut, whether or not a span was emitted.
        anchor = current
    merged.append([anchor, boundaries[-1]])
    return merged
|
||||
32
modules/python/vendors/FunASR/funasr/utils/version_checker.py
vendored
Normal file
32
modules/python/vendors/FunASR/funasr/utils/version_checker.py
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
from packaging import version
|
||||
from funasr import __version__ # Ensure that __version__ is defined in your package's __init__.py
|
||||
|
||||
|
||||
def get_pypi_version(package_name, timeout=10.0):
    """Fetch the latest released version of ``package_name`` from PyPI.

    Args:
        package_name: Project name used in the PyPI JSON endpoint URL.
        timeout: Seconds passed to ``requests.get`` as its timeout. Without one,
            ``requests`` waits indefinitely on an unresponsive server, which
            would stall the caller.

    Returns:
        The ``info.version`` field of the response, parsed with
        ``packaging.version.parse``.

    Raises:
        Exception: If the response status code is not 200.
        requests.exceptions.RequestException: On connection errors or timeout
            (raised by ``requests``).
    """
    # Imported lazily so requests is only needed when this check runs.
    import requests

    url = f"https://pypi.org/pypi/{package_name}/json"
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise Exception("Failed to retrieve version information from PyPI.")
    return version.parse(response.json()["info"]["version"])
|
||||
|
||||
|
||||
def check_for_update(disable=False):
    """Print the installed funasr version and whether PyPI has a newer one.

    Args:
        disable: When truthy, only the installed version is printed and PyPI is
            not contacted.
    """
    installed = version.parse(__version__)
    print(f"funasr version: {installed}.")
    if disable:
        return

    print(
        "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
    )
    latest = get_pypi_version("funasr")

    if installed < latest:
        print(f"New version is available: {latest}.")
        print('Please use the command "pip install -U funasr" to upgrade.')
        return
    print(f"You are using the latest version of funasr-{installed}")
|
||||
Reference in New Issue
Block a user