"""llmux application entry point: wires config, backends, auth, and routes."""

import logging
import os

from fastapi import FastAPI

from llmux.auth import create_api_key_dependency
from llmux.backends.chatterbox_tts import ChatterboxTTSBackend
from llmux.backends.chatterbox_tts import set_physical_models as set_chatterbox_models
from llmux.backends.llamacpp import LlamaCppBackend
from llmux.backends.llamacpp import set_physical_models as set_llamacpp_models
from llmux.backends.transformers_asr import TransformersASRBackend
from llmux.backends.transformers_asr import set_physical_models as set_transformers_asr_models
from llmux.backends.transformers_llm import TransformersLLMBackend
from llmux.backends.transformers_llm import set_physical_models as set_transformers_llm_models
from llmux.config import load_api_keys, load_models_config
from llmux.model_registry import ModelRegistry
from llmux.routes.admin import create_admin_router
from llmux.routes.chat import create_chat_router
from llmux.routes.models import create_models_router
from llmux.routes.speech import create_speech_router
from llmux.routes.transcription import create_transcription_router
from llmux.vram_manager import VRAMManager

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

# Root directory every backend resolves model weights against; overridable per-deployment.
MODELS_DIR = os.environ.get("LLMUX_MODELS_DIR", "/models")

app = FastAPI(title="llmux", version="0.1.0")
|
# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of a
# lifespan handler; migrating requires touching app creation too, so it is
# deliberately left as-is here.
@app.on_event("startup")
async def startup():
    """Load configuration, construct backends, and mount all API routers.

    Runs once when the ASGI server starts.  Side effects: stores the VRAM
    manager and model registry on ``app.state`` and attaches every route
    group to ``app``.
    """
    logger.info("Starting llmux...")

    physical, virtual = load_models_config()
    api_keys = load_api_keys()

    # Each backend module keeps its own reference to the physical-model table.
    set_transformers_llm_models(physical)
    set_transformers_asr_models(physical)
    set_llamacpp_models(physical)
    set_chatterbox_models(physical)

    registry = ModelRegistry(physical, virtual)
    # VRAM budget is configurable via env var; defaults to the previously
    # hard-coded 16.0 GB so existing deployments behave identically.
    total_vram_gb = float(os.environ.get("LLMUX_TOTAL_VRAM_GB", "16.0"))
    vram_manager = VRAMManager(total_vram_gb=total_vram_gb)
    require_api_key = create_api_key_dependency(api_keys)

    transformers_llm = TransformersLLMBackend(models_dir=MODELS_DIR)
    transformers_asr = TransformersASRBackend(models_dir=MODELS_DIR)
    llamacpp = LlamaCppBackend(models_dir=MODELS_DIR)
    chatterbox = ChatterboxTTSBackend(models_dir=MODELS_DIR)

    # Keys are the backend identifiers the registry/routers dispatch on.
    backends = {
        "transformers": transformers_llm,
        "transformers_asr": transformers_asr,
        "llamacpp": llamacpp,
        "chatterbox": chatterbox,
    }

    # Expose shared services to request handlers (e.g. /health reads these).
    app.state.vram_manager = vram_manager
    app.state.registry = registry

    app.include_router(create_models_router(registry, require_api_key))
    app.include_router(create_chat_router(registry, vram_manager, backends, require_api_key))
    app.include_router(create_transcription_router(registry, vram_manager, backends, require_api_key))
    app.include_router(create_speech_router(registry, vram_manager, backends, require_api_key))
    app.include_router(create_admin_router(registry, vram_manager, backends, require_api_key))

    logger.info("llmux started successfully")
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: report currently loaded models and remaining VRAM."""
    manager = app.state.vram_manager
    models = {}
    for model_id, slot in manager.get_loaded_models().items():
        models[model_id] = {"type": slot.model_type, "vram_gb": slot.vram_gb}
    return {
        "status": "ok",
        "loaded_models": models,
        "available_vram_gb": round(manager.available_vram_gb, 1),
    }
|