"""llmux application entry point.

Builds the FastAPI app, wires model backends to routers on startup, and
exposes a /health endpoint reporting loaded models and available VRAM.
"""

import logging
import os

from fastapi import FastAPI

from llmux.config import load_models_config, load_api_keys
from llmux.auth import create_api_key_dependency
from llmux.model_registry import ModelRegistry
from llmux.vram_manager import VRAMManager
from llmux.backends.transformers_llm import TransformersLLMBackend
from llmux.backends.transformers_llm import set_physical_models as set_transformers_llm_models
from llmux.backends.transformers_asr import TransformersASRBackend
from llmux.backends.transformers_asr import set_physical_models as set_transformers_asr_models
from llmux.backends.llamacpp import LlamaCppBackend
from llmux.backends.llamacpp import set_physical_models as set_llamacpp_models
from llmux.backends.chatterbox_tts import ChatterboxTTSBackend
from llmux.backends.chatterbox_tts import set_physical_models as set_chatterbox_models
from llmux.routes.models import create_models_router
from llmux.routes.chat import create_chat_router
from llmux.routes.transcription import create_transcription_router
from llmux.routes.speech import create_speech_router
from llmux.routes.admin import create_admin_router

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

# Both knobs come from the environment; defaults preserve prior behavior
# (previously the VRAM budget was hard-coded to 16.0 while MODELS_DIR was
# already configurable — now they are consistent).
MODELS_DIR = os.environ.get("LLMUX_MODELS_DIR", "/models")
TOTAL_VRAM_GB = float(os.environ.get("LLMUX_TOTAL_VRAM_GB", "16.0"))

app = FastAPI(title="llmux", version="0.1.0")


# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of a
# lifespan handler; kept as-is to avoid restructuring the module surface.
@app.on_event("startup")
async def startup():
    """Load configuration, construct backends, and mount all API routers.

    Side effects: populates ``app.state.vram_manager`` and
    ``app.state.registry`` and registers all routers on ``app``.
    """
    logger.info("Starting llmux...")

    physical, virtual = load_models_config()
    api_keys = load_api_keys()

    # Each backend module keeps its own view of the physical model table.
    set_transformers_llm_models(physical)
    set_transformers_asr_models(physical)
    set_llamacpp_models(physical)
    set_chatterbox_models(physical)

    registry = ModelRegistry(physical, virtual)
    vram_manager = VRAMManager(total_vram_gb=TOTAL_VRAM_GB)

    require_api_key = create_api_key_dependency(api_keys)

    transformers_llm = TransformersLLMBackend(models_dir=MODELS_DIR)
    transformers_asr = TransformersASRBackend(models_dir=MODELS_DIR)
    llamacpp = LlamaCppBackend(models_dir=MODELS_DIR)
    chatterbox = ChatterboxTTSBackend(models_dir=MODELS_DIR)

    # Keys are the backend identifiers routers use to dispatch requests.
    backends = {
        "transformers": transformers_llm,
        "transformers_asr": transformers_asr,
        "llamacpp": llamacpp,
        "chatterbox": chatterbox,
    }

    # Shared state consumed by /health and (presumably) the routers.
    app.state.vram_manager = vram_manager
    app.state.registry = registry

    app.include_router(create_models_router(registry, require_api_key))
    app.include_router(create_chat_router(registry, vram_manager, backends, require_api_key))
    app.include_router(create_transcription_router(registry, vram_manager, backends, require_api_key))
    app.include_router(create_speech_router(registry, vram_manager, backends, require_api_key))
    app.include_router(create_admin_router(registry, vram_manager, backends, require_api_key))

    logger.info("llmux started successfully")


@app.get("/health")
async def health():
    """Report service status, currently loaded models, and free VRAM.

    Returns a "starting" payload (instead of crashing with AttributeError,
    as the previous version did) when a request arrives before the startup
    hook has populated ``app.state``.
    """
    vram_manager = getattr(app.state, "vram_manager", None)
    if vram_manager is None:
        # Startup has not completed yet — report gracefully rather than 500.
        return {"status": "starting", "loaded_models": {}, "available_vram_gb": None}

    loaded = vram_manager.get_loaded_models()
    return {
        "status": "ok",
        "loaded_models": {mid: {"type": slot.model_type, "vram_gb": slot.vram_gb} for mid, slot in loaded.items()},
        "available_vram_gb": round(vram_manager.available_vram_gb, 1),
    }