abdullahalioo committed
Commit 9a3022c · verified · 1 Parent(s): 560244c

Update main.py

Files changed (1)
  1. main.py +15 -14
main.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 import torch
 
 app = FastAPI()
@@ -15,7 +15,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Load model and tokenizer (do this once at startup)
+# Load model and tokenizer
 model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(
@@ -40,30 +40,31 @@ def generate_response_chunks(prompt: str):
             add_generation_prompt=True,
             return_tensors="pt"
         ).to(model.device)
-
-        # Generate streamingly
+
+        # Set up streamer
+        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        # Generate response with streaming
         with torch.no_grad():
-            for outputs in model.generate(
+            model.generate(
                 inputs,
                 max_new_tokens=512,
                 do_sample=True,
                 temperature=0.7,
                 top_p=0.9,
-                streamer=None,  # We'll implement manual streaming
-                stopping_criteria=None
-            ):
-                chunk = outputs[0, inputs.shape[1]:]
-                text = tokenizer.decode(chunk, skip_special_tokens=True)
+                streamer=streamer
+            )
+        # Since TextStreamer handles printing, we yield chunks from the streamer
+        for text in streamer:
             if text:
                 yield text
-
+
     except Exception as e:
-        yield f"Error occurred: {e}"
+        yield f"Error occurred: {str(e)}"
 
 @app.post("/ask")
 async def ask(question: Question):
     return StreamingResponse(
         generate_response_chunks(question.question),
         media_type="text/plain"
-    )
-
+    )
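
Note on the new streaming path: transformers' TextStreamer writes decoded text to stdout and is not iterable, so the added `for text in streamer:` loop will not produce chunks, and `model.generate(...)` blocks until the whole completion is finished. Below is a minimal sketch of how the generator could stream text through StreamingResponse using TextIteratorStreamer and a background thread. It assumes the module-level `model` and `tokenizer` defined earlier in main.py; the `messages` / `apply_chat_template` setup is an assumption, since the diff shows only the tail of that call.

from threading import Thread

from transformers import TextIteratorStreamer

def generate_response_chunks(prompt: str):
    try:
        # Assumed prompt formatting; the diff only shows the tail of this call.
        messages = [{"role": "user", "content": prompt}]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # TextIteratorStreamer buffers decoded text in a queue and supports
        # iteration; TextStreamer only prints to stdout.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        # generate() blocks until it finishes, so run it in a background
        # thread and yield chunks as the streamer produces them.
        thread = Thread(
            target=model.generate,
            kwargs=dict(
                inputs=inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                streamer=streamer,
            ),
        )
        thread.start()

        for text in streamer:
            if text:
                yield text
        thread.join()
    except Exception as e:
        yield f"Error occurred: {e}"

With a generator of this shape, the existing /ask route can keep returning StreamingResponse(generate_response_chunks(question.question), media_type="text/plain"), and clients receive text incrementally instead of only after the full generation completes.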