Qwen3-VL mode working; /unload; normal model loading times
@@ -18,6 +18,7 @@ from transformers import (
     AutoTokenizer,
     AutoProcessor,
     AutoModel,
+    AutoModelForVision2Seq,
 )
 from transformers.utils.import_utils import is_flash_attn_2_available
 
@@ -177,6 +178,10 @@ class EmbeddingResponse(BaseModel):
     usage: Usage
 
 
+class PreloadRequest(BaseModel):
+    model: str
+
+
 # -----------------------------------------------------------------------------
 # Helpers
 # -----------------------------------------------------------------------------
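For reference, the body accepted by the new /preload endpoint is just the single field declared on PreloadRequest. A hedged client-side sketch; the model ID below is only a placeholder and would have to appear in ALLOWED_MODEL_IDS on a real deployment:

# Hedged sketch of the /preload request body implied by PreloadRequest.
# "Qwen/Qwen3-VL-8B-Instruct" is a placeholder model ID, not taken from the diff.
from pydantic import BaseModel

class PreloadRequest(BaseModel):
    model: str

# FastAPI parses the JSON body into this model before preload_model runs.
req = PreloadRequest(model="Qwen/Qwen3-VL-8B-Instruct")
print(req.model)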
@@ -248,28 +253,29 @@ def _load_model_locked(model_id: str):
     # Load Generation Model
     # Check if it is a VL model
     if "VL" in model_id:
-        # Attempt to load as VL
-        # Using AutoModelForVision2Seq or AutoModelForCausalLM
-        # depending on the specific model support in transformers
+        # Use AutoModelForVision2Seq for VL models
+        # The configuration class Qwen3VLConfig requires Vision2Seq or AutoModel
         try:
-            from transformers import Qwen2VLForConditionalGeneration
-
-            model_class = Qwen2VLForConditionalGeneration
-        except ImportError:
-            # Fallback to AutoModel if specific class not available
-            model_class = AutoModelForCausalLM
-
-        # Note: We use AutoModelForCausalLM for broad compatibility.
-        # Qwen2-VL requires Qwen2VLForConditionalGeneration for vision.
-        # We will try AutoModelForCausalLM first.
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=dtype,
-            device_map=device_map,
-            attn_implementation=attn_impl,
-            trust_remote_code=True,  # Often needed for new architectures
-        ).eval()
+            print(f"Loading {model_id} with AutoModelForVision2Seq...")
+            model = AutoModelForVision2Seq.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            ).eval()
+        except Exception as e:
+            print(f"Vision2Seq failed: {e}. Fallback to AutoModel...")
+            # Fallback to generic AutoModel if Vision2Seq fails
+            model = AutoModel.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True,
+            ).eval()
 
         # Processor/Tokenizer
         try:
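A side note on the Qwen3VLConfig comment in this hunk: one way to see which architecture and auto-class a checkpoint actually maps to is to inspect its config before loading. This is a hedged, standalone sketch; the model ID is a placeholder, not necessarily one of the served models:

# Hedged sketch: inspect a checkpoint's config to see which auto-class it needs.
# The model ID is a placeholder; trust_remote_code is assumed to be acceptable here.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen3-VL-8B-Instruct", trust_remote_code=True)
print(type(cfg).__name__)   # e.g. "Qwen3VLConfig"
print(cfg.architectures)    # e.g. ["Qwen3VLForConditionalGeneration"]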
@@ -284,12 +290,14 @@ def _load_model_locked(model_id: str):
         _loaded_model_type = "generation"
     else:
         # Standard Text Model (GPT-OSS)
+        print(f"Loading {model_id} with AutoModelForCausalLM...")
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=dtype,
             device_map=device_map,
             attn_implementation=attn_impl,
             trust_remote_code=True,
+            low_cpu_mem_usage=True,
         ).eval()
         processor = AutoTokenizer.from_pretrained(
             model_id, trust_remote_code=True
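The low_cpu_mem_usage=True flag added in both branches is presumably what the "normal model loading times" part of the commit message refers to: the checkpoint shards are materialized directly instead of initializing the full model first and then overwriting its weights. A hedged, standalone sketch of measuring the effect; the model ID is chosen only because the comment above mentions GPT-OSS:

# Hedged sketch: timing from_pretrained with low_cpu_mem_usage enabled.
# "openai/gpt-oss-20b" is a placeholder large checkpoint, not taken from the diff.
import time
from transformers import AutoModelForCausalLM

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    low_cpu_mem_usage=True,  # skip the full random init; load checkpoint shards directly
)
print(f"loaded in {time.perf_counter() - start:.1f}s")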
@@ -367,6 +375,41 @@ def _extract_embeddings(outputs) -> torch.Tensor:
 # -----------------------------------------------------------------------------
 
 
+@app.post("/preload")
+def preload_model(request: PreloadRequest):
+    model_id = request.model.strip()
+    if model_id not in ALLOWED_MODEL_IDS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model {model_id} not in allowed models.",
+        )
+
+    with _model_lock:
+        try:
+            _ensure_model_loaded(model_id)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to load model: {e}"
+            )
+
+    return {
+        "status": "ok",
+        "loaded_model_id": _loaded_model_id,
+        "vram_bytes": _current_vram_info(),
+    }
+
+
+@app.post("/unload")
+def unload_model():
+    with _model_lock:
+        stats = _unload_model_locked()
+    return {
+        "status": "ok",
+        "vram_bytes": _current_vram_info(),
+        "stats": stats,
+    }
+
+
 @app.get("/health")
 def health():
     cuda_ok = bool(torch.cuda.is_available())
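Taken together, the new endpoints give the workflow described in the commit message: preload a model explicitly, then free VRAM with /unload when done. A hedged client sketch, assuming the service listens on localhost:8000 and that the chosen ID is in ALLOWED_MODEL_IDS:

# Hedged usage sketch for the /preload and /unload endpoints added in this commit.
# Host, port, and the model ID are assumptions, not taken from the diff.
import requests

BASE = "http://localhost:8000"

resp = requests.post(f"{BASE}/preload", json={"model": "Qwen/Qwen3-VL-8B-Instruct"})
print(resp.json())  # {"status": "ok", "loaded_model_id": ..., "vram_bytes": ...}

resp = requests.post(f"{BASE}/unload")
print(resp.json())  # {"status": "ok", "vram_bytes": ..., "stats": ...}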