feat: API routes for models, chat, transcription, speech, and admin

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
tlg
2026-04-05 10:04:45 +02:00
parent ef44bc09b9
commit d55c80ae35
6 changed files with 313 additions and 0 deletions

View File

@@ -0,0 +1,108 @@
import logging
import time
from fastapi import APIRouter, Depends, HTTPException, Request
from llmux.model_registry import ModelRegistry
from llmux.vram_manager import VRAMManager
logger = logging.getLogger(__name__)
TEST_PROMPT = [{"role": "user", "content": "Say hello in one sentence."}]
def create_admin_router(registry, vram_manager, backends, require_api_key):
    """Build the admin router exposing on-demand backend performance self-tests.

    Args:
        registry: Model registry used to look up physical model records.
        vram_manager: VRAM manager (unused here; kept for a uniform factory signature).
        backends: Mapping of backend name -> backend instance (source of models_dir).
        require_api_key: FastAPI dependency enforcing API-key auth.

    Returns:
        APIRouter with the POST /admin/test/performance endpoint.
    """
    router = APIRouter()

    @router.post("/admin/test/performance")
    async def test_performance(request: Request, api_key: str = Depends(require_api_key)):
        body = await request.json()
        physical_id = body.get("physical_model_id")
        if not physical_id:
            raise HTTPException(status_code=400, detail="Missing 'physical_model_id'")
        # An unknown id should surface as a 404, not an unhandled 500.
        # NOTE(review): assumes registry.get_physical raises KeyError for unknown
        # ids, mirroring registry.resolve used by the other routers — confirm.
        try:
            physical = registry.get_physical(physical_id)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{physical_id}' not found")
        backend_name = physical.backend
        # Dispatch on (backend, model type); "transformers" hosts both LLM and ASR.
        if backend_name == "transformers" and physical.type == "llm":
            return await _test_transformers_llm(physical_id, backends)
        elif backend_name == "transformers" and physical.type == "asr":
            return await _test_transformers_asr(physical_id, backends)
        elif backend_name == "llamacpp":
            return await _test_llamacpp(physical_id, backends)
        elif backend_name == "chatterbox":
            return await _test_chatterbox(physical_id, backends)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown backend: {backend_name}")

    return router
async def _test_transformers_llm(physical_id, backends):
    """Time one short generation on GPU ("cuda") and on CPU for a transformers LLM.

    A fresh backend instance is created per device so each run loads the model
    onto exactly that device. Returns per-device wall-clock seconds, the
    CPU/GPU speedup ratio, and a pass flag (GPU must be >= 5x faster).
    """
    from llmux.backends.transformers_llm import TransformersLLMBackend
    results = {}
    for device_label, device in [("gpu", "cuda"), ("cpu", "cpu")]:
        backend = TransformersLLMBackend(models_dir=backends["transformers"]._models_dir)
        await backend.load(physical_id, device=device)
        try:
            start = time.monotonic()
            await backend.generate(physical_id, TEST_PROMPT, params={}, stream=False)
            elapsed = time.monotonic() - start
        finally:
            # Unload even when generation fails, otherwise the test leaves the
            # model occupying VRAM/RAM for the rest of the process lifetime.
            await backend.unload(physical_id)
        results[device_label] = round(elapsed, 2)
    ratio = results["cpu"] / results["gpu"] if results["gpu"] > 0 else 0
    return {"model": physical_id, "gpu_seconds": results["gpu"], "cpu_seconds": results["cpu"], "speedup": round(ratio, 1), "pass": ratio >= 5.0}
async def _test_transformers_asr(physical_id, backends):
    """Time one transcription of a 2s silent WAV on GPU vs CPU for a transformers ASR model.

    Returns per-device wall-clock seconds, the CPU/GPU speedup ratio, and a
    pass flag (GPU must be >= 5x faster).
    """
    from llmux.backends.transformers_asr import TransformersASRBackend
    silent_wav = _make_silent_wav(duration_seconds=2)
    results = {}
    for device_label, device in [("gpu", "cuda"), ("cpu", "cpu")]:
        backend = TransformersASRBackend(models_dir=backends["transformers_asr"]._models_dir)
        await backend.load(physical_id, device=device)
        try:
            start = time.monotonic()
            await backend.transcribe(physical_id, silent_wav, language="en")
            elapsed = time.monotonic() - start
        finally:
            # Unload even when transcription fails, otherwise the model stays
            # resident after a failed test run.
            await backend.unload(physical_id)
        results[device_label] = round(elapsed, 2)
    ratio = results["cpu"] / results["gpu"] if results["gpu"] > 0 else 0
    return {"model": physical_id, "gpu_seconds": results["gpu"], "cpu_seconds": results["cpu"], "speedup": round(ratio, 1), "pass": ratio >= 5.0}
async def _test_llamacpp(physical_id, backends):
    """Time one short generation with all layers on GPU (-1) vs all on CPU (0) for llama.cpp.

    Returns per-configuration wall-clock seconds, the CPU/GPU speedup ratio,
    and a pass flag (GPU must be >= 5x faster).
    """
    from llmux.backends.llamacpp import LlamaCppBackend
    results = {}
    for label, n_gpu_layers in [("gpu", -1), ("cpu", 0)]:
        backend = LlamaCppBackend(models_dir=backends["llamacpp"]._models_dir)
        await backend.load(physical_id, n_gpu_layers=n_gpu_layers)
        try:
            start = time.monotonic()
            await backend.generate(physical_id, TEST_PROMPT, params={}, stream=False)
            elapsed = time.monotonic() - start
        finally:
            # Unload even when generation fails, otherwise the model stays
            # loaded after a failed test run.
            await backend.unload(physical_id)
        results[label] = round(elapsed, 2)
    ratio = results["cpu"] / results["gpu"] if results["gpu"] > 0 else 0
    return {"model": physical_id, "gpu_seconds": results["gpu"], "cpu_seconds": results["cpu"], "speedup": round(ratio, 1), "pass": ratio >= 5.0}
async def _test_chatterbox(physical_id, backends):
    """Time one short TTS synthesis on GPU and report the realtime factor.

    realtime_factor = audio_duration / synthesis_time; > 1.0 means faster
    than realtime.
    """
    from llmux.backends.chatterbox_tts import ChatterboxTTSBackend
    backend = ChatterboxTTSBackend(models_dir=backends["chatterbox"]._models_dir)
    await backend.load(physical_id, device="cuda")
    try:
        test_text = "Hello, this is a performance test."
        start = time.monotonic()
        audio_bytes = await backend.synthesize(physical_id, test_text)
        elapsed = time.monotonic() - start
    finally:
        # Unload even when synthesis fails, otherwise the model keeps
        # occupying VRAM after a failed test run.
        await backend.unload(physical_id)
    # NOTE(review): assumes the backend returns a canonical 44-byte-header WAV
    # with 16-bit mono samples at 24 kHz — confirm against ChatterboxTTSBackend.
    audio_samples = (len(audio_bytes) - 44) / 2
    audio_duration = audio_samples / 24000
    return {"model": physical_id, "synthesis_seconds": round(elapsed, 2), "audio_duration_seconds": round(audio_duration, 2), "realtime_factor": round(audio_duration / elapsed, 1) if elapsed > 0 else 0}
def _make_silent_wav(duration_seconds=2, sample_rate=16000):
import struct
num_samples = int(sample_rate * duration_seconds)
data = b"\x00\x00" * num_samples
header = struct.pack("<4sI4s4sIHHIIHH4sI", b"RIFF", 36 + len(data), b"WAVE", b"fmt ", 16, 1, 1, sample_rate, sample_rate * 2, 2, 16, b"data", len(data))
return header + data

View File

@@ -0,0 +1,46 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from llmux.model_registry import ModelRegistry
from llmux.vram_manager import VRAMManager
logger = logging.getLogger(__name__)
def create_chat_router(registry, vram_manager, backends, require_api_key):
    """Build the router for the OpenAI-compatible /v1/chat/completions endpoint."""
    router = APIRouter()

    @router.post("/v1/chat/completions")
    async def chat_completions(request: Request, api_key: str = Depends(require_api_key)):
        payload = await request.json()
        virtual_name = payload.get("model")
        if not virtual_name:
            raise HTTPException(status_code=400, detail="Missing 'model' field")
        try:
            physical_id, physical, params = registry.resolve(virtual_name)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{virtual_name}' not found")
        # The "transformers" backend name covers both LLM and ASR models; ASR
        # models are served by a separate transformers_asr backend instance.
        is_transformers_asr = physical.backend == "transformers" and physical.type == "asr"
        backend = backends.get("transformers_asr" if is_transformers_asr else physical.backend)
        if backend is None:
            raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'")
        # Ensure the model is resident before generating (may evict others).
        await vram_manager.load_model(
            model_id=physical_id, model_type=physical.type,
            vram_gb=physical.estimated_vram_gb, backend=backend,
        )
        stream = payload.get("stream", False)
        result = await backend.generate(
            model_id=physical_id,
            messages=payload.get("messages", []),
            params=params,
            stream=stream,
            tools=payload.get("tools"),
        )
        return StreamingResponse(result, media_type="text/event-stream") if stream else result

    return router

View File

@@ -0,0 +1,12 @@
from fastapi import APIRouter, Depends
from llmux.model_registry import ModelRegistry
def create_models_router(registry: ModelRegistry, require_api_key) -> APIRouter:
    """Build the router serving the OpenAI-compatible model listing endpoint."""
    router = APIRouter()

    @router.get("/v1/models")
    async def list_models(api_key: str = Depends(require_api_key)):
        # Wrap the registry's virtual model entries in the OpenAI list envelope.
        virtual_models = registry.list_virtual_models()
        return {"object": "list", "data": virtual_models}

    return router

View File

@@ -0,0 +1,42 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response
from llmux.model_registry import ModelRegistry
from llmux.vram_manager import VRAMManager
logger = logging.getLogger(__name__)
def create_speech_router(registry, vram_manager, backends, require_api_key):
    """Build the router for the OpenAI-compatible /v1/audio/speech TTS endpoint."""
    router = APIRouter()

    @router.post("/v1/audio/speech")
    async def create_speech(request: Request, api_key: str = Depends(require_api_key)):
        body = await request.json()
        model_name = body.get("model")
        if not model_name:
            raise HTTPException(status_code=400, detail="Missing 'model' field")
        # Validate the text up front: previously a missing/empty 'input' was
        # silently accepted and synthesized after an expensive model load.
        text = body.get("input", "")
        if not text:
            raise HTTPException(status_code=400, detail="Missing 'input' field")
        try:
            physical_id, physical, params = registry.resolve(model_name)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
        if physical.type != "tts":
            raise HTTPException(status_code=400, detail=f"Model '{model_name}' is not a TTS model")
        backend = backends.get(physical.backend)
        if backend is None:
            raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'")
        # Ensure the model is resident before synthesis (may evict others).
        await vram_manager.load_model(
            model_id=physical_id, model_type=physical.type,
            vram_gb=physical.estimated_vram_gb, backend=backend,
        )
        voice = body.get("voice", "default")
        audio_bytes = await backend.synthesize(model_id=physical_id, text=text, voice=voice)
        return Response(content=audio_bytes, media_type="audio/wav")

    return router

View File

@@ -0,0 +1,43 @@
import logging
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from llmux.model_registry import ModelRegistry
from llmux.vram_manager import VRAMManager
logger = logging.getLogger(__name__)
def create_transcription_router(registry, vram_manager, backends, require_api_key):
    """Build the router for the OpenAI-compatible /v1/audio/transcriptions endpoint."""
    router = APIRouter()

    @router.post("/v1/audio/transcriptions")
    async def create_transcription(
        file: UploadFile = File(...),
        model: str = Form(...),
        language: str = Form("en"),
        api_key: str = Depends(require_api_key),
    ):
        try:
            physical_id, physical, params = registry.resolve(model)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{model}' not found")
        if physical.type != "asr":
            raise HTTPException(status_code=400, detail=f"Model '{model}' is not an ASR model")
        # physical.type is guaranteed "asr" past the guard above, so ASR models
        # on the shared "transformers" backend name map straight to the
        # dedicated transformers_asr backend instance.
        backend_key = "transformers_asr" if physical.backend == "transformers" else physical.backend
        backend = backends.get(backend_key)
        if backend is None:
            raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'")
        # Ensure the model is resident before transcribing (may evict others).
        await vram_manager.load_model(
            model_id=physical_id, model_type=physical.type,
            vram_gb=physical.estimated_vram_gb, backend=backend,
        )
        audio_data = await file.read()
        return await backend.transcribe(model_id=physical_id, audio_data=audio_data, language=language)

    return router