import datetime
import time
from pathlib import Path
from typing import Dict, List

import einops
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tabulate import tabulate

from src.tools.files import json_dump, json_load
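

# Evaluates composed image retrieval on one FashionIQ category: each
# (reference image, modification caption) pair is fused into a query
# embedding, scored against pre-extracted target-image embeddings, and
# ranked; Recall@K is reported on rank 0 and written to a per-category JSON.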
class TestFashionIQ:
    def __init__(self, category: str):
        self.category = category

    @torch.no_grad()
    def __call__(self, model, data_loader, fabric):
        model.eval()

        fabric.print("Computing features for evaluation...")
        start_time = time.time()

        query_feats = []
        captions = []
        idxs = []
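        # Fuse each (reference image, caption) pair into one query embedding:
        # image patch embeddings condition the text encoder via cross-attention,
        # and the projected [CLS] state becomes the query feature.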
        for ref_img, _, caption, idx in data_loader:
            idxs.extend(idx.cpu().numpy().tolist())
            captions.extend(caption)

            device = ref_img.device

            ref_img_embs = model.visual_encoder(ref_img)
            ref_img_atts = torch.ones(ref_img_embs.size()[:-1], dtype=torch.long).to(
                device
            )

            text = model.tokenizer(
                caption,
                padding="longest",
                truncation=True,
                max_length=64,
                return_tensors="pt",
            ).to(device)

            # Replace the first token with the encoder token so the text
            # encoder runs in its image-grounded (multimodal) mode.
            encoder_input_ids = text.input_ids.clone()
            encoder_input_ids[:, 0] = model.tokenizer.enc_token_id
            query_embs = model.text_encoder(
                encoder_input_ids,
                attention_mask=text.attention_mask,
                encoder_hidden_states=ref_img_embs,
                encoder_attention_mask=ref_img_atts,
                return_dict=True,
            )
            query_feat = query_embs.last_hidden_state[:, 0, :]
            query_feat = F.normalize(model.text_proj(query_feat), dim=-1)
            query_feats.append(query_feat.cpu())

        query_feats = torch.cat(query_feats, dim=0)
        query_feats = F.normalize(query_feats, dim=-1)
        idxs = torch.tensor(idxs, dtype=torch.long)
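
        # With multiple devices, all_gather prepends a world dimension;
        # flatten it back into the batch dimension after gathering.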
        if fabric.world_size > 1:
            query_feats = fabric.all_gather(query_feats)
            idxs = fabric.all_gather(idxs)

            query_feats = einops.rearrange(query_feats, "d b e -> (d b) e")
            idxs = einops.rearrange(idxs, "d b -> (d b)")

        # Metrics are computed on rank 0 only; other ranks wait at the barrier.
        if fabric.global_rank == 0:
            idxs = idxs.cpu().numpy()
            ref_img_ids = [data_loader.dataset.pairid2ref[idx] for idx in idxs]
            ref_img_ids = [data_loader.dataset.int2id[img_id] for img_id in ref_img_ids]
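
            # Target gallery: load pre-extracted image embeddings from disk.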
            tar_img_feats = []
            tar_img_ids = []
            for target_id in data_loader.dataset.target_ids:
                tar_img_ids.append(target_id)
                target_emb_pth = data_loader.dataset.id2embpth[target_id]
                target_feat = torch.load(target_emb_pth, map_location="cpu")
                tar_img_feats.append(target_feat)
            tar_img_feats = torch.stack(tar_img_feats)
            tar_img_feats = F.normalize(tar_img_feats, dim=-1)
            tar_img_feats = tar_img_feats.to(query_feats.device)

            # Both sides are L2-normalized, so the dot product is the cosine.
            sim_q2t = (query_feats @ tar_img_feats.t()).cpu()
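
            # Mask self-retrieval: a query must not return its own reference
            # image, so its similarity is pushed below any valid cosine score.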
            for i in range(len(ref_img_ids)):
                for j in range(len(tar_img_ids)):
                    if ref_img_ids[i] == tar_img_ids[j]:
                        sim_q2t[i][j] = -10

            total_time = time.time() - start_time
            total_time_str = str(datetime.timedelta(seconds=int(total_time)))
            print(f"Evaluation time {total_time_str}")

            ref_img_ids = np.array(ref_img_ids)
            tar_img_ids = np.array(tar_img_ids)

            # Ground-truth target id for each query pair.
            cor_img_ids = [data_loader.dataset.pairid2tar[idx] for idx in idxs]
            cor_img_ids = [data_loader.dataset.int2id[img_id] for img_id in cor_img_ids]

            recalls = get_recalls_labels(sim_q2t, cor_img_ids, tar_img_ids)
            fabric.print(recalls)

            json_dump(recalls, f"recalls_fiq-{self.category}.json")
            print(f"Recalls saved in {Path.cwd()} as recalls_fiq-{self.category}.json")

            # Print the three-category average once every category has run.
            mean_results(fabric=fabric)

        fabric.barrier()

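
# Recall@K over string labels: sort the gallery by ascending distance and
# measure how often the query's ground-truth label lands in the top k.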
def recall_at_k_labels(sim, query_lbls, target_lbls, k=10):
    distances = 1 - sim
    sorted_indices = torch.argsort(distances, dim=-1).cpu()
    sorted_index_names = np.array(target_lbls)[sorted_indices]
    # Boolean matrix: True where the ranked target label matches the query's
    # ground-truth label.
    labels = torch.tensor(
        sorted_index_names
        == np.repeat(np.array(query_lbls), len(target_lbls)).reshape(
            len(query_lbls), -1
        )
    )
    # Sanity check: each query has exactly one ground-truth target.
    assert torch.equal(
        torch.sum(labels, dim=-1).int(), torch.ones(len(query_lbls)).int()
    )
    return round((torch.sum(labels[:, :k]) / len(labels)).item() * 100, 2)
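
# A toy sanity check (hypothetical values, not part of the pipeline): with two
# queries, three targets, and the correct target ranked first for each query,
# recall_at_k_labels returns 100.0 for k=1:
#
#   sim = torch.tensor([[0.9, 0.1, 0.0], [0.2, 0.8, 0.1]])
#   recall_at_k_labels(sim, ["a", "b"], ["a", "b", "c"], k=1)  # -> 100.0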


def get_recalls_labels(
    sims, query_lbls, target_lbls, ks: List[int] = [1, 5, 10, 50]
) -> Dict[str, float]:
    return {f"R{k}": recall_at_k_labels(sims, query_lbls, target_lbls, k) for k in ks}

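
# Aggregates the per-category recall JSONs written above into the FashionIQ
# three-category average and prints a LaTeX-ready results row.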
def mean_results(dir=".", fabric=None, save=True):
    dir = Path(dir)
    recall_pths = list(dir.glob("recalls_fiq-*.json"))
    recall_pths.sort()
    # Run only once all three category files (dress, shirt, toptee) exist.
    if len(recall_pths) != 3:
        return

    df = {}
    for pth in recall_pths:
        # "recalls_fiq-dress.json" -> "fiq-dress"
        name = pth.name.split("_")[1].split(".")[0]
        data = json_load(pth)
        df[name] = data

    df = pd.DataFrame(df)

    # .copy() avoids pandas' SettingWithCopyWarning when adding columns below.
    df_fiq = df[df.columns[df.columns.str.contains("fiq")]].copy()
    assert len(df_fiq.columns) == 3
    df_fiq["Average"] = df_fiq.mean(axis=1)
    df_fiq["Average"] = df_fiq["Average"].apply(lambda x: round(x, 2))

    headers = [
        "dress\nR10",
        "dress\nR50",
        "shirt\nR10",
        "shirt\nR50",
        "toptee\nR10",
        "toptee\nR50",
        "Average\nR10",
        "Average\nR50",
    ]
    fiq = []
    for category in ["fiq-dress", "fiq-shirt", "fiq-toptee", "Average"]:
        for recall in ["R10", "R50"]:
            value = df_fiq.loc[recall, category]
            fiq.append(str(value).zfill(2))
    if fabric is None:
        print(tabulate([fiq], headers=headers, tablefmt="latex_raw"))
        print(" & ".join(fiq))
    else:
        fabric.print(tabulate([fiq], headers=headers))
        fabric.print(" & ".join(fiq))

    if save:
        df_mean = df_fiq["Average"].to_dict()
        df_mean = {k + "_mean": round(v, 2) for k, v in df_mean.items()}
        json_dump(df_mean, "recalls_fiq-mean.json")
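
# A minimal driver sketch (hypothetical names: `model` and `loaders` are built
# elsewhere in this repo), assuming a Lightning Fabric run:
#
#   from lightning.fabric import Fabric
#
#   fabric = Fabric(accelerator="auto", devices=1)
#   fabric.launch()
#   for category in ["dress", "shirt", "toptee"]:
#       TestFashionIQ(category)(model, loaders[category], fabric)
#   # mean_results() fires inside __call__ once all three JSONs exist.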