import datetime
import time
from collections import OrderedDict
from pathlib import Path

import einops
import numpy as np
import torch
import torch.nn.functional as F

from src.tools.files import json_dump


class TestCirr:
    def __init__(self):
        pass

    @staticmethod
    @torch.no_grad()
    def __call__(model, data_loader, fabric):
        """Evaluate CIRR retrieval and write recall submission files to disk."""
        model.eval()
        fabric.print("Computing features for test...")
        start_time = time.time()

        tar_img_feats = []
        query_feats = []
        pair_ids = []
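        # Encode each (reference image, relative caption) pair into a query
        # embedding, batch by batch.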
        for ref_img, tar_feat, caption, pair_id, *_ in data_loader:
            pair_ids.extend(pair_id.cpu().numpy().tolist())

            device = ref_img.device

            ref_img_embs = model.visual_encoder(ref_img)
            ref_img_atts = torch.ones(ref_img_embs.size()[:-1], dtype=torch.long).to(
                device
            )
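            # Tokenize the relative caption.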
            text = model.tokenizer(
                caption,
                padding="longest",
                truncation=True,
                max_length=64,
                return_tensors="pt",
            ).to(device)
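            # Fuse text and image: the first token is swapped for the encoder
            # token id, and the text encoder cross-attends to the reference
            # image embeddings.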
            encoder_input_ids = text.input_ids.clone()
            encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
            query_embs = model.text_encoder(
                encoder_input_ids,
                attention_mask=text.attention_mask,
                encoder_hidden_states=ref_img_embs,
                encoder_attention_mask=ref_img_atts,
                return_dict=True,
            )
            query_feat = query_embs.last_hidden_state[:, 0, :]
            query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
            query_feats.append(query_feat.cpu())
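            # Target-image features arrive precomputed from the data loader.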
            tar_img_feats.append(tar_feat.cpu())

        pair_ids = torch.tensor(pair_ids, dtype=torch.long)
        query_feats = torch.cat(query_feats, dim=0)
        tar_img_feats = torch.cat(tar_img_feats, dim=0)
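        # In the distributed case, gather the per-process shards and flatten
        # the leading process dimension.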
        if fabric.world_size > 1:
            query_feats = fabric.all_gather(query_feats)
            tar_img_feats = fabric.all_gather(tar_img_feats)
            pair_ids = fabric.all_gather(pair_ids)

            query_feats = einops.rearrange(query_feats, "d b e -> (d b) e")
            tar_img_feats = einops.rearrange(tar_img_feats, "d b e -> (d b) e")
            pair_ids = einops.rearrange(pair_ids, "d b -> (d b)")
        if fabric.global_rank == 0:
            pair_ids = pair_ids.cpu().numpy().tolist()

            assert len(query_feats) == len(pair_ids)
            img_ids = [data_loader.dataset.pairid2ref[pair_id] for pair_id in pair_ids]
            assert len(img_ids) == len(pair_ids)
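            # Keep one embedding per unique target image, preserving the
            # order in which images first appear.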
            id2emb = OrderedDict()
            for img_id, tar_img_feat in zip(img_ids, tar_img_feats):
                if img_id not in id2emb:
                    id2emb[img_id] = tar_img_feat

            tar_feats = torch.stack(list(id2emb.values()), dim=0)
            sims_q2t = query_feats @ tar_feats.T
            pairid2index = {pair_id: i for i, pair_id in enumerate(pair_ids)}
            tarid2index = {tar_id: j for j, tar_id in enumerate(id2emb.keys())}
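            # Mask out each query's own reference image so it cannot be
            # returned as a retrieval result.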
            for pair_id in pair_ids:
                que_id = data_loader.dataset.pairid2ref[pair_id]
                if que_id in tarid2index:
                    sims_q2t[pairid2index[pair_id], tarid2index[que_id]] = -100
            sims_q2t = sims_q2t.cpu().numpy()

            total_time = time.time() - start_time
            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
            print(f"Evaluation time {total_time_str}")
            recalls = {}
            recalls["version"] = "rc2"
            recalls["metric"] = "recall"

            recalls_subset = {}
            recalls_subset["version"] = "rc2"
            recalls_subset["metric"] = "recall_subset"

            target_imgs = np.array(list(id2emb.keys()))
            assert len(sims_q2t) == len(pair_ids)
            for pair_id, query_sims in zip(pair_ids, sims_q2t):
                sorted_indices = np.argsort(query_sims)[::-1]

                # Top-50 ranking over the full target gallery.
                query_id_recalls = list(target_imgs[sorted_indices][:50])
                query_id_recalls = [
                    str(data_loader.dataset.int2id[x]) for x in query_id_recalls
                ]
                recalls[str(pair_id)] = query_id_recalls
                # Top-3 ranking restricted to this pair's subset members.
                members = data_loader.dataset.pairid2members[pair_id]
                query_id_recalls_subset = [
                    target
                    for target in target_imgs[sorted_indices]
                    if target in members
                ]
                query_id_recalls_subset = [
                    data_loader.dataset.int2id[x] for x in query_id_recalls_subset
                ][:3]
                recalls_subset[str(pair_id)] = query_id_recalls_subset
            json_dump(recalls, "recalls_cirr.json")
            json_dump(recalls_subset, "recalls_cirr_subset.json")

            print(
                f"Recalls saved in {Path.cwd()} as recalls_cirr.json "
                "and recalls_cirr_subset.json"
            )

        fabric.barrier()