sumanthd committed on
Commit
c5ecbf5
·
1 Parent(s): b063b6f

add model inference

Browse files
Files changed (1) hide show
  1. app.py +65 -14
app.py CHANGED
@@ -1,9 +1,24 @@
 
 
 
 
1
  import gradio as gr
 
2
 
3
- model = None
4
- tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # device = 0 if torch.cuda.is_available() else -1
7
 
8
  LANGUAGES = {
9
  "Hindi": "hin_Deva",
@@ -29,9 +44,51 @@ LANGUAGES = {
29
  "Bodo": "brx_Deva"
30
  }
31
 
32
- def translate(src_lang, text, tgt_lang):
 
33
 
34
- return "Translation output will appear here..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def store_feedback(rating, feedback_text):
37
  if not rating:
@@ -59,12 +116,6 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
59
  with gr.Column(elem_id="col-container"):
60
  with gr.Row():
61
  with gr.Column():
62
- src_lang = gr.Dropdown(
63
- ["English"],
64
- value="English",
65
- label="Translate From",
66
- elem_id="translate-from"
67
- )
68
 
69
  text_input = gr.Textbox(
70
  placeholder="Enter text to translate...",
@@ -90,7 +141,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
90
  )
91
 
92
  btn_submit = gr.Button("Translate")
93
- btn_submit.click(fn=translate, inputs=[src_lang, text_input, tgt_lang], outputs=text_output)
94
 
95
  gr.Examples(
96
  examples=[
@@ -100,9 +151,9 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
100
  ["English", "Hello, how are you today? I hope you're doing well.", "Marathi"],
101
  ["English", "Hello, how are you today? I hope you're doing well.", "Malayalam"]
102
  ],
103
- inputs=[src_lang, text_input, tgt_lang],
104
  outputs=text_output,
105
- fn=translate,
106
  cache_examples=True,
107
  examples_per_page=5
108
  )
 
1
+ import torch
2
+ import spaces
3
+ from collections.abc import Iterator
4
+ from threading import Thread
5
  import gradio as gr
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
 
8
# Generation limits. MAX_INPUT_TOKEN_LENGTH is the prompt-length cap that
# generate() enforces by keeping only the trailing tokens of the prompt.
# NOTE(review): MAX_MAX_NEW_TOKENS / DEFAULT_MAX_NEW_TOKENS are presumably
# meant as the upper bound / default for a max-new-tokens control — confirm.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = 4096

# The CPU notice below appends to DESCRIPTION, so it has to exist first;
# without this initialisation the module raises NameError on import whenever
# CUDA is unavailable.
DESCRIPTION = ""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


# The model is only loaded when a GPU is present; on CPU-only hosts `model`
# and `tokenizer` stay undefined and generate() cannot run.
if torch.cuda.is_available():
    model_id = "ai4bharat/IndicTrans3-beta"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
21
 
 
22
 
23
  LANGUAGES = {
24
  "Hindi": "hin_Deva",
 
44
  "Bodo": "brx_Deva"
45
  }
46
 
47
+
48
+ # def translate(src_lang, text, tgt_lang):
49
 
50
+ # return "Translation output will appear here..."
51
+
52
@spaces.GPU
def generate(
    tgt_lang: str,
    message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a translation of *message* into *tgt_lang*.

    Builds a single-turn chat prompt, runs ``model.generate`` on a background
    thread, and yields the accumulated output text each time the streamer
    produces a new chunk.

    Args:
        tgt_lang: Target language name, interpolated into the prompt as-is.
        message: Text to translate.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling parameter forwarded to ``model.generate``.
        top_k: Top-k sampling parameter forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The full text generated so far (not just the newest chunk).
    """
    conversation = [
        {"role": "user", "content": f"Translate the following text to {tgt_lang}: {message}"}
    ]

    # add_generation_prompt=True appends the assistant-turn header so the
    # model starts a reply instead of continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the trailing tokens of an over-long prompt.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Generation runs on a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
91
+
92
 
93
  def store_feedback(rating, feedback_text):
94
  if not rating:
 
116
  with gr.Column(elem_id="col-container"):
117
  with gr.Row():
118
  with gr.Column():
 
 
 
 
 
 
119
 
120
  text_input = gr.Textbox(
121
  placeholder="Enter text to translate...",
 
141
  )
142
 
143
  btn_submit = gr.Button("Translate")
144
+ btn_submit.click(fn=generate, inputs=[tgt_lang, text_input, 4096, 0, 50, 0], outputs=text_output)
145
 
146
  gr.Examples(
147
  examples=[
 
151
  ["English", "Hello, how are you today? I hope you're doing well.", "Marathi"],
152
  ["English", "Hello, how are you today? I hope you're doing well.", "Malayalam"]
153
  ],
154
+ inputs=[tgt_lang, text_input, 4096, 0, 50, 0],
155
  outputs=text_output,
156
+ fn=generate,
157
  cache_examples=True,
158
  examples_per_page=5
159
  )