mirror of
https://gitcode.com/gh_mirrors/eas/EasyFace.git
synced 2026-05-15 03:38:04 +00:00
194 lines · 6.2 KiB · Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import json
|
|
import logging
|
|
from collections import OrderedDict
|
|
|
|
import numpy as np
|
|
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
from . import ontology
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
def max_lens(X):
    """Return the per-depth dimensions of a nested list.

    lens[0] is len(X); lens[k] is the maximum length observed among the
    sub-lists at nesting depth k. Recursion stops as soon as the first
    element at the current depth is not a list (assumes a homogeneous
    nesting structure — TODO confirm callers never mix lists and scalars
    at the same depth).
    """
    lens = [len(X)]
    while isinstance(X[0], list):
        lens.append(max(map(len, X)))
        # Flatten one level and continue measuring the next depth.
        X = [x for xs in X for x in xs]
    return lens


def list2np(X, padding=0, dtype='int64'):
    """Convert a (nested) Python list of numbers into a padded ndarray.

    Supports 1-, 2- and 3-level nesting. Shorter inner lists are padded
    on the right with `padding` so the result is rectangular.

    Args:
        X: list, list of lists, or list of lists of lists of numbers.
        padding: fill value used for the padded positions.
        dtype: numpy dtype name of the returned array (default 'int64').

    Returns:
        np.ndarray with shape ``max_lens(X)`` and dtype `dtype`.
    """
    shape = max_lens(X)
    # Allocate directly in the requested dtype. The previous int32
    # intermediate silently truncated padding/values outside the int32
    # range before the final astype(dtype).
    ret = np.full(shape, padding, dtype=dtype)

    if len(shape) == 1:
        # 1-D input needs no padding at all.
        ret = np.array(X)
    elif len(shape) == 2:
        for i, x in enumerate(X):
            ret[i, :len(x)] = np.array(x)
    elif len(shape) == 3:
        for i, xs in enumerate(X):
            for j, x in enumerate(xs):
                ret[i, j, :len(x)] = np.array(x)
    return ret.astype(dtype)
|
|
|
|
|
|
def clean_replace(s, r, t, forward=True, backward=False):
    """Replace occurrences of substring `r` in `s` with `t`, respecting
    token boundaries.

    With the defaults (forward=True, backward=False) a match must start
    at the beginning of `s` or right after a space, and the replaced
    span is extended over any alphanumeric characters that directly
    follow the match. Repeats until no further replacement is made.

    NOTE(review): if `t` itself contains `r`, the outer loop may never
    terminate — confirm callers never pass such arguments.
    """

    def clean_replace_single(s, r, t, forward, backward, sidx=0):
        # NOTE(review): `sidx` is accepted but unused — the commented
        # line below shows the original intent was to resume searching
        # from `sidx`; as written, every call restarts from position 0.
        # idx = s[sidx:].find(r)
        idx = s.find(r)
        if idx == -1:
            # No occurrence left: signal the caller to stop.
            return s, -1
        idx_r = idx + len(r)
        if backward:
            # NOTE(review): `s[idx - 1]` is truthy for every character,
            # so this walks `idx` all the way back to 0 — it looks like
            # the condition was meant to stop at a space; confirm before
            # relying on backward=True.
            while idx > 0 and s[idx - 1]:
                idx -= 1
        elif idx > 0 and s[idx - 1] != ' ':
            # Match does not start at a word boundary: reject it.
            return s, -1

        if forward:
            # Extend the replaced span over trailing alphanumerics
            # (e.g. replacing "r" also swallows "r123").
            while \
                    idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
                idx_r += 1
        elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
            # Match does not end at a word boundary: reject it.
            return s, -1
        # Splice in the replacement; return the index just past it.
        return s[:idx] + t + s[idx_r:], idx_r

    sidx = 0
    # Keep replacing until clean_replace_single reports no match (-1).
    while sidx != -1:
        s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
    return s
|
|
|
|
|
|
def py2np(values):
    """Convert a (possibly nested) Python list to a numpy array.

    The parameter was renamed from `list`, which shadowed the builtin;
    positional call sites are unaffected.
    """
    return np.array(values)
|
|
|
|
|
|
def write_dict(fn, dic):
    """Serialize `dic` as pretty-printed JSON to the file path `fn`.

    Opens the file with an explicit UTF-8 encoding so the output does
    not depend on the platform's locale default.
    """
    with open(fn, 'w', encoding='utf-8') as f:
        json.dump(dic, f, indent=2)
|
|
|
|
|
|
def f1_score(label_list, pred_list):
    """Compute a smoothed F1 score between two item lists.

    True positives are predictions that also appear in `label_list`
    (duplicate predictions each count, matching the original behavior).
    A 1e-10 epsilon keeps the divisions finite, so empty inputs yield
    0.0 instead of raising.

    Args:
        label_list: gold items.
        pred_list: predicted items.

    Returns:
        float F1 in [0, 1) (epsilon-smoothed, so never exactly 1).
    """
    # Set membership is O(1) vs. O(n) on the list — identical results,
    # avoids accidental O(n^2) on long prediction lists.
    label_set = set(label_list)
    tp = sum(1 for t in pred_list if t in label_set)
    fp = max(0, len(pred_list) - tp)
    fn = max(0, len(label_list) - tp)
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return f1
|
|
|
|
|
|
class MultiWOZVocab(object):
    """Vocabulary for the MultiWOZ dialogue dataset.

    Maintains word<->index mappings plus a word-frequency dict. Special
    tokens are installed first in __init__ so they occupy fixed low
    indices ('[PAD]'=0, '[UNK]'=2, ...). Indices at or beyond
    `vocab_size` are treated as out-of-vocabulary (oov).
    """

    def __init__(self, vocab_size=0):
        """Create an initially empty vocab for the multiwoz dataset.

        Args:
            vocab_size: configured in-vocabulary size; indices beyond it
                are considered oov by `oov_idx_map` and `decode`.
        """
        self.vocab_size = vocab_size
        self.vocab_size_oov = 0  # get after construction
        self._idx2word = {}  # word + oov
        self._word2idx = {}  # word
        self._freq_dict = {}  # word + oov
        # Special tokens go in first so they get stable, low indices.
        for w in [
                '[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>',
                '<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>'
        ]:
            self._absolute_add_word(w)

    def _absolute_add_word(self, w):
        # Unconditionally append `w` with the next free index
        # (no duplicate check — only used for the special tokens).
        idx = len(self._idx2word)
        self._idx2word[idx] = w
        self._word2idx[w] = idx

    def add_word(self, word):
        """Record one occurrence of `word` in the frequency dict."""
        if word not in self._freq_dict:
            self._freq_dict[word] = 0
        self._freq_dict[word] += 1

    def has_word(self, word):
        """Return the recorded frequency of `word`, or None if unseen.

        Note: returns the count itself (truthy/falsy), not a bool.
        """
        return self._freq_dict.get(word)

    def _add_to_vocab(self, word):
        # Assign the next free index to `word` unless already present.
        if word not in self._word2idx:
            idx = len(self._idx2word)
            self._idx2word[idx] = word
            self._word2idx[word] = idx

    def construct(self):
        """Build the index from the accumulated frequency dict.

        Insertion order: ontology domains (+ 'general'), acts, slots,
        '[value_*]' placeholders, then all remaining words by descending
        frequency. Sets `vocab_size_oov` to the final table size.
        """
        freq_dict_sorted = sorted(self._freq_dict.keys(),
                                  key=lambda x: -self._freq_dict[x])
        logger.info('Vocabulary size including oov: %d' %
                    (len(freq_dict_sorted) + len(self._idx2word)))
        if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size:
            # Use the module-level logger for consistency with the rest
            # of this class (was the root `logging.warning`).
            logger.warning(
                'actual label set smaller than that configured: {}/{}'.format(
                    len(freq_dict_sorted) + len(self._idx2word),
                    self.vocab_size))
        for word in ontology.all_domains + ['general']:
            word = '[' + word + ']'
            self._add_to_vocab(word)
        for word in ontology.all_acts:
            word = '[' + word + ']'
            self._add_to_vocab(word)
        for word in ontology.all_slots:
            self._add_to_vocab(word)
        for word in freq_dict_sorted:
            if word.startswith('[value_') and word.endswith(']'):
                self._add_to_vocab(word)
        for word in freq_dict_sorted:
            self._add_to_vocab(word)
        self.vocab_size_oov = len(self._idx2word)

    def load_vocab(self, vocab_path):
        """Load mappings from `<vocab_path>.freq.json` / `.word2idx.json`."""
        self._freq_dict = json.loads(
            open(vocab_path + '.freq.json', 'r', encoding='utf-8').read())
        self._word2idx = json.loads(
            open(vocab_path + '.word2idx.json', 'r', encoding='utf-8').read())
        # Rebuild the inverse mapping from word2idx.
        self._idx2word = {}
        for w, idx in self._word2idx.items():
            self._idx2word[idx] = w
        self.vocab_size_oov = len(self._idx2word)
        logger.info('vocab file loaded from "' + vocab_path + '"')
        logger.info('Vocabulary size including oov: %d' %
                    (self.vocab_size_oov))

    def save_vocab(self, vocab_path):
        """Write word2idx and (frequency-sorted) freq dicts next to
        `vocab_path`."""
        _freq_dict = OrderedDict(
            sorted(self._freq_dict.items(), key=lambda kv: kv[1],
                   reverse=True))
        write_dict(vocab_path + '.word2idx.json', self._word2idx)
        write_dict(vocab_path + '.freq.json', _freq_dict)

    def encode(self, word, include_oov=True):
        """Map `word` to its index.

        With include_oov=True an unknown word raises ValueError;
        otherwise it falls back to the '<unk>' entry (which is assumed
        to exist in the loaded vocab — TODO confirm).
        """
        if include_oov:
            if self._word2idx.get(word, None) is None:
                raise ValueError(
                    'Unknown word: %s. Vocabulary should include oovs here.' %
                    word)
            return self._word2idx[word]
        else:
            word = '<unk>' if word not in self._word2idx else word
            return self._word2idx[word]

    def sentence_encode(self, word_list):
        """Encode a list of words (oov entries included)."""
        return [self.encode(_) for _ in word_list]

    def oov_idx_map(self, idx):
        """Map an oov index (> vocab_size) to 2, the '[UNK]' index."""
        return 2 if idx > self.vocab_size else idx

    def sentence_oov_map(self, index_list):
        """Apply `oov_idx_map` to every index in the list."""
        return [self.oov_idx_map(_) for _ in index_list]

    def decode(self, idx, indicate_oov=False):
        """Map an index back to its word.

        When indicate_oov is True, words at oov indices (>= vocab_size)
        get an '(o)' suffix. Raises ValueError for unknown indices.
        """
        if not self._idx2word.get(idx):
            raise ValueError(
                'Error idx: %d. Vocabulary should include oovs here.' % idx)
        if not indicate_oov or idx < self.vocab_size:
            return self._idx2word[idx]
        else:
            return self._idx2word[idx] + '(o)'