// File listing metadata (captured from the repository browser):
//   Path: weiyu/modules/python/vendors/FunASR/runtime/onnxruntime/src/wfst-decoder.h
//   Date: 2024-12-14 10:43:31 +08:00
//   Size: 87 lines, 2.2 KiB
//   Lang: C++
#ifndef WFST_DECODER_
#define WFST_DECODER_
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "model.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "bias-lm.h"
#include "phone-set.h"
#include "util.h"
#define MAX_SCORE 10.0f
namespace funasr {
/// Decodable that hands acoustic scores to a Kaldi decoder one frame at a time.
///
/// Only the scores of the most recently accepted frame are kept. Every
/// LogLikelihood() call reads the vector passed to the latest
/// AcceptLoglikes(), whatever frame index it is given. This presumably relies
/// on the caller advancing the decoder right after each AcceptLoglikes().
class Decodable : public kaldi::DecodableInterface {
public:
  /// @param scale Factor applied to every score returned by LogLikelihood().
  Decodable(float scale = 1.0f) : scale_(scale) {
    Reset();
  }

  /// Forgets all accepted frames and clears the "finished" flag.
  void Reset() {
    num_frames_ = 0;
    finished_ = false;
    logp_.clear();
  }

  /// Number of frames accepted since construction or the last Reset().
  int NumFramesReady() const { return num_frames_; }

  /// True when SetFinished() has been called and @p frame is the last
  /// accepted frame.
  bool IsLastFrame(int frame) const {
    return finished_ && (frame == num_frames_ - 1);
  }

  /// Returns the scaled score for 1-based index @p id of the current frame.
  ///
  /// @param frm Frame index; must be below NumFramesReady().
  /// @param id  1-based index (id maps to logp_[id - 1]); must be in
  ///            [1, number of stored scores].
  float LogLikelihood(int frm, int id) {
    CHECK_GT(id, 0);
    CHECK_LT(frm, num_frames_);
    // Fail loudly instead of reading past the end of logp_ (undefined
    // behavior) when the graph references an index the model did not score.
    CHECK_LT(id - 1, static_cast<int>(logp_.size()));
    return scale_ * logp_[id - 1];
  }

  /// Accepts one more frame: bumps the frame count and replaces the stored
  /// scores with @p logp.
  void AcceptLoglikes(const std::vector<float>& logp) {
    num_frames_++;
    logp_ = logp;
  }

  /// Always returns 0.
  /// NOTE(review): Kaldi's DecodableInterface describes this as the number of
  /// valid indices (here that would be logp_.size()). Confirm no caller relies
  /// on 0 before changing it.
  int NumIndices() const { return 0; }

  /// Marks the stream as complete so IsLastFrame() can report the final frame.
  void SetFinished() { finished_ = true; }

private:
  int num_frames_ = 0;       // frames accepted since the last Reset()
  float scale_ = 1.0f;       // multiplier applied in LogLikelihood()
  bool finished_ = false;    // set by SetFinished(), cleared by Reset()
  std::vector<float> logp_;  // scores of the most recent frame only
};
struct DecodeOptions : public kaldi::LatticeFasterDecoderConfig {
DecodeOptions(float glob_beam = 3.0f, float lat_beam = 3.0f, float ac_sc = 10.0f) :
kaldi::LatticeFasterDecoderConfig(glob_beam, lat_beam), acoustic_scale(ac_sc) {
}
float acoustic_scale;
};
/// WFST decoder built around a kaldi::LatticeFasterOnlineDecoder (decoder_),
/// with optional biasing through a BiasLm (bias_lm_).
///
/// Standard-library names are fully qualified so this header does not depend
/// on a `using namespace std;` leaking in from another include.
///
/// NOTE(review): a destructor is declared, but copy/move operations are left
/// implicit. A copy would share decoder_, bias_lm_ and the raw pointers.
/// Confirm instances are never copied, or delete the copy operations.
class WfstDecoder {
public:
  /// @param lm        Language-model FST (ownership is not visible in this
  ///                  header; presumably non-owning, confirm in the .cpp).
  /// @param phone_set Phone set used by the decoder.
  /// @param vocab     Vocabulary used by the decoder.
  /// @param glob_beam Presumably the global search beam (see DecodeOptions).
  /// @param lat_beam  Presumably the lattice beam (see DecodeOptions).
  /// @param am_scale  Presumably the acoustic-model score scale.
  WfstDecoder(fst::Fst<fst::StdArc>* lm,
              PhoneSet* phone_set,
              Vocab* vocab,
              float glob_beam,
              float lat_beam,
              float am_scale);
  ~WfstDecoder();

  /// Presumably resets per-utterance state before decoding begins.
  void StartUtterance();
  /// Presumably finishes the current utterance.
  void EndUtterance();

  /// Decodes a buffer of acoustic-model output and returns a string
  /// (presumably the current hypothesis).
  /// NOTE(review): the layout of `in` (len values covering token_nums tokens)
  /// is defined in the .cpp. Confirm it there.
  std::string Search(float *in, int len, int64_t token_nums);

  /// Finalizes decoding and returns a string (presumably the final hypothesis).
  /// @param is_stamp    Presumably requests timestamps built from the vectors
  ///                    below.
  /// @param us_alphas   Defaults to a one-element vector {0.0f}.
  /// @param us_cif_peak Defaults to a one-element vector {0.0f}.
  std::string FinalizeDecode(bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});

  /// Loads biasing resources ("hws", presumably hotwords) from hws_map with
  /// weight inc_bias. Presumably populates bias_lm_.
  void LoadHwsRes(int inc_bias, std::unordered_map<std::string, int> &hws_map);
  /// Releases what LoadHwsRes() loaded.
  void UnloadHwsRes();

private:
  // Declaration order is kept as-is: it fixes member initialization order
  // for the constructor defined in the .cpp.
  Vocab* vocab_ = nullptr;
  PhoneSet* phone_set_ = nullptr;
  int cur_frame_ = 0;
  int cur_token_ = 0;
  DecodeOptions dec_opts_;
  Decodable decodable_;
  fst::Fst<fst::StdArc>* lm_ = nullptr;
  std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_ = nullptr;
  std::shared_ptr<BiasLm> bias_lm_ = nullptr;
};
} // namespace funasr
#endif // WFST_DECODER_