mirror of
https://gitee.com/270580156/weiyu.git
synced 2026-05-14 19:27:53 +00:00
90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import zlib
|
|
import tempfile
|
|
import threading
|
|
|
|
from .ASRData import ASRDataSeg, ASRData
|
|
|
|
|
|
class BaseASR:
|
|
SUPPORTED_SOUND_FORMAT = ["flac", "m4a", "mp3", "wav"]
|
|
CACHE_FILE = os.path.join(tempfile.gettempdir(), "bk_asr", "asr_cache.json")
|
|
_lock = threading.Lock()
|
|
|
|
def __init__(self, audio_path: [str, bytes], use_cache: bool = False):
|
|
self.audio_path = audio_path
|
|
self.file_binary = None
|
|
|
|
self.crc32_hex = None
|
|
self.use_cache = use_cache
|
|
|
|
self._set_data()
|
|
|
|
self.cache = self._load_cache()
|
|
|
|
def _load_cache(self):
|
|
if not self.use_cache:
|
|
return {}
|
|
os.makedirs(os.path.dirname(self.CACHE_FILE), exist_ok=True)
|
|
with self._lock:
|
|
if os.path.exists(self.CACHE_FILE):
|
|
try:
|
|
with open(self.CACHE_FILE, 'r', encoding='utf-8') as f:
|
|
cache = json.load(f)
|
|
if isinstance(cache, dict):
|
|
return cache
|
|
except (json.JSONDecodeError, IOError):
|
|
return {}
|
|
return {}
|
|
|
|
def _save_cache(self):
|
|
if not self.use_cache:
|
|
return
|
|
with self._lock:
|
|
try:
|
|
with open(self.CACHE_FILE, 'w', encoding='utf-8') as f:
|
|
json.dump(self.cache, f, ensure_ascii=False, indent=2)
|
|
if os.path.exists(self.CACHE_FILE) and os.path.getsize(self.CACHE_FILE) > 10 * 1024 * 1024:
|
|
os.remove(self.CACHE_FILE)
|
|
except IOError as e:
|
|
logging.error(f"Failed to save cache: {e}")
|
|
|
|
def _set_data(self):
|
|
if isinstance(self.audio_path, bytes):
|
|
self.file_binary = self.audio_path
|
|
else:
|
|
ext = self.audio_path.split(".")[-1].lower()
|
|
assert ext in self.SUPPORTED_SOUND_FORMAT, f"Unsupported sound format: {ext}"
|
|
assert os.path.exists(self.audio_path), f"File not found: {self.audio_path}"
|
|
with open(self.audio_path, "rb") as f:
|
|
self.file_binary = f.read()
|
|
crc32_value = zlib.crc32(self.file_binary) & 0xFFFFFFFF
|
|
self.crc32_hex = format(crc32_value, '08x')
|
|
|
|
def _get_key(self):
|
|
return f"{self.__class__.__name__}-{self.crc32_hex}"
|
|
|
|
def run(self):
|
|
k = self._get_key()
|
|
if k in self.cache and self.use_cache:
|
|
resp_data = self.cache[k]
|
|
else:
|
|
resp_data = self._run()
|
|
# Cache the result
|
|
self.cache[k] = resp_data
|
|
self._save_cache()
|
|
segments = self._make_segments(resp_data)
|
|
return ASRData(segments)
|
|
|
|
def _make_segments(self, resp_data: dict) -> list[ASRDataSeg]:
|
|
raise NotImplementedError("_make_segments method must be implemented in subclass")
|
|
|
|
def _run(self) -> dict:
|
|
""" Run the ASR service and return the response data. """
|
|
raise NotImplementedError("_run method must be implemented in subclass")
|
|
|
|
|
|
|