feat: VRAM manager with priority-based model eviction
Tracks GPU VRAM usage (default budget 16GB) and handles model loading/unloading with priority-based eviction: LLM (lowest) -> TTS -> ASR (highest, protected). Uses an asyncio Lock for concurrency safety. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
89
kischdle/llmux/llmux/vram_manager.py
Normal file
89
kischdle/llmux/llmux/vram_manager.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)

# Eviction priority ranks: lower rank is evicted first. LLM is cheapest to
# drop, ASR is most protected (see VRAMManager._evict_for for the exact rules).
_PRIORITY = {"llm": 0, "tts": 1, "asr": 2}
|
||||
|
||||
|
||||
@dataclass
class ModelSlot:
    """A loaded model occupying a fixed share of GPU VRAM."""

    model_id: str    # unique identifier of the loaded model
    model_type: str  # one of the _PRIORITY keys: "llm", "tts", "asr"
    vram_gb: float   # VRAM claimed by this model, in gigabytes
    backend: object  # object exposing async load(model_id) / unload(model_id)

    @staticmethod
    def priority_rank(model_type: str) -> int:
        """Eviction rank for *model_type* (KeyError on an unknown type)."""
        return _PRIORITY[model_type]

    @property
    def priority(self) -> int:
        """Eviction rank of this slot; higher-ranked slots are evicted last."""
        return ModelSlot.priority_rank(self.model_type)
|
||||
|
||||
|
||||
class VRAMManager:
    """Tracks a GPU VRAM budget and loads/unloads models with priority eviction.

    When a load request does not fit in the remaining budget, lower-priority
    models are evicted first (LLM before TTS before ASR, per _PRIORITY).
    The highest tier ("asr") is never evicted on behalf of a lower-priority
    request. All mutation of the loaded-model table happens while holding a
    single asyncio.Lock, so concurrent load_model() calls cannot both claim
    the same free VRAM.
    """

    def __init__(self, total_vram_gb: float = 16.0):
        """Create a manager for a GPU with *total_vram_gb* gigabytes of VRAM."""
        self._total_vram_gb = total_vram_gb
        # model_id -> ModelSlot for every currently-resident model.
        self._loaded: dict[str, ModelSlot] = {}
        # Serializes the check-evict-load sequence in load_model().
        self._lock = asyncio.Lock()

    @property
    def available_vram_gb(self) -> float:
        """VRAM (GB) not claimed by any loaded model."""
        used = sum(slot.vram_gb for slot in self._loaded.values())
        return self._total_vram_gb - used

    def is_loaded(self, model_id: str) -> bool:
        """Return True if *model_id* currently occupies a slot."""
        return model_id in self._loaded

    def get_loaded_models(self) -> dict[str, ModelSlot]:
        """Return a shallow copy of the loaded-model table (safe to iterate)."""
        return dict(self._loaded)

    async def load_model(
        self, model_id: str, model_type: str, vram_gb: float, backend: object
    ) -> None:
        """Load *model_id* via *backend*, evicting lower-priority models if needed.

        Args:
            model_id: unique identifier for the model.
            model_type: one of the _PRIORITY keys ("llm", "tts", "asr").
            vram_gb: VRAM the model will claim, in gigabytes.
            backend: object exposing ``async load(model_id)`` and
                ``async unload(model_id)``.

        Raises:
            RuntimeError: if enough VRAM cannot be freed without violating
                the eviction rules (ASR is never evicted for a lower-priority
                request).
        """
        async with self._lock:
            await self._load_model_locked(model_id, model_type, vram_gb, backend)

    async def _load_model_locked(
        self, model_id: str, model_type: str, vram_gb: float, backend: object
    ) -> None:
        """Body of load_model(); caller must hold self._lock."""
        if model_id in self._loaded:
            # Already resident; nothing to do. NOTE(review): a reload request
            # with a different vram_gb/backend is silently ignored — confirm
            # this is intended.
            return

        if self.available_vram_gb < vram_gb:
            await self._evict_for(vram_gb, model_type)

        # Eviction is best-effort; re-check before committing to the load.
        if self.available_vram_gb < vram_gb:
            raise RuntimeError(
                f"Cannot free enough VRAM for {model_id} "
                f"(need {vram_gb}GB, available {self.available_vram_gb}GB)"
            )

        # Lazy %-args so formatting is skipped when INFO is disabled.
        logger.info("Loading %s (%sGB VRAM)", model_id, vram_gb)
        await backend.load(model_id)
        self._loaded[model_id] = ModelSlot(
            model_id=model_id,
            model_type=model_type,
            vram_gb=vram_gb,
            backend=backend,
        )

    async def _evict_for(self, needed_gb: float, requesting_type: str) -> None:
        """Evict loaded models until *needed_gb* is available, or no candidates remain.

        Caller must hold self._lock and must re-check available_vram_gb
        afterwards: this method does not raise when it cannot free enough.
        """
        requesting_priority = _PRIORITY[requesting_type]

        # Evict in priority order: lowest first (LLM=0, TTS=1, ASR=2).
        # Rule: never evict highest-priority tier (ASR) for a lower-priority
        # request. Same-priority replacement is always allowed (e.g., old LLM
        # evicted for new LLM). Lower-priority models are fair game for any
        # requester — cascade through them until enough VRAM is freed.
        candidates = sorted(self._loaded.values(), key=lambda s: s.priority)
        for slot in candidates:
            if self.available_vram_gb >= needed_gb:
                break
            # Skip if this slot is the highest-priority tier and the requester
            # is lower priority. (Protects ASR from eviction by TTS/LLM.)
            if slot.priority > requesting_priority and slot.model_type == "asr":
                continue
            logger.info(
                "Evicting %s (%s, %sGB)", slot.model_id, slot.model_type, slot.vram_gb
            )
            await slot.backend.unload(slot.model_id)
            # candidates is a snapshot of the table, so deleting here does not
            # disturb the iteration.
            del self._loaded[slot.model_id]
|
||||
Reference in New Issue
Block a user