Qwen3-VL mode working; /unload; normal model loading times

Author: llm
Date: 2025-11-28 21:29:07 +01:00
Parent: 21d7ab4a2c
Commit: 9b3d4e40e2


@@ -18,6 +18,7 @@ from transformers import (
     AutoTokenizer,
     AutoProcessor,
     AutoModel,
+    AutoModelForVision2Seq,
 )
 from transformers.utils.import_utils import is_flash_attn_2_available
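This hunk adds AutoModelForVision2Seq next to the existing is_flash_attn_2_available import. The attn_impl value passed to from_pretrained() later in the file is not shown in this diff; a minimal sketch of how it could be derived, assuming the server simply prefers FlashAttention 2 when the package is installed (the helper name and the "sdpa" default are assumptions, not part of this commit):

from transformers.utils.import_utils import is_flash_attn_2_available

# Hypothetical helper (not in this commit): pick the attention backend
# that from_pretrained() should use.
def pick_attn_implementation() -> str:
    # "flash_attention_2" needs the flash-attn package and a supported GPU;
    # "sdpa" (PyTorch scaled-dot-product attention) is the safe default.
    return "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

attn_impl = pick_attn_implementation()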
@@ -177,6 +178,10 @@ class EmbeddingResponse(BaseModel):
     usage: Usage


+class PreloadRequest(BaseModel):
+    model: str
+
+
 # -----------------------------------------------------------------------------
 # Helpers
 # -----------------------------------------------------------------------------
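PreloadRequest is the request-body model for the new /preload endpoint added further down. Assuming Pydantic v2, FastAPI validates the incoming JSON body into it along these lines; the model id used here is only a placeholder:

from pydantic import BaseModel

class PreloadRequest(BaseModel):
    model: str

# FastAPI performs this validation automatically for the request body;
# shown here only to illustrate the expected payload shape.
req = PreloadRequest.model_validate({"model": "Qwen/Qwen3-VL-8B-Instruct"})  # placeholder id
print(req.model)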
@@ -248,28 +253,29 @@ def _load_model_locked(model_id: str):
     # Load Generation Model
     # Check if it is a VL model
     if "VL" in model_id:
-        # Attempt to load as VL
-        # Using AutoModelForVision2Seq or AutoModelForCausalLM
-        # depending on the specific model support in transformers
+        # Use AutoModelForVision2Seq for VL models
+        # The configuration class Qwen3VLConfig requires Vision2Seq or AutoModel
         try:
-            from transformers import Qwen2VLForConditionalGeneration
-            model_class = Qwen2VLForConditionalGeneration
-        except ImportError:
-            # Fallback to AutoModel if specific class not available
-            model_class = AutoModelForCausalLM
-
-        # Note: We use AutoModelForCausalLM for broad compatibility.
-        # Qwen2-VL requires Qwen2VLForConditionalGeneration for vision.
-        # We will try AutoModelForCausalLM first.
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=dtype,
-            device_map=device_map,
-            attn_implementation=attn_impl,
-            trust_remote_code=True,  # Often needed for new architectures
-        ).eval()
+            print(f"Loading {model_id} with AutoModelForVision2Seq...")
+            model = AutoModelForVision2Seq.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            ).eval()
+        except Exception as e:
+            print(f"Vision2Seq failed: {e}. Fallback to AutoModel...")
+            # Fallback to generic AutoModel if Vision2Seq fails
+            model = AutoModel.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            ).eval()

         # Processor/Tokenizer
         try:
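When the Vision2Seq path succeeds, the AutoProcessor loaded in the lines that follow handles both text and images. A rough sketch of single-image generation with the resulting model/processor pair, following the usual Qwen-VL chat-template pattern; the image path and prompt are placeholders, and the exact message format for Qwen3-VL may differ:

from PIL import Image

image = Image.open("example.jpg")  # placeholder image path
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe this image."},
    ],
}]
# Render the chat template to a prompt string, then tokenize text + image together.
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=128)
# Strip the prompt tokens before decoding the answer.
answer = processor.batch_decode(
    output_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)[0]
print(answer)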
@@ -284,12 +290,14 @@ def _load_model_locked(model_id: str):
         _loaded_model_type = "generation"
     else:
         # Standard Text Model (GPT-OSS)
+        print(f"Loading {model_id} with AutoModelForCausalLM...")
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=dtype,
             device_map=device_map,
             attn_implementation=attn_impl,
             trust_remote_code=True,
+            low_cpu_mem_usage=True,
         ).eval()
         processor = AutoTokenizer.from_pretrained(
             model_id, trust_remote_code=True
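low_cpu_mem_usage=True, added in both branches, initializes the model with empty weights and streams checkpoint shards directly onto the mapped devices instead of first materializing a full copy in CPU RAM, which is presumably the "normal model loading times" part of the commit message. For this text-only branch, a minimal generation call with the model and the AutoTokenizer stored in processor could look like this; the prompt and token budget are illustrative only:

messages = [{"role": "user", "content": "Say hello in one sentence."}]
# The tokenizer renders the chat template straight to input ids.
input_ids = processor.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=64)
# Decode only the newly generated tokens, not the prompt.
print(processor.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))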
@@ -367,6 +375,41 @@ def _extract_embeddings(outputs) -> torch.Tensor:
 # -----------------------------------------------------------------------------
+@app.post("/preload")
+def preload_model(request: PreloadRequest):
+    model_id = request.model.strip()
+    if model_id not in ALLOWED_MODEL_IDS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model {model_id} not in allowed models.",
+        )
+    with _model_lock:
+        try:
+            _ensure_model_loaded(model_id)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to load model: {e}"
+            )
+    return {
+        "status": "ok",
+        "loaded_model_id": _loaded_model_id,
+        "vram_bytes": _current_vram_info(),
+    }
+
+
+@app.post("/unload")
+def unload_model():
+    with _model_lock:
+        stats = _unload_model_locked()
+    return {
+        "status": "ok",
+        "vram_bytes": _current_vram_info(),
+        "stats": stats,
+    }
+
+
 @app.get("/health")
 def health():
     cuda_ok = bool(torch.cuda.is_available())
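With these endpoints in place, a client can warm a model up before the first request and release the GPU afterwards. A small usage sketch with requests; the base URL/port and the model id are assumptions, not values taken from this commit:

import requests

BASE = "http://localhost:8000"  # assumed host/port; adjust to the actual deployment

# Load a model into VRAM ahead of the first completion request (placeholder model id).
r = requests.post(f"{BASE}/preload", json={"model": "Qwen/Qwen3-VL-8B-Instruct"}, timeout=600)
print(r.status_code, r.json())  # {"status": "ok", "loaded_model_id": ..., "vram_bytes": ...}

# Free the GPU again once the model is no longer needed.
r = requests.post(f"{BASE}/unload", timeout=60)
print(r.status_code, r.json())  # {"status": "ok", "vram_bytes": ..., "stats": ...}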