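# Gradio demo for Instance-Driven Multi-modal Retrieval (IDMR):
# candidate images are embedded once with the IDMR model, cached to disk,
# and then ranked against a text and/or image query at request time.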
import gradio as gr
import os
from PIL import Image
import numpy as np
import torch
import pickle
from transformers import AutoProcessor
from src.model import MMEBModel
from src.arguments import ModelArguments

QUERY_DIR = "imgs/queries"
IMAGE_DIR = "imgs/candidates"
image_paths = [os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.endswith((".jpg", ".png"))]

IMAGE_TOKEN = "<|image_1|>"
TOP_N = 5

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

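# load_model() builds the IDMR retrieval model plus the processor matching the
# configured backbone; this demo uses the InternVL 2.5 backbone, which also
# switches the image placeholder token from "<|image_1|>" to "<image>".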
def load_model():
    global IMAGE_TOKEN
    model_args = ModelArguments(
        # model_name="/fs-computility/ai-shen/kilab-shared/liubangwei/ckpt/my_hf/IDMR-2B",
        model_name="lbw18601752667/IDMR-2B",
        model_backbone="internvl_2_5",
    )
    if model_args.model_backbone == "phi35v":
        processor = AutoProcessor.from_pretrained(
            model_args.model_name,
            trust_remote_code=True,
            num_crops=model_args.num_crops,
        )
        processor.tokenizer.padding_side = "right"
    elif model_args.model_backbone == "internvl_2_5":
        from src.vlm_backbone.intern_vl import InternVLProcessor
        from transformers import AutoTokenizer, AutoImageProcessor
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name,
            trust_remote_code=True
        )
        image_processor = AutoImageProcessor.from_pretrained(
            model_args.model_name,
            trust_remote_code=True,
            use_fast=False
        )
        processor = InternVLProcessor(
            image_processor=image_processor,
            tokenizer=tokenizer
        )
        IMAGE_TOKEN = "<image>"
    model = MMEBModel.load(model_args)
    model = model.to(device, dtype=torch.bfloat16)
    model.eval()
    return model, processor

model, processor = load_model()

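# get_inputs() tokenizes the text (with the image placeholder stripped when no
# image is given), moves all tensors to the target device, and adds the
# image_flags field the model expects (1 = image present, 0 = text only).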
def get_inputs(processor, text, image_path=None, image=None):
    if image_path:
        image = Image.open(image_path)
    if image is None:
        text = text.replace(IMAGE_TOKEN, "")
    inputs = processor(
        text=text,
        images=[image] if image is not None else None,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}
    inputs["image_flags"] = torch.tensor([1 if image is not None else 0], dtype=torch.long).to(device)
    if image is None:
        # the processor may still emit pixel_values for text-only queries; drop them if present
        inputs.pop("pixel_values", None)
    return inputs

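# encode_image_library() embeds every candidate image one at a time with the
# generic "Represent the given image." prompt and keys each embedding by
# filename for later lookup.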
def encode_image_library(image_paths):
    embeddings_dict = {}
    for img_path in image_paths:
        text = f"{IMAGE_TOKEN}\n Represent the given image."
        print(f"text: {text}")
        inputs = get_inputs(processor, text, image_path=img_path)
        with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
            output = model(tgt=inputs)
        img_name = os.path.basename(img_path)
        embeddings_dict[img_name] = output["tgt_reps"].float().cpu().numpy()
    return embeddings_dict

def save_embeddings(embeddings, file_path="image_embeddings.pkl"):
    with open(file_path, "wb") as f:
        pickle.dump(embeddings, f)

def load_embeddings(file_path="image_embeddings.pkl"):
    with open(file_path, "rb") as f:
        return pickle.load(f)

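# Note: the score below is a plain dot product; it equals cosine similarity
# only under the assumption that the model's query and target embeddings are
# already L2-normalized.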
def cosine_similarity(query_embedding, embeddings):
    similarity = np.sum(query_embedding * embeddings, axis=-1)
    return similarity

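# retrieve_images() embeds the query (text, image, or both), loads the cached
# candidate embeddings, scores them against the query, and returns the file
# paths of the top_n highest-scoring images.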
def retrieve_images(query_text, query_image, top_n=TOP_N):
    if query_text:
        query_text = f"{IMAGE_TOKEN}\n {query_text}"
    else:
        query_text = f"{IMAGE_TOKEN}\n Represent the given image."
    if query_image is not None:
        image = Image.fromarray(query_image)
    else:
        image = None
    inputs = get_inputs(processor, query_text, image=image)
    print(f"inputs: {inputs}")
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.bfloat16):
        query_embedding = model(qry=inputs)["qry_reps"].float().cpu().numpy()
    embeddings_dict = load_embeddings()
    img_names = []
    embeddings = []
    for img_name in os.listdir(IMAGE_DIR):
        if img_name in embeddings_dict:
            img_names.append(img_name)
            embeddings.append(embeddings_dict[img_name])
    embeddings = np.stack(embeddings)
    similarity = cosine_similarity(query_embedding, embeddings)
    similarity = similarity.T
    print(f"cosine_similarity: {similarity}")
    top_indices = np.argsort(-similarity).squeeze(0)[:top_n]
    print(f"top_indices: {top_indices}")
    return [os.path.join(IMAGE_DIR, img_names[i]) for i in top_indices]

def demo(query_text, query_image):
    # print(f"query_text: {query_text}, query_image: {query_image}, type(query_image): {type(query_image)}, image shape: {query_image.shape if query_image is not None else 'None'}")
    retrieved_images = retrieve_images(query_text, query_image)
    return [Image.open(img) for img in retrieved_images]

def load_examples():
    examples = []
    image_files = [f for f in os.listdir(QUERY_DIR) if f.endswith((".jpg", ".png"))]
    for img_file in image_files:
        img_path = os.path.join(QUERY_DIR, img_file)
        txt_file = os.path.splitext(img_file)[0] + ".txt"
        txt_path = os.path.join(QUERY_DIR, txt_file)
        if os.path.exists(txt_path):
            with open(txt_path, 'r', encoding='utf-8') as f:
                query_text = f.read().strip().replace("<|image_1|>\n", "")
            examples.append([query_text, img_path])
    return examples

iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Textbox(placeholder="Enter your query text here...", label="Query Text"),
        gr.Image(label="Query Image", type="numpy")
    ],
    outputs=gr.Gallery(label=f"Retrieved Images (Top {TOP_N})", columns=3),
    examples=load_examples(),
    title="Instance-Driven Multi-modal Retrieval (IDMR) Demo",
    description="Enter a query text or upload an image to retrieve relevant images from the library. You can click on the examples below to try them out."
)

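# Build the candidate-embedding cache on first launch; later launches reuse
# the pickled embeddings instead of re-encoding the image library.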
if not os.path.exists("image_embeddings.pkl"):
    embeddings = encode_image_library(image_paths)
    save_embeddings(embeddings)

iface.launch()
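# To try the demo outside Hugging Face Spaces, install the IDMR repo's
# dependencies, place candidate/query images under imgs/, and run this script
# directly; Gradio will serve the interface on a local URL.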