import json
import pandas as pd
import numpy as np
import torch
from pathlib import Path
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
# transformers' AdamW is deprecated (removed in recent versions);
# torch.optim.AdamW is the drop-in replacement.
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import textwrap
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)
from tqdm.auto import tqdm
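
# The batch keys consumed by the steps below ("text_input_ids",
# "text_attention_mask", "labels", "labels_attention_mask") imply a Dataset
# along these lines. A minimal sketch, assuming a DataFrame with
# hypothetical "poem" and "summary" columns; column names and max lengths
# are illustrative, not taken from the original.
class PoemSummaryDataset(Dataset):
    def __init__(self, data: pd.DataFrame, tokenizer: T5Tokenizer,
                 text_max_len: int = 512, summary_max_len: int = 128):
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.summary_max_len = summary_max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        text_encoding = self.tokenizer(
            row["poem"],
            max_length=self.text_max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        summary_encoding = self.tokenizer(
            row["summary"],
            max_length=self.summary_max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = summary_encoding["input_ids"].flatten()
        # T5's pad token id is 0; setting padded positions to -100 makes the
        # loss ignore them (T5 restores the pad id when shifting the labels
        # right to build decoder inputs).
        labels[labels == 0] = -100
        return dict(
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            labels=labels,
            labels_attention_mask=summary_encoding["attention_mask"].flatten(),
        )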
class PoemSummaryModel(pl.LightningModule):
    """LightningModule wrapping a pretrained T5 for poem summarization."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(
            "t5-base", return_dict=True
        )

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # With labels supplied, T5 builds decoder inputs internally and
        # returns the cross-entropy loss alongside the logits.
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits
    def _shared_step(self, batch):
        # The train/val/test loops run the same forward pass; only the
        # logged metric name differs.
        loss, _ = self(
            input_ids=batch["text_input_ids"],
            attention_mask=batch["text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"],
        )
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss
    def configure_optimizers(self):
        # torch.optim.AdamW; 1e-4 is a common starting LR for T5 fine-tuning.
        return AdamW(self.parameters(), lr=1e-4)
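
# --- Hedged usage sketch (not from the original) -------------------------
# How this module would typically be wired up with the imported
# ModelCheckpoint / TensorBoardLogger, plus a beam-search inference helper.
# Paths, run names, epoch counts, and generation settings are assumptions.

def summarize(poem: str, model: PoemSummaryModel, tokenizer) -> str:
    # Encode a single poem and decode a beam-search summary.
    encoding = tokenizer(
        poem, max_length=512, truncation=True, return_tensors="pt"
    )
    generated_ids = model.model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=128,
        num_beams=2,
        repetition_penalty=2.5,
        early_stopping=True,
    )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    # train_df / val_df are assumed to be prepared elsewhere in the project:
    # train_loader = DataLoader(PoemSummaryDataset(train_df, tokenizer),
    #                           batch_size=8, shuffle=True, num_workers=2)
    # val_loader = DataLoader(PoemSummaryDataset(val_df, tokenizer),
    #                         batch_size=8, num_workers=2)

    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",        # assumed output directory
        filename="best-checkpoint",
        save_top_k=1,
        monitor="val_loss",
        mode="min",
    )
    logger = TensorBoardLogger("lightning_logs", name="poem-summary")
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[checkpoint_callback],
        max_epochs=3,                 # illustrative
    )
    model = PoemSummaryModel()
    # trainer.fit(model, train_loader, val_loader)
    # print(summarize("...", model, tokenizer))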