fix: Chatterbox uses separate classes per variant, remove turbo

ChatterboxTTS and ChatterboxMultilingualTTS are separate classes.
The turbo variant does not exist in chatterbox-tts 0.1.7.
The multilingual model's generate() requires a language_id parameter.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
tlg
2026-04-05 21:43:40 +02:00
parent f24a225baf
commit d615bb4553
5 changed files with 21 additions and 21 deletions

View File

@@ -48,12 +48,6 @@ physical_models:
estimated_vram_gb: 4
default_language: "en"
chatterbox-turbo:
type: tts
backend: chatterbox
variant: "turbo"
estimated_vram_gb: 2
chatterbox-multilingual:
type: tts
backend: chatterbox
@@ -110,8 +104,6 @@ virtual_models:
cohere-transcribe:
physical: cohere-transcribe
Chatterbox-Turbo:
physical: chatterbox-turbo
Chatterbox-Multilingual:
physical: chatterbox-multilingual
Chatterbox:

View File

@@ -1,4 +1,5 @@
import asyncio
import gc
import io
import logging
@@ -24,25 +25,26 @@ class ChatterboxTTSBackend(BaseBackend):
logger.info(f"Loading Chatterbox {variant} to {device}")
def _load():
from chatterbox.tts import ChatterboxTTS
if variant == "turbo":
model = ChatterboxTTS.from_pretrained(device=device, variant="turbo")
elif variant == "multilingual":
model = ChatterboxTTS.from_pretrained(device=device, variant="multilingual")
if variant == "multilingual":
from chatterbox import ChatterboxMultilingualTTS
return ChatterboxMultilingualTTS.from_pretrained(device=device)
else:
model = ChatterboxTTS.from_pretrained(device=device)
return model
from chatterbox.tts import ChatterboxTTS
return ChatterboxTTS.from_pretrained(device=device)
loop = asyncio.get_event_loop()
model = await loop.run_in_executor(None, _load)
self._loaded[model_id] = {"model": model, "device": device}
self._loaded[model_id] = {"model": model, "variant": variant, "device": device}
async def unload(self, model_id: str) -> None:
if model_id not in self._loaded:
return
entry = self._loaded.pop(model_id)
del entry["model"]
del entry
gc.collect()
torch.cuda.empty_cache()
logger.info(f"Unloaded Chatterbox {model_id}")
async def generate(self, model_id, messages, params, stream=False, tools=None):
raise NotImplementedError("TTS backend does not support chat generation")
@@ -50,8 +52,14 @@ class ChatterboxTTSBackend(BaseBackend):
async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes:
entry = self._loaded[model_id]
model = entry["model"]
variant = entry["variant"]
def _synthesize():
if variant == "multilingual":
# Default to English; voice param could encode language
lang = "en" if voice == "default" else voice
wav = model.generate(text, language_id=lang)
else:
wav = model.generate(text)
buf = io.BytesIO()
sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=24000, format="WAV")

View File

@@ -5,8 +5,8 @@ def test_load_models_config_returns_physical_and_virtual():
physical, virtual = load_models_config()
assert isinstance(physical, dict)
assert isinstance(virtual, dict)
assert len(physical) == 9
assert len(virtual) == 16
assert len(physical) == 8
assert len(virtual) == 15
def test_physical_model_has_required_fields():

View File

@@ -10,7 +10,7 @@ def registry():
def test_list_virtual_models(registry):
models = registry.list_virtual_models()
assert len(models) == 16
assert len(models) == 15
names = [m["id"] for m in models]
assert "Qwen3.5-9B-FP8-Thinking" in names
assert "GPT-OSS-20B-High" in names

View File

@@ -45,7 +45,7 @@ def test_list_models_returns_16(client, auth_headers):
assert resp.status_code == 200
body = resp.json()
assert body["object"] == "list"
assert len(body["data"]) == 16
assert len(body["data"]) == 15
def test_list_models_contains_expected_names(client, auth_headers):