mirror of
https://gitee.com/270580156/weiyu.git
synced 2026-05-16 20:27:50 +00:00
Sync from bytedesk-private: update
This commit is contained in:
89
modules/python/vendors/AsrTools/bk_asr/BaseASR.py
vendored
Normal file
89
modules/python/vendors/AsrTools/bk_asr/BaseASR.py
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import zlib
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
from .ASRData import ASRDataSeg, ASRData
|
||||
|
||||
|
||||
class BaseASR:
|
||||
SUPPORTED_SOUND_FORMAT = ["flac", "m4a", "mp3", "wav"]
|
||||
CACHE_FILE = os.path.join(tempfile.gettempdir(), "bk_asr", "asr_cache.json")
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self, audio_path: [str, bytes], use_cache: bool = False):
|
||||
self.audio_path = audio_path
|
||||
self.file_binary = None
|
||||
|
||||
self.crc32_hex = None
|
||||
self.use_cache = use_cache
|
||||
|
||||
self._set_data()
|
||||
|
||||
self.cache = self._load_cache()
|
||||
|
||||
def _load_cache(self):
|
||||
if not self.use_cache:
|
||||
return {}
|
||||
os.makedirs(os.path.dirname(self.CACHE_FILE), exist_ok=True)
|
||||
with self._lock:
|
||||
if os.path.exists(self.CACHE_FILE):
|
||||
try:
|
||||
with open(self.CACHE_FILE, 'r', encoding='utf-8') as f:
|
||||
cache = json.load(f)
|
||||
if isinstance(cache, dict):
|
||||
return cache
|
||||
except (json.JSONDecodeError, IOError):
|
||||
return {}
|
||||
return {}
|
||||
|
||||
def _save_cache(self):
|
||||
if not self.use_cache:
|
||||
return
|
||||
with self._lock:
|
||||
try:
|
||||
with open(self.CACHE_FILE, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.cache, f, ensure_ascii=False, indent=2)
|
||||
if os.path.exists(self.CACHE_FILE) and os.path.getsize(self.CACHE_FILE) > 10 * 1024 * 1024:
|
||||
os.remove(self.CACHE_FILE)
|
||||
except IOError as e:
|
||||
logging.error(f"Failed to save cache: {e}")
|
||||
|
||||
def _set_data(self):
|
||||
if isinstance(self.audio_path, bytes):
|
||||
self.file_binary = self.audio_path
|
||||
else:
|
||||
ext = self.audio_path.split(".")[-1].lower()
|
||||
assert ext in self.SUPPORTED_SOUND_FORMAT, f"Unsupported sound format: {ext}"
|
||||
assert os.path.exists(self.audio_path), f"File not found: {self.audio_path}"
|
||||
with open(self.audio_path, "rb") as f:
|
||||
self.file_binary = f.read()
|
||||
crc32_value = zlib.crc32(self.file_binary) & 0xFFFFFFFF
|
||||
self.crc32_hex = format(crc32_value, '08x')
|
||||
|
||||
def _get_key(self):
|
||||
return f"{self.__class__.__name__}-{self.crc32_hex}"
|
||||
|
||||
def run(self):
|
||||
k = self._get_key()
|
||||
if k in self.cache and self.use_cache:
|
||||
resp_data = self.cache[k]
|
||||
else:
|
||||
resp_data = self._run()
|
||||
# Cache the result
|
||||
self.cache[k] = resp_data
|
||||
self._save_cache()
|
||||
segments = self._make_segments(resp_data)
|
||||
return ASRData(segments)
|
||||
|
||||
def _make_segments(self, resp_data: dict) -> list[ASRDataSeg]:
|
||||
raise NotImplementedError("_make_segments method must be implemented in subclass")
|
||||
|
||||
def _run(self) -> dict:
|
||||
""" Run the ASR service and return the response data. """
|
||||
raise NotImplementedError("_run method must be implemented in subclass")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user