diff --git a/kischdle/llmux/llmux/backends/base.py b/kischdle/llmux/llmux/backends/base.py new file mode 100644 index 0000000..d93ef74 --- /dev/null +++ b/kischdle/llmux/llmux/backends/base.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import AsyncIterator + + +class BaseBackend(ABC): + """Abstract base for all model backends.""" + + @abstractmethod + async def load(self, model_id: str, **kwargs) -> None: + """Load model weights into GPU VRAM. + + Backends accept optional kwargs: + - device: "cuda" or "cpu" (transformers backends, chatterbox) + - n_gpu_layers: int (llamacpp backend, -1=all GPU, 0=CPU only) + """ + + @abstractmethod + async def unload(self, model_id: str) -> None: + """Unload model weights from GPU VRAM.""" + + @abstractmethod + async def generate( + self, + model_id: str, + messages: list[dict], + params: dict, + stream: bool = False, + tools: list[dict] | None = None, + ) -> AsyncIterator[str] | dict: + """Run chat inference. Returns full response dict or async iterator of SSE chunks.""" + + async def transcribe( + self, + model_id: str, + audio_data: bytes, + language: str = "en", + ) -> dict: + """Transcribe audio. Only implemented by ASR backends.""" + raise NotImplementedError(f"{self.__class__.__name__} does not support transcription") + + async def synthesize( + self, + model_id: str, + text: str, + voice: str = "default", + ) -> bytes: + """Synthesize speech. Only implemented by TTS backends.""" + raise NotImplementedError(f"{self.__class__.__name__} does not support speech synthesis")