feat: Chatterbox TTS backend with turbo/multilingual/default variants
This commit is contained in:
73
kischdle/llmux/llmux/backends/chatterbox_tts.py
Normal file
73
kischdle/llmux/llmux/backends/chatterbox_tts.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from llmux.backends.base import BaseBackend
|
||||||
|
from llmux.config import PhysicalModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterboxTTSBackend(BaseBackend):
|
||||||
|
def __init__(self, models_dir: str = "/models"):
|
||||||
|
self._models_dir = models_dir
|
||||||
|
self._loaded: dict[str, dict] = {}
|
||||||
|
|
||||||
|
async def load(self, model_id: str, device: str = "cuda") -> None:
|
||||||
|
if model_id in self._loaded:
|
||||||
|
return
|
||||||
|
physical = _get_physical_config(model_id)
|
||||||
|
variant = physical.variant
|
||||||
|
logger.info(f"Loading Chatterbox {variant} to {device}")
|
||||||
|
|
||||||
|
def _load():
|
||||||
|
from chatterbox.tts import ChatterboxTTS
|
||||||
|
if variant == "turbo":
|
||||||
|
model = ChatterboxTTS.from_pretrained(device=device, variant="turbo")
|
||||||
|
elif variant == "multilingual":
|
||||||
|
model = ChatterboxTTS.from_pretrained(device=device, variant="multilingual")
|
||||||
|
else:
|
||||||
|
model = ChatterboxTTS.from_pretrained(device=device)
|
||||||
|
return model
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
model = await loop.run_in_executor(None, _load)
|
||||||
|
self._loaded[model_id] = {"model": model, "device": device}
|
||||||
|
|
||||||
|
async def unload(self, model_id: str) -> None:
|
||||||
|
if model_id not in self._loaded:
|
||||||
|
return
|
||||||
|
entry = self._loaded.pop(model_id)
|
||||||
|
del entry["model"]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
async def generate(self, model_id, messages, params, stream=False, tools=None):
|
||||||
|
raise NotImplementedError("TTS backend does not support chat generation")
|
||||||
|
|
||||||
|
async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes:
|
||||||
|
entry = self._loaded[model_id]
|
||||||
|
model = entry["model"]
|
||||||
|
|
||||||
|
def _synthesize():
|
||||||
|
wav = model.generate(text)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=24000, format="WAV")
|
||||||
|
buf.seek(0)
|
||||||
|
return buf.read()
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
audio_bytes = await loop.run_in_executor(None, _synthesize)
|
||||||
|
return audio_bytes
|
||||||
|
|
||||||
|
|
||||||
|
_physical_models: dict[str, PhysicalModel] = {}
|
||||||
|
|
||||||
|
def set_physical_models(models: dict[str, PhysicalModel]) -> None:
|
||||||
|
global _physical_models
|
||||||
|
_physical_models = models
|
||||||
|
|
||||||
|
def _get_physical_config(model_id: str) -> PhysicalModel:
|
||||||
|
return _physical_models[model_id]
|
||||||
Reference in New Issue
Block a user