# kischdle/llmux/llmux/routes/admin.py
"""Admin routes: on-demand performance probes for physical models."""
import logging
import time
from fastapi import APIRouter, Depends, HTTPException, Request
from llmux.model_registry import ModelRegistry
from llmux.vram_manager import VRAMManager

logger = logging.getLogger(__name__)

# Minimal chat prompt shared by every LLM performance probe.
TEST_PROMPT = [{"role": "user", "content": "Say hello in one sentence."}]


def create_admin_router(registry, vram_manager, backends, require_api_key):
    """Build the admin APIRouter.

    Args:
        registry: model registry providing ``get_physical(id)``.
        vram_manager: accepted for a uniform factory signature; unused here.
        backends: mapping of backend name -> backend instance (probes only
            read each backend's ``_models_dir``).
        require_api_key: FastAPI dependency enforcing API-key auth.

    Returns:
        APIRouter exposing POST /admin/test/performance.
    """
    router = APIRouter()

    @router.post("/admin/test/performance")
    async def test_performance(request: Request, api_key: str = Depends(require_api_key)):
        body = await request.json()
        physical_id = body.get("physical_model_id")
        if not physical_id:
            raise HTTPException(status_code=400, detail="Missing 'physical_model_id'")

        # FIX: map an unknown id to 404 instead of letting a KeyError surface
        # as a 500 (registry.resolve raises KeyError elsewhere in this package,
        # so get_physical presumably does too — confirm against ModelRegistry).
        try:
            physical = registry.get_physical(physical_id)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Unknown physical model: {physical_id}")

        backend_name = physical.backend
        if backend_name == "transformers" and physical.type == "llm":
            return await _test_transformers_llm(physical_id, backends)
        elif backend_name == "transformers" and physical.type == "asr":
            return await _test_transformers_asr(physical_id, backends)
        elif backend_name == "llamacpp":
            return await _test_llamacpp(physical_id, backends)
        elif backend_name == "chatterbox":
            return await _test_chatterbox(physical_id, backends)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown backend: {backend_name}")

    return router


async def _test_transformers_llm(physical_id, backends):
    """Time one generation on GPU and on CPU and report the speedup ratio."""
    from llmux.backends.transformers_llm import TransformersLLMBackend

    results = {}
    for device_label, device in [("gpu", "cuda"), ("cpu", "cpu")]:
        backend = TransformersLLMBackend(models_dir=backends["transformers"]._models_dir)
        await backend.load(physical_id, device=device)
        try:
            start = time.monotonic()
            await backend.generate(physical_id, TEST_PROMPT, params={}, stream=False)
            elapsed = time.monotonic() - start
        finally:
            # FIX: always unload, even when generation raises, so a failed
            # probe cannot leave the model resident in VRAM.
            await backend.unload(physical_id)
        results[device_label] = round(elapsed, 2)

    ratio = results["cpu"] / results["gpu"] if results["gpu"] > 0 else 0
    return {
        "model": physical_id,
        "gpu_seconds": results["gpu"],
        "cpu_seconds": results["cpu"],
        "speedup": round(ratio, 1),
        "pass": ratio >= 5.0,  # GPU must be at least 5x faster than CPU
    }


async def _test_transformers_asr(physical_id, backends):
    """Time transcription of 2 s of silence on GPU and CPU; report speedup."""
    from llmux.backends.transformers_asr import TransformersASRBackend

    silent_wav = _make_silent_wav(duration_seconds=2)
    results = {}
    for device_label, device in [("gpu", "cuda"), ("cpu", "cpu")]:
        backend = TransformersASRBackend(models_dir=backends["transformers_asr"]._models_dir)
        await backend.load(physical_id, device=device)
        try:
            start = time.monotonic()
            await backend.transcribe(physical_id, silent_wav, language="en")
            elapsed = time.monotonic() - start
        finally:
            # FIX: unload in a finally block so a failed transcription does
            # not leak the loaded model.
            await backend.unload(physical_id)
        results[device_label] = round(elapsed, 2)

    ratio = results["cpu"] / results["gpu"] if results["gpu"] > 0 else 0
    return {
        "model": physical_id,
        "gpu_seconds": results["gpu"],
        "cpu_seconds": results["cpu"],
        "speedup": round(ratio, 1),
        "pass": ratio >= 5.0,
    }
{"model": physical_id, "gpu_seconds": results["gpu"], "cpu_seconds": results["cpu"], "speedup": round(ratio, 1), "pass": ratio >= 5.0} + + +async def _test_llamacpp(physical_id, backends): + from llmux.backends.llamacpp import LlamaCppBackend + results = {} + for label, n_gpu_layers in [("gpu", -1), ("cpu", 0)]: + backend = LlamaCppBackend(models_dir=backends["llamacpp"]._models_dir) + await backend.load(physical_id, n_gpu_layers=n_gpu_layers) + start = time.monotonic() + await backend.generate(physical_id, TEST_PROMPT, params={}, stream=False) + elapsed = time.monotonic() - start + await backend.unload(physical_id) + results[label] = round(elapsed, 2) + + ratio = results["cpu"] / results["gpu"] if results["gpu"] > 0 else 0 + return {"model": physical_id, "gpu_seconds": results["gpu"], "cpu_seconds": results["cpu"], "speedup": round(ratio, 1), "pass": ratio >= 5.0} + + +async def _test_chatterbox(physical_id, backends): + from llmux.backends.chatterbox_tts import ChatterboxTTSBackend + backend = ChatterboxTTSBackend(models_dir=backends["chatterbox"]._models_dir) + await backend.load(physical_id, device="cuda") + test_text = "Hello, this is a performance test." 
+ start = time.monotonic() + audio_bytes = await backend.synthesize(physical_id, test_text) + elapsed = time.monotonic() - start + await backend.unload(physical_id) + + audio_samples = (len(audio_bytes) - 44) / 2 + audio_duration = audio_samples / 24000 + return {"model": physical_id, "synthesis_seconds": round(elapsed, 2), "audio_duration_seconds": round(audio_duration, 2), "realtime_factor": round(audio_duration / elapsed, 1) if elapsed > 0 else 0} + + +def _make_silent_wav(duration_seconds=2, sample_rate=16000): + import struct + num_samples = int(sample_rate * duration_seconds) + data = b"\x00\x00" * num_samples + header = struct.pack("<4sI4s4sIHHIIHH4sI", b"RIFF", 36 + len(data), b"WAVE", b"fmt ", 16, 1, 1, sample_rate, sample_rate * 2, 2, 16, b"data", len(data)) + return header + data diff --git a/kischdle/llmux/llmux/routes/chat.py b/kischdle/llmux/llmux/routes/chat.py new file mode 100644 index 0000000..2dbbe0f --- /dev/null +++ b/kischdle/llmux/llmux/routes/chat.py @@ -0,0 +1,46 @@ +import logging +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import StreamingResponse +from llmux.model_registry import ModelRegistry +from llmux.vram_manager import VRAMManager + +logger = logging.getLogger(__name__) + + +def create_chat_router(registry, vram_manager, backends, require_api_key): + router = APIRouter() + + @router.post("/v1/chat/completions") + async def chat_completions(request: Request, api_key: str = Depends(require_api_key)): + body = await request.json() + virtual_name = body.get("model") + if not virtual_name: + raise HTTPException(status_code=400, detail="Missing 'model' field") + + try: + physical_id, physical, params = registry.resolve(virtual_name) + except KeyError: + raise HTTPException(status_code=404, detail=f"Model '{virtual_name}' not found") + + backend_key = physical.backend + if backend_key == "transformers" and physical.type == "asr": + backend_key = "transformers_asr" + backend = 
backends.get(backend_key) + if backend is None: + raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'") + + await vram_manager.load_model( + model_id=physical_id, model_type=physical.type, + vram_gb=physical.estimated_vram_gb, backend=backend, + ) + + messages = body.get("messages", []) + stream = body.get("stream", False) + tools = body.get("tools") + result = await backend.generate(model_id=physical_id, messages=messages, params=params, stream=stream, tools=tools) + + if stream: + return StreamingResponse(result, media_type="text/event-stream") + return result + + return router diff --git a/kischdle/llmux/llmux/routes/models.py b/kischdle/llmux/llmux/routes/models.py new file mode 100644 index 0000000..fb10ac5 --- /dev/null +++ b/kischdle/llmux/llmux/routes/models.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter, Depends +from llmux.model_registry import ModelRegistry + + +def create_models_router(registry: ModelRegistry, require_api_key) -> APIRouter: + router = APIRouter() + + @router.get("/v1/models") + async def list_models(api_key: str = Depends(require_api_key)): + return {"object": "list", "data": registry.list_virtual_models()} + + return router diff --git a/kischdle/llmux/llmux/routes/speech.py b/kischdle/llmux/llmux/routes/speech.py new file mode 100644 index 0000000..91dc4ff --- /dev/null +++ b/kischdle/llmux/llmux/routes/speech.py @@ -0,0 +1,42 @@ +import logging +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import Response +from llmux.model_registry import ModelRegistry +from llmux.vram_manager import VRAMManager + +logger = logging.getLogger(__name__) + + +def create_speech_router(registry, vram_manager, backends, require_api_key): + router = APIRouter() + + @router.post("/v1/audio/speech") + async def create_speech(request: Request, api_key: str = Depends(require_api_key)): + body = await request.json() + model_name = body.get("model") + if not model_name: + raise 
HTTPException(status_code=400, detail="Missing 'model' field") + + try: + physical_id, physical, params = registry.resolve(model_name) + except KeyError: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + + if physical.type != "tts": + raise HTTPException(status_code=400, detail=f"Model '{model_name}' is not a TTS model") + + backend = backends.get(physical.backend) + if backend is None: + raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'") + + await vram_manager.load_model( + model_id=physical_id, model_type=physical.type, + vram_gb=physical.estimated_vram_gb, backend=backend, + ) + + text = body.get("input", "") + voice = body.get("voice", "default") + audio_bytes = await backend.synthesize(model_id=physical_id, text=text, voice=voice) + return Response(content=audio_bytes, media_type="audio/wav") + + return router diff --git a/kischdle/llmux/llmux/routes/transcription.py b/kischdle/llmux/llmux/routes/transcription.py new file mode 100644 index 0000000..6ce3a4b --- /dev/null +++ b/kischdle/llmux/llmux/routes/transcription.py @@ -0,0 +1,43 @@ +import logging +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from llmux.model_registry import ModelRegistry +from llmux.vram_manager import VRAMManager + +logger = logging.getLogger(__name__) + + +def create_transcription_router(registry, vram_manager, backends, require_api_key): + router = APIRouter() + + @router.post("/v1/audio/transcriptions") + async def create_transcription( + file: UploadFile = File(...), + model: str = Form(...), + language: str = Form("en"), + api_key: str = Depends(require_api_key), + ): + try: + physical_id, physical, params = registry.resolve(model) + except KeyError: + raise HTTPException(status_code=404, detail=f"Model '{model}' not found") + + if physical.type != "asr": + raise HTTPException(status_code=400, detail=f"Model '{model}' is not an ASR model") + + backend_key = physical.backend + if 
backend_key == "transformers" and physical.type == "asr": + backend_key = "transformers_asr" + backend = backends.get(backend_key) + if backend is None: + raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'") + + await vram_manager.load_model( + model_id=physical_id, model_type=physical.type, + vram_gb=physical.estimated_vram_gb, backend=backend, + ) + + audio_data = await file.read() + result = await backend.transcribe(model_id=physical_id, audio_data=audio_data, language=language) + return result + + return router diff --git a/kischdle/llmux/tests/test_routes.py b/kischdle/llmux/tests/test_routes.py new file mode 100644 index 0000000..01e7c9c --- /dev/null +++ b/kischdle/llmux/tests/test_routes.py @@ -0,0 +1,62 @@ +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from llmux.config import ApiKey +from llmux.auth import create_api_key_dependency +from llmux.model_registry import ModelRegistry +from llmux.vram_manager import VRAMManager +from llmux.routes.models import create_models_router + +API_KEY = "sk-test-key" + + +@pytest.fixture +def registry(): + return ModelRegistry.from_config() + + +@pytest.fixture +def vram_manager(): + return VRAMManager(total_vram_gb=16.0) + + +@pytest.fixture +def app(registry, vram_manager): + keys = [ApiKey(key=API_KEY, name="Test")] + require_api_key = create_api_key_dependency(keys) + app = FastAPI() + app.include_router(create_models_router(registry, require_api_key)) + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +@pytest.fixture +def auth_headers(): + return {"Authorization": f"Bearer {API_KEY}"} + + +def test_list_models_returns_16(client, auth_headers): + resp = client.get("/v1/models", headers=auth_headers) + assert resp.status_code == 200 + body = resp.json() + assert body["object"] == "list" + assert len(body["data"]) == 16 + + +def test_list_models_contains_expected_names(client, auth_headers): + resp = 
client.get("/v1/models", headers=auth_headers) + names = [m["id"] for m in resp.json()["data"]] + assert "Qwen3.5-9B-FP8-Thinking" in names + assert "GPT-OSS-20B-High" in names + assert "cohere-transcribe" in names + assert "Chatterbox-Multilingual" in names + + +def test_list_models_requires_auth(client): + resp = client.get("/v1/models") + assert resp.status_code == 401