import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

import comet.src.data.config as cfg
import comet.src.train.utils as train_utils
import comet.src.models.utils as model_utils
import comet.src.evaluate.utils as eval_utils
import comet.utils.utils as utils


##############################################################################
# BATCH
##############################################################################


def batch_atomic_generate(opt, nums, losses, batch_variables, eval_mode=False):
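    """Sample one ATOMIC batch and compute its generation loss.

    Returns the sum over examples of the loss-masked, length-normalized
    per-token NLL, along with the updated counters and the sampler's
    reset flag.
    """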
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]

    batch, reset = data_loader.sample_batch(split, bs=opt.train.dynamic.bs)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    # Targets are the input tokens shifted one position to the left
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss name
    micro_name = "total_micro"
    macro_name = "total_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, (loss * loss_mask).sum(1), split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    return outputs


def batch_conceptnet_generate(opt, nums, losses, batch_variables,
eval_mode=False, tracking_mode=False):
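    """Sample one ConceptNet batch for the given relation category and
    compute its generation loss.

    During evaluation, losses on negative examples are tracked under
    separate "negative_*" counters; in tracking mode, per-example losses
    are also returned.
    """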
    data_loader = batch_variables["data"]
    model = batch_variables["model"]
    split = batch_variables["split"]
    category = batch_variables["category"]

    batch, reset = data_loader.sample_batch(
        split, bs=opt.train.dynamic.bs, cat=category)

    input_ = model_utils.prepare_position_embeddings(
        opt, data_loader.vocab_encoder, batch["sequences"].unsqueeze(-1))
    attention_mask = batch["attention_mask"]
    loss_mask = batch["loss_mask"]

    # Targets are the input tokens shifted one position to the left
    targets = input_.squeeze(0)[:, 1:, 0].contiguous().view(-1)

    loss, dist = mle_steps(
        opt.net.model, model, input_[:, :-1, :], targets,
        attention_mask[:, :-1], loss_reduction="none")

    # Set loss name: negative examples get separate counters at eval time
    if not eval_mode or batch_variables["category"] == "positive":
        micro_name = "total_micro"
        macro_name = "total_macro"
    else:
        micro_name = "negative_micro"
        macro_name = "negative_macro"

    length = loss_mask.sum(1)
    bs = input_.size(0)

    final_loss = (loss * loss_mask).sum(1)

    update_generation_losses(losses, nums, micro_name, macro_name, bs,
                             length, (loss * loss_mask).sum(1), split)

    final_loss = final_loss / length

    outputs = {"loss": final_loss.sum(), "nums": nums, "reset": reset}

    if tracking_mode:
        outputs["tracking"] = final_loss.squeeze().tolist()

    return outputs


def mle_steps(key, model, input_, targets, attention_mask,
loss_reduction="mean", i=None):
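    """Compute the NLL of `targets` under the model's output distribution.

    With loss_reduction="none", returns per-token losses reshaped to
    (batch, sequence); otherwise returns the reduced scalar loss. The
    distribution itself is returned in both cases.
    """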
    word_acts = decode(model, input_.unsqueeze(1),
                       attention_mask, i)

    word_dist = train_utils.modify_output_for_loss_fn(
        "nll", word_acts, dim=-1)

    # Compute losses
    loss = F.nll_loss(
        word_dist.view(-1, word_dist.size(-1)),
        targets, reduction=loss_reduction)

    if loss_reduction != "mean":
        # Keep per-token losses, reshaped to (batch, sequence)
        return loss.view(word_dist.size(0), -1), word_dist
    else:
        return loss, word_dist


def decode(model, input_, attention_mask, i=None):
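    """Run the model forward; `i` is accepted for interface compatibility
    but unused here."""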
    return model(input_, sequence_mask=attention_mask)


def update_generation_losses(losses, nums, micro, macro, bs,
length, loss, split):
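    """Dispatch micro/macro loss bookkeeping to the train or eval helper,
    depending on the split."""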
if split == "train":
train_utils.update_generation_losses(
losses, nums, micro, macro, bs, length, loss)
else:
eval_utils.update_generation_losses(
losses, nums, micro, macro, bs, length, loss)
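

# Minimal usage sketch (illustrative only): the surrounding trainer is assumed
# to supply an `opt` config with the fields read above, a data loader exposing
# `sample_batch`/`vocab_encoder`, and running tallies in `nums`/`losses` whose
# exact structure is defined by train_utils/eval_utils, not by this module.
#
#     batch_variables = {"data": data_loader, "model": model, "split": "train"}
#     outputs = batch_atomic_generate(opt, nums, losses, batch_variables)
#     outputs["loss"].backward()
#     if outputs["reset"]:
#         ...  # the sampler signals that the split has been exhausted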