feat: API routes for models, chat, transcription, speech, and admin
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
108
kischdle/llmux/llmux/routes/admin.py
Normal file
108
kischdle/llmux/llmux/routes/admin.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import logging
|
||||
import time
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from llmux.model_registry import ModelRegistry
|
||||
from llmux.vram_manager import VRAMManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEST_PROMPT = [{"role": "user", "content": "Say hello in one sentence."}]
|
||||
|
||||
|
||||
def create_admin_router(registry, vram_manager, backends, require_api_key):
    """Build the admin router exposing GPU/CPU performance self-tests.

    Args:
        registry: model registry used to look up physical model entries.
        vram_manager: VRAM manager (currently unused here; kept for
            signature parity with the other router factories).
        backends: mapping of backend name -> backend instance.
        require_api_key: FastAPI dependency enforcing API-key auth.

    Returns:
        An APIRouter with POST /admin/test/performance.
    """
    router = APIRouter()

    @router.post("/admin/test/performance")
    async def test_performance(request: Request, api_key: str = Depends(require_api_key)):
        body = await request.json()
        physical_id = body.get("physical_model_id")
        if not physical_id:
            raise HTTPException(status_code=400, detail="Missing 'physical_model_id'")

        # Fix: an unknown model id previously escaped as an unhandled
        # exception (HTTP 500); map it to 404 like the other routers do
        # for registry.resolve(). Assumes get_physical raises KeyError on
        # a miss, matching resolve() -- TODO confirm in ModelRegistry.
        try:
            physical = registry.get_physical(physical_id)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{physical_id}' not found")
        backend_name = physical.backend

        # Dispatch to the backend-specific benchmark helper.
        if backend_name == "transformers" and physical.type == "llm":
            return await _test_transformers_llm(physical_id, backends)
        elif backend_name == "transformers" and physical.type == "asr":
            return await _test_transformers_asr(physical_id, backends)
        elif backend_name == "llamacpp":
            return await _test_llamacpp(physical_id, backends)
        elif backend_name == "chatterbox":
            return await _test_chatterbox(physical_id, backends)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown backend: {backend_name}")

    return router
|
||||
|
||||
|
||||
async def _test_transformers_llm(physical_id, backends):
    """Benchmark a transformers LLM on GPU ("cuda") vs CPU.

    Loads the model on each device in turn, times one non-streaming
    generation of TEST_PROMPT, then unloads. The test passes when the
    GPU run is at least 5x faster than the CPU run.

    Returns a dict with per-device seconds, the speedup ratio, and a
    boolean "pass".
    """
    from llmux.backends.transformers_llm import TransformersLLMBackend

    timings = {}  # raw elapsed seconds, used for the ratio
    results = {}  # rounded seconds, reported to the caller
    for device_label, device in [("gpu", "cuda"), ("cpu", "cpu")]:
        backend = TransformersLLMBackend(models_dir=backends["transformers"]._models_dir)
        await backend.load(physical_id, device=device)
        try:
            start = time.monotonic()
            await backend.generate(physical_id, TEST_PROMPT, params={}, stream=False)
            elapsed = time.monotonic() - start
        finally:
            # Fix: always unload, even when generation raises, so a failed
            # benchmark does not leave the model occupying VRAM.
            await backend.unload(physical_id)
        timings[device_label] = elapsed
        results[device_label] = round(elapsed, 2)

    # Fix: compute the speedup from the raw timings; the rounded values
    # distorted the ratio for fast models (e.g. 0.004s rounding to 0.0).
    ratio = timings["cpu"] / timings["gpu"] if timings["gpu"] > 0 else 0
    return {
        "model": physical_id,
        "gpu_seconds": results["gpu"],
        "cpu_seconds": results["cpu"],
        "speedup": round(ratio, 1),
        "pass": ratio >= 5.0,
    }
|
||||
|
||||
|
||||
async def _test_transformers_asr(physical_id, backends):
    """Benchmark a transformers ASR model on GPU ("cuda") vs CPU.

    Transcribes a 2-second silent WAV on each device, then unloads.
    The test passes when the GPU run is at least 5x faster than CPU.

    Returns a dict with per-device seconds, the speedup ratio, and a
    boolean "pass".
    """
    from llmux.backends.transformers_asr import TransformersASRBackend

    silent_wav = _make_silent_wav(duration_seconds=2)
    timings = {}  # raw elapsed seconds, used for the ratio
    results = {}  # rounded seconds, reported to the caller
    for device_label, device in [("gpu", "cuda"), ("cpu", "cpu")]:
        backend = TransformersASRBackend(models_dir=backends["transformers_asr"]._models_dir)
        await backend.load(physical_id, device=device)
        try:
            start = time.monotonic()
            await backend.transcribe(physical_id, silent_wav, language="en")
            elapsed = time.monotonic() - start
        finally:
            # Fix: always unload, even when transcription raises, so a
            # failed benchmark does not leave the model occupying VRAM.
            await backend.unload(physical_id)
        timings[device_label] = elapsed
        results[device_label] = round(elapsed, 2)

    # Fix: compute the speedup from the raw timings rather than the
    # rounded values, which distorted the ratio for fast models.
    ratio = timings["cpu"] / timings["gpu"] if timings["gpu"] > 0 else 0
    return {
        "model": physical_id,
        "gpu_seconds": results["gpu"],
        "cpu_seconds": results["cpu"],
        "speedup": round(ratio, 1),
        "pass": ratio >= 5.0,
    }
|
||||
|
||||
|
||||
async def _test_llamacpp(physical_id, backends):
    """Benchmark a llama.cpp model fully offloaded (n_gpu_layers=-1) vs CPU.

    Times one non-streaming generation of TEST_PROMPT in each mode, then
    unloads. The test passes when the GPU run is at least 5x faster.

    Returns a dict with per-mode seconds, the speedup ratio, and a
    boolean "pass".
    """
    from llmux.backends.llamacpp import LlamaCppBackend

    timings = {}  # raw elapsed seconds, used for the ratio
    results = {}  # rounded seconds, reported to the caller
    for label, n_gpu_layers in [("gpu", -1), ("cpu", 0)]:
        backend = LlamaCppBackend(models_dir=backends["llamacpp"]._models_dir)
        await backend.load(physical_id, n_gpu_layers=n_gpu_layers)
        try:
            start = time.monotonic()
            await backend.generate(physical_id, TEST_PROMPT, params={}, stream=False)
            elapsed = time.monotonic() - start
        finally:
            # Fix: always unload, even when generation raises, so a failed
            # benchmark does not leave the model occupying VRAM.
            await backend.unload(physical_id)
        timings[label] = elapsed
        results[label] = round(elapsed, 2)

    # Fix: compute the speedup from the raw timings rather than the
    # rounded values, which distorted the ratio for fast models.
    ratio = timings["cpu"] / timings["gpu"] if timings["gpu"] > 0 else 0
    return {
        "model": physical_id,
        "gpu_seconds": results["gpu"],
        "cpu_seconds": results["cpu"],
        "speedup": round(ratio, 1),
        "pass": ratio >= 5.0,
    }
|
||||
|
||||
|
||||
async def _test_chatterbox(physical_id, backends):
    """Benchmark the Chatterbox TTS backend on GPU.

    Synthesizes one short fixed sentence and reports the realtime factor
    (seconds of audio produced per wall-clock second of synthesis).
    """
    from llmux.backends.chatterbox_tts import ChatterboxTTSBackend

    backend = ChatterboxTTSBackend(models_dir=backends["chatterbox"]._models_dir)
    await backend.load(physical_id, device="cuda")
    test_text = "Hello, this is a performance test."
    try:
        start = time.monotonic()
        audio_bytes = await backend.synthesize(physical_id, test_text)
        elapsed = time.monotonic() - start
    finally:
        # Fix: always unload, even when synthesis raises, so a failed
        # benchmark does not leave the model occupying VRAM.
        await backend.unload(physical_id)

    # Estimate duration from the payload size. Assumes a standard 44-byte
    # WAV header with 16-bit mono samples at 24 kHz -- TODO confirm
    # against ChatterboxTTSBackend's actual output format.
    # Fix: clamp at zero so a truncated/invalid payload cannot produce a
    # negative duration.
    audio_samples = max(0, len(audio_bytes) - 44) / 2
    audio_duration = audio_samples / 24000
    return {
        "model": physical_id,
        "synthesis_seconds": round(elapsed, 2),
        "audio_duration_seconds": round(audio_duration, 2),
        "realtime_factor": round(audio_duration / elapsed, 1) if elapsed > 0 else 0,
    }
|
||||
|
||||
|
||||
def _make_silent_wav(duration_seconds=2, sample_rate=16000):
|
||||
import struct
|
||||
num_samples = int(sample_rate * duration_seconds)
|
||||
data = b"\x00\x00" * num_samples
|
||||
header = struct.pack("<4sI4s4sIHHIIHH4sI", b"RIFF", 36 + len(data), b"WAVE", b"fmt ", 16, 1, 1, sample_rate, sample_rate * 2, 2, 16, b"data", len(data))
|
||||
return header + data
|
||||
46
kischdle/llmux/llmux/routes/chat.py
Normal file
46
kischdle/llmux/llmux/routes/chat.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llmux.model_registry import ModelRegistry
|
||||
from llmux.vram_manager import VRAMManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_chat_router(registry, vram_manager, backends, require_api_key):
    """Build the router serving the OpenAI-compatible chat endpoint.

    Args:
        registry: model registry used to resolve virtual model names.
        vram_manager: manages loading models within the VRAM budget.
        backends: mapping of backend name -> backend instance.
        require_api_key: FastAPI dependency enforcing API-key auth.

    Returns:
        An APIRouter with POST /v1/chat/completions.
    """
    router = APIRouter()

    @router.post("/v1/chat/completions")
    async def chat_completions(request: Request, api_key: str = Depends(require_api_key)):
        body = await request.json()
        virtual_name = body.get("model")
        if not virtual_name:
            raise HTTPException(status_code=400, detail="Missing 'model' field")

        # Map the requested virtual model onto a physical model + params.
        try:
            physical_id, physical, params = registry.resolve(virtual_name)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{virtual_name}' not found")

        # Transformers ASR models live under a dedicated backend key.
        is_hf_asr = physical.backend == "transformers" and physical.type == "asr"
        backend = backends.get("transformers_asr" if is_hf_asr else physical.backend)
        if backend is None:
            raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'")

        # Ensure the model is resident before generating.
        await vram_manager.load_model(
            model_id=physical_id,
            model_type=physical.type,
            vram_gb=physical.estimated_vram_gb,
            backend=backend,
        )

        stream = body.get("stream", False)
        result = await backend.generate(
            model_id=physical_id,
            messages=body.get("messages", []),
            params=params,
            stream=stream,
            tools=body.get("tools"),
        )

        # Streaming responses are served as SSE; otherwise return as-is.
        if stream:
            return StreamingResponse(result, media_type="text/event-stream")
        return result

    return router
|
||||
12
kischdle/llmux/llmux/routes/models.py
Normal file
12
kischdle/llmux/llmux/routes/models.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from llmux.model_registry import ModelRegistry
|
||||
|
||||
|
||||
def create_models_router(registry: ModelRegistry, require_api_key) -> APIRouter:
    """Build the router exposing the OpenAI-compatible model listing.

    Args:
        registry: model registry supplying the virtual model catalog.
        require_api_key: FastAPI dependency enforcing API-key auth.

    Returns:
        An APIRouter with GET /v1/models.
    """
    router = APIRouter()

    @router.get("/v1/models")
    async def list_models(api_key: str = Depends(require_api_key)):
        # OpenAI list envelope around the configured virtual models.
        catalog = registry.list_virtual_models()
        return {"object": "list", "data": catalog}

    return router
|
||||
42
kischdle/llmux/llmux/routes/speech.py
Normal file
42
kischdle/llmux/llmux/routes/speech.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from llmux.model_registry import ModelRegistry
|
||||
from llmux.vram_manager import VRAMManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_speech_router(registry, vram_manager, backends, require_api_key):
    """Build the router serving the OpenAI-compatible speech (TTS) endpoint.

    Args:
        registry: model registry used to resolve virtual model names.
        vram_manager: manages loading models within the VRAM budget.
        backends: mapping of backend name -> backend instance.
        require_api_key: FastAPI dependency enforcing API-key auth.

    Returns:
        An APIRouter with POST /v1/audio/speech.
    """
    router = APIRouter()

    @router.post("/v1/audio/speech")
    async def create_speech(request: Request, api_key: str = Depends(require_api_key)):
        body = await request.json()
        model_name = body.get("model")
        if not model_name:
            raise HTTPException(status_code=400, detail="Missing 'model' field")

        # Map the requested virtual model onto a physical model + params.
        try:
            physical_id, physical, params = registry.resolve(model_name)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

        # Only TTS models may be used on this endpoint.
        if physical.type != "tts":
            raise HTTPException(status_code=400, detail=f"Model '{model_name}' is not a TTS model")

        backend = backends.get(physical.backend)
        if backend is None:
            raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'")

        # Ensure the model is resident before synthesizing.
        await vram_manager.load_model(
            model_id=physical_id,
            model_type=physical.type,
            vram_gb=physical.estimated_vram_gb,
            backend=backend,
        )

        audio_bytes = await backend.synthesize(
            model_id=physical_id,
            text=body.get("input", ""),
            voice=body.get("voice", "default"),
        )
        return Response(content=audio_bytes, media_type="audio/wav")

    return router
|
||||
43
kischdle/llmux/llmux/routes/transcription.py
Normal file
43
kischdle/llmux/llmux/routes/transcription.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from llmux.model_registry import ModelRegistry
|
||||
from llmux.vram_manager import VRAMManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_transcription_router(registry, vram_manager, backends, require_api_key):
    """Build the router serving the OpenAI-compatible transcription endpoint.

    Args:
        registry: model registry used to resolve virtual model names.
        vram_manager: manages loading models within the VRAM budget.
        backends: mapping of backend name -> backend instance.
        require_api_key: FastAPI dependency enforcing API-key auth.

    Returns:
        An APIRouter with POST /v1/audio/transcriptions.
    """
    router = APIRouter()

    @router.post("/v1/audio/transcriptions")
    async def create_transcription(
        file: UploadFile = File(...),
        model: str = Form(...),
        language: str = Form("en"),
        api_key: str = Depends(require_api_key),
    ):
        # Map the requested virtual model onto a physical model + params.
        try:
            physical_id, physical, params = registry.resolve(model)
        except KeyError:
            raise HTTPException(status_code=404, detail=f"Model '{model}' not found")

        # Only ASR models may be used on this endpoint.
        if physical.type != "asr":
            raise HTTPException(status_code=400, detail=f"Model '{model}' is not an ASR model")

        # Transformers ASR models live under a dedicated backend key.
        is_hf_asr = physical.backend == "transformers" and physical.type == "asr"
        backend = backends.get("transformers_asr" if is_hf_asr else physical.backend)
        if backend is None:
            raise HTTPException(status_code=500, detail=f"No backend for '{physical.backend}'")

        # Ensure the model is resident before transcribing.
        await vram_manager.load_model(
            model_id=physical_id,
            model_type=physical.type,
            vram_gb=physical.estimated_vram_gb,
            backend=backend,
        )

        audio_data = await file.read()
        return await backend.transcribe(
            model_id=physical_id,
            audio_data=audio_data,
            language=language,
        )

    return router
|
||||
Reference in New Issue
Block a user