74 lines
2.7 KiB
Python
74 lines
2.7 KiB
Python
import asyncio
|
|
import logging
|
|
|
|
import torch
|
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
|
|
|
from llmux.backends.base import BaseBackend
|
|
from llmux.config import PhysicalModel
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TransformersASRBackend(BaseBackend):
    """Backend serving HuggingFace speech-to-text (ASR) models via ``transformers``.

    Models are loaded on demand with :meth:`load` and cached in ``self._loaded``,
    keyed by the llmux model id. Blocking HF/torch calls run in the default
    executor so the event loop is never blocked.
    """

    def __init__(self, models_dir: str = "/models"):
        # Directory passed to HF as cache_dir for downloaded weights.
        self._models_dir = models_dir
        # model_id -> {"model": ..., "processor": ..., "device": ...}
        self._loaded: dict[str, dict] = {}

    async def load(self, model_id: str, device: str = "cuda") -> None:
        """Load the ASR model for *model_id* onto *device*; no-op if already loaded.

        NOTE(review): two concurrent ``load`` calls for the same model_id can
        race past the cache check and load twice — serialize per-model loads
        upstream if that matters.
        """
        if model_id in self._loaded:
            return

        physical = _get_physical_config(model_id)
        hf_id = physical.model_id
        # Lazy %-args: formatting is skipped entirely when INFO is disabled.
        logger.info("Loading ASR model %s to %s", hf_id, device)

        def _load():
            # Blocking download/deserialization — runs in the executor below.
            processor = AutoProcessor.from_pretrained(
                hf_id, cache_dir=self._models_dir, trust_remote_code=True
            )
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                hf_id,
                cache_dir=self._models_dir,
                torch_dtype="auto",
                device_map=device,
                trust_remote_code=True,
            )
            return model, processor

        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() here has been deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        model, processor = await loop.run_in_executor(None, _load)
        self._loaded[model_id] = {"model": model, "processor": processor, "device": device}

    async def unload(self, model_id: str) -> None:
        """Drop all references to *model_id* and release cached CUDA memory."""
        if model_id not in self._loaded:
            return
        entry = self._loaded.pop(model_id)
        # Drop the strong references so the weights become collectable.
        del entry["model"]
        del entry["processor"]
        # Safe no-op when CUDA was never initialized; guard documents intent.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    async def generate(self, model_id, messages, params, stream=False, tools=None):
        """ASR models cannot chat; always raises NotImplementedError."""
        raise NotImplementedError("ASR backend does not support chat generation")

    async def transcribe(self, model_id: str, audio_data: bytes, language: str = "en") -> dict:
        """Transcribe *audio_data* (an encoded audio file, e.g. WAV bytes).

        Returns ``{"text": <transcription>}``.
        Raises ``KeyError`` if *model_id* has not been loaded via :meth:`load`.
        """
        # Local imports keep soundfile optional for deployments that never transcribe.
        import io

        import soundfile as sf

        try:
            entry = self._loaded[model_id]
        except KeyError:
            # Same exception type as before, but actionable.
            raise KeyError(
                f"ASR model {model_id!r} is not loaded; call load() first"
            ) from None
        model = entry["model"]
        processor = entry["processor"]

        def _transcribe():
            audio_array, sample_rate = sf.read(io.BytesIO(audio_data))
            # NOTE(review): passing language= to the processor call assumes a
            # Whisper-style processor — confirm for other architectures.
            inputs = processor(
                audio_array, sampling_rate=sample_rate, return_tensors="pt", language=language
            ).to(model.device)
            with torch.no_grad():
                predicted_ids = model.generate(**inputs)
            return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        # Inference is blocking/CPU+GPU-bound — run off the event loop.
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _transcribe)
        return {"text": text}
|
|
|
|
|
|
# Module-level registry of llmux model id -> PhysicalModel config.
# Populated by set_physical_models(); read by _get_physical_config().
_physical_models: dict[str, PhysicalModel] = {}
|
|
|
|
def set_physical_models(models: dict[str, PhysicalModel]) -> None:
    """Replace the module-level physical-model registry with *models*.

    The mapping is stored by reference (not copied), so later mutations by
    the caller are visible to this module.
    """
    global _physical_models
    _physical_models = models
|
|
|
|
def _get_physical_config(model_id: str) -> PhysicalModel:
    """Look up the PhysicalModel config registered for *model_id*.

    Raises ``KeyError`` (same type as a plain dict miss, so existing callers
    are unaffected) with an actionable message when the id is unknown.
    """
    try:
        return _physical_models[model_id]
    except KeyError:
        raise KeyError(
            f"No physical model configured for {model_id!r}; "
            "call set_physical_models() first"
        ) from None
|