File size: 910 Bytes
6303c5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from torchmetrics import Metric


class PRFMetric(Metric):
    """Running mean of precision, recall and F1, reported on a 0-100 scale.

    Each ``update`` call contributes one (precision, recall) pair plus the
    F1 score derived from that pair. ``compute`` divides the accumulated
    sums by the number of ``update`` calls and multiplies by 100.

    All four accumulators are registered as metric states with
    ``dist_reduce_fx="sum"``.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``Metric`` and register the accumulators."""
        super().__init__(**kwargs)
        # t_precision / t_recall / t_f1 hold running totals; n counts updates.
        # Every state starts at a fresh scalar 0.0 tensor.
        for state_name in ("t_precision", "t_recall", "t_f1", "n"):
            self.add_state(state_name, default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, vid_p, vid_r) -> None:
        """Add one precision/recall pair to the running totals.

        Args:
            vid_p: Precision of the sample being added (presumably one video,
                judging by the name -- confirm with callers).
            vid_r: Recall of the same sample.

        The sum ``vid_p + vid_r`` is truth-tested, so both values are
        expected to be scalars (plain numbers or single-element tensors).
        """
        pr_sum = vid_p + vid_r
        if pr_sum:
            sample_f1 = 2 * (vid_p * vid_r) / pr_sum
        else:
            # Precision and recall sum to zero: count F1 as 0 rather than
            # dividing by zero.
            sample_f1 = 0.0

        self.t_precision += vid_p
        self.t_recall += vid_r
        self.t_f1 += sample_f1
        self.n += 1

    def compute(self):
        """Return the averaged scores as percentages.

        Returns:
            A dict with keys ``"precision"``, ``"recall"`` and ``"f1"``, each
            holding ``total * 100 / n``.

        Note:
            With no prior ``update`` the totals and ``n`` are all 0.0, so
            each entry is 0/0, i.e. NaN.
        """
        totals = {
            "precision": self.t_precision,
            "recall": self.t_recall,
            "f1": self.t_f1,
        }
        return {key: total * 100 / self.n for key, total in totals.items()}