Spaces:
Running
on
Zero
Running
on
Zero
File size: 910 Bytes
6303c5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from torchmetrics import Metric
class PRFMetric(Metric):
    """Average precision, recall and F1 over many ``update`` calls, in percent.

    Each call to :meth:`update` contributes one (precision, recall) pair —
    presumably one video's scores, judging by the ``vid_`` argument names
    (TODO confirm against callers). The pair's F1 score is computed at update
    time, and :meth:`compute` returns the unweighted mean of precision,
    recall and F1 across all calls, multiplied by 100.
    """

    # update() only adds into the running sums and never reads the current
    # state, so torchmetrics may use its cheaper reduce-state path in
    # forward() instead of calling update() twice per step.
    full_state_update = False
    # Precision, recall and F1 all improve as they increase.
    higher_is_better = True

    def __init__(self, **kwargs):
        """Register the running-sum states; ``kwargs`` are passed to ``Metric``."""
        super().__init__(**kwargs)
        # Running sums; each is summed across processes on distributed sync.
        self.add_state("t_precision", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("t_recall", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("t_f1", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Number of update() calls seen so far (kept as a float tensor).
        self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, vid_p, vid_r) -> None:
        """Accumulate one precision/recall pair and its F1 score.

        Args:
            vid_p: precision for this sample, as a Python number or a
                single-element tensor (a multi-element tensor would make the
                truthiness test below raise). Presumably in [0, 1] — TODO
                confirm.
            vid_r: recall for this sample, in the same form as ``vid_p``.
        """
        self.t_precision += vid_p
        self.t_recall += vid_r
        denom = vid_p + vid_r
        # When precision + recall is zero, F1 is taken as 0.0; this guard
        # also avoids dividing by zero.
        self.t_f1 += 2 * (vid_p * vid_r) / denom if denom else 0.0
        self.n += 1

    def compute(self) -> dict:
        """Return the mean precision, recall and F1 over all updates, in percent.

        Returns:
            A dict with keys ``"precision"``, ``"recall"`` and ``"f1"``, each a
            0-dim tensor. If ``update`` was never called, ``n`` is 0 and the
            divisions yield NaN.
        """
        avg_p = self.t_precision * 100 / self.n
        avg_r = self.t_recall * 100 / self.n
        avg_f1 = self.t_f1 * 100 / self.n
        return {"precision": avg_p, "recall": avg_r, "f1": avg_f1}
|