Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torchmetrics import Metric | |
class PRFMetric(Metric):
    """Running average of precision, recall and F1, reported as percentages.

    Each ``update`` call contributes one precision/recall pair (presumably
    per-video scores, judging by the ``vid_`` prefix -- confirm with callers).
    Four scalar states are kept: the summed precision, summed recall, summed
    F1 and the number of updates. All of them are registered with
    ``dist_reduce_fx="sum"``.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``Metric`` and register the accumulators."""
        super().__init__(**kwargs)
        # Every accumulator starts at 0.0; a fresh tensor is created for
        # each state so that no default is shared between them.
        for state_name in ("t_precision", "t_recall", "t_f1", "n"):
            self.add_state(
                state_name,
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )

    def update(self, vid_p, vid_r) -> None:
        """Accumulate one precision/recall pair.

        Args:
            vid_p: Value added to the running precision total.
            vid_r: Value added to the running recall total.

        The F1 added for this pair is ``2 * p * r / (p + r)``. When
        ``p + r`` is falsy (i.e. zero), ``0.0`` is added instead so the
        division is never performed with a zero denominator.
        """
        self.t_precision += vid_p
        self.t_recall += vid_r

        pr_sum = vid_p + vid_r
        if pr_sum:
            pair_f1 = 2 * (vid_p * vid_r) / pr_sum
        else:
            pair_f1 = 0.0
        self.t_f1 += pair_f1

        self.n += 1

    def compute(self):
        """Return the mean precision, recall and F1, each multiplied by 100.

        Returns:
            A dict with the keys ``"precision"``, ``"recall"`` and ``"f1"``.

        Note:
            The totals are divided by ``n``, which is still zero if
            ``update`` has never been called; the results are not
            meaningful in that case.
        """
        count = self.n
        totals = (
            ("precision", self.t_precision),
            ("recall", self.t_recall),
            ("f1", self.t_f1),
        )
        # Scale before dividing, matching the order total * 100 / n.
        return {key: total * 100 / count for key, total in totals}