Ruurd committed
Commit 4f67864 · 2 parents: 0040338, 0e4362c

Merge branch 'main' of https://huggingface.co./spaces/Ruurd/radiolm

Files changed (1): app.py (+50, -0)
app.py CHANGED

@@ -1,6 +1,8 @@
 import os
 import torch
 import time
+import torch
+import time
 import gradio as gr
 import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
@@ -12,6 +14,51 @@ import threading
 from transformers import TextIteratorStreamer
 import queue
 
+class RichTextStreamer(TextIteratorStreamer):
+    def __init__(self, tokenizer, prompt_len=0, **kwargs):
+        super().__init__(tokenizer, **kwargs)
+        self.token_queue = queue.Queue()
+        self.prompt_len = prompt_len
+        self.count = 0
+
+    def put(self, value):
+        if isinstance(value, torch.Tensor):
+            token_ids = value.view(-1).tolist()
+        elif isinstance(value, list):
+            token_ids = value
+        else:
+            token_ids = [value]
+
+        for token_id in token_ids:
+            self.count += 1
+            if self.count <= self.prompt_len:
+                continue  # skip prompt tokens
+            token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
+            is_special = token_id in self.tokenizer.all_special_ids
+            self.token_queue.put({
+                "token_id": token_id,
+                "token": token_str,
+                "is_special": is_special
+            })
+
+    def __iter__(self):
+        while True:
+            try:
+                token_info = self.token_queue.get(timeout=self.timeout)
+                yield token_info
+            except queue.Empty:
+                if self.end_of_generation.is_set():
+                    break
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+import threading
+
+from transformers import TextIteratorStreamer
+import threading
+
+from transformers import TextIteratorStreamer
+import queue
+
 class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, prompt_len=0, **kwargs):
         super().__init__(tokenizer, **kwargs)
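Note on this hunk: the merge lands a second, identical copy of RichTextStreamer plus a run of repeated transformers/threading/queue imports. Python tolerates the re-imports, and the later (pre-existing) class definition shadows the one added here, but both are dead weight a follow-up commit could drop. A separate caveat: __iter__ consults self.end_of_generation, an attribute the stock TextIteratorStreamer does not appear to define (it exposes a timeout and a text queue, not an end-of-generation event), so a queue.Empty timeout would raise AttributeError unless that event is wired up elsewhere in app.py. A minimal sketch of the missing wiring, assuming a threading.Event flipped in the streamer's end() hook, which generate() invokes when decoding finishes; the subclass name is hypothetical:

import threading

class PatchedRichTextStreamer(RichTextStreamer):
    # Hypothetical fix: supply the end_of_generation event that __iter__ expects.
    def __init__(self, tokenizer, prompt_len=0, **kwargs):
        super().__init__(tokenizer, prompt_len=prompt_len, **kwargs)
        self.end_of_generation = threading.Event()

    def end(self):
        # Called by generate() after the last token; unblocks the iterator.
        self.end_of_generation.set()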
 
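For context, a minimal sketch of how such a streamer is presumably driven; the model id, prompt, and generation settings are illustrative, not taken from app.py:

import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-3B-Instruct"  # illustrative pick from the curated list
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("Summarize the findings:", return_tensors="pt")
# generate() first feeds the whole prompt through put(), so prompt_len tells the
# streamer how many leading tokens to swallow before yielding anything.
streamer = PatchedRichTextStreamer(
    tokenizer, prompt_len=inputs["input_ids"].shape[1], timeout=30.0
)

# Run generation on a background thread, as with the stock TextIteratorStreamer.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=128, streamer=streamer),
)
thread.start()

# Unlike TextIteratorStreamer, iteration yields per-token dicts
# ({"token_id", "token", "is_special"}) rather than decoded text chunks.
for token_info in streamer:
    if not token_info["is_special"]:
        print(token_info["token"], end="", flush=True)
thread.join()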
@@ -194,10 +241,13 @@ def add_user_message(user_input, history):
 
 # Curated models
 model_choices = [
+    "meta-llama/Llama-3.2-3B-Instruct",
     "meta-llama/Llama-3.2-3B-Instruct",
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
     "google/gemma-7b",
     "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+    "google/gemma-7b",
+    "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
 ]
 
 with gr.Blocks() as demo:
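Note on this hunk: besides re-adding three model ids already in the list, the merge leaves no comma after the original "mistralai/Mistral-Small-3.1-24B-Instruct-2503" entry, so Python's implicit string-literal concatenation silently fuses it with the following "google/gemma-7b" into one bogus id. A small demonstration of the pitfall:

# Adjacent string literals without a separating comma are joined at parse time.
model_choices = [
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
    "google/gemma-7b",
]
print(model_choices)
# ['mistralai/Mistral-Small-3.1-24B-Instruct-2503google/gemma-7b']

Restoring the comma and dropping the duplicated entries would leave the intended four-model list.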