marufc36 committed
Commit 974e3be · 1 Parent(s): 201a582
Files changed (1)
  1. model.py
model.py ADDED
import torch
import lightning as pl
from torch.optim import AdamW  # transformers' AdamW is deprecated; the torch optimizer is the drop-in replacement
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)


class PoemSummaryModel(pl.LightningModule):
    """Pretrained T5 wrapped as a LightningModule for poem summarization."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # T5 computes the cross-entropy loss internally when labels are given.
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def _shared_step(self, batch):
        # Training, validation, and test all run the same forward pass;
        # only the name of the logged metric differs.
        loss, _ = self(
            input_ids=batch["text_input_ids"],
            attention_mask=batch["text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"],
        )
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=1e-4)
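The module expects batches carrying text_input_ids, text_attention_mask, labels, and labels_attention_mask, but the commit does not include the dataset that builds them. Below is a minimal sketch of one way to produce those keys with the tokenizer imported above; the PoemSummaryDataset name and the "poem"/"summary" column names are assumptions, not part of this commit.

from torch.utils.data import Dataset


class PoemSummaryDataset(Dataset):
    """Hypothetical dataset producing the batch keys PoemSummaryModel expects."""

    def __init__(self, data, tokenizer, text_max_len=512, summary_max_len=128):
        self.data = data  # assumed: a pandas DataFrame with "poem" and "summary" columns
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.summary_max_len = summary_max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        text_enc = self.tokenizer(
            row["poem"],
            max_length=self.text_max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        summary_enc = self.tokenizer(
            row["summary"],
            max_length=self.summary_max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = summary_enc["input_ids"].squeeze(0)
        # T5's internal loss ignores positions set to -100, so mask out padding.
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "text_input_ids": text_enc["input_ids"].squeeze(0),
            "text_attention_mask": text_enc["attention_mask"].squeeze(0),
            "labels": labels,
            "labels_attention_mask": summary_enc["attention_mask"].squeeze(0),
        }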
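The original import list also pulled in ModelCheckpoint, TensorBoardLogger, and DataLoader, which belong to the training script rather than the module itself. A rough sketch of that wiring, assuming train_df and val_df are DataFrames with the columns above (batch size and epoch count are placeholders):

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

tokenizer = T5Tokenizer.from_pretrained("t5-base")
train_loader = DataLoader(PoemSummaryDataset(train_df, tokenizer), batch_size=8, shuffle=True)
val_loader = DataLoader(PoemSummaryDataset(val_df, tokenizer), batch_size=8)

# Keep the best checkpoint by validation loss, as logged in validation_step.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = pl.Trainer(
    max_epochs=3,
    logger=TensorBoardLogger("lightning_logs", name="poem-summary"),
    callbacks=[checkpoint_cb],
)
trainer.fit(PoemSummaryModel(), train_loader, val_loader)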