PyTorch experiments
This commit is contained in:
120
python/experiments_tomoro-colqwen3-embed-4b.py
Normal file
120
python/experiments_tomoro-colqwen3-embed-4b.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
import requests
|
||||
from io import BytesIO
|
||||
import time
|
||||
|
||||
# Configuration
|
||||
MODEL_ID = "TomoroAI/tomoro-colqwen3-embed-4b"
|
||||
DTYPE = torch.bfloat16
|
||||
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
DEVICE = "cpu"
|
||||
# DEVICE = "cuda"
|
||||
|
||||
start_ts = time.perf_counter_ns()
|
||||
|
||||
# Load Model & Processor
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
MODEL_ID,
|
||||
trust_remote_code=True,
|
||||
max_num_visual_tokens=1280,
|
||||
)
|
||||
model = AutoModel.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype=DTYPE,
|
||||
# attn_implementation="flash_attention_2",
|
||||
attn_implementation="sdpa",
|
||||
trust_remote_code=True,
|
||||
device_map=DEVICE,
|
||||
).eval()
|
||||
|
||||
duration_ns = time.perf_counter_ns() - start_ts
|
||||
print(f"Duration Load Model & Processor: {duration_ns:,} ns")
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"total_params: {total_params:,}")
|
||||
|
||||
# Sample Data
|
||||
queries = [
|
||||
"Retrieve the city of Singapore",
|
||||
# "Retrieve the city of Beijing",
|
||||
# "Retrieve the city of London",
|
||||
]
|
||||
docs = [
|
||||
"https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
|
||||
# "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
|
||||
# "https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
|
||||
]
|
||||
|
||||
def load_image(url: str) -> Image.Image:
|
||||
# Some CDNs (e.g., Wikimedia) expect a browser-like UA to avoid 403s.
|
||||
for headers in ({}, {"User-Agent": "Mozilla/5.0 (compatible; ColQwen3-demo/1.0)"}):
|
||||
resp = requests.get(url, headers=headers, timeout=10)
|
||||
if resp.status_code == 403:
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
try:
|
||||
return Image.open(BytesIO(resp.content)).convert("RGB")
|
||||
except UnidentifiedImageError as e:
|
||||
raise RuntimeError(f"Failed to decode image from {url}") from e
|
||||
raise RuntimeError(f"Could not fetch image (HTTP 403) from {url}; try downloading locally and loading from file path.")
|
||||
|
||||
# Helper Functions
|
||||
def encode_queries(texts, batch_size=8):
|
||||
outputs = []
|
||||
for start in range(0, len(texts), batch_size):
|
||||
batch = processor.process_texts(texts=texts[start : start + batch_size])
|
||||
batch = {k: v.to(DEVICE) for k, v in batch.items()}
|
||||
with torch.inference_mode():
|
||||
out = model(**batch)
|
||||
vecs = out.embeddings.to(torch.bfloat16).cpu()
|
||||
outputs.extend(vecs)
|
||||
return outputs
|
||||
|
||||
def encode_docs(urls, batch_size=4):
|
||||
pil_images = [load_image(url) for url in urls]
|
||||
outputs = []
|
||||
for start in range(0, len(pil_images), batch_size):
|
||||
batch_imgs = pil_images[start : start + batch_size]
|
||||
features = processor.process_images(images=batch_imgs)
|
||||
features = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in features.items()}
|
||||
with torch.inference_mode():
|
||||
out = model(**features)
|
||||
print(f"type(out.embeddings) = {type(out.embeddings)}")
|
||||
print(f"out.embeddings.shape = {out.embeddings.shape}")
|
||||
print(f"out.embeddings.ndim = {out.embeddings.ndim}")
|
||||
print(f"out.embeddings.device = {out.embeddings.device}")
|
||||
print(f"out.embeddings.numel() = {out.embeddings.numel()}")
|
||||
print("out.embeddings.element_size() = "
|
||||
f"{out.embeddings.element_size()}")
|
||||
print("out.embeddings.numel() * out.embeddings.element_size() = "
|
||||
f"{out.embeddings.numel() * out.embeddings.element_size()}")
|
||||
vecs = out.embeddings.to(torch.bfloat16).cpu()
|
||||
outputs.extend(vecs)
|
||||
return outputs
|
||||
|
||||
# Execution
|
||||
|
||||
start_ts = time.perf_counter_ns()
|
||||
|
||||
query_embeddings = encode_queries(queries)
|
||||
|
||||
duration_ns = time.perf_counter_ns() - start_ts
|
||||
print(f"Duration encode_queries: {duration_ns:,} ns")
|
||||
start_ts = time.perf_counter_ns()
|
||||
|
||||
doc_embeddings = encode_docs(docs)
|
||||
|
||||
duration_ns = time.perf_counter_ns() - start_ts
|
||||
print(f"Duration encode_docs: {duration_ns:,} ns")
|
||||
|
||||
# MaxSim Scoring
|
||||
|
||||
start_ts = time.perf_counter_ns()
|
||||
|
||||
scores = processor.score_multi_vector(query_embeddings, doc_embeddings)
|
||||
|
||||
duration_ns = time.perf_counter_ns() - start_ts
|
||||
print(f"Duration score_multi_vector: {duration_ns:,} ns")
|
||||
|
||||
print(scores)
|
||||
Reference in New Issue
Block a user