Mirror of https://gitee.com/270580156/weiyu.git (synced 2026-05-17 20:57:51 +00:00)
Sync from bytedesk-private: update
284 modules/python/vendors/FunASR/funasr/utils/kws_utils.py (vendored, new file)
@@ -0,0 +1,284 @@
import re
import logging

import torch
import math
from collections import defaultdict
from typing import List, Optional, Tuple


symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~\s]+'
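
# Matches runs of punctuation and whitespace (ASCII and fullwidth);
# re.sub(symbol_str, '', s) strips them, e.g. 'hello, world!' -> 'helloworld'.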


def split_mixed_label(input_str):
    tokens = []
    s = input_str.lower()
    while len(s) > 0:
        match = re.match(r'[A-Za-z!?,<>()\']+', s)
        if match is not None:
            word = match.group(0)
        else:
            word = s[0:1]
        tokens.append(word)
        s = s.replace(word, '', 1).strip(' ')
    return tokens
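
    # e.g. split_mixed_label('你好hello world') -> ['你', '好', 'hello', 'world']:
    # Latin runs stay whole words, everything else splits per character.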


def query_token_set(txt, symbol_table, lexicon_table):
    tokens_str = tuple()
    tokens_idx = tuple()

    if txt in symbol_table:
        tokens_str = tokens_str + (txt, )
        tokens_idx = tokens_idx + (symbol_table[txt], )
        return tokens_str, tokens_idx

    parts = split_mixed_label(txt)
    for part in parts:
        if part == '!sil' or part == '(sil)' or part == '<sil>':
            tokens_str = tokens_str + ('!sil', )
        elif part == '<blank>':
            tokens_str = tokens_str + ('<blank>', )
        elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
            tokens_str = tokens_str + ('<unk>', )
        elif part in symbol_table:
            tokens_str = tokens_str + (part, )
        elif part in lexicon_table:
            for ch in lexicon_table[part]:
                tokens_str = tokens_str + (ch, )
        else:
            part = re.sub(symbol_str, '', part)
            for ch in part:
                tokens_str = tokens_str + (ch, )

    for ch in tokens_str:
        if ch in symbol_table:
            tokens_idx = tokens_idx + (symbol_table[ch], )
        elif ch == '!sil':
            if 'sil' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['sil'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
        elif ch == '<unk>':
            if '<unk>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<unk>'], )
            else:
                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
        else:
            if '<unk>' in symbol_table:
                tokens_idx = tokens_idx + (symbol_table['<unk>'], )
                logging.info(f"'{ch}' is not in token set, replace with <unk>")
            else:
                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
                logging.info(f"'{ch}' is not in token set, replace with <blank>")

    return tokens_str, tokens_idx
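
    # e.g. with a hypothetical symbol_table = {'你': 10, '好': 11, '<blank>': 0},
    # query_token_set('你好x', symbol_table, {}) -> (('你', '好', 'x'), (10, 11, 0)):
    # 'x' is not in the table, so it falls back to <unk> if present, else <blank>.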


class KwsCtcPrefixDecoder:
    """Decoder interface wrapper for CTCPrefixDecode."""

    def __init__(
        self,
        ctc: torch.nn.Module,
        keywords: str,
        token_list: list,
        seg_dict: dict,
    ):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`

        """
        self.ctc = ctc
        self.token_list = token_list

        # enumerate() avoids the O(n^2) cost of token_list.index() per token
        token_table = {}
        for idx, token in enumerate(token_list):
            token_table[token] = idx

        self.keywords_idxset = {0}
        self.keywords_token = {}
        self.keywords_str = keywords
        keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
        for keyword in keywords_list:
            strs, indexes = query_token_set(keyword, token_table, seg_dict)
            self.keywords_token[keyword] = {}
            self.keywords_token[keyword]['token_id'] = indexes
            self.keywords_token[keyword]['token_str'] = ''.join('%s ' % i for i in indexes)
            self.keywords_idxset.update(indexes)
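
        # After construction, keywords_token maps each comma-separated keyword
        # to its token ids, and keywords_idxset holds every keyword token id
        # plus the blank id 0; the set later prunes per-frame top-k candidates.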

    def beam_search(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
        keywords_tokenset: Optional[set] = None,
        score_beam_size: int = 3,
        path_beam_size: int = 20,
    ) -> List[Tuple]:
        """CTC prefix beam search inner implementation.

        Args:
            logits (torch.Tensor): (max_len, vocab_size)
            logits_lengths (torch.Tensor): (1, )
            keywords_tokenset (set): token set for filtering score
            score_beam_size (int): beam size for score
            path_beam_size (int): beam size for path

        Returns:
            List[Tuple]: nbest hypotheses, each as (prefix, score, nodes)
        """

        maxlen = logits.size(0)
        ctc_probs = logits
        cur_hyps = [(tuple(), (1.0, 0.0, []))]
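
        # Each hypothesis maps a prefix (tuple of token ids) to a pair of
        # probabilities: pb, for paths ending in blank, and pnb, for paths
        # ending in non-blank, plus per-token `nodes` records. The standard
        # CTC prefix beam search updates below are:
        #   blank:           n_pb  += (pb + pnb) * ps
        #   repeat of last:  n_pnb += pnb * ps        (same prefix)
        #                    n_pnb += pb * ps         (prefix + s, via blank)
        #   new token:       n_pnb += (pb + pnb) * ps (prefix + s)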

        # CTC beam search step by step
        for t in range(0, maxlen):
            probs = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value: (pb, pnb, nodes), default (0.0, 0.0, [])
            next_hyps = defaultdict(lambda: (0.0, 0.0, []))

            # 2.1 First beam prune: select topk best
            top_k_probs, top_k_index = probs.topk(
                score_beam_size)  # (score_beam_size,)

            # filter out candidates whose probability is too small, and
            # (when a keyword token set is given) non-keyword tokens
            filter_probs = []
            filter_index = []
            for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
                if keywords_tokenset is not None:
                    if prob > 0.05 and idx in keywords_tokenset:
                        filter_probs.append(prob)
                        filter_index.append(idx)
                else:
                    if prob > 0.05:
                        filter_probs.append(prob)
                        filter_index.append(idx)

            if len(filter_index) == 0:
                continue

            for s in filter_index:
                ps = probs[s].item()
                # print(f'frame:{t}, token:{s}, score:{ps}')

                for prefix, (pb, pnb, cur_nodes) in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == 0:  # blank
                        n_pb, n_pnb, nodes = next_hyps[prefix]
                        n_pb = n_pb + pb * ps + pnb * ps
                        nodes = cur_nodes.copy()
                        next_hyps[prefix] = (n_pb, n_pnb, nodes)
                    elif s == last:
                        if not math.isclose(pnb, 0.0, abs_tol=0.000001):
                            # Update *ss -> *s;
                            n_pb, n_pnb, nodes = next_hyps[prefix]
                            n_pnb = n_pnb + pnb * ps
                            nodes = cur_nodes.copy()
                            if ps > nodes[-1]['prob']:  # update frame and prob
                                nodes[-1]['prob'] = ps
                                nodes[-1]['frame'] = t
                            next_hyps[prefix] = (n_pb, n_pnb, nodes)

                        if not math.isclose(pb, 0.0, abs_tol=0.000001):
                            # Update *s-s -> *ss, - is for blank
                            n_prefix = prefix + (s, )
                            n_pb, n_pnb, nodes = next_hyps[n_prefix]
                            n_pnb = n_pnb + pb * ps
                            nodes = cur_nodes.copy()
                            nodes.append(dict(token=s, frame=t,
                                              prob=ps))  # to record token prob
                            next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
                    else:
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb, nodes = next_hyps[n_prefix]
                        if nodes:
                            if ps > nodes[-1]['prob']:  # update frame and prob
                                nodes[-1]['prob'] = ps
                                nodes[-1]['frame'] = t
                        else:
                            nodes = cur_nodes.copy()
                            nodes.append(dict(token=s, frame=t,
                                              prob=ps))  # to record token prob
                        n_pnb = n_pnb + pb * ps + pnb * ps
                        next_hyps[n_prefix] = (n_pb, n_pnb, nodes)

            # 2.2 Second beam prune
            next_hyps = sorted(next_hyps.items(),
                               key=lambda x: (x[1][0] + x[1][1]),
                               reverse=True)

            cur_hyps = next_hyps[:path_beam_size]

        hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
        return hyps

    def is_sublist(self, main_list, check_list):
        if len(main_list) < len(check_list):
            return -1

        if len(main_list) == len(check_list):
            return 0 if main_list == check_list else -1

        # + 1 so the window may end exactly at the end of main_list
        for i in range(len(main_list) - len(check_list) + 1):
            if main_list[i] == check_list[0]:
                for j in range(len(check_list)):
                    if main_list[i + j] != check_list[j]:
                        break
                else:
                    return i
        return -1
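
        # e.g. is_sublist((1, 2, 3, 4), (3, 4)) -> 2 (the start offset),
        # and is_sublist((1, 2, 3, 4), (2, 4)) -> -1 since no window matches.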

    def _decode_inside(
        self,
        logits: torch.Tensor,
        logits_lengths: torch.Tensor,
    ):
        hyps = self.beam_search(logits, logits_lengths, self.keywords_idxset)

        hit_keyword = None
        hit_score = 1.0
        # start = 0; end = 0
        for one_hyp in hyps:
            prefix_ids = one_hyp[0]
            # path_score = one_hyp[1]
            prefix_nodes = one_hyp[2]
            assert len(prefix_ids) == len(prefix_nodes)
            for word in self.keywords_token.keys():
                lab = self.keywords_token[word]['token_id']
                offset = self.is_sublist(prefix_ids, lab)
                if offset != -1:
                    hit_keyword = word
                    for idx in range(offset, offset + len(lab)):
                        hit_score *= prefix_nodes[idx]['prob']
                    break
            if hit_keyword is not None:
                # square root of the product of per-token probabilities
                hit_score = math.sqrt(hit_score)
                break

        if hit_keyword is not None:
            return True, hit_keyword, hit_score
        else:
            return False, None, None
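
        # The returned confidence is the square root of the product of the
        # matched tokens' per-frame probabilities; callers would typically
        # compare it against a detection threshold (not set in this file).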

    def decode(self, x: torch.Tensor):
        """Run keyword spotting on one encoded utterance.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: decode result

        """

        raw_logp = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
        # raw_logp is (max_len, vocab_size); take the time-axis length
        xlen = torch.tensor([raw_logp.size(0)])

        return self._decode_inside(raw_logp, xlen)
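
# Illustrative usage sketch (names here are hypothetical, not part of this file):
#   decoder = KwsCtcPrefixDecoder(ctc=model.ctc, keywords='你好小云,小云小云',
#                                 token_list=token_list, seg_dict=seg_dict)
#   hit, keyword, score = decoder.decode(encoder_out[0])  # (T, D) features
#   if hit:
#       print(f'detected {keyword} with score {score:.3f}')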