For AI embedding model timing checks

This commit is contained in:
llm
2025-12-13 22:28:55 +01:00
parent 39ee9c3b92
commit 4d25d9c679

View File

@@ -0,0 +1,95 @@
import torch
from transformers import AutoModel, AutoProcessor
from PIL import Image, UnidentifiedImageError
import requests
from io import BytesIO
import time
# Configuration
MODEL_ID = "TomoroAI/tomoro-colqwen3-embed-4b"
DTYPE = torch.bfloat16
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# DEVICE = "cuda"
DEVICE = "cpu"
# Load Model & Processor
processor = AutoProcessor.from_pretrained(
MODEL_ID,
trust_remote_code=True,
max_num_visual_tokens=1280,
)
model = AutoModel.from_pretrained(
MODEL_ID,
dtype=DTYPE,
attn_implementation="flash_attention_2",
trust_remote_code=True,
device_map=DEVICE,
).eval()
# Sample Data
queries = [
"Retrieve the city of Singapore",
"Retrieve the city of Beijing",
"Retrieve the city of London",
]
docs = [
"https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
"https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
"https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
]
def load_image(url: str) -> Image.Image:
# Some CDNs (e.g., Wikimedia) expect a browser-like UA to avoid 403s.
for headers in ({}, {"User-Agent": "Mozilla/5.0 (compatible; ColQwen3-demo/1.0)"}):
resp = requests.get(url, headers=headers, timeout=10)
if resp.status_code == 403:
continue
resp.raise_for_status()
try:
return Image.open(BytesIO(resp.content)).convert("RGB")
except UnidentifiedImageError as e:
raise RuntimeError(f"Failed to decode image from {url}") from e
raise RuntimeError(f"Could not fetch image (HTTP 403) from {url}; try downloading locally and loading from file path.")
# Helper Functions
def encode_queries(texts, batch_size=8):
outputs = []
for start in range(0, len(texts), batch_size):
batch = processor.process_texts(texts=texts[start : start + batch_size])
batch = {k: v.to(DEVICE) for k, v in batch.items()}
with torch.inference_mode():
out = model(**batch)
vecs = out.embeddings.to(torch.bfloat16).cpu()
outputs.extend(vecs)
return outputs
def encode_docs(urls):
outputs = []
for idx, url in enumerate(urls):
img = load_image(url)
features = processor.process_images(images=[img])
features = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in features.items()}
# Warm up on the first image, measure only 2nd and 3rd embeddings generation
if idx in (1, 2):
start_ns = time.perf_counter_ns()
with torch.inference_mode():
out = model(**features)
vecs = out.embeddings.to(torch.bfloat16).cpu()
end_ns = time.perf_counter_ns()
duration_ns = end_ns - start_ns
print(f"Duration encode_docs image {idx + 1}: {duration_ns:,} ns")
else:
with torch.inference_mode():
out = model(**features)
vecs = out.embeddings.to(torch.bfloat16).cpu()
outputs.extend(vecs)
return outputs
# Execution
query_embeddings = encode_queries(queries)
doc_embeddings = encode_docs(docs)
# MaxSim Scoring
scores = processor.score_multi_vector(query_embeddings, doc_embeddings)
print(scores)