Qwen3-VL mode working; /unload; normal model loading times

llm
2025-11-28 21:29:07 +01:00
parent 21d7ab4a2c
commit 9b3d4e40e2


@@ -18,6 +18,7 @@ from transformers import (
     AutoTokenizer,
     AutoProcessor,
     AutoModel,
+    AutoModelForVision2Seq,
 )
 from transformers.utils.import_utils import is_flash_attn_2_available
@@ -177,6 +178,10 @@ class EmbeddingResponse(BaseModel):
     usage: Usage
 
 
+class PreloadRequest(BaseModel):
+    model: str
+
+
 # -----------------------------------------------------------------------------
 # Helpers
 # -----------------------------------------------------------------------------
@@ -248,27 +253,28 @@ def _load_model_locked(model_id: str):
     # Load Generation Model
     # Check if it is a VL model
     if "VL" in model_id:
-        # Attempt to load as VL
-        # Using AutoModelForVision2Seq or AutoModelForCausalLM
-        # depending on the specific model support in transformers
-        try:
-            from transformers import Qwen2VLForConditionalGeneration
-            model_class = Qwen2VLForConditionalGeneration
-        except ImportError:
-            # Fallback to AutoModel if specific class not available
-            model_class = AutoModelForCausalLM
-
-        # Note: We use AutoModelForCausalLM for broad compatibility.
-        # Qwen2-VL requires Qwen2VLForConditionalGeneration for vision.
-        # We will try AutoModelForCausalLM first.
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=dtype,
-            device_map=device_map,
-            attn_implementation=attn_impl,
-            trust_remote_code=True,  # Often needed for new architectures
-        ).eval()
+        # Use AutoModelForVision2Seq for VL models
+        # The configuration class Qwen3VLConfig requires Vision2Seq or AutoModel
+        try:
+            print(f"Loading {model_id} with AutoModelForVision2Seq...")
+            model = AutoModelForVision2Seq.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            ).eval()
+        except Exception as e:
+            print(f"Vision2Seq failed: {e}. Fallback to AutoModel...")
+            # Fallback to generic AutoModel if Vision2Seq fails
+            model = AutoModel.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            ).eval()
 
         # Processor/Tokenizer
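For context, a minimal, self-contained sketch of how a checkpoint loaded through AutoModelForVision2Seq is typically driven together with its AutoProcessor. The checkpoint id, image path, and generation settings are illustrative assumptions, not part of this commit, and the exact chat-template flow can vary between transformers releases:

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "Qwen/Qwen2-VL-2B-Instruct"  # illustrative; the server picks ids from ALLOWED_MODEL_IDS

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
).eval()

image = Image.open("example.jpg")  # illustrative local image
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe this image."},
    ],
}]

# Render the chat template to text, then let the processor pair it with the pixels.
prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)

with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=64)

# Drop the prompt tokens before decoding the reply.
reply = processor.batch_decode(
    generated[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)[0]
print(reply)
```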
@@ -284,12 +290,14 @@ def _load_model_locked(model_id: str):
         _loaded_model_type = "generation"
     else:
         # Standard Text Model (GPT-OSS)
+        print(f"Loading {model_id} with AutoModelForCausalLM...")
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=dtype,
             device_map=device_map,
             attn_implementation=attn_impl,
             trust_remote_code=True,
+            low_cpu_mem_usage=True,
         ).eval()
         processor = AutoTokenizer.from_pretrained(
             model_id, trust_remote_code=True
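The low_cpu_mem_usage flag added in both branches initializes the model with empty weights and loads checkpoint shards directly, keeping peak host-RAM use near a single copy of the weights; this is presumably what the "normal model loading times" in the commit message refers to. A rough sketch for checking the effect on a checkpoint of your own (the model id is an illustrative placeholder):

```python
import time

import torch
from transformers import AutoModelForCausalLM

model_id = "Qwen/Qwen3-8B"  # illustrative placeholder, not from this repo

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    # Build the model on the meta device and stream shards in directly,
    # instead of materializing a second full copy of the weights in RAM.
    low_cpu_mem_usage=True,
).eval()
print(f"Loaded {model_id} in {time.perf_counter() - start:.1f}s")
```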
@@ -367,6 +375,41 @@ def _extract_embeddings(outputs) -> torch.Tensor:
 # -----------------------------------------------------------------------------
+@app.post("/preload")
+def preload_model(request: PreloadRequest):
+    model_id = request.model.strip()
+    if model_id not in ALLOWED_MODEL_IDS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model {model_id} not in allowed models.",
+        )
+    with _model_lock:
+        try:
+            _ensure_model_loaded(model_id)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to load model: {e}"
+            )
+    return {
+        "status": "ok",
+        "loaded_model_id": _loaded_model_id,
+        "vram_bytes": _current_vram_info(),
+    }
+
+
+@app.post("/unload")
+def unload_model():
+    with _model_lock:
+        stats = _unload_model_locked()
+    return {
+        "status": "ok",
+        "vram_bytes": _current_vram_info(),
+        "stats": stats,
+    }
+
+
 @app.get("/health")
 def health():
     cuda_ok = bool(torch.cuda.is_available())
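A quick client-side sketch of the two new routes. The base URL is an assumption about where the FastAPI server is exposed, and the model id must be one of the server's ALLOWED_MODEL_IDS:

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port for the server

# Warm the model before the first real request; loading can take a while,
# so allow a generous timeout.
resp = requests.post(
    f"{BASE_URL}/preload",
    json={"model": "Qwen/Qwen3-VL-8B-Instruct"},  # illustrative id
    timeout=600,
)
print(resp.status_code, resp.json())  # {"status": "ok", "loaded_model_id": ..., "vram_bytes": ...}

# Release GPU memory once the model is no longer needed.
resp = requests.post(f"{BASE_URL}/unload", timeout=60)
print(resp.status_code, resp.json())  # {"status": "ok", "vram_bytes": ..., "stats": ...}
```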