#!/usr/bin/env python
import base64
import gc
import io
import os
import threading
import time
import uuid
from typing import List, Optional, Union, Dict, Any, Literal

import torch
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from fastapi import FastAPI, HTTPException, Request
from PIL import Image, ImageFile
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    AutoModel,
    AutoModelForVision2Seq,
)
from transformers.utils.import_utils import is_flash_attn_2_available

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# Embedding Models
MODEL_ID_NOMIC = "nomic-ai/colnomic-embed-multimodal-7b"
MODEL_ID_EVO_7B = "ApsaraStackMaaS/EvoQwen2.5-VL-Retriever-7B-v1"

# Generation Models
MODEL_ID_QWEN3_VL_8B_INSTRUCT = "Qwen/Qwen3-VL-8B-Instruct"
MODEL_ID_QWEN3_VL_8B_INSTRUCT_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8"
MODEL_ID_QWEN3_VL_8B_THINKING = "Qwen/Qwen3-VL-8B-Thinking"
MODEL_ID_QWEN3_VL_8B_THINKING_FP8 = "Qwen/Qwen3-VL-8B-Thinking-FP8"
MODEL_ID_GPT_OSS_20B = "openai/gpt-oss-20b"

ALLOWED_EMBEDDING_MODELS = {MODEL_ID_NOMIC, MODEL_ID_EVO_7B}
ALLOWED_GENERATION_MODELS = {
    MODEL_ID_QWEN3_VL_8B_INSTRUCT,
    MODEL_ID_QWEN3_VL_8B_INSTRUCT_FP8,
    MODEL_ID_QWEN3_VL_8B_THINKING,
    MODEL_ID_QWEN3_VL_8B_THINKING_FP8,
    MODEL_ID_GPT_OSS_20B,
}

ALLOWED_MODEL_IDS = ALLOWED_EMBEDDING_MODELS | ALLOWED_GENERATION_MODELS

# Default selected model (must be one of ALLOWED_MODEL_IDS).
# If the env value is not allowed, fall back to MODEL_ID_NOMIC.
ENV_DEFAULT_MODEL = os.environ.get("HF_MODEL_ID", MODEL_ID_NOMIC)
DEFAULT_MODEL_ID = (
    ENV_DEFAULT_MODEL
    if ENV_DEFAULT_MODEL in ALLOWED_MODEL_IDS
    else MODEL_ID_NOMIC
)

HF_MODEL_URL = os.environ.get("HF_MODEL_URL")  # optional informational field
API_PORT = int(
    os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000"))
)

# Limits (env-overridable)
MAX_TEXTS_PER_REQUEST = int(os.environ.get("TEXT_MAX_ITEMS", "32"))
MAX_IMAGES_PER_REQUEST = int(os.environ.get("IMAGE_MAX_ITEMS", "8"))
MAX_IMAGE_BASE64_BYTES = int(
    os.environ.get("IMAGE_MAX_BASE64_BYTES", str(25 * 1024 * 1024))
)  # 25 MB per image b64 (approx)
MAX_IMAGE_PIXELS = int(
    os.environ.get("IMAGE_MAX_PIXELS", str(30_000_000))
)  # ~30 MP safety

# PIL safety for large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS
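
# Illustrative only: the variables above are read from the environment at
# import time, so overrides must be in place before this module is imported,
# e.g. from a small launcher script (module name "app" is hypothetical):
#
#   import os
#   os.environ["HF_MODEL_ID"] = "nomic-ai/colnomic-embed-multimodal-7b"
#   os.environ["TEXT_MAX_ITEMS"] = "64"
#   os.environ["PYTORCH_CONTAINER_PORT"] = "9000"
#   import app  # this module; picks up the values above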

# -----------------------------------------------------------------------------
# App + Global State
# -----------------------------------------------------------------------------

app = FastAPI(title="AI Model Service")

_model_lock = threading.RLock()

# Unified model storage
_model: Optional[torch.nn.Module] = None
# Can be ColQwen2_5_Processor, AutoTokenizer, or AutoProcessor
_processor: Optional[Any] = None
_loaded_model_id: Optional[str] = None
_loaded_model_type: Optional[str] = None  # "embedding" or "generation"

# For reporting
_dtype_str: Optional[str] = None
_device_str: str = "cuda:0"

# -----------------------------------------------------------------------------
# Pydantic Models (OpenAI Compatible)
# -----------------------------------------------------------------------------


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = int(time.time())
    owned_by: str = "system"


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard]


class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[Dict[str, Any]]]  # string or multimodal list
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None


class ChatChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Optional[Usage] = None


class EmbeddingRequest(BaseModel):
    # OpenAI supports various inputs
    input: Union[str, List[str], List[int], List[List[int]]]
    model: str
    encoding_format: Optional[str] = "float"  # float or base64
    user: Optional[str] = None


class EmbeddingObject(BaseModel):
    object: str = "embedding"
    index: int
    # OpenAI embeddings are 1D vectors, but ColQwen is multi-vector.
    # We return the raw multi-vector as the "embedding" field,
    # which implies it's a list of lists.
    embedding: Any
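

# Illustrative only (not used by the service): a client can score a query
# against a document with late-interaction MaxSim over the multi-vector
# "embedding" fields returned above. Variable names here are hypothetical.
#
#   import torch
#   q = torch.tensor(query_embedding)  # (num_query_tokens, dim)
#   d = torch.tensor(doc_embedding)    # (num_doc_tokens, dim)
#   score = (q @ d.T).max(dim=1).values.sum().item()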


class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingObject]
    model: str
    usage: Usage


class PreloadRequest(BaseModel):
    model: str


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------


def _torch_dtype_str(dtype: torch.dtype) -> str:
    if dtype == torch.bfloat16:
        return "bfloat16"
    if dtype == torch.float16:
        return "float16"
    return str(dtype)


def _hard_requirements_check():
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available; a CUDA-capable GPU is required."
        )
    if not is_flash_attn_2_available():
        # FlashAttention 2 is preferred but not treated as a hard requirement
        # here: the loader falls back to SDPA for generation models when it is
        # missing, so we only note its absence instead of failing.
        pass


def _pick_dtype() -> torch.dtype:
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def _unload_model_locked():
    global _model, _processor, _loaded_model_id, _loaded_model_type, _dtype_str
    # Assumes caller holds _model_lock
    _model = None
    _processor = None
    _loaded_model_id = None
    _loaded_model_type = None
    _dtype_str = None
    gc.collect()
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            pass


def _load_model_locked(model_id: str):
    global _model, _processor, _loaded_model_id, _loaded_model_type
    global _dtype_str, _device_str

    _hard_requirements_check()
    dtype = _pick_dtype()
    device_map = "cuda:0"
    attn_impl = (
        "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
    )

    if model_id in ALLOWED_EMBEDDING_MODELS:
        # Load Embedding Model
        model = ColQwen2_5.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map=device_map,
            attn_implementation="flash_attention_2",  # ColQwen mandates FA2
        ).eval()
        processor = ColQwen2_5_Processor.from_pretrained(model_id)
        _loaded_model_type = "embedding"

    elif model_id in ALLOWED_GENERATION_MODELS:
        # Load Generation Model
        # Check if it is a VL model
        if "VL" in model_id:
            # Use AutoModelForVision2Seq for VL models
            # The configuration class Qwen3VLConfig requires Vision2Seq or AutoModel
            try:
                print(f"Loading {model_id} with AutoModelForVision2Seq...")
                model = AutoModelForVision2Seq.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    device_map=device_map,
                    attn_implementation=attn_impl,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                ).eval()
            except Exception as e:
                print(f"Vision2Seq failed: {e}. Falling back to AutoModel...")
                # Fall back to generic AutoModel if Vision2Seq fails
                model = AutoModel.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    device_map=device_map,
                    attn_implementation=attn_impl,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                ).eval()

            # Processor/Tokenizer
            try:
                processor = AutoProcessor.from_pretrained(
                    model_id, trust_remote_code=True
                )
            except Exception:
                processor = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )

            _loaded_model_type = "generation"
        else:
            # Standard Text Model (GPT-OSS)
            print(f"Loading {model_id} with AutoModelForCausalLM...")
            # GPT-OSS-20B uses native MXFP4 quantization and needs "auto" dtype
            use_dtype = "auto" if "gpt-oss-20b" in model_id else dtype
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=use_dtype,
                device_map=device_map,
                attn_implementation=attn_impl,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            ).eval()
            processor = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True
            )
            _loaded_model_type = "generation"

    else:
        raise ValueError(f"Unknown model type for {model_id}")

    _model = model
    _processor = processor
    _loaded_model_id = model_id
    _dtype_str = _torch_dtype_str(dtype)
    _device_str = device_map


def _ensure_model_loaded(model_id: str):
    with _model_lock:
        if (
            _model is not None
            and _processor is not None
            and _loaded_model_id == model_id
        ):
            return _model, _processor, _loaded_model_type

        _unload_model_locked()
        _load_model_locked(model_id)
        return _model, _processor, _loaded_model_type


def _current_vram_info():
    if not torch.cuda.is_available():
        return {"free": None, "total": None, "used": None}
    free, total = torch.cuda.mem_get_info(0)
    used = total - free
    return {"free": free, "total": total, "used": used}


def _decode_base64_image(b64_data: str):
    approx_bytes = int(len(b64_data) * 0.75)
    if approx_bytes > MAX_IMAGE_BASE64_BYTES:
        raise ValueError(
            f"Image exceeds max base64 size of {MAX_IMAGE_BASE64_BYTES} bytes"
        )
    try:
        raw = base64.b64decode(b64_data, validate=True)
        img = Image.open(io.BytesIO(raw))
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.load()
        return img
    except Exception as e:
        raise ValueError(f"Unable to decode image: {e}")


def _extract_embeddings(outputs) -> torch.Tensor:
    if isinstance(outputs, torch.Tensor):
        embeddings = outputs
    elif hasattr(outputs, "last_hidden_state"):
        embeddings = outputs.last_hidden_state
    else:
        raise RuntimeError(f"Unexpected model output type: {type(outputs)}")

    if embeddings.dim() == 2:
        embeddings = embeddings.unsqueeze(0)
    elif embeddings.dim() != 3:
        raise RuntimeError(
            f"Unexpected embedding shape: {tuple(embeddings.shape)}"
        )
    return embeddings


# -----------------------------------------------------------------------------
# Endpoints
# -----------------------------------------------------------------------------


@app.post("/preload")
def preload_model(request: PreloadRequest):
    model_id = request.model.strip()
    if model_id not in ALLOWED_MODEL_IDS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_id} not in allowed models.",
        )

    with _model_lock:
        try:
            _ensure_model_loaded(model_id)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load model: {e}"
            )

    return {
        "status": "ok",
        "loaded_model_id": _loaded_model_id,
        "vram_bytes": _current_vram_info(),
    }
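

# Illustrative client call (hypothetical host/port), e.g. with the requests
# library, to warm a model before serving traffic:
#
#   import requests
#   requests.post(
#       "http://localhost:8000/preload",
#       json={"model": "nomic-ai/colnomic-embed-multimodal-7b"},
#   )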


@app.post("/unload")
def unload_model():
    with _model_lock:
        previous_model_id = _loaded_model_id
        _unload_model_locked()
    return {
        "status": "ok",
        "previous_model_id": previous_model_id,
        "vram_bytes": _current_vram_info(),
    }


@app.get("/health")
def health():
    cuda_ok = bool(torch.cuda.is_available())
    flash_ok = bool(is_flash_attn_2_available())

    info = {
        "status": "ok",
        "loaded_model_id": _loaded_model_id,
        "cuda_available": cuda_ok,
        "flash_attn_2_available": flash_ok,
        "vram_bytes": _current_vram_info(),
    }

    if not cuda_ok:
        info["status"] = "error"
        info["error"] = "CUDA is not available."
        raise HTTPException(status_code=500, detail=info)

    return info


@app.get("/v1/models", response_model=ModelList)
def list_models():
    models = []
    for mid in ALLOWED_MODEL_IDS:
        models.append(ModelCard(id=mid))
    return ModelList(data=models)


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    model_id = request.model
    if model_id not in ALLOWED_GENERATION_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_id} not supported or not a generation model."
        )

    with _model_lock:
        try:
            model, processor, mtype = _ensure_model_loaded(model_id)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load model: {e}"
            )

    if mtype != "generation":
        raise HTTPException(
            status_code=500,
            detail=(f"Model loaded as {mtype} "
                    "but accessed via chat completion.")
        )

    # Prepare input: prefer the tokenizer/processor chat template when
    # available, otherwise fall back to naive message concatenation.

    # Check if we have apply_chat_template support (most modern tokenizers do)
    has_template = hasattr(processor, "apply_chat_template")

    if has_template:
        # processor can be Tokenizer or Processor.
        # If it is a Processor (for VL), it might expect a specific format.
        # We'll try passing the message dicts directly.
        try:
            # Convert Pydantic messages to dicts
            msgs = [
                m.model_dump(exclude_none=True) for m in request.messages
            ]

            # TODO: Extract base64 images from message content for VL models.

            text_input = processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception as e:
            # Fall back to manual concatenation
            print(f"Template application failed: {e}")
            text_input = ""
            for m in request.messages:
                content = m.content
                if isinstance(content, list):
                    # Handle multimodal content list - extract text
                    content = " ".join(
                        [
                            c.get("text", "")
                            for c in content
                            if c.get("type") == "text"
                        ]
                    )
                text_input += f"<|im_start|>{m.role}\n"
                text_input += f"{content}<|im_end|>\n"
            text_input += "<|im_start|>assistant\n"
    else:
        text_input = ""
        for m in request.messages:
            content = m.content
            if isinstance(content, list):
                content = " ".join(
                    [
                        c.get("text", "")
                        for c in content
                        if c.get("type") == "text"
                    ]
                )
            text_input += f"{m.role}: {content}\n"
        text_input += "assistant: "

    # Tokenize
    inputs = None
    if (
        hasattr(processor, "process_images")
        or "Processor" in processor.__class__.__name__
    ):
        # It's likely a VL processor.
        inputs = processor(
            text=[text_input], return_tensors="pt", padding=True
        ).to(_device_str)
    else:
        # Standard tokenizer
        inputs = processor(text_input, return_tensors="pt").to(_device_str)

    # Generate (treat an explicit null temperature as the default 1.0)
    temperature = (
        request.temperature if request.temperature is not None else 1.0
    )
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=request.max_tokens or 512,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=request.top_p,
        )

    # Decode only the newly generated tokens
    input_len = inputs.input_ids.shape[1]
    generated_ids = generated_ids[:, input_len:]
    output_text = processor.decode(
        generated_ids[0], skip_special_tokens=True
    )

    # Usage
    usage = Usage(
        prompt_tokens=input_len,
        completion_tokens=generated_ids.shape[1],
        total_tokens=input_len + generated_ids.shape[1],
    )

    choice = ChatChoice(
        index=0,
        message=ChatMessage(role="assistant", content=output_text),
        finish_reason="stop",
    )

    return ChatCompletionResponse(
        id=str(uuid.uuid4()),
        created=int(time.time()),
        model=model_id,
        choices=[choice],
        usage=usage,
    )
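

# Illustrative request body (hypothetical values) for the endpoint above; the
# "messages" list follows the OpenAI chat format accepted by ChatMessage:
#
#   {
#       "model": "Qwen/Qwen3-VL-8B-Instruct",
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Summarize FlashAttention in one line."}
#       ],
#       "max_tokens": 128,
#       "temperature": 0.7
#   }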


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
def create_embeddings(request: EmbeddingRequest):
    model_id = request.model
    # We check if model_id is allowed.
    if model_id not in ALLOWED_EMBEDDING_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_id} not supported or not an embedding model."
        )

    with _model_lock:
        try:
            model, processor, mtype = _ensure_model_loaded(model_id)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load model: {e}"
            )

    if mtype != "embedding":
        raise HTTPException(
            status_code=500, detail="Model is not an embedding model."
        )

    # Handle input
    texts = request.input
    if isinstance(texts, str):
        texts = [texts]
    # If it's a list of token IDs (int), we can't handle it with this processor
    if texts and isinstance(texts[0], int):
        raise HTTPException(
            status_code=400,
            detail="Token IDs input not supported, please provide text.",
        )
    if len(texts) > MAX_TEXTS_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Too many texts; max {MAX_TEXTS_PER_REQUEST} per request.",
        )

    try:
        with torch.inference_mode():
            # ColQwen processor handles queries/docs. Assume queries.
            batch = processor.process_queries(texts).to(_device_str)
            outputs = model(**batch)
            embeddings = _extract_embeddings(outputs)
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Failed to compute embeddings: {exc}"
        )

    embeddings = embeddings.detach().cpu().float().tolist()

    # Approximate prompt token count from the tokenized batch
    token_count = int(batch["input_ids"].numel())

    data = []
    for i, emb in enumerate(embeddings):
        data.append(EmbeddingObject(index=i, embedding=emb))

    return EmbeddingResponse(
        data=data,
        model=model_id,
        usage=Usage(
            prompt_tokens=token_count,
            completion_tokens=0,
            total_tokens=token_count,
        ),
    )
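

# Illustrative request body (hypothetical values) for the endpoint above; the
# response carries one multi-vector embedding (a list of per-token vectors)
# per input string:
#
#   {
#       "model": "nomic-ai/colnomic-embed-multimodal-7b",
#       "input": ["what is flash attention?", "colbert late interaction"]
#   }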


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=API_PORT)