diff --git a/kischdle/llmux/config/models.yaml b/kischdle/llmux/config/models.yaml index 9f28812..e565961 100644 --- a/kischdle/llmux/config/models.yaml +++ b/kischdle/llmux/config/models.yaml @@ -48,12 +48,6 @@ physical_models: estimated_vram_gb: 4 default_language: "en" - chatterbox-turbo: - type: tts - backend: chatterbox - variant: "turbo" - estimated_vram_gb: 2 - chatterbox-multilingual: type: tts backend: chatterbox @@ -110,8 +104,6 @@ virtual_models: cohere-transcribe: physical: cohere-transcribe - Chatterbox-Turbo: - physical: chatterbox-turbo Chatterbox-Multilingual: physical: chatterbox-multilingual Chatterbox: diff --git a/kischdle/llmux/llmux/backends/chatterbox_tts.py b/kischdle/llmux/llmux/backends/chatterbox_tts.py index 3d8b09b..0857fbf 100644 --- a/kischdle/llmux/llmux/backends/chatterbox_tts.py +++ b/kischdle/llmux/llmux/backends/chatterbox_tts.py @@ -1,4 +1,5 @@ import asyncio +import gc import io import logging @@ -24,25 +25,26 @@ class ChatterboxTTSBackend(BaseBackend): logger.info(f"Loading Chatterbox {variant} to {device}") def _load(): - from chatterbox.tts import ChatterboxTTS - if variant == "turbo": - model = ChatterboxTTS.from_pretrained(device=device, variant="turbo") - elif variant == "multilingual": - model = ChatterboxTTS.from_pretrained(device=device, variant="multilingual") + if variant == "multilingual": + from chatterbox import ChatterboxMultilingualTTS + return ChatterboxMultilingualTTS.from_pretrained(device=device) else: - model = ChatterboxTTS.from_pretrained(device=device) - return model + from chatterbox.tts import ChatterboxTTS + return ChatterboxTTS.from_pretrained(device=device) loop = asyncio.get_event_loop() model = await loop.run_in_executor(None, _load) - self._loaded[model_id] = {"model": model, "device": device} + self._loaded[model_id] = {"model": model, "variant": variant, "device": device} async def unload(self, model_id: str) -> None: if model_id not in self._loaded: return entry = self._loaded.pop(model_id) del entry["model"] + del entry + gc.collect() torch.cuda.empty_cache() + logger.info(f"Unloaded Chatterbox {model_id}") async def generate(self, model_id, messages, params, stream=False, tools=None): raise NotImplementedError("TTS backend does not support chat generation") @@ -50,9 +52,15 @@ class ChatterboxTTSBackend(BaseBackend): async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes: entry = self._loaded[model_id] model = entry["model"] + variant = entry["variant"] def _synthesize(): - wav = model.generate(text) + if variant == "multilingual": + # Default to English; voice param could encode language + lang = "en" if voice == "default" else voice + wav = model.generate(text, language_id=lang) + else: + wav = model.generate(text) buf = io.BytesIO() sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=24000, format="WAV") buf.seek(0) diff --git a/kischdle/llmux/tests/test_config.py b/kischdle/llmux/tests/test_config.py index ab807e7..5f7ebd7 100644 --- a/kischdle/llmux/tests/test_config.py +++ b/kischdle/llmux/tests/test_config.py @@ -5,8 +5,8 @@ def test_load_models_config_returns_physical_and_virtual(): physical, virtual = load_models_config() assert isinstance(physical, dict) assert isinstance(virtual, dict) - assert len(physical) == 9 - assert len(virtual) == 16 + assert len(physical) == 8 + assert len(virtual) == 15 def test_physical_model_has_required_fields(): diff --git a/kischdle/llmux/tests/test_model_registry.py b/kischdle/llmux/tests/test_model_registry.py index f9279bd..6541ddc 100644 --- a/kischdle/llmux/tests/test_model_registry.py +++ b/kischdle/llmux/tests/test_model_registry.py @@ -10,7 +10,7 @@ def registry(): def test_list_virtual_models(registry): models = registry.list_virtual_models() - assert len(models) == 16 + assert len(models) == 15 names = [m["id"] for m in models] assert "Qwen3.5-9B-FP8-Thinking" in names assert "GPT-OSS-20B-High" in names diff --git a/kischdle/llmux/tests/test_routes.py b/kischdle/llmux/tests/test_routes.py index 01e7c9c..f465313 100644 --- a/kischdle/llmux/tests/test_routes.py +++ b/kischdle/llmux/tests/test_routes.py @@ -45,7 +45,7 @@ def test_list_models_returns_16(client, auth_headers): assert resp.status_code == 200 body = resp.json() assert body["object"] == "list" - assert len(body["data"]) == 16 + assert len(body["data"]) == 15 def test_list_models_contains_expected_names(client, auth_headers):