#!/usr/bin/env python
"""FastAPI service exposing multi-vector text embeddings from the ColNomic (ColQwen2.5-based) model."""

import os
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor

HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "nomic-ai/colnomic-embed-multimodal-7b")
HF_MODEL_URL = os.environ.get("HF_MODEL_URL")
API_PORT = int(os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000")))
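
# Note: configuration is taken entirely from the environment. A hypothetical launch
# command (the file name "server.py" is only an example, not part of this repo) could be:
#   HF_MODEL_ID=nomic-ai/colnomic-embed-multimodal-7b PORT=8000 python server.py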

app = FastAPI(title="Colnomic Embed Multimodal 7B API")

_model = None
_processor = None
_device = None


def _ensure_model_loaded():
    """
    Lazy-load the ColNomic model and processor on the first request.

    Hard requirements for this deployment:
    - CUDA must be available.
    - FlashAttention-2 must be available (flash-attn successfully installed).

    If either is missing, an exception is raised and /health returns 500.
    """
    global _model, _processor, _device

    if _model is not None and _processor is not None:
        return _model, _processor, _device

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; a CUDA-capable GPU is required.")

    if not is_flash_attn_2_available():
        raise RuntimeError("flash_attn_2 is not available; please install compatible libraries.")

    # Choose dtype: BF16 if supported, otherwise FP16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    # Use a single GPU (cuda:0) for now.
    device_map = "cuda:0"

    # Force FlashAttention-2 (we already checked availability above).
    attn_impl = "flash_attention_2"

    model = ColQwen2_5.from_pretrained(
        HF_MODEL_ID,
        torch_dtype=dtype,
        device_map=device_map,
        attn_implementation=attn_impl,
    ).eval()

    processor = ColQwen2_5_Processor.from_pretrained(HF_MODEL_ID)

    _model = model
    _processor = processor
    _device = device_map

    return _model, _processor, _device


class EmbedRequest(BaseModel):
    texts: List[str]


class EmbedResponse(BaseModel):
    model_id: str
    # results[batch][tokens][dim]
    results: List[List[List[float]]]
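

# Illustrative request/response shapes for the models above (values are made up,
# not real embeddings):
#   request:  {"texts": ["first query", "second query"]}
#   response: {"model_id": "nomic-ai/colnomic-embed-multimodal-7b",
#              "results": [[[0.01, -0.02, ...], ...], [[...], ...]]}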


@app.get("/health")
def health():
    """
    Health check:
    - Reports CUDA and FlashAttention-2 availability.
    - Tries to load the model once (lazily).
    - Returns 200 only if CUDA, FlashAttention-2, and model loading are all OK.
    """
    cuda_ok = bool(torch.cuda.is_available())
    flash_ok = bool(is_flash_attn_2_available())

    info = {
        "status": "ok",
        "model_id": HF_MODEL_ID,
        "model_url": HF_MODEL_URL,
        "cuda_available": cuda_ok,
        "flash_attn_2_available": flash_ok,
    }

    # Missing CUDA or FlashAttention-2 is a hard failure.
    if not cuda_ok:
        info["status"] = "error"
        info["error"] = "CUDA is not available inside the container."
        raise HTTPException(status_code=500, detail=info)

    if not flash_ok:
        info["status"] = "error"
        info["error"] = "flash_attn_2 is not available; this deployment requires FlashAttention-2."
        raise HTTPException(status_code=500, detail=info)

    try:
        _ensure_model_loaded()
    except Exception as exc:  # noqa: BLE001
        info["status"] = "error"
        info["error"] = str(exc)
        raise HTTPException(status_code=500, detail=info) from exc

    return info
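

# A quick smoke test against a running instance (host and port are illustrative,
# adjust them to your deployment):
#   curl -fsS http://localhost:8000/health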


@app.post("/embed", response_model=EmbedResponse)
def embed(request: EmbedRequest):
    """
    Compute multi-vector embeddings for a list of texts.

    Result shape: results[batch][tokens][dim] (one multi-vector per text).
    """
    if not request.texts:
        raise HTTPException(status_code=400, detail="texts must not be empty")

    model, processor, device = _ensure_model_loaded()  # noqa: F841 - device kept for future use

    # For queries, use process_queries (as in the ColQwen2.5 docs).
    with torch.inference_mode():
        batch = processor.process_queries(request.texts).to(model.device)
        outputs = model(**batch)

    # ColQwen2.5 returns either:
    # - a tensor shaped (batch, tokens, dim), or
    # - an object with .last_hidden_state
    if isinstance(outputs, torch.Tensor):
        embeddings = outputs
    elif hasattr(outputs, "last_hidden_state"):
        embeddings = outputs.last_hidden_state
    else:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected model output type from ColQwen/ColPali: {type(outputs)}",
        )

    if embeddings.dim() == 2:  # (tokens, dim) -> single text
        embeddings = embeddings.unsqueeze(0)
    elif embeddings.dim() != 3:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected embedding shape: {tuple(embeddings.shape)}",
        )

    embeddings = embeddings.detach().cpu().float()
    results = embeddings.tolist()

    return EmbedResponse(model_id=HF_MODEL_ID, results=results)
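

# A minimal client sketch, assuming the service is reachable at http://localhost:8000
# and that the "requests" package is installed (host, port, and query text are
# illustrative, not part of this service):
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8000/embed",
#         json={"texts": ["what is late-interaction retrieval?"]},
#         timeout=120,
#     )
#     resp.raise_for_status()
#     data = resp.json()
#     # data["results"][0] is the (tokens x dim) multi-vector for the first text.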


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=API_PORT)