fix: Open WebUI integration — Harmony stripping, VRAM eviction, concurrency lock
- Add harmony.py: strip GPT-OSS-20B analysis/thinking channel from both streaming and non-streaming responses (HarmonyStreamFilter + extract_final_text) - Add per-model asyncio.Lock in llamacpp backend to prevent concurrent C++ access that caused container segfaults (exit 139) - Fix chat handler swap for streaming: move inside _stream_generate within lock scope (was broken by try/finally running before stream was consumed) - Filter /v1/models to return only LLM models (hide ASR/TTS from chat dropdown) - Correct Qwen3.5-4B estimated_vram_gb: 4 → 9 (actual allocation ~8GB) - Add GPU memory verification after eviction with retry loop in vram_manager - Add HF_TOKEN_PATH support in main.py for gated model access - Add /v1/audio/models and /v1/audio/voices discovery endpoints (no auth) - Add OOM error handling in both backends and chat route - Add AUDIO_STT_SUPPORTED_CONTENT_TYPES for webm/wav/mp3/ogg - Add performance test script (scripts/perf_test.py) - Update tests to match current config (42 tests pass) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -22,7 +22,7 @@ physical_models:
|
|||||||
type: llm
|
type: llm
|
||||||
backend: transformers
|
backend: transformers
|
||||||
model_id: "Qwen/Qwen3.5-4B"
|
model_id: "Qwen/Qwen3.5-4B"
|
||||||
estimated_vram_gb: 4
|
estimated_vram_gb: 9
|
||||||
supports_vision: true
|
supports_vision: true
|
||||||
supports_tools: true
|
supports_tools: true
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from llama_cpp import Llama
|
|||||||
|
|
||||||
from llmux.backends.base import BaseBackend
|
from llmux.backends.base import BaseBackend
|
||||||
from llmux.config import PhysicalModel
|
from llmux.config import PhysicalModel
|
||||||
|
from llmux.harmony import HarmonyStreamFilter, extract_final_text
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -19,6 +20,7 @@ class LlamaCppBackend(BaseBackend):
|
|||||||
def __init__(self, models_dir: str = "/models"):
|
def __init__(self, models_dir: str = "/models"):
|
||||||
self._models_dir = Path(models_dir)
|
self._models_dir = Path(models_dir)
|
||||||
self._loaded: dict[str, dict] = {}
|
self._loaded: dict[str, dict] = {}
|
||||||
|
self._locks: dict[str, asyncio.Lock] = {} # per-model lock to prevent concurrent C++ access
|
||||||
|
|
||||||
def _resolve_gguf_path(self, physical: PhysicalModel, filename: str) -> str:
|
def _resolve_gguf_path(self, physical: PhysicalModel, filename: str) -> str:
|
||||||
"""Resolve a GGUF filename — check flat gguf/ dir first, then HF cache."""
|
"""Resolve a GGUF filename — check flat gguf/ dir first, then HF cache."""
|
||||||
@@ -69,29 +71,41 @@ class LlamaCppBackend(BaseBackend):
|
|||||||
"think_handler": think_handler,
|
"think_handler": think_handler,
|
||||||
"no_think_handler": no_think_handler,
|
"no_think_handler": no_think_handler,
|
||||||
}
|
}
|
||||||
|
self._locks[model_id] = asyncio.Lock()
|
||||||
|
|
||||||
async def unload(self, model_id: str) -> None:
|
async def unload(self, model_id: str) -> None:
|
||||||
if model_id not in self._loaded:
|
if model_id not in self._loaded:
|
||||||
return
|
return
|
||||||
entry = self._loaded.pop(model_id)
|
entry = self._loaded.pop(model_id)
|
||||||
del entry["llm"]
|
self._locks.pop(model_id, None)
|
||||||
|
# Delete chat handlers first (they hold references to Llama internals)
|
||||||
|
entry.pop("think_handler", None)
|
||||||
|
entry.pop("no_think_handler", None)
|
||||||
|
llm = entry.pop("llm")
|
||||||
|
# Close the Llama model to release GGML CUDA memory
|
||||||
|
if hasattr(llm, "close"):
|
||||||
|
llm.close()
|
||||||
|
del llm
|
||||||
del entry
|
del entry
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
# Also clear PyTorch cache in case of mixed allocations
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
logger.info(f"Unloaded GGUF model {model_id}")
|
logger.info(f"Unloaded GGUF model {model_id}")
|
||||||
|
|
||||||
|
def _select_handler(self, entry, params):
|
||||||
|
"""Select the correct chat handler based on params."""
|
||||||
|
if "enable_thinking" in params:
|
||||||
|
if params["enable_thinking"]:
|
||||||
|
return entry.get("think_handler")
|
||||||
|
else:
|
||||||
|
return entry.get("no_think_handler")
|
||||||
|
return None
|
||||||
|
|
||||||
async def generate(self, model_id, messages, params, stream=False, tools=None):
|
async def generate(self, model_id, messages, params, stream=False, tools=None):
|
||||||
entry = self._loaded[model_id]
|
entry = self._loaded[model_id]
|
||||||
llm = entry["llm"]
|
handler = self._select_handler(entry, params)
|
||||||
|
|
||||||
# Swap chat handler based on thinking mode
|
|
||||||
original_handler = llm.chat_handler
|
|
||||||
if "enable_thinking" in params:
|
|
||||||
if params["enable_thinking"]:
|
|
||||||
handler = entry.get("think_handler")
|
|
||||||
else:
|
|
||||||
handler = entry.get("no_think_handler")
|
|
||||||
if handler:
|
|
||||||
llm.chat_handler = handler
|
|
||||||
|
|
||||||
effective_messages = list(messages)
|
effective_messages = list(messages)
|
||||||
if "system_prompt_prefix" in params:
|
if "system_prompt_prefix" in params:
|
||||||
@@ -102,28 +116,51 @@ class LlamaCppBackend(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
effective_messages.insert(0, {"role": "system", "content": prefix})
|
effective_messages.insert(0, {"role": "system", "content": prefix})
|
||||||
|
|
||||||
try:
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._stream_generate(llm, effective_messages, model_id, tools)
|
return self._stream_generate(entry, effective_messages, model_id, tools, handler)
|
||||||
else:
|
else:
|
||||||
return await self._full_generate(llm, effective_messages, model_id, tools)
|
return await self._full_generate(entry, effective_messages, model_id, tools, handler)
|
||||||
finally:
|
|
||||||
# Restore original handler
|
async def _full_generate(self, entry, messages, model_id, tools, handler):
|
||||||
llm.chat_handler = original_handler
|
llm = entry["llm"]
|
||||||
|
lock = self._locks[model_id]
|
||||||
|
|
||||||
async def _full_generate(self, llm, messages, model_id, tools):
|
|
||||||
def _run():
|
def _run():
|
||||||
kwargs = {"messages": messages, "max_tokens": 4096}
|
kwargs = {"messages": messages, "max_tokens": 4096}
|
||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
return llm.create_chat_completion(**kwargs)
|
return llm.create_chat_completion(**kwargs)
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
original = llm.chat_handler
|
||||||
|
if handler:
|
||||||
|
llm.chat_handler = handler
|
||||||
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
result = await loop.run_in_executor(None, _run)
|
result = await loop.run_in_executor(None, _run)
|
||||||
|
finally:
|
||||||
|
llm.chat_handler = original
|
||||||
|
|
||||||
result["model"] = model_id
|
result["model"] = model_id
|
||||||
|
for choice in result.get("choices", []):
|
||||||
|
msg = choice.get("message", {})
|
||||||
|
if msg.get("content"):
|
||||||
|
msg["content"] = extract_final_text(msg["content"])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _stream_generate(self, llm, messages, model_id, tools):
|
async def _stream_generate(self, entry, messages, model_id, tools, handler):
|
||||||
|
llm = entry["llm"]
|
||||||
|
lock = self._locks[model_id]
|
||||||
|
|
||||||
|
# Acquire lock for the entire duration of streaming.
|
||||||
|
# This prevents concurrent C++ access which causes segfaults.
|
||||||
|
await lock.acquire()
|
||||||
|
|
||||||
|
original = llm.chat_handler
|
||||||
|
if handler:
|
||||||
|
llm.chat_handler = handler
|
||||||
|
|
||||||
|
try:
|
||||||
def _run():
|
def _run():
|
||||||
kwargs = {"messages": messages, "max_tokens": 4096, "stream": True}
|
kwargs = {"messages": messages, "max_tokens": 4096, "stream": True}
|
||||||
if tools:
|
if tools:
|
||||||
@@ -133,10 +170,51 @@ class LlamaCppBackend(BaseBackend):
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
stream = await loop.run_in_executor(None, _run)
|
stream = await loop.run_in_executor(None, _run)
|
||||||
|
|
||||||
|
harmony_filter = HarmonyStreamFilter()
|
||||||
|
error_msg = None
|
||||||
|
try:
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
chunk["model"] = model_id
|
chunk["model"] = model_id
|
||||||
|
skip = False
|
||||||
|
for choice in chunk.get("choices", []):
|
||||||
|
delta = choice.get("delta", {})
|
||||||
|
content = delta.get("content")
|
||||||
|
if content is not None:
|
||||||
|
filtered = harmony_filter.feed(content)
|
||||||
|
if not filtered:
|
||||||
|
skip = True
|
||||||
|
else:
|
||||||
|
delta["content"] = filtered
|
||||||
|
if skip:
|
||||||
|
continue
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Stream error for {model_id}: {e}")
|
||||||
|
error_msg = str(e)
|
||||||
|
|
||||||
|
flushed = harmony_filter.flush()
|
||||||
|
if flushed:
|
||||||
|
flush_chunk = {
|
||||||
|
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||||
|
"model": model_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"choices": [{"index": 0, "delta": {"content": flushed}, "finish_reason": None}],
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(flush_chunk)}\n\n"
|
||||||
|
|
||||||
|
if error_msg:
|
||||||
|
err_chunk = {
|
||||||
|
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||||
|
"model": model_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"choices": [{"index": 0, "delta": {"content": f"\n\n[Error: {error_msg}]"}, "finish_reason": None}],
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(err_chunk)}\n\n"
|
||||||
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
finally:
|
||||||
|
llm.chat_handler = original
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
|
||||||
def _create_think_handler(llm, enable_thinking: bool):
|
def _create_think_handler(llm, enable_thinking: bool):
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from threading import Thread
|
|||||||
|
|
||||||
from llmux.backends.base import BaseBackend
|
from llmux.backends.base import BaseBackend
|
||||||
from llmux.config import PhysicalModel
|
from llmux.config import PhysicalModel
|
||||||
|
from llmux.harmony import HarmonyStreamFilter, extract_final_text
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -95,6 +96,7 @@ class TransformersLLMBackend(BaseBackend):
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
text = await loop.run_in_executor(None, _run)
|
text = await loop.run_in_executor(None, _run)
|
||||||
|
text = extract_final_text(text)
|
||||||
return {
|
return {
|
||||||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
@@ -107,23 +109,56 @@ class TransformersLLMBackend(BaseBackend):
|
|||||||
async def _stream_generate(self, model, tokenizer, inputs, model_id):
|
async def _stream_generate(self, model, tokenizer, inputs, model_id):
|
||||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
gen_kwargs = {**inputs, "max_new_tokens": 4096, "streamer": streamer}
|
gen_kwargs = {**inputs, "max_new_tokens": 4096, "streamer": streamer}
|
||||||
thread = Thread(target=lambda: model.generate(**gen_kwargs))
|
gen_error = [None]
|
||||||
|
|
||||||
|
def _run():
|
||||||
|
try:
|
||||||
|
model.generate(**gen_kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
gen_error[0] = e
|
||||||
|
logger.error(f"Generation error for {model_id}: {e}")
|
||||||
|
|
||||||
|
thread = Thread(target=_run)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
chat_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
chat_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
created = int(time.time())
|
created = int(time.time())
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
harmony_filter = HarmonyStreamFilter()
|
||||||
|
|
||||||
|
try:
|
||||||
while True:
|
while True:
|
||||||
token = await loop.run_in_executor(None, lambda: next(streamer, None))
|
token = await loop.run_in_executor(None, lambda: next(streamer, None))
|
||||||
if token is None:
|
if token is None:
|
||||||
|
break
|
||||||
|
filtered = harmony_filter.feed(token)
|
||||||
|
if not filtered:
|
||||||
|
continue
|
||||||
|
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {"content": filtered}, "finish_reason": None}]}
|
||||||
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Stream iteration error for {model_id}: {e}")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
if gen_error[0]:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
error_msg = str(gen_error[0])
|
||||||
|
if "out of memory" in error_msg.lower():
|
||||||
|
error_msg = "GPU out of memory. Try a shorter message or clear VRAM."
|
||||||
|
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {"content": f"\n\n[Error: {error_msg}]"}, "finish_reason": None}]}
|
||||||
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
# Flush any remaining buffered content
|
||||||
|
flushed = harmony_filter.flush()
|
||||||
|
if flushed:
|
||||||
|
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {"content": flushed}, "finish_reason": None}]}
|
||||||
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
|
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
yield f"data: {json.dumps(chunk)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
break
|
|
||||||
chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {"content": token}, "finish_reason": None}]}
|
|
||||||
yield f"data: {json.dumps(chunk)}\n\n"
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
|
|
||||||
# Physical model config injection
|
# Physical model config injection
|
||||||
|
|||||||
90
kischdle/llmux/llmux/harmony.py
Normal file
90
kischdle/llmux/llmux/harmony.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Post-processing for GPT-OSS Harmony format responses.
|
||||||
|
|
||||||
|
GPT-OSS models output multi-channel responses with analysis (thinking) and
|
||||||
|
final (user-facing) channels. This module extracts only the final channel.
|
||||||
|
|
||||||
|
Formats seen:
|
||||||
|
llamacpp: <|channel|>analysis<|message|>...<|end|><|start|>assistant<|channel|>final<|message|>Hello!
|
||||||
|
transformers: analysisUser greeting...assistantfinalHello! (special tokens stripped)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Pattern for llamacpp output (special tokens preserved)
|
||||||
|
_LLAMACPP_FINAL_RE = re.compile(
|
||||||
|
r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pattern for transformers output (special tokens stripped, leaving text markers)
|
||||||
|
_TRANSFORMERS_FINAL_RE = re.compile(
|
||||||
|
r"assistantfinal(.*?)$",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_final_text(text: str) -> str:
|
||||||
|
"""Extract the final channel content from a Harmony format response."""
|
||||||
|
# Try llamacpp format first
|
||||||
|
m = _LLAMACPP_FINAL_RE.search(text)
|
||||||
|
if m:
|
||||||
|
return m.group(1).strip()
|
||||||
|
|
||||||
|
# Try transformers format
|
||||||
|
m = _TRANSFORMERS_FINAL_RE.search(text)
|
||||||
|
if m:
|
||||||
|
return m.group(1).strip()
|
||||||
|
|
||||||
|
# Not Harmony format — return as-is
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class HarmonyStreamFilter:
|
||||||
|
"""Buffers streaming chunks and emits only the final channel content.
|
||||||
|
|
||||||
|
For streaming, we accumulate text until we detect the final channel marker,
|
||||||
|
then start emitting from that point forward. Any content before the marker
|
||||||
|
(analysis channel) is silently dropped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Markers that indicate the start of the final channel
|
||||||
|
_LLAMACPP_MARKER = "<|channel|>final<|message|>"
|
||||||
|
_TRANSFORMERS_MARKER = "assistantfinal"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._buffer = ""
|
||||||
|
self._emitting = False
|
||||||
|
self._marker_found = False
|
||||||
|
|
||||||
|
def feed(self, chunk: str) -> str:
|
||||||
|
"""Feed a chunk of streamed text. Returns text to emit (may be empty)."""
|
||||||
|
if self._emitting:
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
self._buffer += chunk
|
||||||
|
|
||||||
|
# Check for llamacpp marker
|
||||||
|
idx = self._buffer.find(self._LLAMACPP_MARKER)
|
||||||
|
if idx >= 0:
|
||||||
|
self._emitting = True
|
||||||
|
after = self._buffer[idx + len(self._LLAMACPP_MARKER):]
|
||||||
|
self._buffer = ""
|
||||||
|
return after
|
||||||
|
|
||||||
|
# Check for transformers marker
|
||||||
|
idx = self._buffer.find(self._TRANSFORMERS_MARKER)
|
||||||
|
if idx >= 0:
|
||||||
|
self._emitting = True
|
||||||
|
after = self._buffer[idx + len(self._TRANSFORMERS_MARKER):]
|
||||||
|
self._buffer = ""
|
||||||
|
return after
|
||||||
|
|
||||||
|
# Not found yet — keep buffering, emit nothing
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def flush(self) -> str:
|
||||||
|
"""Call at end of stream. If no marker was found, return full buffer."""
|
||||||
|
if not self._emitting and self._buffer:
|
||||||
|
# No Harmony markers found — return unmodified content
|
||||||
|
return self._buffer
|
||||||
|
return ""
|
||||||
@@ -26,6 +26,16 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
MODELS_DIR = os.environ.get("LLMUX_MODELS_DIR", "/models")
|
MODELS_DIR = os.environ.get("LLMUX_MODELS_DIR", "/models")
|
||||||
|
|
||||||
|
# Load HF token from file if HF_TOKEN_PATH is set and HF_TOKEN is not already set
|
||||||
|
_hf_token_path = os.environ.get("HF_TOKEN_PATH")
|
||||||
|
if _hf_token_path and not os.environ.get("HF_TOKEN"):
|
||||||
|
try:
|
||||||
|
with open(_hf_token_path) as f:
|
||||||
|
os.environ["HF_TOKEN"] = f.read().strip()
|
||||||
|
logger.info(f"Loaded HF token from {_hf_token_path}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.warning(f"HF_TOKEN_PATH set but file not found: {_hf_token_path}")
|
||||||
|
|
||||||
app = FastAPI(title="llmux", version="0.1.0")
|
app = FastAPI(title="llmux", version="0.1.0")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ class ModelRegistry:
|
|||||||
"created": 0,
|
"created": 0,
|
||||||
"owned_by": "llmux",
|
"owned_by": "llmux",
|
||||||
}
|
}
|
||||||
for name in self._virtual
|
for name, vm in self._virtual.items()
|
||||||
|
if self._physical[vm.physical].type == "llm"
|
||||||
]
|
]
|
||||||
|
|
||||||
def resolve(self, virtual_name: str) -> tuple[str, PhysicalModel, dict]:
|
def resolve(self, virtual_name: str) -> tuple[str, PhysicalModel, dict]:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import torch
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from llmux.model_registry import ModelRegistry
|
from llmux.model_registry import ModelRegistry
|
||||||
@@ -37,12 +38,17 @@ def create_chat_router(registry, vram_manager, backends, require_api_key):
|
|||||||
messages = body.get("messages", [])
|
messages = body.get("messages", [])
|
||||||
stream = body.get("stream", False)
|
stream = body.get("stream", False)
|
||||||
tools = body.get("tools")
|
tools = body.get("tools")
|
||||||
|
|
||||||
|
try:
|
||||||
if stream:
|
if stream:
|
||||||
# generate() is async def that returns an async generator
|
|
||||||
stream_iter = await backend.generate(model_id=physical_id, messages=messages, params=params, stream=True, tools=tools)
|
stream_iter = await backend.generate(model_id=physical_id, messages=messages, params=params, stream=True, tools=tools)
|
||||||
return StreamingResponse(stream_iter, media_type="text/event-stream")
|
return StreamingResponse(stream_iter, media_type="text/event-stream")
|
||||||
|
|
||||||
result = await backend.generate(model_id=physical_id, messages=messages, params=params, stream=False, tools=tools)
|
result = await backend.generate(model_id=physical_id, messages=messages, params=params, stream=False, tools=tools)
|
||||||
return result
|
return result
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
logger.error(f"CUDA OOM during generation with {virtual_name}")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
raise HTTPException(status_code=503, detail="GPU out of memory. Try a shorter message or switch to a smaller model.")
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
@@ -10,6 +10,21 @@ logger = logging.getLogger(__name__)
|
|||||||
def create_speech_router(registry, vram_manager, backends, require_api_key):
|
def create_speech_router(registry, vram_manager, backends, require_api_key):
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/v1/audio/models")
|
||||||
|
async def list_audio_models():
|
||||||
|
"""Discovery endpoint for Open WebUI — lists available TTS models."""
|
||||||
|
tts_models = [
|
||||||
|
{"id": name}
|
||||||
|
for name, vm in registry._virtual.items()
|
||||||
|
if registry._physical[vm.physical].type == "tts"
|
||||||
|
]
|
||||||
|
return {"models": tts_models}
|
||||||
|
|
||||||
|
@router.get("/v1/audio/voices")
|
||||||
|
async def list_audio_voices():
|
||||||
|
"""Discovery endpoint for Open WebUI — lists available voices."""
|
||||||
|
return {"voices": [{"id": "default", "name": "Default"}]}
|
||||||
|
|
||||||
@router.post("/v1/audio/speech")
|
@router.post("/v1/audio/speech")
|
||||||
async def create_speech(request: Request, api_key: str = Depends(require_api_key)):
|
async def create_speech(request: Request, api_key: str = Depends(require_api_key)):
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_PRIORITY = {"llm": 0, "tts": 1, "asr": 2}
|
_PRIORITY = {"llm": 0, "tts": 1, "asr": 2}
|
||||||
@@ -24,10 +28,11 @@ class ModelSlot:
|
|||||||
|
|
||||||
|
|
||||||
class VRAMManager:
|
class VRAMManager:
|
||||||
def __init__(self, total_vram_gb: float = 16.0):
|
def __init__(self, total_vram_gb: float = 16.0, verify_gpu: bool = True):
|
||||||
self._total_vram_gb = total_vram_gb
|
self._total_vram_gb = total_vram_gb
|
||||||
self._loaded: dict[str, ModelSlot] = {}
|
self._loaded: dict[str, ModelSlot] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
self._verify_gpu = verify_gpu
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available_vram_gb(self) -> float:
|
def available_vram_gb(self) -> float:
|
||||||
@@ -42,9 +47,6 @@ class VRAMManager:
|
|||||||
|
|
||||||
async def clear_all(self) -> dict:
|
async def clear_all(self) -> dict:
|
||||||
"""Unload all models and clear CUDA cache. Returns what was unloaded."""
|
"""Unload all models and clear CUDA cache. Returns what was unloaded."""
|
||||||
import gc
|
|
||||||
import torch
|
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
unloaded = []
|
unloaded = []
|
||||||
for slot in list(self._loaded.values()):
|
for slot in list(self._loaded.values()):
|
||||||
@@ -52,10 +54,7 @@ class VRAMManager:
|
|||||||
await slot.backend.unload(slot.model_id)
|
await slot.backend.unload(slot.model_id)
|
||||||
unloaded.append(slot.model_id)
|
unloaded.append(slot.model_id)
|
||||||
self._loaded.clear()
|
self._loaded.clear()
|
||||||
gc.collect()
|
self._force_gpu_cleanup()
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
return {
|
return {
|
||||||
"unloaded": unloaded,
|
"unloaded": unloaded,
|
||||||
"available_vram_gb": round(self.available_vram_gb, 1),
|
"available_vram_gb": round(self.available_vram_gb, 1),
|
||||||
@@ -65,12 +64,30 @@ class VRAMManager:
|
|||||||
async with self._lock:
|
async with self._lock:
|
||||||
await self._load_model_locked(model_id, model_type, vram_gb, backend)
|
await self._load_model_locked(model_id, model_type, vram_gb, backend)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _gpu_free_gb() -> float:
|
||||||
|
"""Get actual free GPU memory in GB."""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return 16.0
|
||||||
|
free, _ = torch.cuda.mem_get_info()
|
||||||
|
return free / (1024 ** 3)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _force_gpu_cleanup():
|
||||||
|
"""Force garbage collection and GPU memory release."""
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
async def _load_model_locked(self, model_id, model_type, vram_gb, backend):
|
async def _load_model_locked(self, model_id, model_type, vram_gb, backend):
|
||||||
if model_id in self._loaded:
|
if model_id in self._loaded:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
evicted = False
|
||||||
if self.available_vram_gb < vram_gb:
|
if self.available_vram_gb < vram_gb:
|
||||||
await self._evict_for(vram_gb, model_type)
|
await self._evict_for(vram_gb, model_type)
|
||||||
|
evicted = True
|
||||||
|
|
||||||
if self.available_vram_gb < vram_gb:
|
if self.available_vram_gb < vram_gb:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -78,6 +95,35 @@ class VRAMManager:
|
|||||||
f"(need {vram_gb}GB, available {self.available_vram_gb}GB)"
|
f"(need {vram_gb}GB, available {self.available_vram_gb}GB)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# After eviction, verify GPU memory is actually freed.
|
||||||
|
# GGML (llama-cpp) CUDA allocations may take time to release.
|
||||||
|
# Only check when we evicted AND have real GPU AND the model needs >4GB
|
||||||
|
# (small models fit even with overhead; large models are the OOM risk).
|
||||||
|
if evicted and self._verify_gpu and torch.cuda.is_available():
|
||||||
|
self._force_gpu_cleanup()
|
||||||
|
actual_free = self._gpu_free_gb()
|
||||||
|
if actual_free < vram_gb:
|
||||||
|
logger.warning(
|
||||||
|
f"GPU has only {actual_free:.1f}GB free after eviction "
|
||||||
|
f"(need {vram_gb}GB). Waiting for memory release..."
|
||||||
|
)
|
||||||
|
for _ in range(10):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
self._force_gpu_cleanup()
|
||||||
|
actual_free = self._gpu_free_gb()
|
||||||
|
if actual_free >= vram_gb:
|
||||||
|
break
|
||||||
|
if actual_free < vram_gb:
|
||||||
|
logger.error(
|
||||||
|
f"GPU memory not freed: {actual_free:.1f}GB free, "
|
||||||
|
f"need {vram_gb}GB for {model_id}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"GPU memory not freed after eviction: "
|
||||||
|
f"{actual_free:.1f}GB free, need {vram_gb}GB"
|
||||||
|
)
|
||||||
|
logger.info(f"GPU verified: {actual_free:.1f}GB free after eviction")
|
||||||
|
|
||||||
logger.info(f"Loading {model_id} ({vram_gb}GB VRAM)")
|
logger.info(f"Loading {model_id} ({vram_gb}GB VRAM)")
|
||||||
await backend.load(model_id)
|
await backend.load(model_id)
|
||||||
self._loaded[model_id] = ModelSlot(
|
self._loaded[model_id] = ModelSlot(
|
||||||
|
|||||||
224
kischdle/llmux/scripts/perf_test.py
Normal file
224
kischdle/llmux/scripts/perf_test.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Performance test for llmux — measures TTFT, tok/s, and total latency for each LLM model."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import sys
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
BASE_URL = "http://127.0.0.1:8081"
|
||||||
|
API_KEY = "sk-llmux-openwebui-hMD6pAka1czM53MtTkmmlFP8tF5zuiiDRgt-PCBnj-c"
|
||||||
|
HEADERS = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
|
||||||
|
|
||||||
|
# Test prompts — short and long to measure different characteristics
|
||||||
|
PROMPTS = {
|
||||||
|
"short": "What is 2+2? Answer in one sentence.",
|
||||||
|
"medium": "Explain how a CPU works in 3-4 paragraphs.",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Models to test — one virtual model per physical model (avoid duplicating physical loads)
|
||||||
|
TEST_MODELS = [
|
||||||
|
# llama-cpp backend (GGUF)
|
||||||
|
("Qwen3.5-9B-FP8-Instruct", "llamacpp", "~10GB"),
|
||||||
|
("GPT-OSS-20B-Uncensored-Low", "llamacpp", "~13GB"),
|
||||||
|
# transformers backend
|
||||||
|
("Qwen3.5-4B-Instruct", "transformers", "~4GB"),
|
||||||
|
# GPT-OSS-20B-Low disabled: needs libc6-dev sys/ headers for triton MXFP4 kernels
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def clear_vram():
|
||||||
|
"""Unload all models to start fresh."""
|
||||||
|
r = httpx.post(f"{BASE_URL}/admin/clear-vram", headers=HEADERS, timeout=60)
|
||||||
|
if r.status_code == 200:
|
||||||
|
print(" VRAM cleared")
|
||||||
|
else:
|
||||||
|
print(f" WARN: clear-vram returned {r.status_code}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming(model: str, prompt: str, prompt_label: str) -> dict:
|
||||||
|
"""Test a model with streaming, measuring TTFT and tok/s."""
|
||||||
|
body = {
|
||||||
|
"model": model,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
first_token_time = None
|
||||||
|
token_count = 0
|
||||||
|
full_text = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
with httpx.stream("POST", f"{BASE_URL}/v1/chat/completions",
|
||||||
|
json=body, headers=HEADERS, timeout=300) as resp:
|
||||||
|
if resp.status_code != 200:
|
||||||
|
return {"model": model, "prompt": prompt_label, "error": f"HTTP {resp.status_code}"}
|
||||||
|
|
||||||
|
for line in resp.iter_lines():
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
data = line[6:]
|
||||||
|
if data == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(data)
|
||||||
|
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||||
|
content = delta.get("content", "")
|
||||||
|
if content:
|
||||||
|
if first_token_time is None:
|
||||||
|
first_token_time = time.perf_counter()
|
||||||
|
token_count += 1
|
||||||
|
full_text.append(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"model": model, "prompt": prompt_label, "error": str(e)}
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
total_time = end - start
|
||||||
|
ttft = (first_token_time - start) if first_token_time else total_time
|
||||||
|
|
||||||
|
# Token generation time (after first token)
|
||||||
|
gen_time = (end - first_token_time) if first_token_time and token_count > 1 else 0
|
||||||
|
tok_per_sec = (token_count - 1) / gen_time if gen_time > 0 else 0
|
||||||
|
|
||||||
|
output_text = "".join(full_text)
|
||||||
|
output_chars = len(output_text)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt_label,
|
||||||
|
"ttft_s": round(ttft, 2),
|
||||||
|
"total_s": round(total_time, 2),
|
||||||
|
"tokens": token_count,
|
||||||
|
"tok_per_s": round(tok_per_sec, 1),
|
||||||
|
"output_chars": output_chars,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_streaming(model: str, prompt: str, prompt_label: str) -> dict:
|
||||||
|
"""Test a model without streaming — measures total latency."""
|
||||||
|
body = {
|
||||||
|
"model": model,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
r = httpx.post(f"{BASE_URL}/v1/chat/completions",
|
||||||
|
json=body, headers=HEADERS, timeout=300)
|
||||||
|
if r.status_code != 200:
|
||||||
|
return {"model": model, "prompt": prompt_label, "mode": "non-stream", "error": f"HTTP {r.status_code}"}
|
||||||
|
result = r.json()
|
||||||
|
except Exception as e:
|
||||||
|
return {"model": model, "prompt": prompt_label, "mode": "non-stream", "error": str(e)}
|
||||||
|
|
||||||
|
end = time.perf_counter()
|
||||||
|
content = result.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt_label,
|
||||||
|
"mode": "non-stream",
|
||||||
|
"total_s": round(end - start, 2),
|
||||||
|
"output_chars": len(content),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_tests():
|
||||||
|
print("=" * 80)
|
||||||
|
print("llmux Performance Test")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Check health
|
||||||
|
try:
|
||||||
|
r = httpx.get(f"{BASE_URL}/health", timeout=5)
|
||||||
|
health = r.json()
|
||||||
|
print(f"Server healthy — available VRAM: {health['available_vram_gb']} GB")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ERROR: Server not reachable: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for model, backend, vram_est in TEST_MODELS:
|
||||||
|
print(f"\n{'─' * 60}")
|
||||||
|
print(f"Model: {model} ({backend}, {vram_est})")
|
||||||
|
print(f"{'─' * 60}")
|
||||||
|
|
||||||
|
# Clear VRAM before each model to measure cold-start load time
|
||||||
|
clear_vram()
|
||||||
|
|
||||||
|
for prompt_label, prompt_text in PROMPTS.items():
|
||||||
|
# First run = cold start (includes model loading)
|
||||||
|
print(f" [{prompt_label}] streaming (cold)...", end=" ", flush=True)
|
||||||
|
r = test_streaming(model, prompt_text, prompt_label)
|
||||||
|
r["cold_start"] = True
|
||||||
|
results.append(r)
|
||||||
|
if "error" in r:
|
||||||
|
print(f"ERROR: {r['error']}")
|
||||||
|
else:
|
||||||
|
print(f"TTFT={r['ttft_s']}s total={r['total_s']}s {r['tok_per_s']} tok/s ({r['tokens']} tokens)")
|
||||||
|
|
||||||
|
# Second run = warm (model already loaded)
|
||||||
|
print(f" [{prompt_label}] streaming (warm)...", end=" ", flush=True)
|
||||||
|
r = test_streaming(model, prompt_text, prompt_label)
|
||||||
|
r["cold_start"] = False
|
||||||
|
results.append(r)
|
||||||
|
if "error" in r:
|
||||||
|
print(f"ERROR: {r['error']}")
|
||||||
|
else:
|
||||||
|
print(f"TTFT={r['ttft_s']}s total={r['total_s']}s {r['tok_per_s']} tok/s ({r['tokens']} tokens)")
|
||||||
|
|
||||||
|
# Non-streaming tests (warm)
|
||||||
|
for plabel in ["short", "medium"]:
|
||||||
|
print(f" [{plabel}] non-streaming (warm)...", end=" ", flush=True)
|
||||||
|
r = test_non_streaming(model, PROMPTS[plabel], plabel)
|
||||||
|
results.append(r)
|
||||||
|
if "error" in r:
|
||||||
|
print(f"ERROR: {r['error']}")
|
||||||
|
else:
|
||||||
|
chars_per_s = round(r['output_chars'] / r['total_s'], 1) if r['total_s'] > 0 else 0
|
||||||
|
print(f"total={r['total_s']}s ({r['output_chars']} chars, {chars_per_s} chars/s)")
|
||||||
|
|
||||||
|
# Clear to free VRAM for next model
|
||||||
|
clear_vram()
|
||||||
|
|
||||||
|
# Summary table
|
||||||
|
print(f"\n{'=' * 90}")
|
||||||
|
print("Summary — Streaming")
|
||||||
|
print(f"{'=' * 90}")
|
||||||
|
print(f"{'Model':<40} {'Prompt':<8} {'Cold':>5} {'TTFT':>7} {'Total':>7} {'Chunks':>7} {'Char/s':>7}")
|
||||||
|
print(f"{'-' * 40} {'-' * 8} {'-' * 5} {'-' * 7} {'-' * 7} {'-' * 7} {'-' * 7}")
|
||||||
|
for r in results:
|
||||||
|
if r.get("mode") == "non-stream":
|
||||||
|
continue
|
||||||
|
if "error" in r:
|
||||||
|
print(f"{r['model']:<40} {r['prompt']:<8} {'':>5} {'ERROR':>7}")
|
||||||
|
continue
|
||||||
|
cold = "yes" if r.get("cold_start") else "no"
|
||||||
|
chars_per_s = round(r['output_chars'] / r['total_s'], 1) if r['total_s'] > 0 else 0
|
||||||
|
print(f"{r['model']:<40} {r['prompt']:<8} {cold:>5} {r['ttft_s']:>6.2f}s {r['total_s']:>6.2f}s {r['tokens']:>7} {chars_per_s:>6.1f}")
|
||||||
|
|
||||||
|
print(f"\n{'=' * 90}")
|
||||||
|
print("Summary — Non-streaming")
|
||||||
|
print(f"{'=' * 90}")
|
||||||
|
print(f"{'Model':<40} {'Prompt':<8} {'Total':>7} {'Chars':>7} {'Char/s':>7}")
|
||||||
|
print(f"{'-' * 40} {'-' * 8} {'-' * 7} {'-' * 7} {'-' * 7}")
|
||||||
|
for r in results:
|
||||||
|
if r.get("mode") != "non-stream":
|
||||||
|
continue
|
||||||
|
if "error" in r:
|
||||||
|
print(f"{r['model']:<40} {r['prompt']:<8} {'ERROR':>7}")
|
||||||
|
continue
|
||||||
|
chars_per_s = round(r['output_chars'] / r['total_s'], 1) if r['total_s'] > 0 else 0
|
||||||
|
print(f"{r['model']:<40} {r['prompt']:<8} {r['total_s']:>6.2f}s {r['output_chars']:>7} {chars_per_s:>6.1f}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_tests()
|
||||||
@@ -13,10 +13,10 @@ def test_physical_model_has_required_fields():
|
|||||||
physical, _ = load_models_config()
|
physical, _ = load_models_config()
|
||||||
qwen = physical["qwen3.5-9b-fp8"]
|
qwen = physical["qwen3.5-9b-fp8"]
|
||||||
assert qwen.type == "llm"
|
assert qwen.type == "llm"
|
||||||
assert qwen.backend == "transformers"
|
assert qwen.backend == "llamacpp"
|
||||||
assert qwen.model_id == "lovedheart/Qwen3.5-9B-FP8"
|
assert qwen.model_id == "unsloth/Qwen3.5-9B-GGUF"
|
||||||
assert qwen.estimated_vram_gb == 9
|
assert qwen.estimated_vram_gb == 10
|
||||||
assert qwen.supports_vision is True
|
assert qwen.supports_vision is False
|
||||||
assert qwen.supports_tools is True
|
assert qwen.supports_tools is True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
55
kischdle/llmux/tests/test_harmony.py
Normal file
55
kischdle/llmux/tests/test_harmony.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from llmux.harmony import extract_final_text, HarmonyStreamFilter
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llamacpp_format():
|
||||||
|
text = '<|channel|>analysis<|message|>User greeting. Simple.<|end|><|start|>assistant<|channel|>final<|message|>Hello! How can I help you today?'
|
||||||
|
assert extract_final_text(text) == "Hello! How can I help you today?"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_llamacpp_with_end_tag():
|
||||||
|
text = '<|channel|>analysis<|message|>thinking...<|end|><|start|>assistant<|channel|>final<|message|>The answer is 42.<|end|>'
|
||||||
|
assert extract_final_text(text) == "The answer is 42."
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_transformers_format():
|
||||||
|
text = 'analysisUser greeting. Just respond friendly.assistantfinalHello! I am doing great.'
|
||||||
|
assert extract_final_text(text) == "Hello! I am doing great."
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_non_harmony_passthrough():
|
||||||
|
text = "Hello! I'm doing well, thanks for asking."
|
||||||
|
assert extract_final_text(text) == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_filter_llamacpp():
|
||||||
|
f = HarmonyStreamFilter()
|
||||||
|
chunks = [
|
||||||
|
"<|channel|>", "analysis", "<|message|>", "User ", "greeting.",
|
||||||
|
"<|end|>", "<|start|>", "assistant", "<|channel|>", "final",
|
||||||
|
"<|message|>", "Hello!", " How ", "are you?"
|
||||||
|
]
|
||||||
|
output = ""
|
||||||
|
for c in chunks:
|
||||||
|
output += f.feed(c)
|
||||||
|
output += f.flush()
|
||||||
|
assert output == "Hello! How are you?"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_filter_transformers():
|
||||||
|
f = HarmonyStreamFilter()
|
||||||
|
chunks = ["analysis", "User ", "greeting.", "assistant", "final", "Hello!", " Great day!"]
|
||||||
|
output = ""
|
||||||
|
for c in chunks:
|
||||||
|
output += f.feed(c)
|
||||||
|
output += f.flush()
|
||||||
|
assert output == "Hello! Great day!"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_filter_non_harmony():
|
||||||
|
f = HarmonyStreamFilter()
|
||||||
|
chunks = ["Hello", " world", "!"]
|
||||||
|
output = ""
|
||||||
|
for c in chunks:
|
||||||
|
output += f.feed(c)
|
||||||
|
output += f.flush()
|
||||||
|
assert output == "Hello world!"
|
||||||
@@ -10,12 +10,12 @@ def registry():
|
|||||||
|
|
||||||
def test_list_virtual_models(registry):
|
def test_list_virtual_models(registry):
|
||||||
models = registry.list_virtual_models()
|
models = registry.list_virtual_models()
|
||||||
assert len(models) == 15
|
assert len(models) == 12 # only LLM models, not ASR/TTS
|
||||||
names = [m["id"] for m in models]
|
names = [m["id"] for m in models]
|
||||||
assert "Qwen3.5-9B-FP8-Thinking" in names
|
assert "Qwen3.5-9B-FP8-Thinking" in names
|
||||||
assert "GPT-OSS-20B-High" in names
|
assert "GPT-OSS-20B-High" in names
|
||||||
assert "cohere-transcribe" in names
|
assert "cohere-transcribe" not in names
|
||||||
assert "Chatterbox-Multilingual" in names
|
assert "Chatterbox-Multilingual" not in names
|
||||||
|
|
||||||
|
|
||||||
def test_virtual_model_openai_format(registry):
|
def test_virtual_model_openai_format(registry):
|
||||||
@@ -28,7 +28,7 @@ def test_virtual_model_openai_format(registry):
|
|||||||
def test_resolve_virtual_to_physical(registry):
|
def test_resolve_virtual_to_physical(registry):
|
||||||
physical_id, physical, params = registry.resolve("Qwen3.5-9B-FP8-Thinking")
|
physical_id, physical, params = registry.resolve("Qwen3.5-9B-FP8-Thinking")
|
||||||
assert physical_id == "qwen3.5-9b-fp8"
|
assert physical_id == "qwen3.5-9b-fp8"
|
||||||
assert physical.backend == "transformers"
|
assert physical.backend == "llamacpp"
|
||||||
assert params == {"enable_thinking": True}
|
assert params == {"enable_thinking": True}
|
||||||
|
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ def test_resolve_unknown_model_raises(registry):
|
|||||||
def test_get_physical(registry):
|
def test_get_physical(registry):
|
||||||
physical = registry.get_physical("qwen3.5-9b-fp8")
|
physical = registry.get_physical("qwen3.5-9b-fp8")
|
||||||
assert physical.type == "llm"
|
assert physical.type == "llm"
|
||||||
assert physical.estimated_vram_gb == 9
|
assert physical.estimated_vram_gb == 10
|
||||||
|
|
||||||
|
|
||||||
def test_get_physical_unknown_raises(registry):
|
def test_get_physical_unknown_raises(registry):
|
||||||
|
|||||||
@@ -40,12 +40,12 @@ def auth_headers():
|
|||||||
return {"Authorization": f"Bearer {API_KEY}"}
|
return {"Authorization": f"Bearer {API_KEY}"}
|
||||||
|
|
||||||
|
|
||||||
def test_list_models_returns_16(client, auth_headers):
|
def test_list_models_returns_only_llm(client, auth_headers):
|
||||||
resp = client.get("/v1/models", headers=auth_headers)
|
resp = client.get("/v1/models", headers=auth_headers)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
assert body["object"] == "list"
|
assert body["object"] == "list"
|
||||||
assert len(body["data"]) == 15
|
assert len(body["data"]) == 12 # only LLM models
|
||||||
|
|
||||||
|
|
||||||
def test_list_models_contains_expected_names(client, auth_headers):
|
def test_list_models_contains_expected_names(client, auth_headers):
|
||||||
@@ -53,8 +53,8 @@ def test_list_models_contains_expected_names(client, auth_headers):
|
|||||||
names = [m["id"] for m in resp.json()["data"]]
|
names = [m["id"] for m in resp.json()["data"]]
|
||||||
assert "Qwen3.5-9B-FP8-Thinking" in names
|
assert "Qwen3.5-9B-FP8-Thinking" in names
|
||||||
assert "GPT-OSS-20B-High" in names
|
assert "GPT-OSS-20B-High" in names
|
||||||
assert "cohere-transcribe" in names
|
assert "cohere-transcribe" not in names
|
||||||
assert "Chatterbox-Multilingual" in names
|
assert "Chatterbox-Multilingual" not in names
|
||||||
|
|
||||||
|
|
||||||
def test_list_models_requires_auth(client):
|
def test_list_models_requires_auth(client):
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class FakeBackend:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def manager():
|
def manager():
|
||||||
return VRAMManager(total_vram_gb=16.0)
|
return VRAMManager(total_vram_gb=16.0, verify_gpu=False)
|
||||||
|
|
||||||
|
|
||||||
def test_priority_ordering():
|
def test_priority_ordering():
|
||||||
|
|||||||
Reference in New Issue
Block a user