import json
import pandas as pd
import numpy as np
import torch
from pathlib import Path
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW  # transformers' bundled AdamW is deprecated; use torch's implementation
import textwrap
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)
from tqdm.auto import tqdm
class PoemSummaryModel(pl.LightningModule):
    """T5-base wrapped as a LightningModule for poem summarisation."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base", return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # T5 computes the cross-entropy loss internally when labels are provided.
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]
        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]
        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]
        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        # AdamW imported from torch.optim above; transformers' AdamW is deprecated.
        return AdamW(self.parameters(), lr=0.0001)
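

# --- Usage sketch (illustrative assumption, not part of the original Space file) ---
# A minimal example of how this LightningModule could be trained. The tiny
# in-memory dataset below is hypothetical; it only exists to show the batch keys
# the model expects ("text_input_ids", "text_attention_mask", "labels",
# "labels_attention_mask"). The real Space presumably builds its DataLoaders
# from an actual poem/summary corpus elsewhere.
class _ToyPoemDataset(Dataset):
    def __init__(self, pairs, tokenizer, text_max_len=512, summary_max_len=128):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.summary_max_len = summary_max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        poem, summary = self.pairs[idx]
        text_enc = self.tokenizer(
            poem, max_length=self.text_max_len, padding="max_length",
            truncation=True, return_tensors="pt",
        )
        label_enc = self.tokenizer(
            summary, max_length=self.summary_max_len, padding="max_length",
            truncation=True, return_tensors="pt",
        )
        labels = label_enc["input_ids"].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100  # ignore padding in the loss
        return {
            "text_input_ids": text_enc["input_ids"].squeeze(0),
            "text_attention_mask": text_enc["attention_mask"].squeeze(0),
            "labels": labels,
            "labels_attention_mask": label_enc["attention_mask"].squeeze(0),
        }


if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    pairs = [("The rose is red, the violet's blue.", "A short love verse.")]
    loader = DataLoader(_ToyPoemDataset(pairs, tokenizer), batch_size=1)

    model = PoemSummaryModel()
    trainer = pl.Trainer(
        max_epochs=1,
        logger=TensorBoardLogger("lightning_logs", name="poem-summary"),
        callbacks=[ModelCheckpoint(monitor="val_loss", save_top_k=1)],
    )
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)

    # Quick sanity check: summarise the training poem with the fine-tuned weights.
    enc = tokenizer(pairs[0][0], return_tensors="pt")
    generated = model.model.generate(enc["input_ids"].to(model.device), max_length=64)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))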