#!/usr/bin/env python
import base64
import gc
import io
import os
import threading
from typing import List, Optional

import torch
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from fastapi import FastAPI, HTTPException
from PIL import Image, ImageFile
from pydantic import BaseModel, Field
from transformers.utils.import_utils import is_flash_attn_2_available

# --------------------------------------------------------------------------------------
# Configuration
# --------------------------------------------------------------------------------------

# Allowed models (strictly limited per user request)
MODEL_ID_NOMIC = "nomic-ai/colnomic-embed-multimodal-7b"
MODEL_ID_EVO_7B = "ApsaraStackMaaS/EvoQwen2.5-VL-Retriever-7B-v1"
ALLOWED_MODEL_IDS = {MODEL_ID_NOMIC, MODEL_ID_EVO_7B}

# Default selected model (must be one of ALLOWED_MODEL_IDS).
# If the env value is not allowed, fall back to MODEL_ID_NOMIC.
ENV_DEFAULT_MODEL = os.environ.get("HF_MODEL_ID", MODEL_ID_NOMIC)
DEFAULT_MODEL_ID = (
    ENV_DEFAULT_MODEL if ENV_DEFAULT_MODEL in ALLOWED_MODEL_IDS else MODEL_ID_NOMIC
)

HF_MODEL_URL = os.environ.get("HF_MODEL_URL")  # optional informational field
API_PORT = int(os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000")))

# Limits (env-overridable)
MAX_TEXTS_PER_REQUEST = int(os.environ.get("TEXT_MAX_ITEMS", "32"))
MAX_IMAGES_PER_REQUEST = int(os.environ.get("IMAGE_MAX_ITEMS", "8"))
MAX_IMAGE_BASE64_BYTES = int(
    os.environ.get("IMAGE_MAX_BASE64_BYTES", str(25 * 1024 * 1024))
)  # 25MB per image b64 (approx)
MAX_IMAGE_PIXELS = int(
    os.environ.get("IMAGE_MAX_PIXELS", str(30_000_000))
)  # ~30MP safety

# PIL safety for large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS

# --------------------------------------------------------------------------------------
# App + Global State
# --------------------------------------------------------------------------------------
app = FastAPI(title="Colnomic Embed Multimodal API")

_model_lock = threading.RLock()
_model: Optional[torch.nn.Module] = None
_processor: Optional[ColQwen2_5_Processor] = None
_loaded_model_id: Optional[str] = None

# For reporting
_dtype_str: Optional[str] = None
_device_str: str = "cuda:0"
_active_model_id: str = DEFAULT_MODEL_ID


# --------------------------------------------------------------------------------------
# Pydantic Models
# --------------------------------------------------------------------------------------
class SelectModelRequest(BaseModel):
    model_id: str = Field(..., description="One of the allowed model IDs")


class EmbedTextsRequest(BaseModel):
    # Validated in handler for min length
    texts: List[str]


class EmbedResponse(BaseModel):
    model_id: str
    # results[batch][tokens][dim]
    results: List[List[List[float]]]


class EmbedImagesRequest(BaseModel):
    # Validated in handler for min length
    images_b64: List[str]  # base64 encoded images only


class ImageMetadata(BaseModel):
    index: int
    status: str  # "ok" | "error"
    width: Optional[int] = None
    height: Optional[int] = None
    mode: Optional[str] = None
    error: Optional[str] = None


class EmbedImagesResponse(BaseModel):
    model_id: str
    results: List[List[List[float]]]
    metadata: List[ImageMetadata]

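
# Illustrative only (not used by any endpoint): the `results` nesting is
# batch -> tokens -> embedding dim, so a response for one text embedded as two
# tokens in a hypothetical 4-dimensional space would look like the object
# built below. The values are placeholders; real embeddings from these models
# have far more tokens and dimensions.
def _example_embed_response() -> EmbedResponse:
    return EmbedResponse(
        model_id=MODEL_ID_NOMIC,
        results=[  # one text in the batch
            [  # two tokens for that text
                [0.1, 0.2, 0.3, 0.4],  # token 0 (placeholder values)
                [0.5, 0.6, 0.7, 0.8],  # token 1 (placeholder values)
            ]
        ],
    )
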

# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------
def _torch_dtype_str(dtype: torch.dtype) -> str:
    if dtype == torch.bfloat16:
        return "bfloat16"
    if dtype == torch.float16:
        return "float16"
    return str(dtype)


def _hard_requirements_check():
    # CUDA hard requirement
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; a CUDA-capable GPU is required.")
    # FlashAttention-2 hard requirement
    if not is_flash_attn_2_available():
        raise RuntimeError(
            "flash_attn_2 is not available; this deployment requires FlashAttention-2."
        )


def _pick_dtype() -> torch.dtype:
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def _unload_model_locked():
    global _model, _processor, _loaded_model_id, _dtype_str
    # Assumes caller holds _model_lock

    before_free, before_total = (0, 0)
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info(0)
        before_free, before_total = free, total

    _model = None
    _processor = None
    _loaded_model_id = None
    _dtype_str = None

    gc.collect()
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            # Best-effort cleanup; continue
            pass

    after_free, after_total = (0, 0)
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info(0)
        after_free, after_total = free, total

    return {
        "before": {"free": before_free, "total": before_total},
        "after": {"free": after_free, "total": after_total},
        "freed": max(0, (after_free - before_free)),
    }


def _load_model_locked(model_id: str):
    # Assumes caller holds _model_lock
    global _model, _processor, _loaded_model_id, _dtype_str, _device_str

    _hard_requirements_check()

    dtype = _pick_dtype()
    device_map = "cuda:0"
    attn_impl = "flash_attention_2"  # we ensured availability above

    model = ColQwen2_5.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        attn_implementation=attn_impl,
    ).eval()
    processor = ColQwen2_5_Processor.from_pretrained(model_id)

    _model = model
    _processor = processor
    _loaded_model_id = model_id
    _dtype_str = _torch_dtype_str(dtype)
    _device_str = device_map


def _ensure_model_loaded():
    with _model_lock:
        if (
            _model is not None
            and _processor is not None
            and _loaded_model_id == _active_model_id
        ):
            model, processor = _model, _processor
            assert model is not None and processor is not None
            return model, processor

        # Different or missing model: (re)load
        _unload_model_locked()
        _load_model_locked(_active_model_id)
        model, processor = _model, _processor
        assert model is not None and processor is not None
        return model, processor


def _current_vram_info():
    if not torch.cuda.is_available():
        return {"free": None, "total": None, "used": None}
    free, total = torch.cuda.mem_get_info(0)
    used = total - free
    return {"free": free, "total": total, "used": used}


def _decode_base64_image(b64_data: str):
    # Size guard
    approx_bytes = int(len(b64_data) * 0.75)  # rough, base64 overhead ~33%
    if approx_bytes > MAX_IMAGE_BASE64_BYTES:
        raise ValueError(
            f"Image exceeds max base64 size of {MAX_IMAGE_BASE64_BYTES} bytes"
        )
    try:
        raw = base64.b64decode(b64_data, validate=True)
    except Exception as e:
        raise ValueError(f"Invalid base64: {e}")
    try:
        img = Image.open(io.BytesIO(raw))
        # Convert to RGB for model compatibility
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.load()  # ensure data is read
        return img
    except Exception as e:
        raise ValueError(f"Unable to decode image: {e}")


def _extract_embeddings(outputs) -> torch.Tensor:
    # ColQwen2.5 returns either:
    #   - a tensor shaped (batch, tokens, dim), or
    #   - an object with .last_hidden_state
    if isinstance(outputs, torch.Tensor):
        embeddings = outputs
    elif hasattr(outputs, "last_hidden_state"):
        embeddings = outputs.last_hidden_state
    else:
        raise RuntimeError(f"Unexpected model output type: {type(outputs)}")

    if embeddings.dim() == 2:
        # (tokens, dim) -> single item
        embeddings = embeddings.unsqueeze(0)
    elif embeddings.dim() != 3:
        raise RuntimeError(f"Unexpected embedding shape: {tuple(embeddings.shape)}")
    return embeddings

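
# Illustrative sketch, not called by any endpoint: multi-vector ("late
# interaction") embeddings like the ones returned by this API are typically
# scored with MaxSim: for every query token, take the maximum similarity over
# all document tokens, then sum over query tokens. colpali_engine also ships
# its own scoring utilities; the plain-torch version below only demonstrates
# the shape arithmetic on the serialized results.
def _example_maxsim_score(
    query_embedding: List[List[float]], doc_embedding: List[List[float]]
) -> float:
    q = torch.tensor(query_embedding, dtype=torch.float32)  # (q_tokens, dim)
    d = torch.tensor(doc_embedding, dtype=torch.float32)  # (d_tokens, dim)
    sim = q @ d.T  # (q_tokens, d_tokens) dot-product similarities
    return sim.max(dim=1).values.sum().item()  # sum of per-query-token maxima
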

# --------------------------------------------------------------------------------------
# Endpoints
# --------------------------------------------------------------------------------------
@app.get("/health")
def health():
    """
    Health check with hard requirements:
      - CUDA available
      - FlashAttention-2 available
      - Lazy-loads the active model once
      - Includes dtype, device, and VRAM info
    """
    cuda_ok = bool(torch.cuda.is_available())
    flash_ok = bool(is_flash_attn_2_available())

    info = {
        "status": "ok",
        "model_id": _active_model_id,
        "model_url": HF_MODEL_URL,
        "cuda_available": cuda_ok,
        "flash_attn_2_available": flash_ok,
        "dtype": _dtype_str,
        "device": _device_str,
        "vram_bytes": _current_vram_info(),
    }

    # Hard failures
    if not cuda_ok:
        info["status"] = "error"
        info["error"] = "CUDA is not available inside the container."
        raise HTTPException(status_code=500, detail=info)
    if not flash_ok:
        info["status"] = "error"
        info["error"] = (
            "flash_attn_2 is not available; this deployment requires FlashAttention-2."
        )
        raise HTTPException(status_code=500, detail=info)

    try:
        _ensure_model_loaded()
    except Exception as exc:
        info["status"] = "error"
        info["error"] = str(exc)
        raise HTTPException(status_code=500, detail=info) from exc

    # Ensure final dtype/device populated
    info["dtype"] = _dtype_str
    info["device"] = _device_str
    info["vram_bytes"] = _current_vram_info()
    return info


@app.post("/select-model")
def select_model(req: SelectModelRequest):
    """
    Switch the active model between the allowed set.
    Fully unloads the current model (frees VRAM), then loads the new one.
    Blocks concurrent requests briefly via a lock.
    """
    global _active_model_id

    target = req.model_id.strip()
    if target not in ALLOWED_MODEL_IDS:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Unsupported model_id",
                "allowed": sorted(list(ALLOWED_MODEL_IDS)),
                "received": target,
            },
        )

    with _model_lock:
        if (
            target == _active_model_id
            and _model is not None
            and _loaded_model_id == target
        ):
            # No-op
            return {
                "status": "ok",
                "model_id": _active_model_id,
                "message": "Model unchanged; already active.",
            }

        # Switch
        _active_model_id = target
        _unload_model_locked()
        try:
            _load_model_locked(_active_model_id)
        except Exception as exc:
            # Attempt to revert to a safe state: no model loaded
            _unload_model_locked()
            raise HTTPException(
                status_code=500,
                detail={"error": f"Failed to load model '{target}': {exc}"},
            ) from exc

    return {"status": "ok", "model_id": _active_model_id}

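
# Illustrative client sketch, not used by the server: switching the active
# model from another process would look roughly like this. It assumes the
# service is reachable at http://localhost:8000 (the default API_PORT) and
# uses only the standard library so it can be copied out as-is.
def _example_select_model_request(model_id: str = MODEL_ID_EVO_7B) -> dict:
    import json
    import urllib.request

    payload = json.dumps({"model_id": model_id}).encode("utf-8")
    req = urllib.request.Request(
        "http://localhost:8000/select-model",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:  # raises on non-2xx responses
        return json.loads(resp.read().decode("utf-8"))
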

@app.post("/embed-texts", response_model=EmbedResponse)
def embed_texts(request: EmbedTextsRequest):
    """
    Compute multi-vector embeddings for a list of texts.
    Result shape: results[batch][tokens][dim] (multi-vector per text).

    Limits:
      - Max texts per request: TEXT_MAX_ITEMS (default 32)
    """
    texts = request.texts
    if not texts:
        raise HTTPException(status_code=400, detail="texts must not be empty")
    if len(texts) > MAX_TEXTS_PER_REQUEST:
        raise HTTPException(
            status_code=400, detail=f"Too many texts; max is {MAX_TEXTS_PER_REQUEST}"
        )

    with _model_lock:
        model, processor = _ensure_model_loaded()
        try:
            with torch.inference_mode():
                batch = processor.process_queries(texts).to(_device_str)
                outputs = model(**batch)
                embeddings = _extract_embeddings(outputs)
        except Exception as exc:
            raise HTTPException(
                status_code=500, detail=f"Failed to compute text embeddings: {exc}"
            ) from exc

    embeddings = embeddings.detach().cpu().float()
    results = embeddings.tolist()
    return EmbedResponse(model_id=_active_model_id, results=results)


@app.post("/embed-images", response_model=EmbedImagesResponse)
def embed_images(request: EmbedImagesRequest):
    """
    Compute multi-vector embeddings for a list of base64-encoded images.
    Returns results aligned with the input order: results[i] is [] if the
    i-th image failed to decode.

    Limits:
      - Max images per request: IMAGE_MAX_ITEMS (default 8)
      - Max base64 bytes per image: IMAGE_MAX_BASE64_BYTES (default ~25MB)
      - Max image pixels (safety): IMAGE_MAX_PIXELS (default ~30MP)
    """
    b64_list = request.images_b64
    if not b64_list:
        raise HTTPException(status_code=400, detail="images_b64 must not be empty")
    if len(b64_list) > MAX_IMAGES_PER_REQUEST:
        raise HTTPException(
            status_code=400, detail=f"Too many images; max is {MAX_IMAGES_PER_REQUEST}"
        )

    # Decode individually and track metadata
    decoded_images: List[Optional[Image.Image]] = [None] * len(b64_list)
    metadata: List[ImageMetadata] = []
    ok_indices: List[int] = []

    for idx, b64_img in enumerate(b64_list):
        try:
            img = _decode_base64_image(b64_img)
            decoded_images[idx] = img
            w, h = img.size
            metadata.append(
                ImageMetadata(index=idx, status="ok", width=w, height=h, mode=img.mode)
            )
            ok_indices.append(idx)
        except Exception as exc:
            metadata.append(ImageMetadata(index=idx, status="error", error=str(exc)))

    if not ok_indices:
        raise HTTPException(
            status_code=400,
            detail="All provided images failed to decode or were rejected by limits",
        )

    # Prepare only successful images for batching, but preserve order in output
    images_ok: List[Image.Image] = []
    for i in ok_indices:
        img_i = decoded_images[i]
        assert img_i is not None
        images_ok.append(img_i)

    with _model_lock:
        model, processor = _ensure_model_loaded()
        try:
            with torch.inference_mode():
                batch_images = processor.process_images(images_ok).to(_device_str)
                outputs = model(**batch_images)
                embeddings = _extract_embeddings(outputs)
        except Exception as exc:
            raise HTTPException(
                status_code=500, detail=f"Failed to compute image embeddings: {exc}"
            ) from exc

    embeddings = embeddings.detach().cpu().float().tolist()

    # Distribute embeddings back into results aligned with original indices.
    # For failed entries, place an empty list [].
    results: List[List[List[float]]] = [[] for _ in range(len(b64_list))]
    for pos, idx in enumerate(ok_indices):
        results[idx] = embeddings[pos]

    return EmbedImagesResponse(
        model_id=_active_model_id, results=results, metadata=metadata
    )

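
# Illustrative client sketch, not used by the server: how a caller might build
# the request body for /embed-images. A real client would read an image file
# from disk; here a tiny in-memory PNG is generated so the snippet stays
# self-contained.
def _example_embed_images_payload() -> dict:
    buf = io.BytesIO()
    Image.new("RGB", (64, 64), color="white").save(buf, format="PNG")
    b64_png = base64.b64encode(buf.getvalue()).decode("ascii")
    # POST this dict as JSON to /embed-images; results[0] will hold the
    # multi-vector embedding for the single image above.
    return {"images_b64": [b64_png]}
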
""" with _model_lock: before = _current_vram_info() stats = _unload_model_locked() after = _current_vram_info() return { "status": "ok", "active_model_id": _active_model_id, "vram_bytes_before": before, "vram_bytes_after": after, "free_stats": stats, } # -------------------------------------------------------------------------------------- # Entrypoint # -------------------------------------------------------------------------------------- if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=API_PORT)