mirror of
https://gitee.com/270580156/weiyu.git
synced 2026-05-19 05:37:53 +00:00
Sync from bytedesk-private: update
This commit is contained in:
86
modules/python/vendors/FunASR/runtime/onnxruntime/src/wfst-decoder.h
vendored
Normal file
86
modules/python/vendors/FunASR/runtime/onnxruntime/src/wfst-decoder.h
vendored
Normal file
@@ -0,0 +1,86 @@
|
||||
#ifndef WFST_DECODER_
|
||||
#define WFST_DECODER_
|
||||
#include "kaldi/decoder/lattice-faster-online-decoder.h"
|
||||
#include "model.h"
|
||||
#include "fst/fstlib.h"
|
||||
#include "fst/symbol-table.h"
|
||||
#include "bias-lm.h"
|
||||
#include "phone-set.h"
|
||||
#include "util.h"
|
||||
|
||||
// Score constant; it is not referenced anywhere in this header.
// NOTE(review): presumably an upper bound applied to scores in the decoder
// implementation file -- confirm against its users before changing the value.
#define MAX_SCORE 10.0f
|
||||
namespace funasr {
|
||||
class Decodable : public kaldi::DecodableInterface {
|
||||
public:
|
||||
Decodable(float scale = 1.0f) : scale_(scale) {
|
||||
Reset();
|
||||
}
|
||||
void Reset() {
|
||||
num_frames_ = 0;
|
||||
finished_ = false;
|
||||
logp_.clear();
|
||||
}
|
||||
|
||||
int NumFramesReady() const { return num_frames_; }
|
||||
|
||||
bool IsLastFrame(int frame) const {
|
||||
return finished_ && (frame == num_frames_ - 1);
|
||||
}
|
||||
|
||||
float LogLikelihood(int frm, int id) {
|
||||
CHECK_GT(id, 0);
|
||||
CHECK_LT(frm, num_frames_);
|
||||
return scale_ * logp_[id - 1];
|
||||
}
|
||||
|
||||
void AcceptLoglikes(const std::vector<float>& logp) {
|
||||
num_frames_++;
|
||||
logp_ = logp;
|
||||
}
|
||||
|
||||
int NumIndices() const { return 0; }
|
||||
void SetFinished() { finished_ = true; }
|
||||
|
||||
private:
|
||||
int num_frames_ = 0;
|
||||
float scale_ = 1.0f;
|
||||
bool finished_ = false;
|
||||
std::vector<float> logp_;
|
||||
};
|
||||
|
||||
// Decoder settings: everything in kaldi::LatticeFasterDecoderConfig plus an
// acoustic scale value.
struct DecodeOptions : public kaldi::LatticeFasterDecoderConfig {
  // glob_beam and lat_beam are forwarded, in that order, to the base-class
  // constructor; ac_sc is stored in acoustic_scale.
  DecodeOptions(float glob_beam = 3.0f,
                float lat_beam = 3.0f,
                float ac_sc = 10.0f)
      : kaldi::LatticeFasterDecoderConfig(glob_beam, lat_beam),
        acoustic_scale(ac_sc) {}

  float acoustic_scale;
};
|
||||
|
||||
class WfstDecoder {
|
||||
public:
|
||||
WfstDecoder(fst::Fst<fst::StdArc>* lm,
|
||||
PhoneSet* phone_set,
|
||||
Vocab* vocab,
|
||||
float glob_beam,
|
||||
float lat_beam,
|
||||
float am_scale);
|
||||
~WfstDecoder();
|
||||
void StartUtterance();
|
||||
void EndUtterance();
|
||||
string Search(float *in, int len, int64_t token_nums);
|
||||
string FinalizeDecode(bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
|
||||
void LoadHwsRes(int inc_bias, unordered_map<string, int> &hws_map);
|
||||
void UnloadHwsRes();
|
||||
|
||||
private:
|
||||
Vocab* vocab_ = nullptr;
|
||||
PhoneSet* phone_set_ = nullptr;
|
||||
int cur_frame_ = 0;
|
||||
int cur_token_ = 0;
|
||||
DecodeOptions dec_opts_;
|
||||
Decodable decodable_;
|
||||
fst::Fst<fst::StdArc>* lm_ = nullptr;
|
||||
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_ = nullptr;
|
||||
std::shared_ptr<BiasLm> bias_lm_ = nullptr;
|
||||
};
|
||||
} // namespace funasr
|
||||
#endif // WFST_DECODER_
|
||||
Reference in New Issue
Block a user