import torch
import numpy as np


class NME:
    """Normalized Mean Error (NME) between predicted and ground-truth landmarks.

    For each sample, the error is the mean Euclidean distance between
    corresponding predicted and ground-truth landmarks. That mean is divided by
    a normalization distance measured on the ground-truth landmarks between a
    "left" and a "right" reference point.
    """

    def __init__(self, nme_left_index, nme_right_index):
        """Store the landmark indices that define the normalization distance.

        Args:
            nme_left_index: Either a single landmark index (int) or a list of
                landmark indices whose mean position is the left reference
                point.
            nme_right_index: Same as ``nme_left_index``, for the right
                reference point. Its type (list vs. int) selects how the
                normalization distance is computed in ``test``.
        """
        self.nme_left_index = nme_left_index
        self.nme_right_index = nme_right_index

    def __repr__(self):
        return "NME()"

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array.

        Torch tensors are detached and moved to CPU first. Anything else
        (NumPy arrays, nested lists) goes through ``np.asarray``.
        """
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def get_norm_distance(self, landmarks):
        """Distance between the mean left and mean right reference points.

        Args:
            landmarks: 2-D array of one sample's landmarks, indexed as
                ``landmarks[landmark_index, coordinate]``.

        Returns:
            The Euclidean distance between the mean of the landmarks selected
            by ``nme_right_index`` and the mean of those selected by
            ``nme_left_index``.

        Raises:
            AssertionError: if either index attribute is not a list.
        """
        assert isinstance(self.nme_right_index, list), 'the nme_right_index is not list.'
        assert isinstance(self.nme_left_index, list), 'the nme_left_index is not list.'
        # Variable names in the original suggest these are pupil centres
        # (averaged eye landmarks) -- NOTE(review): confirm against callers.
        right_pupil = landmarks[self.nme_right_index, :].mean(0)
        left_pupil = landmarks[self.nme_left_index, :].mean(0)
        return np.linalg.norm(right_pupil - left_pupil)

    def test(self, label_pd, label_gt):
        """Compute the per-sample NME for a batch.

        Args:
            label_pd: Predicted landmarks, batch-first, each sample indexed as
                ``[landmark_index, coordinate]``. Accepts a torch tensor (any
                device, with or without grad) or anything ``np.asarray``
                accepts.
            label_gt: Ground-truth landmarks with the same layout as
                ``label_pd``.

        Returns:
            A list with one NME value per sample. The normalization distance is
            always measured on the ground truth. If it is 0 for a sample, NumPy
            division yields inf/nan for that sample.

        Raises:
            NotImplementedError: if ``nme_right_index`` is neither a list nor
                an integer (Python ``int`` or NumPy integer).
        """
        label_pd = self._to_numpy(label_pd)
        label_gt = self._to_numpy(label_gt)

        nme_list = []
        for landmarks_pv, landmarks_gt in zip(label_pd, label_gt):
            if isinstance(self.nme_right_index, list):
                # Reference points are means over groups of landmarks.
                norm_distance = self.get_norm_distance(landmarks_gt)
            elif isinstance(self.nme_right_index, (int, np.integer)):
                # Reference points are single landmarks.
                norm_distance = np.linalg.norm(
                    landmarks_gt[self.nme_left_index] - landmarks_gt[self.nme_right_index]
                )
            else:
                raise NotImplementedError

            # Per-landmark Euclidean error (norm over the coordinate axis),
            # normalized and then averaged over landmarks.
            landmarks_delta = landmarks_pv - landmarks_gt
            nme = (np.linalg.norm(landmarks_delta, axis=1) / norm_distance).mean()
            nme_list.append(nme)
        return nme_list