ruslanmv commited on
Commit
c1f9bb7
·
verified ·
1 Parent(s): 31d70aa

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +182 -252
src/app.py CHANGED
@@ -1,131 +1,105 @@
1
  """Template Demo for IBM Granite Hugging Face spaces."""
2
 
3
- from collections.abc import Iterator
4
- from datetime import datetime
5
- from pathlib import Path
6
- from threading import Thread
7
-
8
- import gradio as gr
9
- import spaces
10
- import torch
11
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
-
13
- # Vision model imports
14
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
15
- import random
16
-
17
- from themes.research_monochrome import theme
18
-
19
- today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
20
-
21
- SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
22
- Today's Date: {today_date}.
23
- You are Granite, developed by IBM. You are a helpful AI assistant"""
24
- TITLE = "IBM Granite 3.1 8b Instruct & Vision Preview"
25
- DESCRIPTION = """
26
- <p>Granite 3.1 8b instruct is an open-source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision-language capabilities. Start with one of the sample prompts
27
- or enter your own. Upload an image to use the vision model. Keep in mind that AI can occasionally make mistakes.
28
- <span class="gr_docs_link">
29
- <a href="https://www.ibm.com/granite/docs/">View Granite Instruct Documentation <i class="fa fa-external-link"></i></a>
30
- </span>
31
- <span class="gr_docs_link">
32
- <a href="https://www.ibm.com/granite/vision/docs/">View Granite Vision Documentation <i class="fa fa-external-link"></i></a>
33
- </span>
34
- </p>
35
- """
36
- MAX_INPUT_TOKEN_LENGTH = 128_000
37
- MAX_NEW_TOKENS = 1024
38
- TEMPERATURE = 0.7
39
- TOP_P = 0.85
40
- TOP_K = 50
41
- REPETITION_PENALTY = 1.05
42
-
43
- VISION_TEMPERATURE = 0.2
44
- VISION_TOP_P = 0.95
45
- VISION_TOP_K = 50
46
- VISION_MAX_TOKENS = 128
47
-
48
-
49
- if not torch.cuda.is_available():
50
- print("This demo may not work on CPU.")
51
-
52
- # Text model loading
53
- text_model = AutoModelForCausalLM.from_pretrained(
54
- "ibm-granite/granite-3.1-8b-instruct", torch_dtype=torch.float16, device_map="auto"
55
- )
56
- text_tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
57
- text_tokenizer.use_default_system_prompt = False
58
-
59
- # Vision model loading
60
- vision_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
61
- vision_processor = LlavaNextProcessor.from_pretrained(vision_model_path, use_fast=True)
62
-
63
- # Option 1: Use the default settings (like the original demo)
64
- vision_model = LlavaNextForConditionalGeneration.from_pretrained(
65
- vision_model_path,
66
- torch_dtype="auto",
67
- device_map="auto"
68
- )
69
-
70
- # --- OR ---
71
-
72
- # Option 2: Use torch.float16 but ensure the custom model code is trusted
73
- # vision_model = LlavaNextForConditionalGeneration.from_pretrained(
74
- # vision_model_path,
75
- # torch_dtype=torch.float16,
76
- # device_map="auto",
77
- # trust_remote_code=True
78
- # )
79
-
80
-
81
- @spaces.GPU
82
- def generate(
83
- message: str,
84
- chat_history: list[dict],
85
- temperature: float = TEMPERATURE,
86
- repetition_penalty: float = REPETITION_PENALTY,
87
- top_p: float = TOP_P,
88
- top_k: float = TOP_K,
89
- max_new_tokens: int = MAX_NEW_TOKENS,
90
- ) -> Iterator[str]:
91
- """Generate function for text chat demo."""
92
- # Build messages
93
- conversation = []
94
- conversation.append({"role": "system", "content": SYS_PROMPT})
95
- conversation += chat_history
96
- conversation.append({"role": "user", "content": message})
97
-
98
- # Convert messages to prompt format
99
- input_ids = text_tokenizer.apply_chat_template(
100
- conversation,
101
- return_tensors="pt",
102
- add_generation_prompt=True,
103
- truncation=True,
104
- max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
105
- )
106
-
107
- input_ids = input_ids.to(text_model.device)
108
- streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
109
- generate_kwargs = dict(
110
- {"input_ids": input_ids},
111
- streamer=streamer,
112
- max_new_tokens=max_new_tokens,
113
- do_sample=True,
114
- top_p=top_p,
115
- top_k=top_k,
116
- temperature=temperature,
117
- num_beams=1,
118
- repetition_penalty=repetition_penalty,
119
- )
120
-
121
- t = Thread(target=text_model.generate, kwargs=generate_kwargs)
122
- t.start()
123
-
124
- outputs = []
125
- for text in streamer:
126
- outputs.append(text)
127
- yield "".join(outputs)
128
-
129
 
130
  def get_text_from_content(content):
131
  texts = []
@@ -137,7 +111,7 @@ def get_text_from_content(content):
137
  return " ".join(texts)
138
 
139
  @spaces.GPU
140
- def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, top_p=VISION_TOP_P, top_k=VISION_TOP_K, max_tokens=VISION_MAX_TOKENS):
141
  if conversation is None:
142
  conversation = []
143
 
@@ -194,156 +168,126 @@ def conversation_display(conversation):
194
  return chat_history
195
 
196
  def clear_chat():
197
- return [], [], "", None, [] # Cleared state for both text and vision
198
 
199
  css_file_path = Path(Path(__file__).parent / "app.css")
200
  head_file_path = Path(Path(__file__).parent / "app_head.html")
201
 
202
- # Advanced settings (displayed in Accordion) - Text Model
203
- text_temperature_slider = gr.Slider(
204
- minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Text Temperature", elem_classes=["gr_accordion_element"]
205
  )
206
- text_top_p_slider = gr.Slider(
207
- minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Text Top P", elem_classes=["gr_accordion_element"]
208
  )
209
- text_top_k_slider = gr.Slider(
210
- minimum=0, maximum=100, value=TOP_K, step=1, label="Text Top K", elem_classes=["gr_accordion_element"]
211
  )
212
- text_repetition_penalty_slider = gr.Slider(
 
 
213
  minimum=0,
214
  maximum=2.0,
215
  value=REPETITION_PENALTY,
216
  step=0.05,
217
- label="Text Repetition Penalty",
218
  elem_classes=["gr_accordion_element"],
219
  )
220
- text_max_new_tokens_slider = gr.Slider(
221
  minimum=1,
222
  maximum=2000,
223
  value=MAX_NEW_TOKENS,
224
  step=1,
225
- label="Text Max New Tokens",
226
  elem_classes=["gr_accordion_element"],
227
  )
228
- text_chat_interface_accordion = gr.Accordion(label="Text Model Advanced Settings", open=False)
229
 
230
- # Advanced settings (displayed in Accordion) - Vision Model
231
- vision_temperature_slider = gr.Slider(
232
- minimum=0.0, maximum=2.0, value=VISION_TEMPERATURE, step=0.01, label="Vision Temperature", elem_classes=["gr_accordion_element"]
233
- )
234
- vision_top_p_slider = gr.Slider(
235
- minimum=0.0, maximum=1.0, value=VISION_TOP_P, step=0.01, label="Vision Top p", elem_classes=["gr_accordion_element"]
236
- )
237
- vision_top_k_slider = gr.Slider(
238
- minimum=0, maximum=100, value=VISION_TOP_K, step=1, label="Vision Top k", elem_classes=["gr_accordion_element"]
239
- )
240
- vision_max_tokens_slider = gr.Slider(
241
- minimum=10, maximum=300, value=VISION_MAX_TOKENS, step=1, label="Vision Max Tokens", elem_classes=["gr_accordion_element"]
242
  )
243
- vision_chat_interface_accordion = gr.Accordion(label="Vision Model Advanced Settings", open=False)
244
 
 
245
 
246
  with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
247
  gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
248
  gr.HTML(DESCRIPTION)
249
 
250
- chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", height=500, type='messages')
251
- text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
252
- image_input = gr.Image(type="pil", label="Upload Image (optional)")
253
-
254
- with text_chat_interface_accordion:
255
- text_temperature_slider
256
- text_top_p_slider
257
- text_top_k_slider
258
- text_repetition_penalty_slider
259
- text_max_new_tokens_slider
260
-
261
- with vision_chat_interface_accordion:
262
- vision_temperature_slider
263
- vision_top_p_slider
264
- vision_top_k_slider
265
- vision_max_tokens_slider
266
-
267
-
268
- clear_button = gr.Button("Clear Chat")
269
- send_button = gr.Button("Send Message") # Changed from "Chat" to "Send Message" for clarity
270
-
271
- text_state = gr.State([]) # State for text chatbot history
272
- vision_state = gr.State([]) # State for vision chatbot history
273
- chatbot_type_state = gr.State("text") # State to track which chatbot is in use
274
-
275
- def send_message(image_input, text_input, chatbot_type_state, text_state, vision_state,
276
- text_temperature, text_repetition_penalty, text_top_p, text_top_k, text_max_new_tokens,
277
- vision_temperature, vision_top_p, vision_top_k, vision_max_tokens):
278
  if image_input:
279
- chatbot_type_state = "vision"
280
- history = vision_state
281
- gen_kwargs_vision = {
282
- "temperature": vision_temperature,
283
- "top_p": vision_top_p,
284
- "top_k": vision_top_k,
285
- "max_tokens": vision_max_tokens,
286
- "conversation": history
287
- }
288
- chat_output, updated_vision_state = chat_inference(image=image_input, text=text_input, **gen_kwargs_vision)
289
- return chat_output, updated_vision_state, chatbot_type_state, gr.ChatInterface.update(visible=False), gr.Chatbot.update(visible=True) # Hide text interface, show vision chatbot
290
-
291
  else:
292
- chatbot_type_state = "text"
293
- history = text_state
294
- gen_kwargs_text = {
295
- "temperature": text_temperature,
296
- "repetition_penalty": text_repetition_penalty,
297
- "top_p": text_top_p,
298
- "top_k": text_top_k,
299
- "max_new_tokens": text_max_new_tokens,
300
- "message": text_input,
301
- "chat_history": history
302
- }
303
 
304
- chat_output_iterator = generate(**gen_kwargs_text)
305
- output_text = ""
306
- for text_chunk in chat_output_iterator:
307
- output_text = text_chunk
308
-
309
- updated_text_state = history + [{"role": "user", "content": text_input}, {"role": "assistant", "content": output_text}]
310
- text_chatbot_history = updated_text_state # format for chatbot display
311
- formatted_history = []
312
- for message in text_chatbot_history:
313
- formatted_history.append((message["content"] if message["role"] == "user" else None, message["content"] if message["role"] == "assistant" else None))
 
 
 
 
314
 
 
 
315
 
316
- return formatted_history, updated_text_state, chatbot_type_state, gr.ChatInterface.update(visible=True), gr.Chatbot.update(visible=False) # Show text interface, hide vision chatbot
317
 
318
 
319
  send_button.click(
320
- send_message,
321
- inputs=[image_input, text_input, chatbot_type_state, text_state, vision_state,
322
- text_temperature_slider, text_repetition_penalty_slider, text_top_p_slider, text_top_k_slider, text_max_new_tokens_slider,
323
- vision_temperature_slider, vision_top_p_slider, vision_top_k_slider, vision_max_tokens_slider],
324
- outputs=[chatbot, vision_state, chatbot_type_state, gr.ChatInterface(), gr.Chatbot()] # Dummy ChatInterface output, real Chatbot output
325
  )
326
 
327
  clear_button.click(
328
  clear_chat,
329
  inputs=None,
330
- outputs=[chatbot, vision_state, text_input, image_input, text_state] # Added text_state to clear
331
  )
332
 
333
-
334
- image_examples_dir = Path(__file__).parent / "themes" # Create a folder 'themes' in the same directory as app.py and put images there
335
- image_examples = [
336
- str(image_examples_dir / "cat.jpg"), # Replace "cat.jpg" with actual image file names in 'image_examples'
337
- str(image_examples_dir / "dog.jpg"), # Replace "dog.jpg" with actual image file names in 'image_examples'
338
- str(image_examples_dir / "horse.jpg") # Replace "horse.jpg" with actual image file names in 'image_examples'
339
- ]
340
-
341
  gr.Examples(
342
  examples=[
343
- ["Explain the concept of quantum computing to someone with no background in physics or computer science.", None],
344
- ["What is OpenShift?", None],
345
- ["What's the importance of low latency inference?", None],
346
- ["Help me boost productivity habits.", None],
347
  [
348
  """Explain the following code in a concise manner:
349
 
@@ -384,7 +328,7 @@ class Pair {
384
  this.y = y;
385
  }
386
  }
387
- ```""", None
388
  ],
389
  [
390
  """Generate a Java code block from the following explanation:
@@ -395,27 +339,13 @@ The findPairs method takes two arguments: an array of integers and a difference
395
 
396
  The Pair class is a simple data structure that stores two integers.
397
 
398
- The main method creates an array of integers, initializes the difference value, and calls the findPairs method to find all pairs in the array. Finally, the code iterates over the list of pairs and prints each pair to the console.""" , None # noqa: E501
399
  ],
400
- ["What is in this image?", image_examples[0]], # Vision example using local image
401
- ["Describe this image in detail", image_examples[1]], # Vision example using local image
402
- ["Identify the object in the image", image_examples[2]], # Vision example using local image
403
  ],
404
- inputs=[text_input, image_input],
405
- example_labels=[
406
- "Explain quantum computing",
407
- "What is OpenShift?",
408
- "Importance of low latency inference",
409
- "Boosting productivity habits",
410
- "Explain and document your code",
411
- "Generate Java Code",
412
- "Vision Example 1: What is in this image?",
413
- "Vision Example 2: Describe this image",
414
- "Vision Example 3: Identify object",
415
- ],
416
- cache_examples=False,
417
  )
418
 
419
-
420
  if __name__ == "__main__":
421
  demo.queue().launch()
 
1
  """Template Demo for IBM Granite Hugging Face spaces."""
2
 
3
+ from collections.abc import Iterator
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from threading import Thread
7
+
8
+ import gradio as gr
9
+ import spaces
10
+ import torch
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
+
13
+ from themes.research_monochrome import theme
14
+
15
+ # Vision imports
16
+ import random
17
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
18
+
19
+ today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
20
+
21
+ SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
22
+ Today's Date: {today_date}.
23
+ You are Granite, developed by IBM. You are a helpful AI assistant"""
24
+ TITLE = "IBM Granite 3.1 8b Instruct & Vision Preview"
25
+ DESCRIPTION = """
26
+ <p>Granite 3.1 8b instruct is an open-source LLM supporting a 128k context window. Start with one of the sample prompts
27
+ or upload an image and ask a question. Keep in mind that AI can occasionally make mistakes.
28
+ <span class="gr_docs_link">
29
+ <a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
30
+ </span>
31
+ </p>
32
+ """
33
+ MAX_INPUT_TOKEN_LENGTH = 128_000
34
+ MAX_NEW_TOKENS = 1024
35
+ TEMPERATURE = 0.7
36
+ TOP_P = 0.85
37
+ TOP_K = 50
38
+ REPETITION_PENALTY = 1.05
39
+
40
+ if not torch.cuda.is_available():
41
+ print("This demo may not work on CPU.")
42
+
43
+ # Text Model and Tokenizer
44
+ text_model = AutoModelForCausalLM.from_pretrained(
45
+ "ibm-granite/granite-3.1-8b-instruct", torch_dtype=torch.float16, device_map="auto"
46
+ )
47
+ text_tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
48
+ text_tokenizer.use_default_system_prompt = False
49
+
50
+ # Vision Model and Processor
51
+ vision_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
52
+ vision_processor = LlavaNextProcessor.from_pretrained(vision_model_path, use_fast=True)
53
+ vision_model = LlavaNextForConditionalGeneration.from_pretrained(vision_model_path, torch_dtype="auto", device_map="auto")
54
+
55
+
56
+ @spaces.GPU
57
+ def generate(
58
+ message: str,
59
+ chat_history: list[dict],
60
+ temperature: float = TEMPERATURE,
61
+ repetition_penalty: float = REPETITION_PENALTY,
62
+ top_p: float = TOP_P,
63
+ top_k: float = TOP_K,
64
+ max_new_tokens: int = MAX_NEW_TOKENS,
65
+ ) -> Iterator[str]:
66
+ """Generate function for text chat demo."""
67
+ # Build messages
68
+ conversation = []
69
+ conversation.append({"role": "system", "content": SYS_PROMPT})
70
+ conversation += chat_history
71
+ conversation.append({"role": "user", "content": message})
72
+
73
+ # Convert messages to prompt format
74
+ input_ids = text_tokenizer.apply_chat_template(
75
+ conversation,
76
+ return_tensors="pt",
77
+ add_generation_prompt=True,
78
+ truncation=True,
79
+ max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
80
+ )
81
+
82
+ input_ids = input_ids.to(text_model.device)
83
+ streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
84
+ generate_kwargs = dict(
85
+ {"input_ids": input_ids},
86
+ streamer=streamer,
87
+ max_new_tokens=max_new_tokens,
88
+ do_sample=True,
89
+ top_p=top_p,
90
+ top_k=top_k,
91
+ temperature=temperature,
92
+ num_beams=1,
93
+ repetition_penalty=repetition_penalty,
94
+ )
95
+
96
+ t = Thread(target=text_model.generate, kwargs=generate_kwargs)
97
+ t.start()
98
+
99
+ outputs = []
100
+ for text in streamer:
101
+ outputs.append(text)
102
+ yield "".join(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  def get_text_from_content(content):
105
  texts = []
 
111
  return " ".join(texts)
112
 
113
  @spaces.GPU
114
+ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
115
  if conversation is None:
116
  conversation = []
117
 
 
168
  return chat_history
169
 
170
  def clear_chat():
171
+ return [], [], "", None
172
 
173
  css_file_path = Path(Path(__file__).parent / "app.css")
174
  head_file_path = Path(Path(__file__).parent / "app_head.html")
175
 
176
+ # Advanced settings (displayed in Accordion) - Common settings for both models
177
+ temperature_slider = gr.Slider(
178
+ minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
179
  )
180
+ top_p_slider = gr.Slider(
181
+ minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
182
  )
183
+ top_k_slider = gr.Slider(
184
+ minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
185
  )
186
+
187
+ # Advanced settings specific to Text model
188
+ repetition_penalty_slider = gr.Slider(
189
  minimum=0,
190
  maximum=2.0,
191
  value=REPETITION_PENALTY,
192
  step=0.05,
193
+ label="Repetition Penalty (Text Model)",
194
  elem_classes=["gr_accordion_element"],
195
  )
196
+ max_new_tokens_slider = gr.Slider(
197
  minimum=1,
198
  maximum=2000,
199
  value=MAX_NEW_TOKENS,
200
  step=1,
201
+ label="Max New Tokens (Text Model)",
202
  elem_classes=["gr_accordion_element"],
203
  )
 
204
 
205
+ # Advanced settings specific to Vision model
206
+ max_tokens_slider_vision = gr.Slider(
207
+ minimum=10,
208
+ maximum=300,
209
+ value=128,
210
+ step=1,
211
+ label="Max Tokens (Vision Model)",
212
+ elem_classes=["gr_accordion_element"],
 
 
 
 
213
  )
 
214
 
215
+ chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
216
 
217
  with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
218
  gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
219
  gr.HTML(DESCRIPTION)
220
 
221
+ state = gr.State([]) # State for vision chat history
222
+ chat_history_state = gr.State([]) # State for text chat history
223
+
224
+ with gr.Row():
225
+ with gr.Column(scale=2):
226
+ image_input = gr.Image(type="pil", label="Upload Image (optional)")
227
+ with gr.Accordion(label="Vision Model Settings", open=False):
228
+ max_tokens_input_vision = max_tokens_slider_vision
229
+ with gr.Accordion(label="Text Model Settings", open=False):
230
+ repetition_penalty_input = repetition_penalty_slider
231
+ max_new_tokens_input = max_new_tokens_slider
232
+ with chat_interface_accordion: # Common Settings
233
+ temperature_input = temperature_slider
234
+ top_p_input = top_p_slider
235
+ top_k_input = top_k_slider
236
+
237
+ with gr.Column(scale=3):
238
+ chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type='messages')
239
+ text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
240
+ with gr.Row():
241
+ send_button = gr.Button("Chat")
242
+ clear_button = gr.Button("Clear Chat")
243
+
244
+ def process_chat(image_input, text_input, temperature_input, top_p_input, top_k_input, repetition_penalty_input, max_new_tokens_input, max_tokens_input_vision, state, chat_history_state):
 
 
 
 
245
  if image_input:
246
+ # Use Vision model
247
+ return chat_inference(image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input_vision, state)
 
 
 
 
 
 
 
 
 
 
248
  else:
249
+ # Use Text model
250
+ return generate(text_input, chat_history_state, temperature_input, repetition_penalty_input, top_p_input, top_k_input, max_new_tokens_input), None # Return None for state as text model doesn't use it
 
 
 
 
 
 
 
 
 
251
 
252
+ def process_chat_wrapper(image_input_val, text_input_val, temperature_input_val, top_p_input_val, top_k_input_val, repetition_penalty_input_val, max_new_tokens_input_val, max_tokens_input_vision_val, state_val, chat_history_state_val):
253
+ if image_input_val:
254
+ chatbot_output, updated_state = process_chat(image_input_val, text_input_val, temperature_input_val, top_p_input_val, top_k_input_val, repetition_penalty_input_val, max_new_tokens_input_val, max_tokens_input_vision_val, state_val, chat_history_state_val)
255
+ return chatbot_output, updated_state, chat_history_state_val # Return vision state and keep text state unchanged
256
+ else:
257
+ chatbot_output_generator, _ = process_chat(image_input_val, text_input_val, temperature_input_val, top_p_input_val, top_k_input_val, repetition_penalty_input_val, max_new_tokens_input_val, max_tokens_input_vision_val, state_val, chat_history_state_val)
258
+ updated_chat_history = []
259
+ full_response = ""
260
+ for response_chunk in chatbot_output_generator:
261
+ full_response = response_chunk
262
+ if chat_history_state_val is None:
263
+ updated_chat_history = []
264
+ else:
265
+ updated_chat_history = chat_history_state_val
266
 
267
+ updated_chat_history.append({"role": "user", "content": text_input_val})
268
+ updated_chat_history.append({"role": "assistant", "content": full_response})
269
 
270
+ return updated_chat_history, state_val, updated_chat_history # Return text chat history, keep vision state unchanged, return updated text history for chatbot display
271
 
272
 
273
  send_button.click(
274
+ process_chat_wrapper,
275
+ inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, repetition_penalty_input, max_new_tokens_input, max_tokens_input_vision, state, chat_history_state],
276
+ outputs=[chatbot, state, chat_history_state] # Keep both states as output
 
 
277
  )
278
 
279
  clear_button.click(
280
  clear_chat,
281
  inputs=None,
282
+ outputs=[chatbot, state, text_input, image_input] # clear_chat clears vision state and input. Need to clear text state also.
283
  )
284
 
 
 
 
 
 
 
 
 
285
  gr.Examples(
286
  examples=[
287
+ ["Explain the concept of quantum computing to someone with no background in physics or computer science."],
288
+ ["What is OpenShift?"],
289
+ ["What's the importance of low latency inference?"],
290
+ ["Help me boost productivity habits."],
291
  [
292
  """Explain the following code in a concise manner:
293
 
 
328
  this.y = y;
329
  }
330
  }
331
+ ```"""
332
  ],
333
  [
334
  """Generate a Java code block from the following explanation:
 
339
 
340
  The Pair class is a simple data structure that stores two integers.
341
 
342
+ The main method creates an array of integers, initializes the difference value, and calls the findPairs method to find all pairs in the array. Finally, the code iterates over the list of pairs and prints each pair to the console.""" # noqa: E501
343
  ],
344
+ ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"] # Vision example
 
 
345
  ],
346
+ inputs=[text_input, text_input, text_input, text_input, text_input, text_input, image_input, image_input] , # Duplicated text_input to match example count, last two are image_input for vision example
347
+ examples_per_page=7
 
 
 
 
 
 
 
 
 
 
 
348
  )
349
 
 
350
  if __name__ == "__main__":
351
  demo.queue().launch()