mmchowdhury committed on
Commit
a74afc2
·
verified ·
1 Parent(s): dbd54db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -1,12 +1,20 @@
1
  import gradio as gr
2
- from transformers import T5TokenizerFast
3
- from model import PoemSummaryModel
4
-
5
 
6
  tokenizer = T5TokenizerFast.from_pretrained("t5-base")
7
- best_model = PoemSummaryModel.load_from_checkpoint("best-checkpoint.ckpt")
8
- best_model.freeze()
9
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def encode_text(text):
12
  encoding = tokenizer.encode_plus(
@@ -21,7 +29,7 @@ def encode_text(text):
21
 
22
  def generate_summary(input_ids, attention_mask, model):
23
  model = model.to(input_ids.device)
24
- generated_ids = model.model.generate(
25
  input_ids=input_ids,
26
  attention_mask=attention_mask,
27
  max_length=150,
@@ -39,7 +47,7 @@ def decode_summary(generated_ids):
39
 
40
  def summarize(text):
41
  input_ids, attention_mask = encode_text(text)
42
- generated_ids = generate_summary(input_ids, attention_mask, best_model)
43
  summary = decode_summary(generated_ids)
44
  return summary
45
 
@@ -53,4 +61,4 @@ gr.Interface(
53
  outputs=output_text,
54
  title="Poem Pulse",
55
  description="Enter a Poem and get its Jist."
56
- ).launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import T5ForConditionalGeneration, T5TokenizerFast
 
4
 
5
  tokenizer = T5TokenizerFast.from_pretrained("t5-base")
 
 
6
 
7
+ # Instantiate the base t5-base architecture; the saved weights are loaded into it below
8
+ quantized_model = T5ForConditionalGeneration.from_pretrained("t5-base")
9
+
10
+ # Load the state dictionary
11
+ state_dict = torch.load("quantized_model.pt")
12
+
13
+ # Filter out keys that are not present in the quantized model
14
+ filtered_state_dict = {k: v for k, v in state_dict.items() if k in quantized_model.state_dict()}
15
+
16
+ # Load the filtered state dictionary into the quantized model
17
+ quantized_model.load_state_dict(filtered_state_dict, strict=False)
18
 
19
  def encode_text(text):
20
  encoding = tokenizer.encode_plus(
 
29
 
30
  def generate_summary(input_ids, attention_mask, model):
31
  model = model.to(input_ids.device)
32
+ generated_ids = model.generate(
33
  input_ids=input_ids,
34
  attention_mask=attention_mask,
35
  max_length=150,
 
47
 
48
  def summarize(text):
49
  input_ids, attention_mask = encode_text(text)
50
+ generated_ids = generate_summary(input_ids, attention_mask, quantized_model)
51
  summary = decode_summary(generated_ids)
52
  return summary
53
 
 
61
  outputs=output_text,
62
  title="Poem Pulse",
63
  description="Enter a Poem and get its Jist."
64
+ ).launch()