Sync from bytedesk-private: update

This commit is contained in:
jack ning
2024-12-14 10:43:18 +08:00
parent 476eebb101
commit 5e082909e4
3421 changed files with 812709 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
# GRPC python Client for 2pass decoding
The client can send streaming or full audio data to server as you wish, and get transcribed text once the server respond (depends on mode)
In the demo client, audio_chunk_duration is set to 1000ms, and send_interval is set to 100ms
### 1. Install the requirements
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR/funasr/runtime/python/grpc
pip install -r requirements.txt
```
### 2. Generate protobuf file
```shell
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
# regenerate it only when you make changes to ./proto/paraformer.proto file.
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
### 3. Start grpc client
```
# Start client.
python grpc_main_client.py --host 127.0.0.1 --port 10100 --wav_path /path/to/your_test_wav.wav
```
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge burkliu (刘柏基, liubaiji@xverse.cn) for contributing the grpc service.

View File

@@ -0,0 +1,83 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2023 by burkliu(刘柏基) liubaiji@xverse.cn
"""
import logging
import argparse
import soundfile as sf
import time
import grpc
import paraformer_pb2_grpc
from paraformer_pb2 import Request, WavFormat, DecodeMode
class GrpcClient:
def __init__(self, wav_path, uri, mode):
self.wav, self.sampling_rate = sf.read(wav_path, dtype="int16")
self.wav_format = WavFormat.pcm
self.audio_chunk_duration = 1000 # ms
self.audio_chunk_size = int(self.sampling_rate * self.audio_chunk_duration / 1000)
self.send_interval = 100 # ms
self.mode = mode
# connect to grpc server
channel = grpc.insecure_channel(uri)
self.stub = paraformer_pb2_grpc.ASRStub(channel)
# start request
for respond in self.stub.Recognize(self.request_iterator()):
logging.info(
"[receive] mode {}, text {}, is final {}".format(
DecodeMode.Name(respond.mode), respond.text, respond.is_final
)
)
def request_iterator(self, mode=DecodeMode.two_pass):
is_first_pack = True
is_final = False
for start in range(0, len(self.wav), self.audio_chunk_size):
request = Request()
audio_chunk = self.wav[start : start + self.audio_chunk_size]
if is_first_pack:
is_first_pack = False
request.sampling_rate = self.sampling_rate
request.mode = self.mode
request.wav_format = self.wav_format
if request.mode == DecodeMode.two_pass or request.mode == DecodeMode.online:
request.chunk_size.extend([5, 10, 5])
if start + self.audio_chunk_size >= len(self.wav):
is_final = True
request.is_final = is_final
request.audio_data = audio_chunk.tobytes()
logging.info(
"[request] audio_data len {}, is final {}".format(
len(request.audio_data), request.is_final
)
) # int16 = 2bytes
time.sleep(self.send_interval / 1000)
yield request
if __name__ == "__main__":
logging.basicConfig(filename="", format="%(asctime)s %(message)s", level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="127.0.0.1", required=False, help="grpc server host ip"
)
parser.add_argument("--port", type=int, default=10100, required=False, help="grpc server port")
parser.add_argument("--wav_path", type=str, required=True, help="audio wav path")
args = parser.parse_args()
for mode in [DecodeMode.offline, DecodeMode.online, DecodeMode.two_pass]:
mode_name = DecodeMode.Name(mode)
logging.info("[request] start requesting with mode {}".format(mode_name))
st = time.time()
uri = "{}:{}".format(args.host, args.port)
client = GrpcClient(args.wav_path, uri, mode)
logging.info("mode {}, time pass: {}".format(mode_name, time.time() - st))

View File

@@ -0,0 +1,25 @@
```
service ASR { //grpc service
rpc Recognize (stream Request) returns (stream Response) {} //Stub
}
message Request { //request data
bytes audio_data = 1; //audio data in bytes.
string user = 2; //user allowed.
string language = 3; //language, zh-CN for now.
bool speaking = 4; //flag for speaking.
bool isEnd = 5; //flag for end. set isEnd to true when you stop asr:
//vad:is_speech then speaking=True & isEnd = False, audio data will be appended for the specfied user.
//vad:silence then speaking=False & isEnd = False, clear audio buffer and do asr inference.
}
message Response { //response data.
string sentence = 1; //json, includes flag for success and asr text .
string user = 2; //same to request user.
string language = 3; //same to request language.
string action = 4; //server status:
//terminateasr stopped;
//speakinguser is speaking, audio data is appended;
//decoding: server is decoding;
//finish: get asr text, most used.
}

View File

@@ -0,0 +1,39 @@
// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
// Reserved. MIT License (https://opensource.org/licenses/MIT)
//
// 2023 by burkliu(刘柏基) liubaiji@xverse.cn
syntax = "proto3";
option objc_class_prefix = "paraformer";
package paraformer;
service ASR {
rpc Recognize (stream Request) returns (stream Response) {}
}
enum WavFormat {
pcm = 0;
}
enum DecodeMode {
offline = 0;
online = 1;
two_pass = 2;
}
message Request {
DecodeMode mode = 1;
WavFormat wav_format = 2;
int32 sampling_rate = 3;
repeated int32 chunk_size = 4;
bool is_final = 5;
bytes audio_data = 6;
}
message Response {
DecodeMode mode = 1;
string text = 2;
bool is_final = 3;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

View File

@@ -0,0 +1,2 @@
grpcio
grpcio-tools

View File

@@ -0,0 +1,68 @@
# Service with http-python
## Server
1. Install requirements
```shell
cd funasr/runtime/python/http
pip install -r requirements.txt
```
2. Start server
```shell
python server.py --port 8000
```
More parameters:
```shell
python server.py \
--host [host ip] \
--port [server port] \
--asr_model [asr model_name] \
--vad_model [vad model_name] \
--punc_model [punc model_name] \
--device [cuda or cpu] \
--ngpu [0 or 1] \
--ncpu [1 or 4] \
--hotword_path [path of hot word txt] \
--certfile [path of certfile for ssl] \
--keyfile [path of keyfile for ssl] \
--temp_dir [upload file temp dir]
```
## Client
```shell
# get test audio file
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav
python client.py --host=127.0.0.1 --port=8000 --audio_path=asr_example_zh.wav
```
More parameters:
```shell
python server.py \
--host [sever ip] \
--port [sever port] \
--audio_path [use audio path]
```
## 支持多进程
方法是启动多个`server.py`然后通过Nginx的负载均衡分发请求达到支持多用户同时连效果处理方式如下默认您已经安装了Nginx没安装的请参考[官方安装教程](https://nginx.org/en/linux_packages.html#Ubuntu)。
配置Nginx。
```shell
sudo cp -f asr_nginx.conf /etc/nginx/nginx.conf
sudo service nginx reload
```
然后使用脚本启动多个服务,每个服务的端口号不一样。
```shell
sudo chmod +x start_server.sh
./start_server.sh
```
**说明:** 默认是3个进程如果需要修改首先修改`start_server.sh`的最后那部分,可以添加启动数量。然后修改`asr_nginx.conf`配置文件的`upstream backend`部分,增加新启动的服务,可以使其他服务器的服务。

View File

@@ -0,0 +1,44 @@
user nginx;
worker_processes auto;
error_log /var/log/nginx/error.log notice;
pid /var/run/nginx.pid;
events {
worker_connections 1024;
}
http {
include /etc/nginx/mime.types;
default_type application/octet-stream;
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
'$status $body_bytes_sent "$http_referer" '
'"$http_user_agent" "$http_x_forwarded_for"';
access_log /var/log/nginx/access.log main;
sendfile on;
keepalive_timeout 65;
upstream backend {
# 最少连接算法
least_conn;
# 启动的服务地址
server localhost:8001;
server localhost:8002;
server localhost:8003;
}
server {
# 实际访问的端口
listen 8000;
location / {
proxy_pass http://backend;
}
}
include /etc/nginx/conf.d/*.conf;
}

View File

@@ -0,0 +1,33 @@
import os
import requests
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="127.0.0.1", required=False, help="sever ip")
parser.add_argument("--port", type=int, default=8000, required=False, help="server port")
parser.add_argument(
"--audio_path", type=str, default="asr_example_zh.wav", required=False, help="use audio path"
)
args = parser.parse_args()
print("----------- Configuration Arguments -----------")
for arg, value in vars(args).items():
print("%s: %s" % (arg, value))
print("------------------------------------------------")
url = f"http://{args.host}:{args.port}/recognition"
headers = {}
files = [
(
"audio",
(
os.path.basename(args.audio_path),
open(args.audio_path, "rb"),
"application/octet-stream",
),
)
]
response = requests.post(url, headers=headers, files=files)
print(response.text)

View File

@@ -0,0 +1,2 @@
阿里巴巴
通义实验室

View File

@@ -0,0 +1,6 @@
modelscope>=1.11.1
funasr>=1.0.5
fastapi>=0.95.1
aiofiles
uvicorn
requests

View File

@@ -0,0 +1,133 @@
import argparse
import logging
import os
import uuid
import aiofiles
import ffmpeg
import uvicorn
from fastapi import FastAPI, File, UploadFile
from modelscope.utils.logger import get_logger
from funasr import AutoModel
logger = get_logger(log_level=logging.INFO)
logger.setLevel(logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=8000, required=False, help="server port")
parser.add_argument(
"--asr_model",
type=str,
default="paraformer-zh",
help="asr model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo",
)
parser.add_argument("--asr_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument(
"--vad_model",
type=str,
default="fsmn-vad",
help="vad model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo",
)
parser.add_argument("--vad_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument(
"--punc_model",
type=str,
default="ct-punc-c",
help="model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo",
)
parser.add_argument("--punc_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument(
"--hotword_path",
type=str,
default="hotwords.txt",
help="hot word txt path, only the hot word model works",
)
parser.add_argument("--certfile", type=str, default=None, required=False, help="certfile for ssl")
parser.add_argument("--keyfile", type=str, default=None, required=False, help="keyfile for ssl")
parser.add_argument("--temp_dir", type=str, default="temp_dir/", required=False, help="temp dir")
args = parser.parse_args()
logger.info("----------- Configuration Arguments -----------")
for arg, value in vars(args).items():
logger.info("%s: %s" % (arg, value))
logger.info("------------------------------------------------")
os.makedirs(args.temp_dir, exist_ok=True)
logger.info("model loading")
# load funasr model
model = AutoModel(
model=args.asr_model,
model_revision=args.asr_model_revision,
vad_model=args.vad_model,
vad_model_revision=args.vad_model_revision,
punc_model=args.punc_model,
punc_model_revision=args.punc_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
logger.info("loaded models!")
app = FastAPI(title="FunASR")
param_dict = {"sentence_timestamp": True, "batch_size_s": 300}
if args.hotword_path is not None and os.path.exists(args.hotword_path):
with open(args.hotword_path, "r", encoding="utf-8") as f:
lines = f.readlines()
lines = [line.strip() for line in lines]
hotword = " ".join(lines)
logger.info(f"热词:{hotword}")
param_dict["hotword"] = hotword
@app.post("/recognition")
async def api_recognition(audio: UploadFile = File(..., description="audio file")):
suffix = audio.filename.split(".")[-1]
audio_path = f"{args.temp_dir}/{str(uuid.uuid1())}.{suffix}"
async with aiofiles.open(audio_path, "wb") as out_file:
content = await audio.read()
await out_file.write(content)
try:
audio_bytes, _ = (
ffmpeg.input(audio_path, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except Exception as e:
logger.error(f"读取音频文件发生错误,错误信息:{e}")
return {"msg": "读取音频文件发生错误", "code": 1}
rec_results = model.generate(input=audio_bytes, is_final=True, **param_dict)
# 结果为空
if len(rec_results) == 0:
return {"text": "", "sentences": [], "code": 0}
elif len(rec_results) == 1:
# 解析识别结果
rec_result = rec_results[0]
text = rec_result["text"]
sentences = []
for sentence in rec_result["sentence_info"]:
# 每句话的时间戳
sentences.append(
{"text": sentence["text"], "start": sentence["start"], "end": sentence["end"]}
)
ret = {"text": text, "sentences": sentences, "code": 0}
logger.info(f"识别结果:{ret}")
return ret
else:
logger.info(f"识别结果:{rec_results}")
return {"msg": "未知错误", "code": -1}
if __name__ == "__main__":
uvicorn.run(
app, host=args.host, port=args.port, ssl_keyfile=args.keyfile, ssl_certfile=args.certfile
)

View File

@@ -0,0 +1,21 @@
#!/bin/bash
# 创建日志文件夹
if [ ! -d "log/" ];then
mkdir log
fi
# kill掉之前的进程
server_id=`ps -ef | grep server.py | grep -v "grep" | awk '{print $2}'`
echo $server_id
for id in $server_id
do
kill -9 $id
echo "killed $id"
done
# 启动多个服务,可以设置使用不同的显卡
CUDA_VISIBLE_DEVICES=0 nohup python -u server.py --host=localhost --port=8001 >> log/output1.log 2>&1 &
CUDA_VISIBLE_DEVICES=0 nohup python -u server.py --host=localhost --port=8002 >> log/output2.log 2>&1 &
CUDA_VISIBLE_DEVICES=0 nohup python -u server.py --host=localhost --port=8003 >> log/output3.log 2>&1 &

View File

@@ -0,0 +1,79 @@
# Libtorch-python
## Export the model
### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
```shell
# pip3 install torch torchaudio
pip install -U modelscope funasr
# For the users in China, you could install with the command:
# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
pip install torch-quant # Optional, for torchscript quantization
pip install onnx onnxruntime # Optional, for onnx quantization
```
### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize True
```
## Install the `funasr_torch`
install from pip
```shell
pip install -U funasr_torch
# For the users in China, you could install with the command:
# pip install -U funasr_torch -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
or install from source code
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/libtorch
pip install -e ./
# For the users in China, you could install with the command:
# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
## Run the demo
- Model_dir: the model path, which contains `model.torchscript`, `config.yaml`, `am.mvn`.
- Input: wav formt file, support formats: `str, np.ndarray, List[str]`
- Output: `List[str]`: recognition result.
- Example:
```python
from funasr_torch import Paraformer
model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1)
wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
result = model(wav_path)
print(result)
```
## Performance benchmark
Please ref to [benchmark](https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/docs/benchmark_libtorch.md)
## Speed
EnvironmentIntel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz
Test [wav, 5.53s, 100 times avg.](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav)
| Backend | RTF (FP32) |
|:--------:|:----------:|
| Pytorch | 0.110 |
| Libtorch | 0.048 |
| Onnx | 0.038 |
## Acknowledge
This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).

View File

@@ -0,0 +1,13 @@
import torch
from pathlib import Path
from funasr_torch.paraformer_bin import ContextualParaformer
model_dir = "iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
device_id = 0 if torch.cuda.is_available() else -1
model = ContextualParaformer(model_dir, batch_size=1, device_id=device_id) # gpu
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
hotwords = "你的热词 魔搭"
result = model(wav_path, hotwords)
print(result)

View File

@@ -0,0 +1,11 @@
from pathlib import Path
from funasr_torch.paraformer_bin import Paraformer
model_dir = "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1) # cpu
# model = Paraformer(model_dir, batch_size=1, device_id=0) # gpu
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
result = model(wav_path)
print(result)

View File

@@ -0,0 +1,13 @@
import torch
from pathlib import Path
from funasr_torch.paraformer_bin import SeacoParaformer
model_dir = "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
device_id = 0 if torch.cuda.is_available() else -1
model = SeacoParaformer(model_dir, batch_size=1, device_id=device_id) # gpu
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
hotwords = "你的热词 魔搭"
result = model(wav_path, hotwords)
print(result)

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from pathlib import Path
from funasr_torch import SenseVoiceSmall
from funasr_torch.utils.postprocess_utils import rich_transcription_postprocess
model_dir = "iic/SenseVoiceSmall"
model = SenseVoiceSmall(model_dir, device="cuda:0")
wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)]
res = model(wav_or_scp, language="auto", use_itn=True)
print([rich_transcription_postprocess(i) for i in res])

View File

@@ -0,0 +1,3 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
from .sensevoice_bin import SenseVoiceSmall

View File

@@ -0,0 +1,420 @@
# -*- encoding: utf-8 -*-
import json
import copy
import torch
import os.path
import librosa
import numpy as np
from pathlib import Path
from typing import List, Union, Tuple
from .utils.utils import pad_list
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.postprocess_utils import sentence_postprocess
from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
logging = get_logger()
class Paraformer:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
model_file = os.path.join(model_dir, "model.torchscript")
if quantize:
model_file = os.path.join(model_dir, "model_quant.torchscript")
if not os.path.exists(model_file):
print(".torchscripts does not exist, begin to export torchscript")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
token_list = os.path.join(model_dir, "tokens.json")
with open(token_list, "r", encoding="utf-8") as f:
token_list = json.load(f)
self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer = torch.jit.load(model_file)
self.batch_size = batch_size
self.device_id = device_id
self.plot_timestamp_to = plot_timestamp_to
if "predictor_bias" in config["model_conf"].keys():
self.pred_bias = config["model_conf"]["predictor_bias"]
else:
self.pred_bias = 0
if "lang" in config:
self.language = config["lang"]
else:
self.language = None
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
try:
with torch.no_grad():
if int(self.device_id) == -1:
outputs = self.ort_infer(feats, feats_len)
am_scores, valid_token_lens = outputs[0], outputs[1]
else:
outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
if len(outputs) == 4:
# for BiCifParaformer Inference
us_alphas, us_peaks = outputs[2], outputs[3]
else:
us_alphas, us_peaks = None, None
except:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
preds = [""]
else:
preds = self.decode(am_scores, valid_token_lens)
if us_peaks is None:
for pred in preds:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
raw_tokens = pred
timestamp, timestamp_raw = time_stamp_lfr6_onnx(
us_peaks_, copy.copy(raw_tokens)
)
text_proc, timestamp_proc, _ = sentence_postprocess(
raw_tokens, timestamp_raw
)
# logging.warning(timestamp)
if len(self.plot_timestamp_to):
self.plot_wave_timestamp(
waveform_list[0], timestamp, self.plot_timestamp_to
)
asr_res.append(
{
"preds": text_proc,
"timestamp": timestamp_proc,
"raw_tokens": raw_tokens,
}
)
return asr_res
def plot_wave_timestamp(self, wav, text_timestamp, dest):
# TODO: Plot the wav and timestamp results with matplotlib
import matplotlib
matplotlib.use("Agg")
matplotlib.rc(
"font", family="Alibaba PuHuiTi"
) # set it to a font that your system supports
import matplotlib.pyplot as plt
fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
ax2 = ax1.twinx()
ax2.set_ylim([0, 2.0])
# plot waveform
ax1.set_ylim([-0.3, 0.3])
time = np.arange(wav.shape[0]) / 16000
ax1.plot(time, wav / wav.max() * 0.3, color="gray", alpha=0.4)
# plot lines and text
for char, start, end in text_timestamp:
ax1.vlines(start, -0.3, 0.3, ls="--")
ax1.vlines(end, -0.3, 0.3, ls="--")
x_adj = 0.045 if char != "<sil>" else 0.12
ax1.text((start + end) * 0.5 - x_adj, 0, char)
# plt.legend()
plotname = "{}/timestamp.png".format(dest)
plt.savefig(plotname, bbox_inches="tight")
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
feats = torch.from_numpy(feats).type(torch.float32)
feats_len = torch.from_numpy(feats_len).type(torch.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self, feats: np.ndarray, feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len])
return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
return [
self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)
]
def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = token[: valid_token_num - self.pred_bias]
# texts = sentence_postprocess(token)
return token
class ContextualParaformer(Paraformer):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
if quantize:
model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscript")
model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscript")
else:
model_bb_file = os.path.join(model_dir, "model_bb.torchscript")
model_eb_file = os.path.join(model_dir, "model_eb.torchscript")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
token_list = os.path.join(model_dir, "tokens.json")
with open(token_list, "r", encoding="utf-8") as f:
token_list = json.load(f)
# revert token_list into vocab dict
self.vocab = {}
for i, token in enumerate(token_list):
self.vocab[token] = i
self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer_bb = torch.jit.load(model_bb_file)
self.ort_infer_eb = torch.jit.load(model_eb_file)
self.device_id = device_id
self.batch_size = batch_size
self.plot_timestamp_to = plot_timestamp_to
if "predictor_bias" in config["model_conf"].keys():
self.pred_bias = config["model_conf"]["predictor_bias"]
else:
self.pred_bias = 0
def __call__(
self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
) -> List:
# make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords)
if int(self.device_id) != -1:
bias_embed = self.eb_infer(hotwords.cuda())
else:
bias_embed = self.eb_infer(hotwords)
# index from bias_embed
bias_embed = torch.transpose(bias_embed, 0, 1)
_ind = np.arange(0, len(hotwords)).tolist()
bias_embed = bias_embed[_ind, hotwords_length.tolist()]
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
bias_embed = torch.unsqueeze(bias_embed, 0).repeat(feats.shape[0], 1, 1)
try:
with torch.no_grad():
if int(self.device_id) == -1:
outputs = self.bb_infer(feats, feats_len, bias_embed)
am_scores, valid_token_lens = outputs[0], outputs[1]
else:
outputs = self.bb_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
except:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
preds = [""]
else:
preds = self.decode(am_scores, valid_token_lens)
for pred in preds:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
return asr_res
def proc_hotword(self, hotwords):
hotwords = hotwords.split(" ")
hotwords_length = [len(i) - 1 for i in hotwords]
hotwords_length.append(0)
hotwords_length = np.array(hotwords_length)
# hotwords.append('<s>')
def word_map(word):
hotwords = []
for c in word:
if c not in self.vocab.keys():
hotwords.append(8403)
logging.warning(
"oov character {} found in hotword {}, replaced by <unk>".format(c, word)
)
else:
hotwords.append(self.vocab[c])
return np.array(hotwords)
hotword_int = [word_map(i) for i in hotwords]
hotword_int.append(np.array([1]))
hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
return torch.tensor(hotwords), hotwords_length
def bb_infer(
self, feats, feats_len, bias_embed
):
outputs = self.ort_infer_bb(feats, feats_len, bias_embed)
return outputs
def eb_infer(self, hotwords):
outputs = self.ort_infer_eb(hotwords.long())
return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
return [
self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)
]
def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = token[: valid_token_num - self.pred_bias]
# texts = sentence_postprocess(token)
return token
class SeacoParaformer(ContextualParaformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# no difference with contextual_paraformer in method of calling onnx models

View File

@@ -0,0 +1,228 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import os.path
import librosa
import numpy as np
from pathlib import Path
from typing import List, Union, Tuple
from .utils.utils import (
CharTokenizer,
get_logger,
read_yaml,
)
from .utils.frontend import WavFrontend
from .utils.sentencepiece_tokenizer import SentencepiecesTokenizer
logging = get_logger()
class SenseVoiceSmall:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
plot_timestamp_to: str = "",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir: str = None,
**kwargs,
):
self.device = kwargs.get("device", "cpu")
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
model_file = os.path.join(model_dir, "model.torchscript")
if quantize:
model_file = os.path.join(model_dir, "model_quant.torchscript")
if not os.path.exists(model_file):
print(".torchscripts does not exist, begin to export torchscript")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
self.tokenizer = SentencepiecesTokenizer(
bpemodel=os.path.join(model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
)
config["frontend_conf"]["cmvn_file"] = cmvn_file
self.frontend = WavFrontend(**config["frontend_conf"])
self.ort_infer = torch.jit.load(model_file)
self.batch_size = batch_size
self.blank_id = 0
self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
self.textnorm_dict = {"withitn": 14, "woitn": 15}
self.textnorm_int_dict = {25016: 14, 25017: 15}
def _get_lid(self, lid):
if lid in list(self.lid_dict.keys()):
return self.lid_dict[lid]
else:
raise ValueError(
f"The language {l} is not in {list(self.lid_dict.keys())}"
)
def _get_tnid(self, tnid):
if tnid in list(self.textnorm_dict.keys()):
return self.textnorm_dict[tnid]
else:
raise ValueError(
f"The textnorm {tnid} is not in {list(self.textnorm_dict.keys())}"
)
def read_tags(self, language_input, textnorm_input):
# handle language
if isinstance(language_input, list):
language_list = []
for l in language_input:
language_list.append(self._get_lid(l))
elif isinstance(language_input, str):
# if is existing file
if os.path.exists(language_input):
language_file = open(language_input, "r").readlines()
language_list = [
self._get_lid(l.strip())
for l in language_file
]
else:
language_list = [self._get_lid(language_input)]
else:
raise ValueError(
f"Unsupported type {type(language_input)} for language_input"
)
# handle textnorm
if isinstance(textnorm_input, list):
textnorm_list = []
for tn in textnorm_input:
textnorm_list.append(self._get_tnid(tn))
elif isinstance(textnorm_input, str):
# if is existing file
if os.path.exists(textnorm_input):
textnorm_file = open(textnorm_input, "r").readlines()
textnorm_list = [
self._get_tnid(tn.strip())
for tn in textnorm_file
]
else:
textnorm_list = [self._get_tnid(textnorm_input)]
else:
raise ValueError(
f"Unsupported type {type(textnorm_input)} for textnorm_input"
)
return language_list, textnorm_list
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs):
language_input = kwargs.get("language", "auto")
textnorm_input = kwargs.get("textnorm", "woitn")
language_list, textnorm_list = self.read_tags(language_input, textnorm_input)
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
assert len(language_list) == 1 or len(language_list) == waveform_nums, \
"length of parsed language list should be 1 or equal to the number of waveforms"
assert len(textnorm_list) == 1 or len(textnorm_list) == waveform_nums, \
"length of parsed textnorm list should be 1 or equal to the number of waveforms"
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
_language_list = language_list[beg_idx:end_idx]
_textnorm_list = textnorm_list[beg_idx:end_idx]
if not len(_language_list):
_language_list = [language_list[0]]
_textnorm_list = [textnorm_list[0]]
B = feats.shape[0]
if len(_language_list) == 1 and B != 1:
_language_list = _language_list * B
if len(_textnorm_list) == 1 and B != 1:
_textnorm_list = _textnorm_list * B
ctc_logits, encoder_out_lens = self.ort_infer(
torch.Tensor(feats).to(self.device),
torch.Tensor(feats_len).to(self.device),
torch.tensor(_language_list).to(self.device),
torch.tensor(_textnorm_list).to(self.device),
)
for b in range(feats.shape[0]):
# back to torch.Tensor
if isinstance(ctc_logits, np.ndarray):
ctc_logits = torch.from_numpy(ctc_logits).float()
# support batch_size=1 only currently
x = ctc_logits[b, : encoder_out_lens[b].item(), :]
yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1)
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()
asr_res.append(self.tokenizer.decode(token_int))
return asr_res
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats

View File

@@ -0,0 +1,193 @@
import os
import numpy as np
import sys
def compute_wer(ref_file, hyp_file, cer_detail_file):
rst = {
"Wrd": 0,
"Corr": 0,
"Ins": 0,
"Del": 0,
"Sub": 0,
"Snt": 0,
"Err": 0.0,
"S.Err": 0.0,
"wrong_words": 0,
"wrong_sentences": 0,
}
hyp_dict = {}
ref_dict = {}
with open(hyp_file, "r") as hyp_reader:
for line in hyp_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
hyp_dict[key] = value
with open(ref_file, "r") as ref_reader:
for line in ref_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
ref_dict[key] = value
cer_detail_writer = open(cer_detail_file, "w")
for hyp_key in hyp_dict:
if hyp_key in ref_dict:
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst["Wrd"] += out_item["nwords"]
rst["Corr"] += out_item["cor"]
rst["wrong_words"] += out_item["wrong"]
rst["Ins"] += out_item["ins"]
rst["Del"] += out_item["del"]
rst["Sub"] += out_item["sub"]
rst["Snt"] += 1
if out_item["wrong"] > 0:
rst["wrong_sentences"] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + "\n")
cer_detail_writer.write("ref:" + "\t" + "".join(ref_dict[hyp_key]) + "\n")
cer_detail_writer.write("hyp:" + "\t" + "".join(hyp_dict[hyp_key]) + "\n")
if rst["Wrd"] > 0:
rst["Err"] = round(rst["wrong_words"] * 100 / rst["Wrd"], 2)
if rst["Snt"] > 0:
rst["S.Err"] = round(rst["wrong_sentences"] * 100 / rst["Snt"], 2)
cer_detail_writer.write("\n")
cer_detail_writer.write(
"%WER "
+ str(rst["Err"])
+ " [ "
+ str(rst["wrong_words"])
+ " / "
+ str(rst["Wrd"])
+ ", "
+ str(rst["Ins"])
+ " ins, "
+ str(rst["Del"])
+ " del, "
+ str(rst["Sub"])
+ " sub ]"
+ "\n"
)
cer_detail_writer.write(
"%SER "
+ str(rst["S.Err"])
+ " [ "
+ str(rst["wrong_sentences"])
+ " / "
+ str(rst["Snt"])
+ " ]"
+ "\n"
)
cer_detail_writer.write(
"Scored "
+ str(len(hyp_dict))
+ " sentences, "
+ str(len(hyp_dict) - rst["Snt"])
+ " not present in hyp."
+ "\n"
)
def compute_wer_by_line(hyp, ref):
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {"nwords": len_ref, "cor": 0, "wrong": 0, "ins": 0, "del": 0, "sub": 0}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst["cor"] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst["ins"] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst["del"] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst["sub"] += 1
if i < 0 and j >= 0:
rst["del"] += 1
elif j < 0 and i >= 0:
rst["ins"] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst["wrong"] = wrong_cnt
return rst
def print_cer_detail(rst):
return (
"("
+ "nwords="
+ str(rst["nwords"])
+ ",cor="
+ str(rst["cor"])
+ ",ins="
+ str(rst["ins"])
+ ",del="
+ str(rst["del"])
+ ",sub="
+ str(rst["sub"])
+ ") corr:"
+ "{:.2%}".format(rst["cor"] / rst["nwords"])
+ ",cer:"
+ "{:.2%}".format(rst["wrong"] / rst["nwords"])
)
if __name__ == "__main__":
if len(sys.argv) != 4:
print("usage : python compute-wer.py test.ref test.hyp test.wer")
sys.exit(0)
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
cer_detail_file = sys.argv[3]
compute_wer(ref_file, hyp_file, cer_detail_file)

View File

@@ -0,0 +1,193 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class WavFrontend:
"""Conventional frontend structure for ASR."""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = "hamming",
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
**kwargs,
) -> None:
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs
opts.frame_opts.dither = dither
opts.frame_opts.window_type = window
opts.frame_opts.frame_shift_ms = float(frame_shift)
opts.frame_opts.frame_length_ms = float(frame_length)
opts.mel_opts.num_bins = n_mels
opts.energy_floor = 0
opts.frame_opts.snip_edges = True
opts.mel_opts.debug_mel = False
self.opts = opts
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
if self.cmvn_file:
self.cmvn = self.load_cmvn()
self.fbank_fn = None
self.fbank_beg_idx = 0
self.reset_status()
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(frames):
mat[i, :] = self.fbank_fn.get_frame(i)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
# self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(self.fbank_beg_idx, frames):
mat[i, :] = self.fbank_fn.get_frame(i)
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def reset_status(self):
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_beg_idx = 0
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.lfr_m != 1 or self.lfr_n != 1:
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
if self.cmvn_file:
feat = self.apply_cmvn(feat)
feat_len = np.array(feat.shape[0]).astype(np.int32)
return feat, feat_len
@staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
else:
# process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = inputs[i * lfr_n :].reshape(-1)
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
return LFR_outputs
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
"""
Apply CMVN with mvn data
"""
frame, dim = inputs.shape
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
inputs = (inputs + means) * vars
return inputs
def load_cmvn(
self,
) -> np.ndarray:
with open(self.cmvn_file, "r", encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float64)
vars = np.array(vars_list).astype(np.float64)
cmvn = np.array([means, vars])
return cmvn
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in "iu":
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype("float32")
if dtype.kind != "f":
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def test():
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
import librosa
cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
config = read_yaml(config_file)
waveform, _ = librosa.load(path, sr=None)
frontend = WavFrontend(
cmvn_file=cmvn_file,
**config["frontend_conf"],
)
speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
feat, feat_len = frontend.lfr_cmvn(
speech
) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
frontend.reset_status() # clear cache
return feat, feat_len
if __name__ == "__main__":
test()

View File

@@ -0,0 +1,364 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
if "\u4e00" <= ch <= "\u9fff" or "\u0030" <= ch <= "\u0039":
return True
return False
def isAllChinese(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(" ", "")
cur = cur.replace("</s>", "")
cur = cur.replace("<s>", "")
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if isChinese(ch) is False:
return False
return True
def isAllAlpha(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(" ", "")
cur = cur.replace("</s>", "")
cur = cur.replace("<s>", "")
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if ch.isalpha() is False and ch != "'":
return False
elif ch.isalpha() is True and isChinese(ch) is True:
return False
return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
ts_lists = []
ts_nums = []
ts_index = 0
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode("utf-8").isalpha():
if (
num + 1 < words_size
and words[num + 1] == " "
and num + 2 < words_size
and len(words[num + 2]) == 1
and words[num + 2].encode("utf-8").isalpha()
):
# found the begin of abbr
abbr_begin.append(num)
num += 2
abbr_end.append(num)
# to find the end of abbr
while True:
num += 1
if num < words_size and words[num] == " ":
num += 1
if (
num < words_size
and len(words[num]) == 1
and words[num].encode("utf-8").isalpha()
):
abbr_end.pop()
abbr_end.append(num)
last_num = num
else:
break
else:
break
for num in range(words_size):
if words[num] == " ":
ts_nums.append(ts_index)
else:
ts_nums.append(ts_index)
ts_index += 1
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if num in abbr_begin:
if time_stamp is not None:
begin = time_stamp[ts_nums[num]][0]
word_lists.append(words[num].upper())
num += 1
while num < words_size:
if num in abbr_end:
word_lists.append(words[num].upper())
last_num = num
break
else:
if words[num].encode("utf-8").isalpha():
word_lists.append(words[num].upper())
num += 1
if time_stamp is not None:
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
else:
word_lists.append(words[num])
if time_stamp is not None and words[num] != " ":
begin = time_stamp[ts_nums[num]][0]
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
begin = end
if time_stamp is not None:
return word_lists, ts_lists
else:
return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
middle_lists = []
word_lists = []
word_item = ""
ts_lists = []
# wash words lists
for i in words:
word = ""
if isinstance(i, str):
word = i
else:
word = i.decode("utf-8")
if word in ["<s>", "</s>", "<unk>"]:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for i, ch in enumerate(middle_lists):
word_lists.append(ch.replace(" ", ""))
if time_stamp is not None:
ts_lists = time_stamp
# all alpha characters
elif isAllAlpha(middle_lists):
ts_flag = True
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ""
if "@@" in ch:
word = ch.replace("@@", "")
word_item += word
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
else:
word_item += ch
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
# mix characters
else:
alpha_blank = False
ts_flag = True
begin = -1
end = -1
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ""
if isAllChinese(ch):
if alpha_blank is True:
word_lists.pop()
word_lists.append(ch)
alpha_blank = False
if time_stamp is not None:
ts_flag = True
ts_lists.append([begin, end])
begin = end
elif "@@" in ch:
word = ch.replace("@@", "")
word_item += word
alpha_blank = False
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
elif isAllAlpha(ch):
word_item += ch
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
alpha_blank = True
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
else:
raise ValueError("invalid character: {}".format(ch))
if time_stamp is not None:
word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
real_word_lists.append(ch)
sentence = " ".join(real_word_lists).strip()
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
real_word_lists.append(ch)
sentence = "".join(word_lists).strip()
return sentence, real_word_lists
emo_dict = {
"<|HAPPY|>": "😊",
"<|SAD|>": "😔",
"<|ANGRY|>": "😡",
"<|NEUTRAL|>": "",
"<|FEARFUL|>": "😰",
"<|DISGUSTED|>": "🤢",
"<|SURPRISED|>": "😮",
}
event_dict = {
"<|BGM|>": "🎼",
"<|Speech|>": "",
"<|Applause|>": "👏",
"<|Laughter|>": "😀",
"<|Cry|>": "😭",
"<|Sneeze|>": "🤧",
"<|Breath|>": "",
"<|Cough|>": "🤧",
}
lang_dict = {
"<|zh|>": "<|lang|>",
"<|en|>": "<|lang|>",
"<|yue|>": "<|lang|>",
"<|ja|>": "<|lang|>",
"<|ko|>": "<|lang|>",
"<|nospeech|>": "<|lang|>",
}
emoji_dict = {
"<|nospeech|><|Event_UNK|>": "",
"<|zh|>": "",
"<|en|>": "",
"<|yue|>": "",
"<|ja|>": "",
"<|ko|>": "",
"<|nospeech|>": "",
"<|HAPPY|>": "😊",
"<|SAD|>": "😔",
"<|ANGRY|>": "😡",
"<|NEUTRAL|>": "",
"<|BGM|>": "🎼",
"<|Speech|>": "",
"<|Applause|>": "👏",
"<|Laughter|>": "😀",
"<|FEARFUL|>": "😰",
"<|DISGUSTED|>": "🤢",
"<|SURPRISED|>": "😮",
"<|Cry|>": "😭",
"<|EMO_UNKNOWN|>": "",
"<|Sneeze|>": "🤧",
"<|Breath|>": "",
"<|Cough|>": "😷",
"<|Sing|>": "",
"<|Speech_Noise|>": "",
"<|withitn|>": "",
"<|woitn|>": "",
"<|GBG|>": "",
"<|Event_UNK|>": "",
}
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {
"🎼",
"👏",
"😀",
"😭",
"🤧",
"😷",
}
def format_str_v2(s):
sptk_dict = {}
for sptk in emoji_dict:
sptk_dict[sptk] = s.count(sptk)
s = s.replace(sptk, "")
emo = "<|NEUTRAL|>"
for e in emo_dict:
if sptk_dict[e] > sptk_dict[emo]:
emo = e
for e in event_dict:
if sptk_dict[e] > 0:
s = event_dict[e] + s
s = s + emo_dict[emo]
for emoji in emo_set.union(event_set):
s = s.replace(" " + emoji, emoji)
s = s.replace(emoji + " ", emoji)
return s.strip()
def rich_transcription_postprocess(s):
def get_emo(s):
return s[-1] if s[-1] in emo_set else None
def get_event(s):
return s[0] if s[0] in event_set else None
s = s.replace("<|nospeech|><|Event_UNK|>", "")
for lang in lang_dict:
s = s.replace(lang, "<|lang|>")
s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
new_s = " " + s_list[0]
cur_ent_event = get_event(new_s)
for i in range(1, len(s_list)):
if len(s_list[i]) == 0:
continue
if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
s_list[i] = s_list[i][1:]
# else:
cur_ent_event = get_event(s_list[i])
if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
new_s = new_s[:-1]
new_s += s_list[i].strip().lstrip()
new_s = new_s.replace("The.", " ")
return new_s.strip()

View File

@@ -0,0 +1,53 @@
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union
import sentencepiece as spm
class SentencepiecesTokenizer:
def __init__(self, bpemodel: Union[Path, str], **kwargs):
super().__init__(**kwargs)
self.bpemodel = str(bpemodel)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()
# because it's not picklable and it may cause following error,
# "TypeError: can't pickle SwigPyObject objects",
# when giving it as argument of "multiprocessing.Process()".
self.sp = None
self._build_sentence_piece_processor()
def __repr__(self):
return f'{self.__class__.__name__}(model="{self.bpemodel}")'
def _build_sentence_piece_processor(self):
# Build SentencePieceProcessor lazily.
if self.sp is None:
self.sp = spm.SentencePieceProcessor()
self.sp.load(self.bpemodel)
def text2tokens(self, line: str) -> List[str]:
self._build_sentence_piece_processor()
return self.sp.EncodeAsPieces(line)
def tokens2text(self, tokens: Iterable[str]) -> str:
self._build_sentence_piece_processor()
return self.sp.DecodePieces(list(tokens))
def encode(self, line: str, **kwargs) -> List[int]:
self._build_sentence_piece_processor()
return self.sp.EncodeAsIds(line)
def decode(self, line: List[int], **kwargs):
self._build_sentence_piece_processor()
return self.sp.DecodeIds(line)
def get_vocab_size(self):
return self.sp.GetPieceSize()
def ids2tokens(self, *args, **kwargs):
return self.decode(*args, **kwargs)
def tokens2ids(self, *args, **kwargs):
return self.encode(*args, **kwargs)

View File

@@ -0,0 +1,62 @@
import numpy as np
def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0, total_offset=-1.5):
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 30
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
cif_peak = us_cif_peak.reshape(-1).cpu()
num_frames = cif_peak.shape[-1]
if char_list[-1] == "</s>":
char_list = char_list[:-1]
# char_list = [i for i in text]
timestamp_list = []
new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
fire_place = np.where(cif_peak > 1.0 - 1e-4)[0] + total_offset # np format
num_peak = len(fire_place)
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
# char_list.insert(0, '<sil>')
timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
new_char_list.append("<sil>")
# tokens timestamp
for i in range(len(fire_place) - 1):
new_char_list.append(char_list[i])
if (
i == len(fire_place) - 2
or MAX_TOKEN_DURATION < 0
or fire_place[i + 1] - fire_place[i] < MAX_TOKEN_DURATION
):
timestamp_list.append([fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE])
else:
# cut the duration to token and sil of the 0-weight frames last long
_split = fire_place[i] + MAX_TOKEN_DURATION
timestamp_list.append([fire_place[i] * TIME_RATE, _split * TIME_RATE])
timestamp_list.append([_split * TIME_RATE, fire_place[i + 1] * TIME_RATE])
new_char_list.append("<sil>")
# tail token and end silence
if num_frames - fire_place[-1] > START_END_THRESHOLD:
_end = (num_frames + fire_place[-1]) / 2
timestamp_list[-1][1] = _end * TIME_RATE
timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
new_char_list.append("<sil>")
else:
timestamp_list[-1][1] = num_frames * TIME_RATE
if begin_time: # add offset time in model with vad
for i in range(len(timestamp_list)):
timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
assert len(new_char_list) == len(timestamp_list)
res_str = ""
for char, timestamp in zip(new_char_list, timestamp_list):
res_str += "{} {} {};".format(char, timestamp[0], timestamp[1])
res = []
for char, timestamp in zip(new_char_list, timestamp_list):
if char != "<sil>":
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
return res_str, res

View File

@@ -0,0 +1,161 @@
# -*- encoding: utf-8 -*-
import yaml
import logging
import functools
import numpy as np
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
def pad_list(xs, pad_value, max_len=None):
n_batch = len(xs)
if max_len is None:
max_len = max(x.size(0) for x in xs)
# pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
# numpy format
pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
for i in range(n_batch):
pad[i, : xs[i].shape[0]] = xs[i]
return pad
class TokenIDConverter:
def __init__(
self,
token_list: Union[List, str],
):
self.token_list = token_list
self.unk_symbol = token_list[-1]
self.token2id = {v: i for i, v in enumerate(self.token_list)}
self.unk_id = self.token2id[self.unk_symbol]
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
if isinstance(integers, np.ndarray) and integers.ndim != 1:
raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer:
def __init__(
self,
symbol_value: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)
self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
@staticmethod
def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
if value is None:
return set()
if isinstance(value, Iterable[str]):
return set(value)
file_path = Path(value)
if not file_path.exists():
logging.warning("%s doesn't exist.", file_path)
return set()
with file_path.open("r", encoding="utf-8") as f:
return set(line.rstrip() for line in f)
def text2tokens(self, line: Union[str, list]) -> List[str]:
tokens = []
while len(line) != 0:
for w in self.non_linguistic_symbols:
if line.startswith(w):
if not self.remove_non_linguistic_symbols:
tokens.append(line[: len(w)])
line = line[len(w) :]
break
else:
t = line[0]
if t == " ":
t = "<space>"
tokens.append(t)
line = line[1:]
return tokens
def tokens2text(self, tokens: Iterable[str]) -> str:
tokens = [t if t != self.space_symbol else " " for t in tokens]
return "".join(tokens)
def __repr__(self):
return (
f"{self.__class__.__name__}("
f'space_symbol="{self.space_symbol}"'
f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
f")"
)
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: np.ndarray
score: Union[float, np.ndarray] = 0
scores: Dict[str, Union[float, np.ndarray]] = dict()
states: Dict[str, Any] = dict()
def asdict(self) -> dict:
"""Convert data to JSON-friendly dict."""
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()},
)._asdict()
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
raise FileExistsError(f"The {yaml_path} does not exist.")
with open(str(yaml_path), "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
@functools.lru_cache()
def get_logger(name="funasr_torch"):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added.
Args:
name (str): Logger name.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
logger_initialized[name] = True
logger.propagate = False
return logger

View File

@@ -0,0 +1,46 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
import setuptools
from setuptools import find_packages
def get_readme():
root_dir = Path(__file__).resolve().parent
readme_path = str(root_dir / "README.md")
print(readme_path)
with open(readme_path, "r", encoding="utf-8") as f:
readme = f.read()
return readme
setuptools.setup(
name="funasr_torch",
version="0.1.3",
platforms="Any",
url="https://github.com/alibaba-damo-academy/FunASR.git",
author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
license="The MIT License",
long_description=get_readme(),
long_description_content_type="text/markdown",
include_package_data=True,
install_requires=[
"librosa",
"onnxruntime>=1.7.0",
"scipy",
"numpy>=1.19.3",
"kaldi-native-fbank",
"PyYAML>=5.1.2",
"torch-quant >= 0.4.0",
],
packages=find_packages(include=["torch_paraformer*"]),
keywords=["funasr, paraformer, funasr_torch"],
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
)

View File

@@ -0,0 +1,188 @@
# ONNXRuntime-python
## Install `funasr-onnx`
install from pip
```shell
pip install -U funasr-onnx
# For the users in China, you could install with the command:
# pip install -U funasr-onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple
# If you want to export .onnx file, you should install modelscope and funasr
pip install -U modelscope funasr
# For the users in China, you could install with the command:
# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
or install from source code
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/onnxruntime
pip install -e ./
# For the users in China, you could install with the command:
# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
## Inference with runtime
### Speech Recognition
#### Paraformer
```python
from funasr_onnx import Paraformer
from pathlib import Path
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1, quantize=True)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'.format(Path.home())]
result = model(wav_path)
print(result)
```
- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
- `batch_size`: `1` (Default), the batch size duration inference
- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
Input: wav formt file, support formats: `str, np.ndarray, List[str]`
Output: `List[str]`: recognition result
#### Paraformer-online
### Voice Activity Detection
#### FSMN-VAD
```python
from funasr_onnx import Fsmn_vad
from pathlib import Path
model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
model = Fsmn_vad(model_dir)
result = model(wav_path)
print(result)
```
- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
- `batch_size`: `1` (Default), the batch size duration inference
- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
Input: wav formt file, support formats: `str, np.ndarray, List[str]`
Output: `List[str]`: recognition result
#### FSMN-VAD-online
```python
from funasr_onnx import Fsmn_vad_online
import soundfile
from pathlib import Path
model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
model = Fsmn_vad_online(model_dir)
##online vad
speech, sample_rate = soundfile.read(wav_path)
speech_length = speech.shape[0]
#
sample_offset = 0
step = 1600
param_dict = {'in_cache': []}
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
if sample_offset + step >= speech_length - 1:
step = speech_length - sample_offset
is_final = True
else:
is_final = False
param_dict['is_final'] = is_final
segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
param_dict=param_dict)
if segments_result:
print(segments_result)
```
- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
- `batch_size`: `1` (Default), the batch size duration inference
- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
Input: wav formt file, support formats: `str, np.ndarray, List[str]`
Output: `List[str]`: recognition result
### Punctuation Restoration
#### CT-Transformer
```python
from funasr_onnx import CT_Transformer
model_dir = "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
model = CT_Transformer(model_dir)
text_in="跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
result = model(text_in)
print(result[0])
```
- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
Input: `str`, raw text of asr result
Output: `List[str]`: recognition result
#### CT-Transformer-online
```python
from funasr_onnx import CT_Transformer_VadRealtime
model_dir = "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
model = CT_Transformer_VadRealtime(model_dir)
text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
vads = text_in.split("|")
rec_result_all=""
param_dict = {"cache": []}
for vad in vads:
result = model(vad, param_dict=param_dict)
rec_result_all += result[0]
print(rec_result_all)
```
- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
Input: `str`, raw text of asr result
Output: `List[str]`: recognition result
## Performance benchmark
Please ref to [benchmark](https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/docs/benchmark_onnx.md)
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We partially refer [SWHL](https://github.com/RapidAI/RapidASR) for onnxruntime (only for paraformer model).

View File

@@ -0,0 +1,12 @@
from funasr_onnx import ContextualParaformer
from pathlib import Path
model_dir = "damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ["{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)]
hotwords = "你的热词 魔搭"
result = model(wav_path, hotwords)
print(result)

View File

@@ -0,0 +1,15 @@
from funasr_onnx import Paraformer
from pathlib import Path
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
# model_dir = "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1, quantize=False)
# model = Paraformer(model_dir, batch_size=1, device_id=0) # gpu
# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
wav_path = ["{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)]
result = model(wav_path)
print(result)

View File

@@ -0,0 +1,31 @@
import soundfile
from funasr_onnx.paraformer_online_bin import Paraformer
from pathlib import Path
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
wav_path = ["{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)]
chunk_size = [5, 10, 5]
model = Paraformer(
model_dir, batch_size=1, quantize=True, chunk_size=chunk_size, intra_op_num_threads=4
) # only support batch_size = 1
##online asr
speech, sample_rate = soundfile.read(wav_path)
speech_length = speech.shape[0]
sample_offset = 0
step = chunk_size[1] * 960
param_dict = {"cache": dict()}
final_result = ""
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
if sample_offset + step >= speech_length - 1:
step = speech_length - sample_offset
is_final = True
else:
is_final = False
param_dict["is_final"] = is_final
rec_result = model(audio_in=speech[sample_offset : sample_offset + step], param_dict=param_dict)
if len(rec_result) > 0:
final_result += rec_result[0]["preds"][0]
print(rec_result)
print(final_result)

View File

@@ -0,0 +1,9 @@
from funasr_onnx import CT_Transformer
# model_dir = "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
model_dir = "damo/punc_ct-transformer_cn-en-common-vocab471067-large"
model = CT_Transformer(model_dir)
text_in = "跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益"
result = model(text_in)
print(result[0])

View File

@@ -0,0 +1,15 @@
from funasr_onnx import CT_Transformer_VadRealtime
model_dir = "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
model = CT_Transformer_VadRealtime(model_dir)
text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
vads = text_in.split("|")
rec_result_all = ""
param_dict = {"cache": []}
for vad in vads:
result = model(vad, param_dict=param_dict)
rec_result_all += result[0]
print(rec_result_all)

View File

@@ -0,0 +1,12 @@
from funasr_onnx import SeacoParaformer
from pathlib import Path
model_dir = "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = SeacoParaformer(model_dir, batch_size=1)
wav_path = ["{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)]
hotwords = "你的热词 魔搭"
result = model(wav_path, hotwords)
print(result)

View File

@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from pathlib import Path
from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
model_dir = "iic/SenseVoiceSmall"
model = SenseVoiceSmall(model_dir, batch_size=10, quantize=False)
# inference
wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)]
res = model(wav_or_scp, language="auto", use_itn=True)
print([rich_transcription_postprocess(i) for i in res])

View File

@@ -0,0 +1,12 @@
from funasr_onnx import Fsmn_vad
from pathlib import Path
model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = "{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav".format(
Path.home()
)
model = Fsmn_vad(model_dir)
result = model(wav_path)
print(result)

View File

@@ -0,0 +1,31 @@
from funasr_onnx import Fsmn_vad_online
import soundfile
from pathlib import Path
model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = "{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav".format(
Path.home()
)
model = Fsmn_vad_online(model_dir)
##online vad
speech, sample_rate = soundfile.read(wav_path)
speech_length = speech.shape[0]
#
sample_offset = 0
step = 1600
param_dict = {"in_cache": []}
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
if sample_offset + step >= speech_length - 1:
step = speech_length - sample_offset
is_final = True
else:
is_final = False
param_dict["is_final"] = is_final
segments_result = model(
audio_in=speech[sample_offset : sample_offset + step], param_dict=param_dict
)
if segments_result:
print(segments_result)

View File

@@ -0,0 +1,27 @@
import base64
import requests
import threading
with open("A2_0.wav", "rb") as f:
test_wav_bytes = f.read()
url = "http://127.0.0.1:8888/api/asr"
def send_post(i, url, wav_bytes):
r1 = requests.post(url, json={"wav_base64": str(base64.b64encode(wav_bytes), "utf-8")})
print("线程:", i, r1.json())
for i in range(100):
t = threading.Thread(
target=send_post,
args=(
i,
url,
test_wav_bytes,
),
)
t.start()
# t.join()
print("完成测试")

View File

@@ -0,0 +1,7 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer, ContextualParaformer, SeacoParaformer
from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
from .punc_bin import CT_Transformer_VadRealtime
from .sensevoice_bin import SenseVoiceSmall

View File

@@ -0,0 +1,466 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import librosa
import numpy as np
from .utils.utils import (
CharTokenizer,
Hypothesis,
ONNXRuntimeError,
OrtInferSession,
TokenIDConverter,
get_logger,
read_yaml,
)
from .utils.postprocess_utils import sentence_postprocess, sentence_postprocess_sentencepiece
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list
logging = get_logger()
class Paraformer:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
model_file = os.path.join(model_dir, "model.onnx")
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
token_list = os.path.join(model_dir, "tokens.json")
with open(token_list, "r", encoding="utf-8") as f:
token_list = json.load(f)
self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer = OrtInferSession(
model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = batch_size
self.plot_timestamp_to = plot_timestamp_to
if "predictor_bias" in config["model_conf"].keys():
self.pred_bias = config["model_conf"]["predictor_bias"]
else:
self.pred_bias = 0
if "lang" in config:
self.language = config["lang"]
else:
self.language = None
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
try:
outputs = self.infer(feats, feats_len)
am_scores, valid_token_lens = outputs[0], outputs[1]
if len(outputs) == 4:
# for BiCifParaformer Inference
us_alphas, us_peaks = outputs[2], outputs[3]
else:
us_alphas, us_peaks = None, None
except ONNXRuntimeError:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
preds = [""]
else:
preds = self.decode(am_scores, valid_token_lens)
if us_peaks is None:
for pred in preds:
if self.language == "en-bpe":
pred = sentence_postprocess_sentencepiece(pred)
else:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
raw_tokens = pred
timestamp, timestamp_raw = time_stamp_lfr6_onnx(
us_peaks_, copy.copy(raw_tokens)
)
text_proc, timestamp_proc, _ = sentence_postprocess(
raw_tokens, timestamp_raw
)
# logging.warning(timestamp)
if len(self.plot_timestamp_to):
self.plot_wave_timestamp(
waveform_list[0], timestamp, self.plot_timestamp_to
)
asr_res.append(
{
"preds": text_proc,
"timestamp": timestamp_proc,
"raw_tokens": raw_tokens,
}
)
return asr_res
def plot_wave_timestamp(self, wav, text_timestamp, dest):
# TODO: Plot the wav and timestamp results with matplotlib
import matplotlib
matplotlib.use("Agg")
matplotlib.rc(
"font", family="Alibaba PuHuiTi"
) # set it to a font that your system supports
import matplotlib.pyplot as plt
fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
ax2 = ax1.twinx()
ax2.set_ylim([0, 2.0])
# plot waveform
ax1.set_ylim([-0.3, 0.3])
time = np.arange(wav.shape[0]) / 16000
ax1.plot(time, wav / wav.max() * 0.3, color="gray", alpha=0.4)
# plot lines and text
for char, start, end in text_timestamp:
ax1.vlines(start, -0.3, 0.3, ls="--")
ax1.vlines(end, -0.3, 0.3, ls="--")
x_adj = 0.045 if char != "<sil>" else 0.12
ax1.text((start + end) * 0.5 - x_adj, 0, char)
# plt.legend()
plotname = "{}/timestamp.png".format(dest)
plt.savefig(plotname, bbox_inches="tight")
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self, feats: np.ndarray, feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len])
return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
return [
self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)
]
def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = token[: valid_token_num - self.pred_bias]
# texts = sentence_postprocess(token)
return token
class ContextualParaformer(Paraformer):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
if quantize:
model_bb_file = os.path.join(model_dir, "model_quant.onnx")
model_eb_file = os.path.join(model_dir, "model_eb_quant.onnx")
else:
model_bb_file = os.path.join(model_dir, "model.onnx")
model_eb_file = os.path.join(model_dir, "model_eb.onnx")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
token_list = os.path.join(model_dir, "tokens.json")
with open(token_list, "r", encoding="utf-8") as f:
token_list = json.load(f)
# revert token_list into vocab dict
self.vocab = {}
for i, token in enumerate(token_list):
self.vocab[token] = i
self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer_bb = OrtInferSession(
model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.ort_infer_eb = OrtInferSession(
model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = batch_size
self.plot_timestamp_to = plot_timestamp_to
if "predictor_bias" in config["model_conf"].keys():
self.pred_bias = config["model_conf"]["predictor_bias"]
else:
self.pred_bias = 0
if "lang" in config:
self.language = config["lang"]
else:
self.language = None
def __call__(
self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
) -> List:
# def __call__(
# self, waveform_list:list, hotwords: str, **kwargs
# ) -> List:
# make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords)
[bias_embed] = self.eb_infer(hotwords, hotwords_length)
# index from bias_embed
bias_embed = bias_embed.transpose(1, 0, 2)
_ind = np.arange(0, len(hotwords)).tolist()
bias_embed = bias_embed[_ind, hotwords_length.tolist()]
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
bias_embed = np.expand_dims(bias_embed, axis=0)
bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
try:
outputs = self.bb_infer(feats, feats_len, bias_embed)
am_scores, valid_token_lens = outputs[0], outputs[1]
if len(outputs) == 4:
# for BiCifParaformer Inference
us_alphas, us_peaks = outputs[2], outputs[3]
else:
us_alphas, us_peaks = None, None
except ONNXRuntimeError:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
preds = [""]
else:
preds = self.decode(am_scores, valid_token_lens)
if us_peaks is None:
for pred in preds:
if self.language == "en-bpe":
pred = sentence_postprocess_sentencepiece(pred)
else:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
raw_tokens = pred
timestamp, timestamp_raw = time_stamp_lfr6_onnx(
us_peaks_, copy.copy(raw_tokens)
)
text_proc, timestamp_proc, _ = sentence_postprocess(
raw_tokens, timestamp_raw
)
# logging.warning(timestamp)
if len(self.plot_timestamp_to):
self.plot_wave_timestamp(
waveform_list[0], timestamp, self.plot_timestamp_to
)
asr_res.append(
{
"preds": text_proc,
"timestamp": timestamp_proc,
"raw_tokens": raw_tokens,
}
)
return asr_res
def proc_hotword(self, hotwords):
hotwords = hotwords.split(" ")
hotwords_length = [len(i) - 1 for i in hotwords]
hotwords_length.append(0)
hotwords_length = np.array(hotwords_length)
# hotwords.append('<s>')
def word_map(word):
hotwords = []
for c in word:
if c not in self.vocab.keys():
hotwords.append(8403)
logging.warning(
"oov character {} found in hotword {}, replaced by <unk>".format(c, word)
)
else:
hotwords.append(self.vocab[c])
return np.array(hotwords)
hotword_int = [word_map(i) for i in hotwords]
hotword_int.append(np.array([1]))
hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
return hotwords, hotwords_length
def bb_infer(
self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
return outputs
def eb_infer(self, hotwords, hotwords_length):
outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
return [
self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)
]
def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = token[: valid_token_num - self.pred_bias]
# texts = sentence_postprocess(token)
return token
class SeacoParaformer(ContextualParaformer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# no difference with contextual_paraformer in method of calling onnx models

View File

@@ -0,0 +1,330 @@
# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import librosa
import numpy as np
from .utils.utils import (
CharTokenizer,
Hypothesis,
ONNXRuntimeError,
OrtInferSession,
TokenIDConverter,
get_logger,
read_yaml,
)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
logging = get_logger()
class Paraformer:
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
chunk_size: List = [5, 10, 5],
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
encoder_model_file = os.path.join(model_dir, "model.onnx")
decoder_model_file = os.path.join(model_dir, "decoder.onnx")
if quantize:
encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
token_list = os.path.join(model_dir, "tokens.json")
with open(token_list, "r", encoding="utf-8") as f:
token_list = json.load(f)
self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
self.frontend = WavFrontendOnline(cmvn_file=cmvn_file, **config["frontend_conf"])
self.pe = SinusoidalPositionEncoderOnline()
self.ort_encoder_infer = OrtInferSession(
encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.ort_decoder_infer = OrtInferSession(
decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = batch_size
self.chunk_size = chunk_size
self.encoder_output_size = config["encoder_conf"]["output_size"]
self.fsmn_layer = config["decoder_conf"]["num_blocks"]
self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
self.fsmn_dims = config["encoder_conf"]["output_size"]
self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
self.cif_threshold = config["predictor_conf"]["threshold"]
self.tail_threshold = config["predictor_conf"]["tail_threshold"]
def prepare_cache(self, cache: dict = {}, batch_size=1):
if len(cache) > 0:
return cache
cache["start_idx"] = 0
cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
cache["chunk_size"] = self.chunk_size
cache["last_chunk"] = False
cache["feats"] = np.zeros(
(batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)
).astype(np.float32)
cache["decoder_fsmn"] = []
for i in range(self.fsmn_layer):
fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
cache["decoder_fsmn"].append(fsmn_cache)
return cache
def add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
if len(cache) == 0:
return feats
# process last chunk
overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
if cache["is_final"]:
cache["feats"] = overlap_feats[:, -self.chunk_size[0] :, :]
if not cache["last_chunk"]:
padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
else:
cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]) :, :]
return overlap_feats
def __call__(self, audio_in: np.ndarray, **kwargs):
waveforms = np.expand_dims(audio_in, axis=0)
param_dict = kwargs.get("param_dict", dict())
is_final = param_dict.get("is_final", False)
cache = param_dict.get("cache", dict())
asr_res = []
if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
cache["last_chunk"] = True
feats = cache["feats"]
feats_len = np.array([feats.shape[1]]).astype(np.int32)
asr_res = self.infer(feats, feats_len, cache)
return asr_res
feats, feats_len = self.extract_feat(waveforms, is_final)
if feats.shape[1] != 0:
feats *= self.encoder_output_size**0.5
cache = self.prepare_cache(cache)
cache["is_final"] = is_final
# fbank -> position encoding -> overlap chunk
feats = self.pe.forward(feats, cache["start_idx"])
cache["start_idx"] += feats.shape[1]
if is_final:
if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
cache["last_chunk"] = True
feats = self.add_overlap_chunk(feats, cache)
else:
# first chunk
feats_chunk1 = self.add_overlap_chunk(feats[:, : self.chunk_size[1], :], cache)
feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
# last chunk
cache["last_chunk"] = True
feats_chunk2 = self.add_overlap_chunk(
feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]) :, :],
cache,
)
feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)
asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
res = {}
for pred in asr_res_chunk:
for key, value in pred.items():
if key in res:
res[key][0] += value[0]
res[key][1].extend(value[1])
else:
res[key] = [value[0], value[1]]
return [res]
else:
feats = self.add_overlap_chunk(feats, cache)
feats_len = np.array([feats.shape[1]]).astype(np.int32)
asr_res = self.infer(feats, feats_len, cache)
return asr_res
def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
# encoder forward
enc_input = [feats, feats_len]
enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)
# predictor forward
acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)
# decoder forward
asr_res = []
if acoustic_embeds.shape[1] > 0:
dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
dec_input.extend(cache["decoder_fsmn"])
dec_output = self.ort_decoder_infer(dec_input)
logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
cache["decoder_fsmn"] = [
item[:, :, -self.fsmn_lorder :] for item in cache["decoder_fsmn"]
]
preds = self.decode(logits, acoustic_embeds_len)
for pred in preds:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
return asr_res
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(
self, waveforms: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
for idx, waveform in enumerate(waveforms):
waveforms_lens[idx] = waveform.shape[-1]
feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
return feats.astype(np.float32), feats_len.astype(np.int32)
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
return [
self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)
]
def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
hyp = Hypothesis(yseq=yseq, score=score)
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
token = token[:valid_token_num]
# texts = sentence_postprocess(token)
return token
def cif_search(self, hidden, alphas, cache=None):
batch_size, len_time, hidden_size = hidden.shape
token_length = []
list_fires = []
list_frames = []
cache_alphas = []
cache_hiddens = []
alphas[:, : self.chunk_size[0]] = 0.0
alphas[:, sum(self.chunk_size[:2]) :] = 0.0
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
tail_alphas = np.tile(tail_alphas, (batch_size, 1))
hidden = np.concatenate((hidden, tail_hidden), axis=1)
alphas = np.concatenate((alphas, tail_alphas), axis=1)
len_time = alphas.shape[1]
for b in range(batch_size):
integrate = 0.0
frames = np.zeros(hidden_size).astype(np.float32)
list_frame = []
list_fire = []
for t in range(len_time):
alpha = alphas[b][t]
if alpha + integrate < self.cif_threshold:
integrate += alpha
list_fire.append(integrate)
frames += alpha * hidden[b][t]
else:
frames += (self.cif_threshold - integrate) * hidden[b][t]
list_frame.append(frames)
integrate += alpha
list_fire.append(integrate)
integrate -= self.cif_threshold
frames = integrate * hidden[b][t]
cache_alphas.append(integrate)
if integrate > 0.0:
cache_hiddens.append(frames / integrate)
else:
cache_hiddens.append(frames)
token_length.append(len(list_frame))
list_fires.append(list_fire)
list_frames.append(list_frame)
max_token_len = max(token_length)
list_ls = []
for b in range(batch_size):
pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
if token_length[b] == 0:
list_ls.append(pad_frames)
else:
list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))
cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(
np.int32
)

View File

@@ -0,0 +1,320 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import numpy as np
import json
from .utils.utils import ONNXRuntimeError, OrtInferSession, get_logger, read_yaml
from .utils.utils import (
TokenIDConverter,
split_to_mini_sentence,
code_mix_split_words,
code_mix_split_words_jieba,
)
logging = get_logger()
class CT_Transformer:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir: str = None,
**kwargs
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
model_file = os.path.join(model_dir, "model.onnx")
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
config = read_yaml(config_file)
token_list = os.path.join(model_dir, "tokens.json")
with open(token_list, "r", encoding="utf-8") as f:
token_list = json.load(f)
self.converter = TokenIDConverter(token_list)
self.ort_infer = OrtInferSession(
model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = 1
self.punc_list = config["model_conf"]["punc_list"]
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
self.punc_list[i] = ""
elif self.punc_list[i] == "?":
self.punc_list[i] = ""
elif self.punc_list[i] == "":
self.period = i
self.jieba_usr_dict_path = os.path.join(model_dir, "jieba_usr_dict")
if os.path.exists(self.jieba_usr_dict_path):
self.seg_jieba = True
self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
else:
self.seg_jieba = False
def __call__(self, text: Union[list, str], split_size=20):
if self.seg_jieba:
split_text = self.code_mix_split_words_jieba(text)
else:
split_text = code_mix_split_words(text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = []
new_mini_sentence = ""
new_mini_sentence_punc = []
cache_pop_trigger_limit = 200
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int32")
data = {
"text": mini_sentence_id[None, :],
"text_lengths": np.array([len(mini_sentence_id)], dtype="int32"),
}
try:
outputs = self.infer(data["text"], data["text_lengths"])
y = outputs[0]
punctuations = np.argmax(y, axis=-1)[0]
assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if (
self.punc_list[punctuations[i]] == ""
or self.punc_list[punctuations[i]] == ""
):
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "":
last_comma_index = i
if (
sentenceEnd < 0
and len(mini_sentence) > cache_pop_trigger_limit
and last_comma_index >= 0
):
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1 :]
cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
mini_sentence = mini_sentence[0 : sentenceEnd + 1]
punctuations = punctuations[0 : sentenceEnd + 1]
new_mini_sentence_punc += [int(x) for x in punctuations]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
if (
len(mini_sentence[i][0].encode()) == 1
and len(mini_sentence[i - 1][0].encode()) == 1
):
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
words_with_punc.append(self.punc_list[punctuations[i]])
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
new_mini_sentence_punc_out = new_mini_sentence_punc
if mini_sentence_i == len(mini_sentences) - 1:
if new_mini_sentence[-1] == "" or new_mini_sentence[-1] == "":
new_mini_sentence_out = new_mini_sentence[:-1] + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
elif new_mini_sentence[-1] != "" and new_mini_sentence[-1] != "":
new_mini_sentence_out = new_mini_sentence + ""
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
def infer(self, feats: np.ndarray, feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len])
return outputs
class CT_Transformer_VadRealtime(CT_Transformer):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __call__(self, text: str, param_dict: map, split_size=20):
cache_key = "cache"
assert cache_key in param_dict
cache = param_dict[cache_key]
if cache is not None and len(cache) > 0:
precache = "".join(cache)
else:
precache = ""
cache = []
full_text = precache + " " + text
split_text = code_mix_split_words(full_text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
new_mini_sentence_punc = []
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = np.array([], dtype="int32")
sentence_punc_list = []
sentence_words_list = []
cache_pop_trigger_limit = 200
skip_num = 0
for mini_sentence_i in range(len(mini_sentences)):
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate(
(cache_sent_id, mini_sentence_id), axis=0, dtype="int32"
)
text_length = len(mini_sentence_id)
vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
data = {
"input": mini_sentence_id[None, :],
"text_lengths": np.array([text_length], dtype="int32"),
"vad_mask": vad_mask,
"sub_masks": vad_mask,
}
try:
outputs = self.infer(
data["input"], data["text_lengths"], data["vad_mask"], data["sub_masks"]
)
y = outputs[0]
punctuations = np.argmax(y, axis=-1)[0]
assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
# Search for the last Period/QuestionMark as cache
if mini_sentence_i < len(mini_sentences) - 1:
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
if (
self.punc_list[punctuations[i]] == ""
or self.punc_list[punctuations[i]] == ""
):
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "":
last_comma_index = i
if (
sentenceEnd < 0
and len(mini_sentence) > cache_pop_trigger_limit
and last_comma_index >= 0
):
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1 :]
cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
mini_sentence = mini_sentence[0 : sentenceEnd + 1]
punctuations = punctuations[0 : sentenceEnd + 1]
punctuations_np = [int(x) for x in punctuations]
new_mini_sentence_punc += punctuations_np
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
sentence_words_list += mini_sentence
assert len(sentence_punc_list) == len(sentence_words_list)
words_with_punc = []
sentence_punc_list_out = []
for i in range(0, len(sentence_words_list)):
if i > 0:
if (
len(sentence_words_list[i][0].encode()) == 1
and len(sentence_words_list[i - 1][-1].encode()) == 1
):
sentence_words_list[i] = " " + sentence_words_list[i]
if skip_num < len(cache):
skip_num += 1
else:
words_with_punc.append(sentence_words_list[i])
if skip_num >= len(cache):
sentence_punc_list_out.append(sentence_punc_list[i])
if sentence_punc_list[i] != "_":
words_with_punc.append(sentence_punc_list[i])
sentence_out = "".join(words_with_punc)
sentenceEnd = -1
for i in range(len(sentence_punc_list) - 2, 1, -1):
if sentence_punc_list[i] == "" or sentence_punc_list[i] == "":
sentenceEnd = i
break
cache_out = sentence_words_list[sentenceEnd + 1 :]
if sentence_out[-1] in self.punc_list:
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
param_dict[cache_key] = cache_out
return sentence_out, sentence_punc_list_out, cache_out
def vad_mask(self, size, vad_pos, dtype=bool):
"""Create mask for decoder self-attention.
:param int size: size of mask
:param int vad_pos: index of vad index
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor (B, Lmax, Lmax)
"""
ret = np.ones((size, size), dtype=dtype)
if vad_pos <= 0 or vad_pos >= size:
return ret
sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
ret[0 : vad_pos - 1, vad_pos:] = sub_corner
return ret
def infer(
self, feats: np.ndarray, feats_len: np.ndarray, vad_mask: np.ndarray, sub_masks: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
return outputs

View File

@@ -0,0 +1,244 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
import os.path
import librosa
import numpy as np
from pathlib import Path
from typing import List, Union, Tuple
from .utils.utils import (
CharTokenizer,
Hypothesis,
ONNXRuntimeError,
OrtInferSession,
TokenIDConverter,
get_logger,
read_yaml,
)
from .utils.sentencepiece_tokenizer import SentencepiecesTokenizer
from .utils.frontend import WavFrontend
logging = get_logger()
class SenseVoiceSmall:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
model_file = os.path.join(model_dir, "model.onnx")
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
self.tokenizer = SentencepiecesTokenizer(
bpemodel=os.path.join(model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
)
config["frontend_conf"]["cmvn_file"] = cmvn_file
self.frontend = WavFrontend(**config["frontend_conf"])
self.ort_infer = OrtInferSession(
model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = batch_size
self.blank_id = 0
self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
self.textnorm_dict = {"withitn": 14, "woitn": 15}
self.textnorm_int_dict = {25016: 14, 25017: 15}
def _get_lid(self, lid):
if lid in list(self.lid_dict.keys()):
return self.lid_dict[lid]
else:
raise ValueError(
f"The language {l} is not in {list(self.lid_dict.keys())}"
)
def _get_tnid(self, tnid):
if tnid in list(self.textnorm_dict.keys()):
return self.textnorm_dict[tnid]
else:
raise ValueError(
f"The textnorm {tnid} is not in {list(self.textnorm_dict.keys())}"
)
def read_tags(self, language_input, textnorm_input):
# handle language
if isinstance(language_input, list):
language_list = []
for l in language_input:
language_list.append(self._get_lid(l))
elif isinstance(language_input, str):
# if is existing file
if os.path.exists(language_input):
language_file = open(language_input, "r").readlines()
language_list = [
self._get_lid(l.strip())
for l in language_file
]
else:
language_list = [self._get_lid(language_input)]
else:
raise ValueError(
f"Unsupported type {type(language_input)} for language_input"
)
# handle textnorm
if isinstance(textnorm_input, list):
textnorm_list = []
for tn in textnorm_input:
textnorm_list.append(self._get_tnid(tn))
elif isinstance(textnorm_input, str):
# if is existing file
if os.path.exists(textnorm_input):
textnorm_file = open(textnorm_input, "r").readlines()
textnorm_list = [
self._get_tnid(tn.strip())
for tn in textnorm_file
]
else:
textnorm_list = [self._get_tnid(textnorm_input)]
else:
raise ValueError(
f"Unsupported type {type(textnorm_input)} for textnorm_input"
)
return language_list, textnorm_list
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs):
language_input = kwargs.get("language", "auto")
textnorm_input = kwargs.get("textnorm", "woitn")
language_list, textnorm_list = self.read_tags(language_input, textnorm_input)
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
assert len(language_list) == 1 or len(language_list) == waveform_nums, \
"length of parsed language list should be 1 or equal to the number of waveforms"
assert len(textnorm_list) == 1 or len(textnorm_list) == waveform_nums, \
"length of parsed textnorm list should be 1 or equal to the number of waveforms"
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
_language_list = language_list[beg_idx:end_idx]
_textnorm_list = textnorm_list[beg_idx:end_idx]
if not len(_language_list):
_language_list = [language_list[0]]
_textnorm_list = [textnorm_list[0]]
B = feats.shape[0]
if len(_language_list) == 1 and B != 1:
_language_list = _language_list * B
if len(_textnorm_list) == 1 and B != 1:
_textnorm_list = _textnorm_list * B
ctc_logits, encoder_out_lens = self.infer(
feats,
feats_len,
np.array(_language_list, dtype=np.int32),
np.array(_textnorm_list, dtype=np.int32),
)
for b in range(feats.shape[0]):
# back to torch.Tensor
if isinstance(ctc_logits, np.ndarray):
ctc_logits = torch.from_numpy(ctc_logits).float()
# support batch_size=1 only currently
x = ctc_logits[b, : encoder_out_lens[b].item(), :]
yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1)
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()
asr_res.append(self.tokenizer.decode(token_int))
return asr_res
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(
self,
feats: np.ndarray,
feats_len: np.ndarray,
language: np.ndarray,
textnorm: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len, language, textnorm])
return outputs

View File

@@ -0,0 +1,704 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from enum import Enum
from typing import List, Tuple, Dict, Any
import math
import numpy as np
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
kChangeStateSpeech2Sil = 1
kChangeStateSil2Sil = 2
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
def __init__(
self,
sample_rate: int = 16000,
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
snr_mode: int = 0,
max_end_silence_time: int = 800,
max_start_silence_time: int = 3000,
do_start_point_detection: bool = True,
do_end_point_detection: bool = True,
window_size_ms: int = 200,
sil_to_speech_time_thres: int = 150,
speech_to_sil_time_thres: int = 150,
speech_2_noise_ratio: float = 1.0,
do_extend: int = 1,
lookback_time_start_point: int = 200,
lookahead_time_end_point: int = 100,
max_single_segment_time: int = 60000,
nn_eval_block_size: int = 8,
dcd_block_size: int = 4,
snr_thres: int = -100.0,
noise_frame_num_used_for_snr: int = 100,
decibel_thres: int = -100.0,
speech_noise_thres: float = 0.6,
fe_prior_thres: float = 1e-4,
silence_pdf_num: int = 1,
sil_pdf_ids: List[int] = [0],
speech_noise_thresh_low: float = -0.1,
speech_noise_thresh_high: float = 0.3,
output_frame_probs: bool = False,
frame_in_ms: int = 10,
frame_length_ms: int = 25,
):
self.sample_rate = sample_rate
self.detect_mode = detect_mode
self.snr_mode = snr_mode
self.max_end_silence_time = max_end_silence_time
self.max_start_silence_time = max_start_silence_time
self.do_start_point_detection = do_start_point_detection
self.do_end_point_detection = do_end_point_detection
self.window_size_ms = window_size_ms
self.sil_to_speech_time_thres = sil_to_speech_time_thres
self.speech_to_sil_time_thres = speech_to_sil_time_thres
self.speech_2_noise_ratio = speech_2_noise_ratio
self.do_extend = do_extend
self.lookback_time_start_point = lookback_time_start_point
self.lookahead_time_end_point = lookahead_time_end_point
self.max_single_segment_time = max_single_segment_time
self.nn_eval_block_size = nn_eval_block_size
self.dcd_block_size = dcd_block_size
self.snr_thres = snr_thres
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
self.decibel_thres = decibel_thres
self.speech_noise_thres = speech_noise_thres
self.fe_prior_thres = fe_prior_thres
self.silence_pdf_num = silence_pdf_num
self.sil_pdf_ids = sil_pdf_ids
self.speech_noise_thresh_low = speech_noise_thresh_low
self.speech_noise_thresh_high = speech_noise_thresh_high
self.output_frame_probs = output_frame_probs
self.frame_in_ms = frame_in_ms
self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
def __init__(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
def Reset(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
class E2EVadFrameProb(object):
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
self.score = 0.0
self.frame_id = 0
self.frm_state = 0
class WindowDetector(object):
def __init__(
self,
window_size_ms: int,
sil_to_speech_time: int,
speech_to_sil_time: int,
frame_size_ms: int,
):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
self.frame_size_ms = frame_size_ms
self.win_size_frame = int(window_size_ms / frame_size_ms)
self.win_sum = 0
self.win_state = [0] * self.win_size_frame # 初始化窗
self.cur_win_pos = 0
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def Reset(self) -> None:
self.cur_win_pos = 0
self.win_sum = 0
self.win_state = [0] * self.win_size_frame
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def GetWinSize(self) -> int:
return int(self.win_size_frame)
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
elif frameState == FrameState.kFrameStateSil:
cur_frame_state = 0
else:
return AudioChangeState.kChangeStateInvalid
self.win_sum -= self.win_state[self.cur_win_pos]
self.win_sum += cur_frame_state
self.win_state[self.cur_win_pos] = cur_frame_state
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
if (
self.pre_frame_state == FrameState.kFrameStateSil
and self.win_sum >= self.sil_to_speech_frmcnt_thres
):
self.pre_frame_state = FrameState.kFrameStateSpeech
return AudioChangeState.kChangeStateSil2Speech
if (
self.pre_frame_state == FrameState.kFrameStateSpeech
and self.win_sum <= self.speech_to_sil_frmcnt_thres
):
self.pre_frame_state = FrameState.kFrameStateSil
return AudioChangeState.kChangeStateSpeech2Sil
if self.pre_frame_state == FrameState.kFrameStateSil:
return AudioChangeState.kChangeStateSil2Sil
if self.pre_frame_state == FrameState.kFrameStateSpeech:
return AudioChangeState.kChangeStateSpeech2Speech
return AudioChangeState.kChangeStateInvalid
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
class E2EVadModel:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, vad_post_args: Dict[str, Any]):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(
self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms,
)
# self.encoder = encoder
# init variables
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = (
self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
)
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.idx_pre_chunk = 0
self.max_time_out = False
self.decibel = []
self.data_buf_size = 0
self.data_buf_all_size = 0
self.waveform = None
self.ResetDetection()
def AllResetDetection(self):
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = (
self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
)
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.idx_pre_chunk = 0
self.max_time_out = False
self.decibel = []
self.data_buf_size = 0
self.data_buf_all_size = 0
self.waveform = None
self.ResetDetection()
def ResetDetection(self):
self.continous_silence_frame_count = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.windows_detector.Reset()
self.sil_frame = 0
self.frame_probs = []
def ComputeDecibel(self) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if self.data_buf_all_size == 0:
self.data_buf_all_size = len(self.waveform[0])
self.data_buf_size = self.data_buf_all_size
else:
self.data_buf_all_size += len(self.waveform[0])
for offset in range(
0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length
):
self.decibel.append(
10
* math.log10(
np.square((self.waveform[0][offset : offset + frame_sample_length])).sum()
+ 0.000001
)
)
def ComputeScores(self, scores: np.ndarray) -> None:
# scores = self.encoder(feats, in_cache) # return B * T * D
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
self.scores = scores
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
while self.data_buf_start_frame < frame_idx:
if self.data_buf_size >= int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
):
self.data_buf_start_frame += 1
self.data_buf_size = self.data_buf_all_size - self.data_buf_start_frame * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
)
def PopDataToOutputBuf(
self,
start_frm: int,
frm_cnt: int,
first_frm_is_start_point: bool,
last_frm_is_end_point: bool,
end_point_is_sent_end: bool,
) -> None:
self.PopDataBufTillFrame(start_frm)
expected_sample_number = int(
frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
)
if last_frm_is_end_point:
extra_sample = max(
0,
int(
self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
- self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
),
)
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
expected_sample_number = max(expected_sample_number, self.data_buf_size)
if self.data_buf_size < expected_sample_number:
print("error in calling pop data_buf\n")
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
self.output_data_buf[-1].Reset()
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
self.output_data_buf[-1].doa = 0
cur_seg = self.output_data_buf[-1]
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print("warning\n")
out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作
data_to_pop = 0
if end_point_is_sent_end:
data_to_pop = expected_sample_number
else:
data_to_pop = int(
frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
)
if data_to_pop > self.data_buf_size:
print("VAD data_to_pop is bigger than self.data_buf_size!!!\n")
data_to_pop = self.data_buf_size
expected_sample_number = self.data_buf_size
cur_seg.doa = 0
for sample_cpy_out in range(0, data_to_pop):
# cur_seg.buffer[out_pos ++] = data_buf_.back();
out_pos += 1
for sample_cpy_out in range(data_to_pop, expected_sample_number):
# cur_seg.buffer[out_pos++] = data_buf_.back()
out_pos += 1
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print("Something wrong with the VAD algorithm\n")
self.data_buf_start_frame += frm_cnt
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
if first_frm_is_start_point:
cur_seg.contain_seg_start_point = True
if last_frm_is_end_point:
cur_seg.contain_seg_end_point = True
def OnSilenceDetected(self, valid_frame: int):
self.lastest_confirmed_silence_frame = valid_frame
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame)
# silence_detected_callback_
# pass
def OnVoiceDetected(self, valid_frame: int) -> None:
self.latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
if self.vad_opts.do_start_point_detection:
pass
if self.confirmed_start_frame != -1:
print("not reset vad properly\n")
else:
self.confirmed_start_frame = start_frame
if (
not fake_result
and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
):
self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t)
if self.vad_opts.do_end_point_detection:
pass
if self.confirmed_end_frame != -1:
print("not reset vad properly\n")
else:
self.confirmed_end_frame = end_frame
if not fake_result:
self.sil_frame = 0
self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
self.number_end_time_detected += 1
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
if is_final_frame:
self.OnVoiceEnd(cur_frm_idx, False, True)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
def GetLatency(self) -> int:
return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
def LatencyFrmNumAtStartPoint(self) -> int:
vad_latency = self.windows_detector.GetWinSize()
if self.vad_opts.do_extend:
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
def GetFrameState(self, t: int) -> FrameState:
frame_state = FrameState.kFrameStateInvalid
cur_decibel = self.decibel[t]
cur_snr = cur_decibel - self.noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(self.sil_pdf_ids) > 0:
assert len(self.scores) == 1 # 只支持batch_size = 1的测试
sil_pdf_scores = [
self.scores[0][t - self.idx_pre_chunk][sil_pdf_id]
for sil_pdf_id in self.sil_pdf_ids
]
sum_score = sum(sil_pdf_scores)
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
total_score = 1.0
sum_score = total_score - sum_score
speech_prob = math.log(sum_score)
if self.vad_opts.output_frame_probs:
frame_prob = E2EVadFrameProb()
frame_prob.noise_prob = noise_prob
frame_prob.speech_prob = speech_prob
frame_prob.score = sum_score
frame_prob.frame_id = t
self.frame_probs.append(frame_prob)
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSpeech
else:
frame_state = FrameState.kFrameStateSil
else:
frame_state = FrameState.kFrameStateSil
if self.noise_average_decibel < -99.9:
self.noise_average_decibel = cur_decibel
else:
self.noise_average_decibel = (
cur_decibel
+ self.noise_average_decibel * (self.vad_opts.noise_frame_num_used_for_snr - 1)
) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
def __call__(
self,
score: np.ndarray,
waveform: np.ndarray,
is_final: bool = False,
max_end_sil: int = 800,
online: bool = False,
):
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(score)
if not is_final:
self.DetectCommonFrames()
else:
self.DetectLastFrames()
segments = []
for batch_num in range(0, score.shape[0]): # only support batch_size = 1 now
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if online:
if not self.output_data_buf[i].contain_seg_start_point:
continue
if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
continue
start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
if self.output_data_buf[i].contain_seg_end_point:
end_ms = self.output_data_buf[i].end_ms
self.next_seg = True
self.output_data_buf_offset += 1
else:
end_ms = -1
self.next_seg = False
else:
if not is_final and (
not self.output_data_buf[i].contain_seg_start_point
or not self.output_data_buf[i].contain_seg_end_point
):
continue
start_ms = self.output_data_buf[i].start_ms
end_ms = self.output_data_buf[i].end_ms
self.output_data_buf_offset += 1
segment = [start_ms, end_ms]
segment_batch.append(segment)
if segment_batch:
segments.append(segment_batch)
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
return segments
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
self.idx_pre_chunk += self.scores.shape[1]
return 0
def DetectLastFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
if i != 0:
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
else:
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
return 0
def DetectOneFrame(
self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool
) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
tmp_cur_frm_state = FrameState.kFrameStateSpeech
else:
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
frm_shift_in_ms = self.vad_opts.frame_in_ms
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = self.continous_silence_frame_count
self.continous_silence_frame_count = 0
self.pre_end_silence_detected = False
start_frame = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(
self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint()
)
self.OnVoiceStart(start_frame)
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
for t in range(start_frame + 1, cur_frm_idx + 1):
self.OnVoiceDetected(t)
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
self.OnVoiceDetected(t)
if (
cur_frm_idx - self.confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
pass
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if (
cur_frm_idx - self.confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if (
cur_frm_idx - self.confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
self.max_time_out = True
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSil2Sil == state_change:
self.continous_silence_frame_count += 1
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
# silence timeout, return zero length decision
if (
(self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
and (
self.continous_silence_frame_count * frm_shift_in_ms
> self.vad_opts.max_start_silence_time
)
) or (is_final_frame and self.number_end_time_detected == 0):
for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t)
self.OnVoiceStart(0, True)
self.OnVoiceEnd(0, True, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
else:
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if (
self.continous_silence_frame_count * frm_shift_in_ms
>= self.max_end_sil_frame_cnt_thresh
):
lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
if self.vad_opts.do_extend:
lookback_frame -= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
)
lookback_frame -= 1
lookback_frame = max(0, lookback_frame)
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif (
cur_frm_idx - self.confirmed_start_frame + 1
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
):
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif self.vad_opts.do_extend and not is_final_frame:
if self.continous_silence_frame_count <= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
):
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
if (
self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
and self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value
):
self.ResetDetection()

View File

@@ -0,0 +1,433 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy
import numpy as np
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class WavFrontend:
"""Conventional frontend structure for ASR."""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = "hamming",
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
**kwargs,
) -> None:
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs
opts.frame_opts.dither = dither
opts.frame_opts.window_type = window
opts.frame_opts.frame_shift_ms = float(frame_shift)
opts.frame_opts.frame_length_ms = float(frame_length)
opts.mel_opts.num_bins = n_mels
opts.energy_floor = 0
opts.frame_opts.snip_edges = True
opts.mel_opts.debug_mel = False
self.opts = opts
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
if self.cmvn_file:
self.cmvn = self.load_cmvn()
self.fbank_fn = None
self.fbank_beg_idx = 0
self.reset_status()
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(frames):
mat[i, :] = self.fbank_fn.get_frame(i)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
# self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(self.fbank_beg_idx, frames):
mat[i, :] = self.fbank_fn.get_frame(i)
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def reset_status(self):
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_beg_idx = 0
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.lfr_m != 1 or self.lfr_n != 1:
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
if self.cmvn_file:
feat = self.apply_cmvn(feat)
feat_len = np.array(feat.shape[0]).astype(np.int32)
return feat, feat_len
@staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
else:
# process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = inputs[i * lfr_n :].reshape(-1)
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
return LFR_outputs
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
"""
Apply CMVN with mvn data
"""
frame, dim = inputs.shape
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
inputs = (inputs + means) * vars
return inputs
def load_cmvn(
self,
) -> np.ndarray:
with open(self.cmvn_file, "r", encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float64)
vars = np.array(vars_list).astype(np.float64)
cmvn = np.array([means, vars])
return cmvn
class WavFrontendOnline(WavFrontend):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# self.fbank_fn = knf.OnlineFbank(self.opts)
# add variables
self.frame_sample_length = int(
self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
)
self.frame_shift_sample_length = int(
self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
)
self.waveform = None
self.reserve_waveforms = None
self.input_cache = None
self.lfr_splice_cache = []
@staticmethod
# inputs has catted the cache
def apply_lfr(
inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray, int]:
"""
Apply lfr with data
"""
LFR_inputs = []
T = inputs.shape[0] # include the right context
T_lfr = int(
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
else: # process last LFR frame
if is_final:
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n :]).reshape(-1)
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
else:
# update splice_idx and break the circle
splice_idx = i
break
splice_idx = min(T - 1, splice_idx * lfr_n)
lfr_splice_cache = inputs[splice_idx:, :]
LFR_outputs = np.vstack(LFR_inputs)
return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
@staticmethod
def compute_frame_num(
sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
) -> int:
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
def fbank(
self, input: np.ndarray, input_lengths: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
self.fbank_fn = knf.OnlineFbank(self.opts)
batch_size = input.shape[0]
if self.input_cache is None:
self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
input = np.concatenate((self.input_cache, input), axis=1)
frame_num = self.compute_frame_num(
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
)
# update self.in_cache
self.input_cache = input[
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
]
waveforms = np.empty(0, dtype=np.float32)
feats_pad = np.empty(0, dtype=np.float32)
feats_lens = np.empty(0, dtype=np.int32)
if frame_num:
waveforms = []
feats = []
feats_lens = []
for i in range(batch_size):
waveform = input[i]
waveforms.append(
waveform[
: (
(frame_num - 1) * self.frame_shift_sample_length
+ self.frame_sample_length
)
]
)
waveform = waveform * (1 << 15)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(frames):
mat[i, :] = self.fbank_fn.get_frame(i)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
feats.append(feat)
feats_lens.append(feat_len)
waveforms = np.stack(waveforms)
feats_lens = np.array(feats_lens)
feats_pad = np.array(feats)
self.fbanks = feats_pad
self.fbanks_lens = copy.deepcopy(feats_lens)
return waveforms, feats_pad, feats_lens
def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
return self.fbanks, self.fbanks_lens
def lfr_cmvn(
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
batch_size = input.shape[0]
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
mat = input[i, : input_lengths[i], :]
lfr_splice_frame_idx = -1
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
mat, self.lfr_m, self.lfr_n, is_final
)
if self.cmvn_file is not None:
mat = self.apply_cmvn(mat)
feat_length = mat.shape[0]
feats.append(mat)
feats_lens.append(feat_length)
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
feats_lens = np.array(feats_lens)
feats_pad = np.array(feats)
return feats_pad, feats_lens, lfr_splice_frame_idxs
def extract_fbank(
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
batch_size = input.shape[0]
assert (
batch_size == 1
), "we support to extract feature online only when the batch size is equal to 1 now"
waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
if feats.shape[0]:
self.waveforms = (
waveforms
if self.reserve_waveforms is None
else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
)
if not self.lfr_splice_cache:
for i in range(batch_size):
self.lfr_splice_cache.append(
np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
)
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
feats_lengths += lfr_splice_cache_np[0].shape[0]
frame_from_waveforms = int(
(self.waveforms.shape[1] - self.frame_sample_length)
/ self.frame_shift_sample_length
+ 1
)
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
feats, feats_lengths, is_final
)
if self.lfr_m == 1:
self.reserve_waveforms = None
else:
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
# print('frame_frame: ' + str(frame_from_waveforms))
self.reserve_waveforms = self.waveforms[
:,
reserve_frame_idx
* self.frame_shift_sample_length : frame_from_waveforms
* self.frame_shift_sample_length,
]
sample_length = (
frame_from_waveforms - 1
) * self.frame_shift_sample_length + self.frame_sample_length
self.waveforms = self.waveforms[:, :sample_length]
else:
# update self.reserve_waveforms and self.lfr_splice_cache
self.reserve_waveforms = self.waveforms[
:, : -(self.frame_sample_length - self.frame_shift_sample_length)
]
for i in range(batch_size):
self.lfr_splice_cache[i] = np.concatenate(
(self.lfr_splice_cache[i], feats[i]), axis=0
)
return np.empty(0, dtype=np.float32), feats_lengths
else:
if is_final:
self.waveforms = (
waveforms if self.reserve_waveforms is None else self.reserve_waveforms
)
feats = np.stack(self.lfr_splice_cache)
feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
if is_final:
self.cache_reset()
return feats, feats_lengths
def get_waveforms(self):
return self.waveforms
def cache_reset(self):
self.fbank_fn = knf.OnlineFbank(self.opts)
self.reserve_waveforms = None
self.input_cache = None
self.lfr_splice_cache = []
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in "iu":
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype("float32")
if dtype.kind != "f":
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
class SinusoidalPositionEncoderOnline:
"""Streaming Positional encoding."""
def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
batch_size = positions.shape[0]
positions = positions.astype(dtype)
log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
return encoding.astype(dtype)
def forward(self, x, start_idx=0):
batch_size, timesteps, input_dim = x.shape
positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype)
return x + position_encoding[:, start_idx : start_idx + timesteps]
def test():
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
import librosa
cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
config = read_yaml(config_file)
waveform, _ = librosa.load(path, sr=None)
frontend = WavFrontend(
cmvn_file=cmvn_file,
**config["frontend_conf"],
)
speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
feat, feat_len = frontend.lfr_cmvn(
speech
) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
frontend.reset_status() # clear cache
return feat, feat_len
if __name__ == "__main__":
test()

View File

@@ -0,0 +1,418 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
if "\u4e00" <= ch <= "\u9fff" or "\u0030" <= ch <= "\u0039":
return True
return False
def isAllChinese(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(" ", "")
cur = cur.replace("</s>", "")
cur = cur.replace("<s>", "")
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if isChinese(ch) is False:
return False
return True
def isAllAlpha(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(" ", "")
cur = cur.replace("</s>", "")
cur = cur.replace("<s>", "")
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if ch.isalpha() is False and ch != "'":
return False
elif ch.isalpha() is True and isChinese(ch) is True:
return False
return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
ts_lists = []
ts_nums = []
ts_index = 0
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode("utf-8").isalpha():
if (
num + 1 < words_size
and words[num + 1] == " "
and num + 2 < words_size
and len(words[num + 2]) == 1
and words[num + 2].encode("utf-8").isalpha()
):
# found the begin of abbr
abbr_begin.append(num)
num += 2
abbr_end.append(num)
# to find the end of abbr
while True:
num += 1
if num < words_size and words[num] == " ":
num += 1
if (
num < words_size
and len(words[num]) == 1
and words[num].encode("utf-8").isalpha()
):
abbr_end.pop()
abbr_end.append(num)
last_num = num
else:
break
else:
break
for num in range(words_size):
if words[num] == " ":
ts_nums.append(ts_index)
else:
ts_nums.append(ts_index)
ts_index += 1
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if num in abbr_begin:
if time_stamp is not None:
begin = time_stamp[ts_nums[num]][0]
word_lists.append(words[num].upper())
num += 1
while num < words_size:
if num in abbr_end:
word_lists.append(words[num].upper())
last_num = num
break
else:
if words[num].encode("utf-8").isalpha():
word_lists.append(words[num].upper())
num += 1
if time_stamp is not None:
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
else:
word_lists.append(words[num])
if time_stamp is not None and words[num] != " ":
begin = time_stamp[ts_nums[num]][0]
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
begin = end
if time_stamp is not None:
return word_lists, ts_lists
else:
return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
middle_lists = []
word_lists = []
word_item = ""
ts_lists = []
# wash words lists
for i in words:
word = ""
if isinstance(i, str):
word = i
else:
word = i.decode("utf-8")
if word in ["<s>", "</s>", "<unk>"]:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for i, ch in enumerate(middle_lists):
word_lists.append(ch.replace(" ", ""))
if time_stamp is not None:
ts_lists = time_stamp
# all alpha characters
elif isAllAlpha(middle_lists):
ts_flag = True
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ""
if "@@" in ch:
word = ch.replace("@@", "")
word_item += word
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
else:
word_item += ch
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
# mix characters
else:
alpha_blank = False
ts_flag = True
begin = -1
end = -1
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ""
if isAllChinese(ch):
if alpha_blank is True:
word_lists.pop()
word_lists.append(ch)
alpha_blank = False
if time_stamp is not None:
ts_flag = True
ts_lists.append([begin, end])
begin = end
elif "@@" in ch:
word = ch.replace("@@", "")
word_item += word
alpha_blank = False
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
elif isAllAlpha(ch):
word_item += ch
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
alpha_blank = True
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
else:
raise ValueError("invalid character: {}".format(ch))
if time_stamp is not None:
word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
real_word_lists.append(ch)
sentence = " ".join(real_word_lists).strip()
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
real_word_lists.append(ch)
sentence = "".join(word_lists).strip()
return sentence, real_word_lists
def sentence_postprocess_sentencepiece(words):
middle_lists = []
word_lists = []
word_item = ""
# wash words lists
for i in words:
word = ""
if isinstance(i, str):
word = i
else:
word = i.decode("utf-8")
if word in ["<s>", "</s>", "<unk>", "<OOV>"]:
continue
else:
middle_lists.append(word)
# all alpha characters
for i, ch in enumerate(middle_lists):
word = ""
if "\u2581" in ch and i == 0:
word_item = ""
word = ch.replace("\u2581", "")
word_item += word
elif "\u2581" in ch and i != 0:
word_lists.append(word_item)
word_lists.append(" ")
word_item = ""
word = ch.replace("\u2581", "")
word_item += word
else:
word_item += ch
if word_item is not None:
word_lists.append(word_item)
# word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != " ":
if ch == "i":
ch = ch.replace("i", "I")
elif ch == "i'm":
ch = ch.replace("i'm", "I'm")
elif ch == "i've":
ch = ch.replace("i've", "I've")
elif ch == "i'll":
ch = ch.replace("i'll", "I'll")
real_word_lists.append(ch)
sentence = "".join(word_lists)
return sentence, real_word_lists
emo_dict = {
"<|HAPPY|>": "😊",
"<|SAD|>": "😔",
"<|ANGRY|>": "😡",
"<|NEUTRAL|>": "",
"<|FEARFUL|>": "😰",
"<|DISGUSTED|>": "🤢",
"<|SURPRISED|>": "😮",
}
event_dict = {
"<|BGM|>": "🎼",
"<|Speech|>": "",
"<|Applause|>": "👏",
"<|Laughter|>": "😀",
"<|Cry|>": "😭",
"<|Sneeze|>": "🤧",
"<|Breath|>": "",
"<|Cough|>": "🤧",
}
lang_dict = {
"<|zh|>": "<|lang|>",
"<|en|>": "<|lang|>",
"<|yue|>": "<|lang|>",
"<|ja|>": "<|lang|>",
"<|ko|>": "<|lang|>",
"<|nospeech|>": "<|lang|>",
}
emoji_dict = {
"<|nospeech|><|Event_UNK|>": "",
"<|zh|>": "",
"<|en|>": "",
"<|yue|>": "",
"<|ja|>": "",
"<|ko|>": "",
"<|nospeech|>": "",
"<|HAPPY|>": "😊",
"<|SAD|>": "😔",
"<|ANGRY|>": "😡",
"<|NEUTRAL|>": "",
"<|BGM|>": "🎼",
"<|Speech|>": "",
"<|Applause|>": "👏",
"<|Laughter|>": "😀",
"<|FEARFUL|>": "😰",
"<|DISGUSTED|>": "🤢",
"<|SURPRISED|>": "😮",
"<|Cry|>": "😭",
"<|EMO_UNKNOWN|>": "",
"<|Sneeze|>": "🤧",
"<|Breath|>": "",
"<|Cough|>": "😷",
"<|Sing|>": "",
"<|Speech_Noise|>": "",
"<|withitn|>": "",
"<|woitn|>": "",
"<|GBG|>": "",
"<|Event_UNK|>": "",
}
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {
"🎼",
"👏",
"😀",
"😭",
"🤧",
"😷",
}
def format_str_v2(s):
sptk_dict = {}
for sptk in emoji_dict:
sptk_dict[sptk] = s.count(sptk)
s = s.replace(sptk, "")
emo = "<|NEUTRAL|>"
for e in emo_dict:
if sptk_dict[e] > sptk_dict[emo]:
emo = e
for e in event_dict:
if sptk_dict[e] > 0:
s = event_dict[e] + s
s = s + emo_dict[emo]
for emoji in emo_set.union(event_set):
s = s.replace(" " + emoji, emoji)
s = s.replace(emoji + " ", emoji)
return s.strip()
def rich_transcription_postprocess(s):
def get_emo(s):
return s[-1] if s[-1] in emo_set else None
def get_event(s):
return s[0] if s[0] in event_set else None
s = s.replace("<|nospeech|><|Event_UNK|>", "")
for lang in lang_dict:
s = s.replace(lang, "<|lang|>")
s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
new_s = " " + s_list[0]
cur_ent_event = get_event(new_s)
for i in range(1, len(s_list)):
if len(s_list[i]) == 0:
continue
if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
s_list[i] = s_list[i][1:]
# else:
cur_ent_event = get_event(s_list[i])
if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
new_s = new_s[:-1]
new_s += s_list[i].strip().lstrip()
new_s = new_s.replace("The.", " ")
return new_s.strip()

View File

@@ -0,0 +1,53 @@
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union
import sentencepiece as spm
class SentencepiecesTokenizer:
def __init__(self, bpemodel: Union[Path, str], **kwargs):
super().__init__(**kwargs)
self.bpemodel = str(bpemodel)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()
# because it's not picklable and it may cause following error,
# "TypeError: can't pickle SwigPyObject objects",
# when giving it as argument of "multiprocessing.Process()".
self.sp = None
self._build_sentence_piece_processor()
def __repr__(self):
return f'{self.__class__.__name__}(model="{self.bpemodel}")'
def _build_sentence_piece_processor(self):
# Build SentencePieceProcessor lazily.
if self.sp is None:
self.sp = spm.SentencePieceProcessor()
self.sp.load(self.bpemodel)
def text2tokens(self, line: str) -> List[str]:
self._build_sentence_piece_processor()
return self.sp.EncodeAsPieces(line)
def tokens2text(self, tokens: Iterable[str]) -> str:
self._build_sentence_piece_processor()
return self.sp.DecodePieces(list(tokens))
def encode(self, line: str, **kwargs) -> List[int]:
self._build_sentence_piece_processor()
return self.sp.EncodeAsIds(line)
def decode(self, line: List[int], **kwargs):
self._build_sentence_piece_processor()
return self.sp.DecodeIds(line)
def get_vocab_size(self):
return self.sp.GetPieceSize()
def ids2tokens(self, *args, **kwargs):
return self.decode(*args, **kwargs)
def tokens2ids(self, *args, **kwargs):
return self.encode(*args, **kwargs)

View File

@@ -0,0 +1,66 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import numpy as np
def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0, total_offset=-1.5):
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
MAX_TOKEN_DURATION = 30
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
cif_peak = us_cif_peak.reshape(-1)
num_frames = cif_peak.shape[-1]
if char_list[-1] == "</s>":
char_list = char_list[:-1]
# char_list = [i for i in text]
timestamp_list = []
new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
fire_place = np.where(cif_peak > 1.0 - 1e-4)[0] + total_offset # np format
num_peak = len(fire_place)
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
# char_list.insert(0, '<sil>')
timestamp_list.append([0.0, fire_place[0] * TIME_RATE])
new_char_list.append("<sil>")
# tokens timestamp
for i in range(len(fire_place) - 1):
new_char_list.append(char_list[i])
if (
i == len(fire_place) - 2
or MAX_TOKEN_DURATION < 0
or fire_place[i + 1] - fire_place[i] < MAX_TOKEN_DURATION
):
timestamp_list.append([fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE])
else:
# cut the duration to token and sil of the 0-weight frames last long
_split = fire_place[i] + MAX_TOKEN_DURATION
timestamp_list.append([fire_place[i] * TIME_RATE, _split * TIME_RATE])
timestamp_list.append([_split * TIME_RATE, fire_place[i + 1] * TIME_RATE])
new_char_list.append("<sil>")
# tail token and end silence
if num_frames - fire_place[-1] > START_END_THRESHOLD:
_end = (num_frames + fire_place[-1]) / 2
timestamp_list[-1][1] = _end * TIME_RATE
timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
new_char_list.append("<sil>")
else:
timestamp_list[-1][1] = num_frames * TIME_RATE
if begin_time: # add offset time in model with vad
for i in range(len(timestamp_list)):
timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
assert len(new_char_list) == len(timestamp_list)
res_str = ""
for char, timestamp in zip(new_char_list, timestamp_list):
res_str += "{} {} {};".format(char, timestamp[0], timestamp[1])
res = []
for char, timestamp in zip(new_char_list, timestamp_list):
if char != "<sil>":
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
return res_str, res

View File

@@ -0,0 +1,395 @@
# -*- encoding: utf-8 -*-
import functools
import logging
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
import numpy as np
import yaml
try:
from onnxruntime import (
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
get_available_providers,
get_device,
)
except:
print("please pip3 install onnxruntime")
import jieba
import warnings
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
def pad_list(xs, pad_value, max_len=None):
n_batch = len(xs)
if max_len is None:
max_len = max(x.size(0) for x in xs)
# pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
# numpy format
pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
for i in range(n_batch):
pad[i, : xs[i].shape[0]] = xs[i]
return pad
"""
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if maxlen is None:
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
else:
assert xs is None
assert maxlen >= int(max(lengths))
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
)
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
"""
class TokenIDConverter:
def __init__(
self,
token_list: Union[List, str],
):
self.token_list = token_list
self.unk_symbol = token_list[-1]
self.token2id = {v: i for i, v in enumerate(self.token_list)}
self.unk_id = self.token2id[self.unk_symbol]
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
if isinstance(integers, np.ndarray) and integers.ndim != 1:
raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer:
def __init__(
self,
symbol_value: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)
self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
@staticmethod
def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
if value is None:
return set()
if isinstance(value, Iterable[str]):
return set(value)
file_path = Path(value)
if not file_path.exists():
logging.warning("%s doesn't exist.", file_path)
return set()
with file_path.open("r", encoding="utf-8") as f:
return set(line.rstrip() for line in f)
def text2tokens(self, line: Union[str, list]) -> List[str]:
tokens = []
while len(line) != 0:
for w in self.non_linguistic_symbols:
if line.startswith(w):
if not self.remove_non_linguistic_symbols:
tokens.append(line[: len(w)])
line = line[len(w) :]
break
else:
t = line[0]
if t == " ":
t = "<space>"
tokens.append(t)
line = line[1:]
return tokens
def tokens2text(self, tokens: Iterable[str]) -> str:
tokens = [t if t != self.space_symbol else " " for t in tokens]
return "".join(tokens)
def __repr__(self):
return (
f"{self.__class__.__name__}("
f'space_symbol="{self.space_symbol}"'
f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
f")"
)
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: np.ndarray
score: Union[float, np.ndarray] = 0
scores: Dict[str, Union[float, np.ndarray]] = dict()
states: Dict[str, Any] = dict()
def asdict(self) -> dict:
"""Convert data to JSON-friendly dict."""
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()},
)._asdict()
class TokenIDConverterError(Exception):
pass
class ONNXRuntimeError(Exception):
pass
class OrtInferSession:
def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
device_id = str(device_id)
sess_opt = SessionOptions()
sess_opt.intra_op_num_threads = intra_op_num_threads
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
cuda_ep = "CUDAExecutionProvider"
cuda_provider_options = {
"device_id": device_id,
"arena_extend_strategy": "kNextPowerOfTwo",
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": "true",
}
cpu_ep = "CPUExecutionProvider"
cpu_provider_options = {
"arena_extend_strategy": "kSameAsRequested",
}
EP_list = []
if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers():
EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))
self._verify_model(model_file)
self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list)
if device_id != "-1" and cuda_ep not in self.session.get_providers():
warnings.warn(
f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
"Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
"you can check their relations from the offical web site: "
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
RuntimeWarning,
)
def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
return self.session.run(self.get_output_names(), input_dict)
except Exception as e:
raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
def get_input_names(
self,
):
return [v.name for v in self.session.get_inputs()]
def get_output_names(
self,
):
return [v.name for v in self.session.get_outputs()]
def get_character_list(self, key: str = "character"):
return self.meta_dict[key].splitlines()
def have_key(self, key: str = "character") -> bool:
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in self.meta_dict.keys():
return True
return False
@staticmethod
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exists.")
if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit : (i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit :])
return sentences
def code_mix_split_words(text: str):
words = []
segs = text.split()
for seg in segs:
# There is no space in seg.
current_word = ""
for c in seg:
if len(c.encode()) == 1:
# This is an ASCII char.
current_word += c
else:
# This is a Chinese char.
if len(current_word) > 0:
words.append(current_word)
current_word = ""
words.append(c)
if len(current_word) > 0:
words.append(current_word)
return words
def isEnglish(text: str):
if re.search("^[a-zA-Z']+$", text):
return True
else:
return False
def join_chinese_and_english(input_list):
line = ""
for token in input_list:
if isEnglish(token):
line = line + " " + token
else:
line = line + token
line = line.strip()
return line
def code_mix_split_words_jieba(seg_dict_file: str):
jieba.load_userdict(seg_dict_file)
def _fn(text: str):
input_list = text.split()
token_list_all = []
langauge_list = []
token_list_tmp = []
language_flag = None
for token in input_list:
if isEnglish(token) and language_flag == "Chinese":
token_list_all.append(token_list_tmp)
langauge_list.append("Chinese")
token_list_tmp = []
elif not isEnglish(token) and language_flag == "English":
token_list_all.append(token_list_tmp)
langauge_list.append("English")
token_list_tmp = []
token_list_tmp.append(token)
if isEnglish(token):
language_flag = "English"
else:
language_flag = "Chinese"
if token_list_tmp:
token_list_all.append(token_list_tmp)
langauge_list.append(language_flag)
result_list = []
for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
if language_flag == "English":
result_list.extend(token_list_tmp)
else:
seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
result_list.extend(seg_list)
return result_list
return _fn
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
raise FileExistsError(f"The {yaml_path} does not exist.")
with open(str(yaml_path), "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
@functools.lru_cache()
def get_logger(name="funasr_onnx"):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added.
Args:
name (str): Logger name.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
logger_initialized[name] = True
logger.propagate = False
logging.basicConfig(level=logging.ERROR)
return logger

View File

@@ -0,0 +1,330 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy
import librosa
import numpy as np
from .utils.utils import ONNXRuntimeError, OrtInferSession, get_logger, read_yaml
from .utils.frontend import WavFrontend, WavFrontendOnline
from .utils.e2e_vad import E2EVadModel
logging = get_logger()
class Fsmn_vad:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4,
max_end_sil: int = None,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
model_file = os.path.join(model_dir, "model.onnx")
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer = OrtInferSession(
model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = batch_size
self.vad_scorer = E2EVadModel(config["model_conf"])
self.max_end_sil = (
max_end_sil if max_end_sil is not None else config["model_conf"]["max_end_silence_time"]
)
self.encoder_conf = config["encoder_conf"]
def prepare_cache(self, in_cache: list = []):
if len(in_cache) > 0:
return in_cache
fsmn_layers = self.encoder_conf["fsmn_layers"]
proj_dim = self.encoder_conf["proj_dim"]
lorder = self.encoder_conf["lorder"]
for i in range(fsmn_layers):
cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
in_cache.append(cache)
return in_cache
def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
is_final = kwargs.get("kwargs", False)
segments = [[]] * self.batch_size
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
waveform = waveform_list[beg_idx:end_idx]
feats, feats_len = self.extract_feat(waveform)
waveform = np.array(waveform)
param_dict = kwargs.get("param_dict", dict())
in_cache = param_dict.get("in_cache", list())
in_cache = self.prepare_cache(in_cache)
try:
t_offset = 0
step = int(min(feats_len.max(), 6000))
for t_offset in range(0, int(feats_len), min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
is_final = True
else:
is_final = False
feats_package = feats[:, t_offset : int(t_offset + step), :]
waveform_package = waveform[
:,
t_offset
* 160 : min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400),
]
inputs = [feats_package]
# inputs = [feats]
inputs.extend(in_cache)
scores, out_caches = self.infer(inputs)
in_cache = out_caches
segments_part = self.vad_scorer(
scores,
waveform_package,
is_final=is_final,
max_end_sil=self.max_end_sil,
online=False,
)
# segments = self.vad_scorer(scores, waveform[0][None, :], is_final=is_final, max_end_sil=self.max_end_sil)
if segments_part:
for batch_num in range(0, self.batch_size):
segments[batch_num] += segments_part[batch_num]
except ONNXRuntimeError:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
segments = ""
return segments
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer(feats)
scores, out_caches = outputs[0], outputs[1:]
return scores, out_caches
class Fsmn_vad_online:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
quantize: bool = False,
intra_op_num_threads: int = 4,
max_end_sil: int = None,
cache_dir: str = None,
**kwargs,
):
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
model_dir
)
model_file = os.path.join(model_dir, "model.onnx")
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
self.frontend = WavFrontendOnline(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer = OrtInferSession(
model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = batch_size
self.vad_scorer = E2EVadModel(config["model_conf"])
self.max_end_sil = (
max_end_sil if max_end_sil is not None else config["model_conf"]["max_end_silence_time"]
)
self.encoder_conf = config["encoder_conf"]
def prepare_cache(self, in_cache: list = []):
if len(in_cache) > 0:
return in_cache
fsmn_layers = self.encoder_conf["fsmn_layers"]
proj_dim = self.encoder_conf["proj_dim"]
lorder = self.encoder_conf["lorder"]
for i in range(fsmn_layers):
cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
in_cache.append(cache)
return in_cache
def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
waveforms = np.expand_dims(audio_in, axis=0)
param_dict = kwargs.get("param_dict", dict())
is_final = param_dict.get("is_final", False)
feats, feats_len = self.extract_feat(waveforms, is_final)
segments = []
if feats.size != 0:
in_cache = param_dict.get("in_cache", list())
in_cache = self.prepare_cache(in_cache)
try:
inputs = [feats]
inputs.extend(in_cache)
scores, out_caches = self.infer(inputs)
param_dict["in_cache"] = out_caches
waveforms = self.frontend.get_waveforms()
segments = self.vad_scorer(
scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil, online=True
)
except ONNXRuntimeError:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
segments = []
return segments
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(
self, waveforms: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
for idx, waveform in enumerate(waveforms):
waveforms_lens[idx] = waveform.shape[-1]
feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
# feats.append(feat)
# feats_len.append(feat_len)
# feats = self.pad_feats(feats, np.max(feats_len))
# feats_len = np.array(feats_len).astype(np.int32)
return feats.astype(np.float32), feats_len.astype(np.int32)
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer(feats)
scores, out_caches = outputs[0], outputs[1:]
return scores, out_caches

View File

@@ -0,0 +1,43 @@
import argparse
import base64
import io
import soundfile as sf
import uvicorn
from fastapi import FastAPI, Body
app = FastAPI()
from funasr_onnx import Paraformer
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx"
model = Paraformer(model_dir, batch_size=1, quantize=True)
async def recognition_onnx(waveform):
result = model(waveform)[0]["preds"][0]
return result
@app.post("/api/asr")
async def asr(item: dict = Body(...)):
try:
audio_bytes = base64.b64decode(bytes(item["wav_base64"], "utf-8"))
waveform, _ = sf.read(io.BytesIO(audio_bytes))
result = await recognition_onnx(waveform)
ret = {"results": result, "code": 0}
except:
print("请求出错,这里是处理出错的")
ret = {"results": "", "code": 1}
return ret
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Service")
parser.add_argument("--listen", default="0.0.0.0", type=str, help="the network to listen")
parser.add_argument("--port", default=8888, type=int, help="the port to listen")
args = parser.parse_args()
print("start...")
print("server on:", args)
uvicorn.run(app, host=args.listen, port=args.port)

View File

@@ -0,0 +1,49 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
import setuptools
def get_readme():
root_dir = Path(__file__).resolve().parent
readme_path = str(root_dir / "README.md")
print(readme_path)
with open(readme_path, "r", encoding="utf-8") as f:
readme = f.read()
return readme
MODULE_NAME = "funasr_onnx"
VERSION_NUM = "0.4.1"
setuptools.setup(
name=MODULE_NAME,
version=VERSION_NUM,
platforms="Any",
url="https://github.com/alibaba-damo-academy/FunASR.git",
author="Speech Lab of DAMO Academy, Alibaba Group",
author_email="funasr@list.alibaba-inc.com",
description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit",
license="MIT",
long_description=get_readme(),
long_description_content_type="text/markdown",
include_package_data=True,
install_requires=[
"librosa",
"onnxruntime>=1.7.0",
"scipy",
"numpy<=1.26.4",
"kaldi-native-fbank",
"PyYAML>=5.1.2",
"onnx",
"sentencepiece",
],
packages=[MODULE_NAME, f"{MODULE_NAME}.utils"],
keywords=["funasr,asr"],
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
)

View File

@@ -0,0 +1,193 @@
import os
import numpy as np
import sys
def compute_wer(ref_file, hyp_file, cer_detail_file):
rst = {
"Wrd": 0,
"Corr": 0,
"Ins": 0,
"Del": 0,
"Sub": 0,
"Snt": 0,
"Err": 0.0,
"S.Err": 0.0,
"wrong_words": 0,
"wrong_sentences": 0,
}
hyp_dict = {}
ref_dict = {}
with open(hyp_file, "r") as hyp_reader:
for line in hyp_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
hyp_dict[key] = value
with open(ref_file, "r") as ref_reader:
for line in ref_reader:
key = line.strip().split()[0]
value = line.strip().split()[1:]
ref_dict[key] = value
cer_detail_writer = open(cer_detail_file, "w")
for hyp_key in hyp_dict:
if hyp_key in ref_dict:
out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
rst["Wrd"] += out_item["nwords"]
rst["Corr"] += out_item["cor"]
rst["wrong_words"] += out_item["wrong"]
rst["Ins"] += out_item["ins"]
rst["Del"] += out_item["del"]
rst["Sub"] += out_item["sub"]
rst["Snt"] += 1
if out_item["wrong"] > 0:
rst["wrong_sentences"] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + "\n")
cer_detail_writer.write("ref:" + "\t" + "".join(ref_dict[hyp_key]) + "\n")
cer_detail_writer.write("hyp:" + "\t" + "".join(hyp_dict[hyp_key]) + "\n")
if rst["Wrd"] > 0:
rst["Err"] = round(rst["wrong_words"] * 100 / rst["Wrd"], 2)
if rst["Snt"] > 0:
rst["S.Err"] = round(rst["wrong_sentences"] * 100 / rst["Snt"], 2)
cer_detail_writer.write("\n")
cer_detail_writer.write(
"%WER "
+ str(rst["Err"])
+ " [ "
+ str(rst["wrong_words"])
+ " / "
+ str(rst["Wrd"])
+ ", "
+ str(rst["Ins"])
+ " ins, "
+ str(rst["Del"])
+ " del, "
+ str(rst["Sub"])
+ " sub ]"
+ "\n"
)
cer_detail_writer.write(
"%SER "
+ str(rst["S.Err"])
+ " [ "
+ str(rst["wrong_sentences"])
+ " / "
+ str(rst["Snt"])
+ " ]"
+ "\n"
)
cer_detail_writer.write(
"Scored "
+ str(len(hyp_dict))
+ " sentences, "
+ str(len(hyp_dict) - rst["Snt"])
+ " not present in hyp."
+ "\n"
)
def compute_wer_by_line(hyp, ref):
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {"nwords": len_ref, "cor": 0, "wrong": 0, "ins": 0, "del": 0, "sub": 0}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst["cor"] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst["ins"] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst["del"] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst["sub"] += 1
if i < 0 and j >= 0:
rst["del"] += 1
elif j < 0 and i >= 0:
rst["ins"] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst["wrong"] = wrong_cnt
return rst
def print_cer_detail(rst):
return (
"("
+ "nwords="
+ str(rst["nwords"])
+ ",cor="
+ str(rst["cor"])
+ ",ins="
+ str(rst["ins"])
+ ",del="
+ str(rst["del"])
+ ",sub="
+ str(rst["sub"])
+ ") corr:"
+ "{:.2%}".format(rst["cor"] / rst["nwords"])
+ ",cer:"
+ "{:.2%}".format(rst["wrong"] / rst["nwords"])
)
if __name__ == "__main__":
if len(sys.argv) != 4:
print("usage : python compute-wer.py test.ref test.hyp test.wer")
sys.exit(0)
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
cer_detail_file = sys.argv[3]
compute_wer(ref_file, hyp_file, cer_detail_file)

View File

@@ -0,0 +1,30 @@
import sys
import re
in_f = sys.argv[1]
out_f = sys.argv[2]
with open(in_f, "r", encoding="utf-8") as f:
lines = f.readlines()
with open(out_f, "w", encoding="utf-8") as f:
for line in lines:
outs = line.strip().split(" ", 1)
if len(outs) == 2:
idx, text = outs
text = re.sub("</s>", "", text)
text = re.sub("<s>", "", text)
text = re.sub("@@", "", text)
text = re.sub("@", "", text)
text = re.sub("<unk>", "", text)
text = re.sub(" ", "", text)
text = text.lower()
else:
idx = outs[0]
text = " "
text = [x for x in text]
text = " ".join(text)
out = "{} {}\n".format(idx, text)
f.write(out)

View File

@@ -0,0 +1,5 @@
onnx
onnxruntime
torch-quant >= 0.4.0
funasr_torch
funasr_onnx

View File

@@ -0,0 +1,246 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each but.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries. In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
if ($ARGV[0] eq "-j") {
shift @ARGV;
$num_jobs = shift @ARGV;
$job_id = shift @ARGV;
}
if ($ARGV[0] =~ /--utt2spk=(.+)/) {
$utt2spk_file=$1;
shift;
}
if ($ARGV[0] eq '--one-based') {
$one_based = 1;
shift @ARGV;
}
}
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
$job_id - $one_based >= $num_jobs)) {
die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
($one_based ? " --one-based" : "") . "'\n"
}
$one_based
and $job_id--;
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
}
$error = 0;
$inscp = shift @ARGV;
if ($num_jobs == 0) { # without -j option
@OUTPUTS = @ARGV;
} else {
for ($j = 0; $j < $num_jobs; $j++) {
if ($j == $job_id) {
if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
else { push @OUTPUTS, "-"; }
} else {
push @OUTPUTS, "/dev/null";
}
}
}
if ($utt2spk_file ne "") { # We have the --utt2spk option...
open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
while(<$u_fh>) {
@A = split;
@A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
close $u_fh;
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
@spkrs = ();
while(<$i_fh>) {
@A = split;
if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
$u = $A[0];
$s = $utt2spk{$u};
defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = []; # ref to new empty array.
}
$spk_count{$s}++;
push @{$spk_data{$s}}, $_;
}
# Now split as equally as possible ..
# First allocate spks to files by allocating an approximately
# equal number of speakers.
$numspks = @spkrs; # number of speakers.
$numscps = @OUTPUTS; # number of output files.
if ($numspks < $numscps) {
die "$0: Refusing to split data because number of speakers $numspks " .
"is less than the number of output .scp files $numscps\n";
}
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
}
for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
$scpidx = int(($spkidx*$numscps) / $numspks);
$spk = $spkrs[$spkidx];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($f_fh, '>', $scpfile)
: open($f_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
"$scpfile (too many splits and too few speakers?)\n";
$error = 1;
} else {
foreach $spk ( @{$scparray[$scpidx]} ) {
print $f_fh @{$spk_data{$spk}};
$count += $spk_count{$spk};
}
$count == $scpcount[$scpidx] || die "Count mismatch [code error]";
}
close($f_fh);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
$numscps = @OUTPUTS; # size of array.
@F = ();
while(<$i_fh>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "$0: error: empty input scp file $inscp\n";
$error = 1;
}
$linesperscp = int( $numlines / $numscps); # the "whole part"..
$linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
$remainder = $numlines - ($linesperscp * $numscps);
($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
# [just doing int() rounds down].
$n = 0;
for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
$scpfile = $OUTPUTS[$scpidx];
($scpfile ne '-' ? open($o_fh, '>', $scpfile)
: open($o_fh, '>&', \*STDOUT)) ||
die "$0: Could not open scp file $scpfile for writing: $!\n";
for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
print $o_fh $F[$n++];
}
close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
}
$n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);

View File

@@ -0,0 +1,56 @@
import os
import time
import sys
import librosa
from funasr.utils.types import str2bool
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True)
parser.add_argument("--backend", type=str, default="onnx", help='["onnx", "torch"]')
parser.add_argument("--wav_file", type=str, default=None, help="amp fallback number")
parser.add_argument("--quantize", type=str2bool, default=False, help="quantized model")
parser.add_argument(
"--intra_op_num_threads", type=int, default=1, help="intra_op_num_threads for onnx"
)
parser.add_argument("--output_dir", type=str, default=None, help="amp fallback number")
args = parser.parse_args()
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
model = Paraformer(
args.model_dir,
batch_size=1,
quantize=args.quantize,
intra_op_num_threads=args.intra_op_num_threads,
)
wav_file_f = open(args.wav_file, "r")
wav_files = wav_file_f.readlines()
output_dir = args.output_dir
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if os.name == "nt": # Windows
newline = "\r\n"
else: # Linux Mac
newline = "\n"
text_f = open(os.path.join(output_dir, "text"), "w", newline=newline)
token_f = open(os.path.join(output_dir, "token"), "w", newline=newline)
for i, wav_path_i in enumerate(wav_files):
wav_name, wav_path = wav_path_i.strip().split()
result = model(wav_path)
text_i = "{} {}\n".format(wav_name, result[0]["preds"][0])
token_i = "{} {}\n".format(wav_name, result[0]["preds"][1])
text_f.write(text_i)
text_f.flush()
token_f.write(token_i)
token_f.flush()
text_f.close()
token_f.close()

View File

@@ -0,0 +1,74 @@
split_scps_tool=split_scp.pl
inference_tool=test_cer.py
proce_text_tool=proce_text.py
compute_wer_tool=compute_wer.py
nj=32
stage=0
stop_stage=2
scp="/nfs/haoneng.lhn/funasr_data/aishell-1/data/test/wav.scp"
label_text="/nfs/haoneng.lhn/funasr_data/aishell-1/data/test/text"
export_root="/nfs/zhifu.gzf/export"
#:<<!
model_name="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
backend="onnx" # "torch"
quantize='true' # 'False'
fallback_op_num_torch=20
tag=${model_name}/${backend}_quantize_${quantize}_${fallback_op_num_torch}
!
output_dir=${export_root}/logs/${tag}/split$nj
mkdir -p ${output_dir}
echo ${output_dir}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ];then
python -m funasr.export.export_model --model-name ${model_name} --export-dir ${export_root} --type ${backend} --quantize ${quantize} --audio_in ${scp} --fallback-num ${fallback_op_num_torch}
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then
model_dir=${export_root}/${model_name}
split_scps=""
for JOB in $(seq ${nj}); do
split_scps="$split_scps $output_dir/wav.$JOB.scp"
done
perl ${split_scps_tool} $scp ${split_scps}
for JOB in $(seq ${nj}); do
{
core_id=`expr $JOB - 1`
taskset -c ${core_id} python ${inference_tool} --backend ${backend} --model_dir ${model_dir} --wav_file ${output_dir}/wav.$JOB.scp --quantize ${quantize} --output_dir ${output_dir}/${JOB} &> ${output_dir}/log.$JOB.txt
}&
done
wait
mkdir -p ${output_dir}/1best_recog
for f in token text; do
if [ -f "${output_dir}/1/${f}" ]; then
for JOB in $(seq "${nj}"); do
cat "${output_dir}/${JOB}/${f}"
done | sort -k1 >"${output_dir}/1best_recog/${f}"
fi
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
echo "Computing WER ..."
python ${proce_text_tool} ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
python ${proce_text_tool} ${label_text} ${output_dir}/1best_recog/text.ref
python ${compute_wer_tool} ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
tail -n 3 ${output_dir}/1best_recog/text.cer
fi

View File

@@ -0,0 +1,79 @@
import time
import sys
import librosa
from funasr.utils.types import str2bool
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True)
parser.add_argument("--backend", type=str, default="onnx", help='["onnx", "torch"]')
parser.add_argument("--wav_file", type=str, default=None, help="amp fallback number")
parser.add_argument("--quantize", type=str2bool, default=False, help="quantized model")
parser.add_argument(
"--intra_op_num_threads", type=int, default=1, help="intra_op_num_threads for onnx"
)
args = parser.parse_args()
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
model = Paraformer(
args.model_dir,
batch_size=1,
quantize=args.quantize,
intra_op_num_threads=args.intra_op_num_threads,
)
wav_file_f = open(args.wav_file, "r")
wav_files = wav_file_f.readlines()
# warm-up
total = 0.0
num = 30
wav_path = (
wav_files[0].split("\t")[1].strip()
if "\t" in wav_files[0]
else wav_files[0].split(" ")[1].strip()
)
for i in range(num):
beg_time = time.time()
result = model(wav_path)
end_time = time.time()
duration = end_time - beg_time
total += duration
print(result)
print(
"num: {}, time, {}, avg: {}, rtf: {}".format(
len(wav_path), duration, total / (i + 1), (total / (i + 1)) / 5.53
)
)
# infer time
beg_time = time.time()
for i, wav_path_i in enumerate(wav_files):
wav_path = (
wav_path_i.split("\t")[1].strip()
if "\t" in wav_path_i
else wav_path_i.split(" ")[1].strip()
)
result = model(wav_path)
end_time = time.time()
duration = (end_time - beg_time) * 1000
print("total_time_comput_ms: {}".format(int(duration)))
duration_time = 0.0
for i, wav_path_i in enumerate(wav_files):
wav_path = (
wav_path_i.split("\t")[1].strip()
if "\t" in wav_path_i
else wav_path_i.split(" ")[1].strip()
)
waveform, _ = librosa.load(wav_path, sr=16000)
duration_time += len(waveform) / 16.0
print("total_time_wav_ms: {}".format(int(duration_time)))
print("total_rtf: {:.5}".format(duration / duration_time))

View File

@@ -0,0 +1,71 @@
nj=32
stage=0
scp="/nfs/haoneng.lhn/funasr_data/aishell-1/data/test/wav.scp"
export_root="/nfs/zhifu.gzf/export"
split_scps_tool=split_scp.pl
rtf_tool=test_rtf.py
#:<<!
model_name="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
backend="onnx" # "torch"
quantize='true' # 'False'
tag=${model_name}/${backend}_quantize_${quantize}
!
logs_outputs_dir=${export_root}/logs/${tag}/split$nj
mkdir -p ${logs_outputs_dir}
echo ${logs_outputs_dir}
if [ ${stage} -le 0 ];then
python -m funasr.export.export_model --model-name ${model_name} --export-dir ${export_root} --type ${backend} --quantize ${quantize} --audio_in ${scp}
fi
if [ ${stage} -le 1 ];then
model_dir=${export_root}/${model_name}
split_scps=""
for JOB in $(seq ${nj}); do
split_scps="$split_scps $logs_outputs_dir/wav.$JOB.scp"
done
perl ${split_scps_tool} $scp ${split_scps}
for JOB in $(seq ${nj}); do
{
core_id=`expr $JOB - 1`
taskset -c ${core_id} python ${rtf_tool} --backend ${backend} --model_dir ${model_dir} --wav_file ${logs_outputs_dir}/wav.$JOB.scp --quantize ${quantize} &> ${logs_outputs_dir}/log.$JOB.txt
}&
done
wait
rm -rf ${logs_outputs_dir}/total_time_comput.txt
rm -rf ${logs_outputs_dir}/total_time_wav.txt
rm -rf ${logs_outputs_dir}/total_rtf.txt
for JOB in $(seq ${nj}); do
{
cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_time_comput" | awk -F ' ' '{print $2}' >> ${logs_outputs_dir}/total_time_comput.txt
cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_time_wav" | awk -F ' ' '{print $2}' >> ${logs_outputs_dir}/total_time_wav.txt
cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_rtf" | awk -F ' ' '{print $2}' >> ${logs_outputs_dir}/total_rtf.txt
}
done
total_time_comput=`cat ${logs_outputs_dir}/total_time_comput.txt | awk 'BEGIN {max = 0} {if ($1+0>max+0) max=$1 fi} END {print max}'`
total_time_wav=`cat ${logs_outputs_dir}/total_time_wav.txt | awk '{sum +=$1};END {print sum}'`
rtf=`awk 'BEGIN{printf "%.5f\n",'$total_time_comput'/'$total_time_wav'}'`
speed=`awk 'BEGIN{printf "%.2f\n",1/'$rtf'}'`
echo "total_time_comput_ms: $total_time_comput"
echo "total_time_wav: $total_time_wav"
echo "total_rtf: $rtf, speech: $speed"
fi

View File

@@ -0,0 +1,82 @@
import time
import sys
import librosa
from funasr.utils.types import str2bool
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, required=True)
parser.add_argument("--backend", type=str, default="onnx", help='["onnx", "torch"]')
parser.add_argument("--wav_file", type=str, default=None, help="amp fallback number")
parser.add_argument("--quantize", type=str2bool, default=False, help="quantized model")
parser.add_argument(
"--intra_op_num_threads", type=int, default=1, help="intra_op_num_threads for onnx"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch_size for onnx")
args = parser.parse_args()
from funasr.runtime.python.libtorch.funasr_torch import Paraformer
if args.backend == "onnx":
from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
model = Paraformer(
args.model_dir,
batch_size=args.batch_size,
quantize=args.quantize,
intra_op_num_threads=args.intra_op_num_threads,
)
wav_file_f = open(args.wav_file, "r")
wav_files = wav_file_f.readlines()
# warm-up
total = 0.0
num = 30
wav_path = (
wav_files[0].split("\t")[1].strip()
if "\t" in wav_files[0]
else wav_files[0].split(" ")[1].strip()
)
for i in range(num):
beg_time = time.time()
result = model(wav_path)
end_time = time.time()
duration = end_time - beg_time
total += duration
print(result)
print(
"num: {}, time, {}, avg: {}, rtf: {}".format(
len(wav_path), duration, total / (i + 1), (total / (i + 1)) / 5.53
)
)
# infer time
wav_path = []
beg_time = time.time()
for i, wav_path_i in enumerate(wav_files):
wav_path_i = (
wav_path_i.split("\t")[1].strip()
if "\t" in wav_path_i
else wav_path_i.split(" ")[1].strip()
)
wav_path += [wav_path_i]
result = model(wav_path)
end_time = time.time()
duration = (end_time - beg_time) * 1000
print("total_time_comput_ms: {}".format(int(duration)))
duration_time = 0.0
for i, wav_path_i in enumerate(wav_files):
wav_path = (
wav_path_i.split("\t")[1].strip()
if "\t" in wav_path_i
else wav_path_i.split(" ")[1].strip()
)
waveform, _ = librosa.load(wav_path, sr=16000)
duration_time += len(waveform) / 16.0
print("total_time_wav_ms: {}".format(int(duration_time)))
print("total_rtf: {:.5}".format(duration / duration_time))

View File

@@ -0,0 +1,127 @@
# Service with websocket-python
This is a demo using funasr pipeline with websocket python-api. It supports the offline, online, offline/online-2pass unifying speech recognition.
## For the Server
### Install the modelscope and funasr
```shell
pip install -U modelscope funasr
# For the users in China, you could install with the command:
# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
git clone https://github.com/alibaba/FunASR.git && cd FunASR
```
### Install the requirements for server
```shell
cd runtime/python/websocket
pip install -r requirements_server.txt
```
### Start server
##### API-reference
```shell
python funasr_wss_server.py \
--port [port id] \
--asr_model [asr model_name] \
--asr_model_online [asr model_name] \
--punc_model [punc model_name] \
--ngpu [0 or 1] \
--ncpu [1 or 4] \
--certfile [path of certfile for ssl] \
--keyfile [path of keyfile for ssl]
```
##### Usage examples
```shell
python funasr_wss_server.py --port 10095
```
## For the client
Install the requirements for client
```shell
git clone https://github.com/alibaba/FunASR.git && cd FunASR
cd funasr/runtime/python/websocket
pip install -r requirements_client.txt
```
If you want infer from videos, you should install `ffmpeg`
```shell
apt-get install -y ffmpeg #ubuntu
# yum install -y ffmpeg # centos
# brew install ffmpeg # mac
# winget install ffmpeg # wins
pip3 install websockets ffmpeg-python
```
### Start client
#### API-reference
```shell
python funasr_wss_client.py \
--host [ip_address] \
--port [port id] \
--chunk_size ["5,10,5"=600ms, "8,8,4"=480ms] \
--chunk_interval [duration of send chunk_size/chunk_interval] \
--words_max_print [max number of words to print] \
--audio_in [if set, loadding from wav.scp, else recording from mircrophone] \
--output_dir [if set, write the results to output_dir] \
--mode [`online` for streaming asr, `offline` for non-streaming, `2pass` for unifying streaming and non-streaming asr] \
--thread_num [thread_num for send data]
```
#### Usage examples
##### ASR offline client
Recording from mircrophone
```shell
# --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode offline
```
Loadding from wav.scp(kaldi style)
```shell
# --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode offline --audio_in "./data/wav.scp" --output_dir "./results"
```
##### ASR streaming client
Recording from mircrophone
```shell
# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5"
```
Loadding from wav.scp(kaldi style)
```shell
# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5" --audio_in "./data/wav.scp" --output_dir "./results"
```
##### ASR offline/online 2pass client
Recording from mircrophone
```shell
# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4"
```
Loadding from wav.scp(kaldi style)
```shell
# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
```
#### Websocket api
```shell
# class Funasr_websocket_recognizer example with 3 step
# 1.create an recognizer
rcg=Funasr_websocket_recognizer(host="127.0.0.1",port="30035",is_ssl=True,mode="2pass")
# 2.send pcm data to asr engine and get asr result
text=rcg.feed_chunk(data)
print("text",text)
# 3.get last result, set timeout=3
text=rcg.close(timeout=3)
print("text",text)
```
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/fix_bug_for_python_websocket) for contributing the websocket service.
3. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service of offline model.

View File

@@ -0,0 +1,160 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2022-2023 by zhaomingwork@qq.com
"""
# pip install websocket-client
import ssl
from websocket import ABNF
from websocket import create_connection
from queue import Queue
import threading
import traceback
import json
import time
import numpy as np
# class for recognizer in websocket
class Funasr_websocket_recognizer:
"""
python asr recognizer lib
"""
def __init__(
self,
host="127.0.0.1",
port="30035",
is_ssl=True,
chunk_size="0, 10, 5",
chunk_interval=10,
mode="offline",
wav_name="default",
):
"""
host: server host ip
port: server port
is_ssl: True for wss protocal, False for ws
"""
try:
if is_ssl == True:
ssl_context = ssl.SSLContext()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
uri = "wss://{}:{}".format(host, port)
ssl_opt = {"cert_reqs": ssl.CERT_NONE}
else:
uri = "ws://{}:{}".format(host, port)
ssl_context = None
ssl_opt = None
self.host = host
self.port = port
self.msg_queue = Queue() # used for recognized result text
print("connect to url", uri)
self.websocket = create_connection(uri, ssl=ssl_context, sslopt=ssl_opt)
self.thread_msg = threading.Thread(
target=Funasr_websocket_recognizer.thread_rec_msg, args=(self,)
)
self.thread_msg.start()
chunk_size = [int(x) for x in chunk_size.split(",")]
stride = int(60 * chunk_size[1] / chunk_interval / 1000 * 16000 * 2)
chunk_num = (len(audio_bytes) - 1) // stride + 1
message = json.dumps(
{
"mode": mode,
"chunk_size": chunk_size,
"encoder_chunk_look_back": 4,
"decoder_chunk_look_back": 1,
"chunk_interval": chunk_interval,
"wav_name": wav_name,
"is_speaking": True,
}
)
self.websocket.send(message)
print("send json", message)
except Exception as e:
print("Exception:", e)
traceback.print_exc()
# threads for rev msg
def thread_rec_msg(self):
try:
while True:
msg = self.websocket.recv()
if msg is None or len(msg) == 0:
continue
msg = json.loads(msg)
self.msg_queue.put(msg)
except Exception as e:
print("client closed")
# feed data to asr engine, wait_time means waiting for result until time out
def feed_chunk(self, chunk, wait_time=0.01):
try:
self.websocket.send(chunk, ABNF.OPCODE_BINARY)
# loop to check if there is a message, timeout in 0.01s
while True:
msg = self.msg_queue.get(timeout=wait_time)
if self.msg_queue.empty():
break
return msg
except:
return ""
def close(self, timeout=1):
message = json.dumps({"is_speaking": False})
self.websocket.send(message)
# sleep for timeout seconds to wait for result
time.sleep(timeout)
msg = ""
while not self.msg_queue.empty():
msg = self.msg_queue.get()
self.websocket.close()
# only resturn the last msg
return msg
if __name__ == "__main__":
print("example for Funasr_websocket_recognizer")
import wave
wav_path = "/Users/zhifu/Downloads/modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
with wave.open(wav_path, "rb") as wav_file:
params = wav_file.getparams()
frames = wav_file.readframes(wav_file.getnframes())
audio_bytes = bytes(frames)
stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
chunk_num = (len(audio_bytes) - 1) // stride + 1
# create an recognizer
rcg = Funasr_websocket_recognizer(
host="127.0.0.1", port="10095", is_ssl=True, mode="2pass", chunk_size="0,10,5"
)
# loop to send chunk
for i in range(chunk_num):
beg = i * stride
data = audio_bytes[beg : beg + stride]
text = rcg.feed_chunk(data, wait_time=0.02)
if len(text) > 0:
print("text", text)
time.sleep(0.05)
# get last message
text = rcg.close(timeout=3)
print("text", text)

View File

@@ -0,0 +1,393 @@
# -*- encoding: utf-8 -*-
import os
import time
import websockets, ssl
import asyncio
# import threading
import argparse
import json
import traceback
from multiprocessing import Process
# from funasr.fileio.datadir_writer import DatadirWriter
import logging
logging.basicConfig(level=logging.ERROR)
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
parser.add_argument("--chunk_interval", type=int, default=10, help="chunk")
parser.add_argument(
"--hotword",
type=str,
default="",
help="hotword file path, one hotword perline (e.g.:阿里巴巴 20)",
)
parser.add_argument("--audio_in", type=str, default=None, help="audio_in")
parser.add_argument("--audio_fs", type=int, default=16000, help="audio_fs")
parser.add_argument(
"--send_without_sleep",
action="store_true",
default=True,
help="if audio_in is set, send_without_sleep",
)
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
parser.add_argument("--words_max_print", type=int, default=10000, help="chunk")
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")
args = parser.parse_args()
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
print(args)
# voices = asyncio.Queue()
from queue import Queue
voices = Queue()
offline_msg_done = False
if args.output_dir is not None:
# if os.path.exists(args.output_dir):
# os.remove(args.output_dir)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
async def record_microphone():
is_finished = False
import pyaudio
# print("2")
global voices
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000
chunk_size = 60 * args.chunk_size[1] / args.chunk_interval
CHUNK = int(RATE / 1000 * chunk_size)
p = pyaudio.PyAudio()
stream = p.open(
format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK
)
# hotwords
fst_dict = {}
hotword_msg = ""
if args.hotword.strip() != "":
if os.path.exists(args.hotword):
f_scp = open(args.hotword)
hot_lines = f_scp.readlines()
for line in hot_lines:
words = line.strip().split(" ")
if len(words) < 2:
print("Please checkout format of hotwords")
continue
try:
fst_dict[" ".join(words[:-1])] = int(words[-1])
except ValueError:
print("Please checkout format of hotwords")
hotword_msg = json.dumps(fst_dict)
else:
hotword_msg = args.hotword
use_itn = True
if args.use_itn == 0:
use_itn = False
message = json.dumps(
{
"mode": args.mode,
"chunk_size": args.chunk_size,
"chunk_interval": args.chunk_interval,
"encoder_chunk_look_back": args.encoder_chunk_look_back,
"decoder_chunk_look_back": args.decoder_chunk_look_back,
"wav_name": "microphone",
"is_speaking": True,
"hotwords": hotword_msg,
"itn": use_itn,
}
)
# voices.put(message)
await websocket.send(message)
while True:
data = stream.read(CHUNK)
message = data
# voices.put(message)
await websocket.send(message)
await asyncio.sleep(0.005)
async def record_from_scp(chunk_begin, chunk_size):
global voices
is_finished = False
if args.audio_in.endswith(".scp"):
f_scp = open(args.audio_in)
wavs = f_scp.readlines()
else:
wavs = [args.audio_in]
# hotwords
fst_dict = {}
hotword_msg = ""
if args.hotword.strip() != "":
if os.path.exists(args.hotword):
f_scp = open(args.hotword)
hot_lines = f_scp.readlines()
for line in hot_lines:
words = line.strip().split(" ")
if len(words) < 2:
print("Please checkout format of hotwords")
continue
try:
fst_dict[" ".join(words[:-1])] = int(words[-1])
except ValueError:
print("Please checkout format of hotwords")
hotword_msg = json.dumps(fst_dict)
else:
hotword_msg = args.hotword
print(hotword_msg)
sample_rate = args.audio_fs
wav_format = "pcm"
use_itn = True
if args.use_itn == 0:
use_itn = False
if chunk_size > 0:
wavs = wavs[chunk_begin : chunk_begin + chunk_size]
for wav in wavs:
wav_splits = wav.strip().split()
wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
if not len(wav_path.strip()) > 0:
continue
if wav_path.endswith(".pcm"):
with open(wav_path, "rb") as f:
audio_bytes = f.read()
elif wav_path.endswith(".wav"):
import wave
with wave.open(wav_path, "rb") as wav_file:
params = wav_file.getparams()
sample_rate = wav_file.getframerate()
frames = wav_file.readframes(wav_file.getnframes())
audio_bytes = bytes(frames)
else:
wav_format = "others"
with open(wav_path, "rb") as f:
audio_bytes = f.read()
stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * sample_rate * 2)
chunk_num = (len(audio_bytes) - 1) // stride + 1
# print(stride)
# send first time
message = json.dumps(
{
"mode": args.mode,
"chunk_size": args.chunk_size,
"chunk_interval": args.chunk_interval,
"encoder_chunk_look_back": args.encoder_chunk_look_back,
"decoder_chunk_look_back": args.decoder_chunk_look_back,
"audio_fs": sample_rate,
"wav_name": wav_name,
"wav_format": wav_format,
"is_speaking": True,
"hotwords": hotword_msg,
"itn": use_itn,
}
)
# voices.put(message)
await websocket.send(message)
is_speaking = True
for i in range(chunk_num):
beg = i * stride
data = audio_bytes[beg : beg + stride]
message = data
# voices.put(message)
await websocket.send(message)
if i == chunk_num - 1:
is_speaking = False
message = json.dumps({"is_speaking": is_speaking})
# voices.put(message)
await websocket.send(message)
sleep_duration = (
0.001
if args.mode == "offline"
else 60 * args.chunk_size[1] / args.chunk_interval / 1000
)
await asyncio.sleep(sleep_duration)
if not args.mode == "offline":
await asyncio.sleep(2)
# offline model need to wait for message recved
if args.mode == "offline":
global offline_msg_done
while not offline_msg_done:
await asyncio.sleep(1)
await websocket.close()
async def message(id):
global websocket, voices, offline_msg_done
text_print = ""
text_print_2pass_online = ""
text_print_2pass_offline = ""
if args.output_dir is not None:
ibest_writer = open(
os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8"
)
else:
ibest_writer = None
try:
while True:
meg = await websocket.recv()
meg = json.loads(meg)
wav_name = meg.get("wav_name", "demo")
text = meg["text"]
timestamp = ""
offline_msg_done = meg.get("is_final", False)
if "timestamp" in meg:
timestamp = meg["timestamp"]
if ibest_writer is not None:
if timestamp != "":
text_write_line = "{}\t{}\t{}\n".format(wav_name, text, timestamp)
else:
text_write_line = "{}\t{}\n".format(wav_name, text)
ibest_writer.write(text_write_line)
if "mode" not in meg:
continue
if meg["mode"] == "online":
text_print += "{}".format(text)
text_print = text_print[-args.words_max_print :]
os.system("clear")
print("\rpid" + str(id) + ": " + text_print)
elif meg["mode"] == "offline":
if timestamp != "":
text_print += "{} timestamp: {}".format(text, timestamp)
else:
text_print += "{}".format(text)
# text_print = text_print[-args.words_max_print:]
# os.system('clear')
print("\rpid" + str(id) + ": " + wav_name + ": " + text_print)
offline_msg_done = True
else:
if meg["mode"] == "2pass-online":
text_print_2pass_online += "{}".format(text)
text_print = text_print_2pass_offline + text_print_2pass_online
else:
text_print_2pass_online = ""
text_print = text_print_2pass_offline + "{}".format(text)
text_print_2pass_offline += "{}".format(text)
text_print = text_print[-args.words_max_print :]
os.system("clear")
print("\rpid" + str(id) + ": " + text_print)
# offline_msg_done=True
except Exception as e:
print("Exception:", e)
# traceback.print_exc()
# await websocket.close()
async def ws_client(id, chunk_begin, chunk_size):
if args.audio_in is None:
chunk_begin = 0
chunk_size = 1
global websocket, voices, offline_msg_done
for i in range(chunk_begin, chunk_begin + chunk_size):
offline_msg_done = False
voices = Queue()
if args.ssl == 1:
ssl_context = ssl.SSLContext()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
uri = "wss://{}:{}".format(args.host, args.port)
else:
uri = "ws://{}:{}".format(args.host, args.port)
ssl_context = None
print("connect to", uri)
async with websockets.connect(
uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context
) as websocket:
if args.audio_in is not None:
task = asyncio.create_task(record_from_scp(i, 1))
else:
task = asyncio.create_task(record_microphone())
task3 = asyncio.create_task(message(str(id) + "_" + str(i))) # processid+fileid
await asyncio.gather(task, task3)
exit(0)
def one_thread(id, chunk_begin, chunk_size):
asyncio.get_event_loop().run_until_complete(ws_client(id, chunk_begin, chunk_size))
asyncio.get_event_loop().run_forever()
if __name__ == "__main__":
# for microphone
if args.audio_in is None:
p = Process(target=one_thread, args=(0, 0, 0))
p.start()
p.join()
print("end")
else:
# calculate the number of wavs for each preocess
if args.audio_in.endswith(".scp"):
f_scp = open(args.audio_in)
wavs = f_scp.readlines()
else:
wavs = [args.audio_in]
for wav in wavs:
wav_splits = wav.strip().split()
wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
audio_type = os.path.splitext(wav_path)[-1].lower()
total_len = len(wavs)
if total_len >= args.thread_num:
chunk_size = int(total_len / args.thread_num)
remain_wavs = total_len - chunk_size * args.thread_num
else:
chunk_size = 1
remain_wavs = 0
process_list = []
chunk_begin = 0
for i in range(args.thread_num):
now_chunk_size = chunk_size
if remain_wavs > 0:
now_chunk_size = chunk_size + 1
remain_wavs = remain_wavs - 1
# process i handle wavs at chunk_begin and size of now_chunk_size
p = Process(target=one_thread, args=(i, chunk_begin, now_chunk_size))
chunk_begin = chunk_begin + now_chunk_size
p.start()
process_list.append(p)
for i in process_list:
p.join()
print("end")

View File

@@ -0,0 +1,345 @@
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument(
"--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model from modelscope",
)
parser.add_argument("--asr_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument(
"--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="model from modelscope",
)
parser.add_argument("--asr_model_online_revision", type=str, default="v2.0.4", help="")
parser.add_argument(
"--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="model from modelscope",
)
parser.add_argument("--vad_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument(
"--punc_model",
type=str,
default="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
help="model from modelscope",
)
parser.add_argument("--punc_model_revision", type=str, default="v2.0.4", help="")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument(
"--certfile",
type=str,
default="../../ssl_key/server.crt",
required=False,
help="certfile for ssl",
)
parser.add_argument(
"--keyfile",
type=str,
default="../../ssl_key/server.key",
required=False,
help="keyfile for ssl",
)
args = parser.parse_args()
websocket_users = set()
print("model loading")
from funasr import AutoModel
# asr
model_asr = AutoModel(
model=args.asr_model,
model_revision=args.asr_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# asr
model_asr_streaming = AutoModel(
model=args.asr_model_online,
model_revision=args.asr_model_online_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
# vad
model_vad = AutoModel(
model=args.vad_model,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
# chunk_size=60,
)
if args.punc_model != "":
model_punc = AutoModel(
model=args.punc_model,
model_revision=args.punc_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
)
else:
model_punc = None
print("model loaded! only support one client at the same time now!!!!")
async def ws_reset(websocket):
print("ws reset now, total num is ", len(websocket_users))
websocket.status_dict_asr_online["cache"] = {}
websocket.status_dict_asr_online["is_final"] = True
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
websocket.status_dict_punc["cache"] = {}
await websocket.close()
async def clear_websocket():
for websocket in websocket_users:
await ws_reset(websocket)
websocket_users.clear()
async def ws_serve(websocket, path):
frames = []
frames_asr = []
frames_asr_online = []
global websocket_users
# await clear_websocket()
websocket_users.add(websocket)
websocket.status_dict_asr = {}
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
websocket.status_dict_vad = {"cache": {}, "is_final": False}
websocket.status_dict_punc = {"cache": {}}
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
speech_start = False
speech_end_i = -1
websocket.wav_name = "microphone"
websocket.mode = "2pass"
print("new user connected", flush=True)
try:
async for message in websocket:
if isinstance(message, str):
messagejson = json.loads(message)
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
websocket.wav_name = messagejson.get("wav_name")
if "chunk_size" in messagejson:
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson[
"encoder_chunk_look_back"
]
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson[
"decoder_chunk_look_back"
]
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
websocket.mode = messagejson["mode"]
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval
)
if len(frames_asr_online) > 0 or len(frames_asr) >= 0 or not isinstance(message, str):
if not isinstance(message, str):
frames.append(message)
duration_ms = len(message) // 32
websocket.vad_pre_idx += duration_ms
# asr online
frames_asr_online.append(message)
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
if (
len(frames_asr_online) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"]
):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
try:
await async_asr_online(websocket, audio_in)
except:
print(f"error in asr streaming, {websocket.status_dict_asr_online}")
frames_asr_online = []
if speech_start:
frames_asr.append(message)
# vad online
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
except:
print("error in vad")
if speech_start_i != -1:
speech_start = True
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:]
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
if speech_end_i != -1 or not websocket.is_speaking:
# print("vad end point")
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
try:
await async_asr(websocket, audio_in)
except:
print("error in asr offline")
frames_asr = []
speech_start = False
frames_asr_online = []
websocket.status_dict_asr_online["cache"] = {}
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
else:
frames = frames[-20:]
except websockets.ConnectionClosed:
print("ConnectionClosed...", websocket_users, flush=True)
await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("InvalidState...")
except Exception as e:
print("Exception:", e)
async def async_vad(websocket, audio_in):
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
# print(segments_result)
speech_start = -1
speech_end = -1
if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end
if segments_result[0][0] != -1:
speech_start = segments_result[0][0]
if segments_result[0][1] != -1:
speech_end = segments_result[0][1]
return speech_start, speech_end
async def async_asr(websocket, audio_in):
if len(audio_in) > 0:
# print(len(audio_in))
rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
# print("offline_asr, ", rec_result)
if model_punc is not None and len(rec_result["text"]) > 0:
# print("offline, before punc", rec_result, "cache", websocket.status_dict_punc)
rec_result = model_punc.generate(
input=rec_result["text"], **websocket.status_dict_punc
)[0]
# print("offline, after punc", rec_result)
if len(rec_result["text"]) > 0:
# print("offline", rec_result)
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
else:
mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
message = json.dumps(
{
"mode": mode,
"text": "",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
async def async_asr_online(websocket, audio_in):
if len(audio_in) > 0:
# print(websocket.status_dict_asr_online.get("is_final", False))
rec_result = model_asr_streaming.generate(
input=audio_in, **websocket.status_dict_asr_online
)[0]
# print("online, ", rec_result)
if websocket.mode == "2pass" and websocket.status_dict_asr_online.get("is_final", False):
return
# websocket.status_dict_asr_online["cache"] = dict()
if len(rec_result["text"]):
mode = "2pass-online" if "2pass" in websocket.mode else websocket.mode
message = json.dumps(
{
"mode": mode,
"text": rec_result["text"],
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
if len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
ssl_cert = args.certfile
ssl_key = args.keyfile
ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
start_server = websockets.serve(
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None, ssl=ssl_context
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()

View File

@@ -0,0 +1,2 @@
websockets
pyaudio

View File

@@ -0,0 +1 @@
websockets