File size: 2,263 Bytes
a446b0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import time
import torch

import comet.utils.utils as utils
import comet.src.data.config as cfg


class Evaluator(object):
    """Base evaluation loop that runs a model over one data split.

    This class calls several hooks that it does not define itself:
    ``initialize_losses``, ``batch``, ``counter``, ``compute_final_scores``
    and ``print_result``. They must be supplied elsewhere (presumably by a
    subclass -- verify against the rest of the project).
    """

    def __init__(self, opt, model, data_loader):
        """Store the configuration, model and data loader.

        Args:
            opt: Configuration object; ``opt.eval.es`` is read in ``epoch``.
            model: Model to evaluate; ``model.eval()`` is called in ``epoch``.
            data_loader: Loader exposing ``reset_offsets``,
                ``offset_summary`` and ``total_size``.
        """
        super(Evaluator, self).__init__()

        self.data_loader = data_loader
        self.model = model

        # Context dictionary handed to ``self.batch`` on every batch;
        # ``validate`` additionally stores the current split under "split".
        self.batch_variables = {
            "model": model,
            "data": data_loader
        }

        self.opt = opt

    def validate(self, l, split="dev", losses=None, keyset=None):
        """Evaluate one split and record the resulting losses under key ``l``.

        Args:
            l: Key under which this run's values are stored, i.e.
                ``losses[loss_name][l] = value``.
            split: Name of the data split to evaluate.
            losses: Optional dict to update in place. When omitted, a fresh
                dict is created for this call (a shared mutable default would
                silently mix results from unrelated calls).
            keyset: Forwarded unchanged to ``epoch``.

        Returns:
            The ``losses`` dict that was updated (the one passed in, or the
            newly created one).
        """
        if losses is None:
            losses = {}

        self.batch_variables["split"] = split
        print("Evaluating {}".format(split))

        epoch_losses = self.epoch(
            self.opt, self.model, self.data_loader, split, keyset)

        self.print_result(split, epoch_losses)

        for loss_name, loss_val in epoch_losses.items():
            losses.setdefault(loss_name, {})[l] = loss_val

        return losses

    def epoch(self, opt, model, data_loader, split, keyset=None):
        """Run a single gradient-free evaluation pass over ``split``.

        Args:
            opt: Configuration object; ``opt.eval.es`` is either ``"full"``
                or a threshold compared against ``self.counter(nums)`` to
                stop early.
            model: Model to evaluate; switched to eval mode here.
            data_loader: Loader to pull batches from; its offsets for
                ``split`` are reset without shuffling first.
            split: Name of the data split to evaluate.
            keyset: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``self.compute_final_scores(average_loss, nums)``
            returns.
        """
        average_loss, nums = self.initialize_losses()

        data_loader.reset_offsets(splits=split, shuffle=False)

        # Set evaluation mode
        model.eval()

        # Wall-clock timer. Kept in its own variable so that the per-batch
        # offset bookkeeping below cannot overwrite it.
        start_time = time.time()

        # Initialize progress bar
        bar = utils.set_progress_bar(
            data_loader.total_size[split])

        reset = False

        with torch.no_grad():
            while not reset:

                offset_before = data_loader.offset_summary(split)

                outputs = self.batch(
                    opt, nums, average_loss,
                    self.batch_variables, eval_mode=True)

                offset_after = data_loader.offset_summary(split)

                reset = outputs["reset"]

                if not reset:
                    # Advance the bar by the change in the loader's offset.
                    bar.update(offset_after - offset_before)
                else:
                    print(offset_after)

                # Early exits: toy mode caps the count at 100, otherwise
                # opt.eval.es caps it unless it is the string "full".
                if cfg.toy and self.counter(nums) > 100:
                    break
                if (opt.eval.es != "full" and
                        (self.counter(nums) > opt.eval.es)):
                    break

        # The loop body always runs at least once, so ``outputs`` is bound.
        nums = outputs["nums"]

        # Only synchronize when CUDA exists; calling this unconditionally
        # raises on CPU-only installations.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        print("{} evaluation completed in: {} s".format(
            split.capitalize(), time.time() - start_time))

        average_loss = self.compute_final_scores(
            average_loss, nums)

        return average_loss