abdullahalioo committed · verified
Commit 2bfae3d · 1 Parent(s): 9a3022c

Update main.py

Files changed (1):
  1. main.py (+55 -18)
main.py CHANGED
@@ -2,8 +2,10 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
+from queue import Queue
+from threading import Thread
 
 app = FastAPI()
 
@@ -27,6 +29,36 @@ model = AutoModelForCausalLM.from_pretrained(
 class Question(BaseModel):
     question: str
 
+class CustomTextStreamer:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.queue = Queue()
+        self.skip_prompt = True
+        self.skip_special_tokens = True
+
+    def put(self, value):
+        # Handle token IDs (value is a tensor of token IDs)
+        if isinstance(value, torch.Tensor):
+            if value.dim() > 1:
+                value = value.squeeze(0)  # Remove batch dimension if present
+            text = self.tokenizer.decode(value, skip_special_tokens=self.skip_special_tokens)
+            if text and not (self.skip_prompt and self.is_prompt(value)):
+                self.queue.put(text)
+
+    def end(self):
+        self.queue.put(None)  # Signal end of generation
+
+    def is_prompt(self, value):
+        # Simple heuristic to skip prompt tokens (optional, adjust as needed)
+        return False  # For simplicity, assume all tokens are response tokens
+
+    def __iter__(self):
+        while True:
+            item = self.queue.get()
+            if item is None:
+                break
+            yield item
+
 def generate_response_chunks(prompt: str):
     try:
         # Prepare input
@@ -41,23 +73,28 @@ def generate_response_chunks(prompt: str):
             return_tensors="pt"
         ).to(model.device)
 
-        # Set up streamer
-        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-        # Generate response with streaming
-        with torch.no_grad():
-            model.generate(
-                inputs,
-                max_new_tokens=512,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-                streamer=streamer
-            )
-        # Since TextStreamer handles printing, we yield chunks from the streamer
-        for text in streamer:
-            if text:
-                yield text
+        # Set up custom streamer
+        streamer = CustomTextStreamer(tokenizer)
+
+        # Run generation in a separate thread to avoid blocking
+        def generate():
+            with torch.no_grad():
+                model.generate(
+                    inputs,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.9,
+                    streamer=streamer
+                )
+
+        # Start generation in a thread
+        thread = Thread(target=generate)
+        thread.start()
+
+        # Yield chunks from the streamer
+        for text in streamer:
+            yield text
 
     except Exception as e:
         yield f"Error occurred: {str(e)}"