Spaces:
Running
Running
File size: 2,001 Bytes
201a582 a74afc2 201a582 a74afc2 201a582 a74afc2 201a582 a74afc2 201a582 a74afc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
# Tokenizer for the stock "t5-base" vocabulary, shared by encode/decode below.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

# Instantiate the model architecture from the stock "t5-base" checkpoint; the
# saved weights are layered on top of it below.
quantized_model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Load the saved state dictionary. map_location="cpu" lets a checkpoint that
# was saved from GPU tensors load on a CPU-only host; load_state_dict copies
# the values into the model's own parameters wherever they were loaded.
# NOTE(review): torch.load unpickles the file — only ever load checkpoints
# shipped with this app, never user-supplied files.
state_dict = torch.load("quantized_model.pt", map_location="cpu")

# Keep only entries whose key exists in the model. The model's key set is
# built once up front: calling quantized_model.state_dict() inside the
# comprehension would rebuild the whole dictionary for every checkpoint entry.
# NOTE(review): if this checkpoint came from dynamic quantization, its Linear
# layers are likely stored under packed-parameter keys that do not match this
# float model, so those weights would be silently skipped here — confirm which
# weights actually get loaded.
_model_keys = set(quantized_model.state_dict())
filtered_state_dict = {k: v for k, v in state_dict.items() if k in _model_keys}

# strict=False tolerates model keys absent from the filtered checkpoint (they
# keep their "t5-base" values); tensors with mismatched shapes still raise.
quantized_model.load_state_dict(filtered_state_dict, strict=False)
def encode_text(text):
    """Tokenize *text* into model-ready tensors.

    The text is truncated to at most 512 tokens and padded up to exactly 512,
    with an attention mask marking the real (non-padding) positions.

    Args:
        text: The raw input string.

    Returns:
        An ``(input_ids, attention_mask)`` pair of PyTorch tensors.
    """
    batch = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=512,
        return_attention_mask=True,
        return_tensors='pt',
    )
    ids, mask = batch["input_ids"], batch["attention_mask"]
    return ids, mask
def generate_summary(
    input_ids,
    attention_mask,
    model,
    *,
    max_length=150,
    num_beams=2,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
):
    """Run beam-search generation and return the raw output token ids.

    The inputs are moved onto the model's device, not the model onto the
    inputs' device: ``nn.Module.to`` works in place, so moving the model would
    relocate the shared, module-level model on every request.

    Args:
        input_ids: Token-id tensor, e.g. as returned by ``encode_text``.
        attention_mask: Mask tensor matching ``input_ids``.
        model: Model exposing ``parameters()`` and ``generate`` (a Hugging
            Face seq2seq model in this app).
        max_length: Maximum length of the generated sequence.
        num_beams: Beam width for beam search.
        repetition_penalty: Penalty for repeated tokens (1.0 means none).
        length_penalty: Exponent applied to sequence length when scoring beams.
        early_stopping: Stopping condition passed through to beam search.

    Returns:
        The tensor of generated token ids returned by ``model.generate``.
    """
    # Follow the device of the first parameter; a model without parameters
    # leaves the inputs where they are.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)
        attention_mask = attention_mask.to(first_param.device)
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=early_stopping,
    )
def decode_summary(generated_ids):
    """Turn generated token-id sequences back into text.

    Each sequence is decoded with special tokens removed and tokenization
    spaces cleaned up; the decoded strings are concatenated with no separator.

    Args:
        generated_ids: Iterable of token-id sequences (e.g. the output of
            ``generate_summary``).

    Returns:
        The concatenated decoded text.
    """
    pieces = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return "".join(pieces)
def summarize(text):
    """Summarize *text* end to end: tokenize, generate, decode.

    Empty or whitespace-only input (e.g. a blank Gradio textbox) returns an
    empty string without invoking the model, which would otherwise run
    generation on an input containing no real text.

    Args:
        text: The text to summarize; ``None`` is treated as empty.

    Returns:
        The decoded summary string.
    """
    if text is None or not text.strip():
        return ""
    input_ids, attention_mask = encode_text(text)
    generated_ids = generate_summary(input_ids, attention_mask, quantized_model)
    return decode_summary(generated_ids)
# Build the Gradio UI: a multi-line textbox in, the summary out.
input_text = gr.Textbox(lines=10, label="Input Text")
output_text = gr.Textbox(label="Summary")
demo = gr.Interface(
    fn=summarize,
    inputs=input_text,
    outputs=output_text,
    title="Poem Pulse",
    description="Enter a poem and get its gist.",
)

# Launch only when run as a script, so importing this module (e.g. from tests,
# or Gradio's reload mode, which looks up ``demo``) does not start a server.
if __name__ == "__main__":
    demo.launch()
|