diff --git a/kischdle/llmux/llmux/vram_manager.py b/kischdle/llmux/llmux/vram_manager.py
index b2c5dca..e3af2a6 100644
--- a/kischdle/llmux/llmux/vram_manager.py
+++ b/kischdle/llmux/llmux/vram_manager.py
@@ -70,20 +70,39 @@ class VRAMManager:
         requesting_priority = _PRIORITY[requesting_type]
 
         # Evict in priority order: lowest first (LLM=0, TTS=1, ASR=2).
-        # Rule: never evict highest-priority tier (ASR) for a lower-priority
-        # request. Same-priority replacement is always allowed (e.g., old LLM
-        # evicted for new LLM). Lower-priority models are fair game for any
-        # requester — cascade through them until enough VRAM is freed.
+        #
+        # Rule: never evict a higher-priority model to make room for a
+        # lower-priority one. E.g., a TTS request must not evict ASR —
+        # it should evict the LLM instead. But an LLM request CAN cascade
+        # through TTS and ASR as a last resort, because there is nothing
+        # lower to evict. Same-priority replacement is always allowed.
+        #
+        # Pass 1: evict models with priority <= requesting priority
+        #         (lower or same tier).
+        # Pass 2: if still not enough, evict higher-priority models
+        #         in ascending order (only when the requester has no
+        #         lower-priority alternatives left).
         candidates = sorted(self._loaded.values(), key=lambda s: s.priority)
-        for slot in candidates:
+
+        # Pass 1: evict lower and same priority
+        for slot in list(candidates):
             if self.available_vram_gb >= needed_gb:
                 break
-            # Skip if this slot is the highest-priority tier and the requester
-            # is lower priority. (Protects ASR from eviction by TTS/LLM.)
-            if slot.priority > requesting_priority and slot.model_type == "asr":
-                continue
-            logger.info(
-                f"Evicting {slot.model_id} ({slot.model_type}, {slot.vram_gb}GB)"
-            )
-            await slot.backend.unload(slot.model_id)
-            del self._loaded[slot.model_id]
+            if slot.priority <= requesting_priority:
+                logger.info(
+                    f"Evicting {slot.model_id} ({slot.model_type}, {slot.vram_gb}GB)"
+                )
+                await slot.backend.unload(slot.model_id)
+                del self._loaded[slot.model_id]
+
+        # Pass 2: last resort — LLM requests only (priority 0, nothing lower)
+        if requesting_priority == 0 and self.available_vram_gb < needed_gb:
+            candidates = sorted(self._loaded.values(), key=lambda s: s.priority)
+            for slot in list(candidates):
+                if self.available_vram_gb >= needed_gb:
+                    break
+                logger.info(
+                    f"Evicting {slot.model_id} ({slot.model_type}, {slot.vram_gb}GB) [last resort]"
+                )
+                await slot.backend.unload(slot.model_id)
+                del self._loaded[slot.model_id]
diff --git a/kischdle/llmux/tests/test_vram_manager.py b/kischdle/llmux/tests/test_vram_manager.py
index 489530d..49d0c8b 100644
--- a/kischdle/llmux/tests/test_vram_manager.py
+++ b/kischdle/llmux/tests/test_vram_manager.py
@@ -65,19 +65,53 @@ async def test_evict_llm_first(manager):
 
 
 @pytest.mark.asyncio
-async def test_evict_cascade_for_large_llm(manager):
+async def test_evict_cascade_asr_survives(manager):
+    """When LLM fits alongside ASR after evicting LLM+TTS, ASR survives."""
     backend = FakeBackend()
     await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
     await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
     await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
-    # 10 GB used. gpt-oss-20b needs 12GB. Evict LLM(4)->free=10. Evict TTS(2)->free=12. Load.
-    await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=12.0, backend=backend)
+    # 10 GB used. Need 12GB. Evict LLM(4)->free=10. Evict TTS(2)->free=12. ASR+12=16, fits.
+    await manager.load_model("large-llm", model_type="llm", vram_gb=12.0, backend=backend)
     assert not manager.is_loaded("qwen3.5-4b")
     assert not manager.is_loaded("chatterbox-multilingual")
-    assert manager.is_loaded("cohere-transcribe")  # ASR survives if possible
+    assert manager.is_loaded("cohere-transcribe")  # ASR survives
+    assert manager.is_loaded("large-llm")
+
+
+@pytest.mark.asyncio
+async def test_evict_cascade_full_for_huge_llm(manager):
+    """When LLM is too large to fit alongside ASR, everything gets evicted."""
+    backend = FakeBackend()
+    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
+    await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
+    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
+    # 10 GB used. gpt-oss-20b needs 13GB. Evict LLM(4)->free=10. TTS(2)->free=12. ASR(4)->free=16. Load alone.
+    await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=13.0, backend=backend)
+    assert not manager.is_loaded("qwen3.5-4b")
+    assert not manager.is_loaded("chatterbox-multilingual")
+    assert not manager.is_loaded("cohere-transcribe")  # ASR evicted as last resort
     assert manager.is_loaded("gpt-oss-20b")
 
 
+@pytest.mark.asyncio
+async def test_tts_cannot_evict_asr(manager):
+    """TTS request must not evict ASR — it evicts LLM instead."""
+    backend = FakeBackend()
+    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
+    await manager.load_model("qwen3.5-9b-fp8", model_type="llm", vram_gb=9.0, backend=backend)
+    # 13GB used, 3GB free. TTS needs 2GB — fits! Load alongside.
+    await manager.load_model("chatterbox", model_type="tts", vram_gb=2.0, backend=backend)
+    assert manager.is_loaded("cohere-transcribe")
+    assert manager.is_loaded("qwen3.5-9b-fp8")
+    assert manager.is_loaded("chatterbox")
+    # Now replace TTS with a bigger one that needs eviction
+    # 15GB used, 1GB free. New TTS needs 2GB. Evict old TTS(2)->free=3. Load.
+    await manager.load_model("chatterbox-ml", model_type="tts", vram_gb=2.0, backend=backend)
+    assert manager.is_loaded("cohere-transcribe")  # ASR must survive
+    assert manager.is_loaded("chatterbox-ml")
+
+
 @pytest.mark.asyncio
 async def test_asr_evicts_llm_not_reversed(manager):
     """When ASR request arrives and LLM is loaded, evict LLM (lower priority)."""