Update main.py
main.py CHANGED

@@ -2,7 +2,7 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 import torch
 
 app = FastAPI()
@@ -15,7 +15,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Load model and tokenizer
+# Load model and tokenizer
 model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(
@@ -40,30 +40,31 @@ def generate_response_chunks(prompt: str):
             add_generation_prompt=True,
             return_tensors="pt"
         ).to(model.device)
-
-        #
+
+        # Set up streamer
+        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        # Generate response with streaming
         with torch.no_grad():
-
+            model.generate(
                 inputs,
                 max_new_tokens=512,
                 do_sample=True,
                 temperature=0.7,
                 top_p=0.9,
-                streamer=
-
-
-
-            text = tokenizer.decode(chunk, skip_special_tokens=True)
+                streamer=streamer
+            )
+        # Since TextStreamer handles printing, we yield chunks from the streamer
+        for text in streamer:
             if text:
                 yield text
-
+
     except Exception as e:
-        yield f"Error occurred: {e}"
+        yield f"Error occurred: {str(e)}"
 
 @app.post("/ask")
 async def ask(question: Question):
     return StreamingResponse(
         generate_response_chunks(question.question),
         media_type="text/plain"
-    )
-
+    )
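A caveat on the new code: in the transformers API, `TextStreamer` prints decoded text to stdout as tokens arrive and is not iterable, so `for text in streamer:` will raise a `TypeError` once `generate` returns. The iterator-friendly variant is `TextIteratorStreamer`, which queues decoded chunks while `generate` runs on a background thread. Below is a minimal sketch of the generator rewritten around that pattern, reusing the `model` and `tokenizer` names from the diff; the message list passed to `apply_chat_template` is not visible in the hunks above, so the single-user-message form here is an assumption.

```python
from threading import Thread

from transformers import TextIteratorStreamer

def generate_response_chunks(prompt: str):
    try:
        # Same preprocessing as the diff: chat template -> token ids on the model's device.
        # The single-user-message list is an assumption; the real one is outside the hunk.
        inputs = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        # TextIteratorStreamer buffers decoded text in a queue, so generate() has to run
        # on a separate thread while this generator drains the queue.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = Thread(
            target=model.generate,
            kwargs=dict(
                input_ids=inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                streamer=streamer,
            ),
        )
        thread.start()

        # Yield chunks as they arrive; iteration stops when generation completes.
        for text in streamer:
            if text:
                yield text
        thread.join()
    except Exception as e:
        yield f"Error occurred: {e}"
```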
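The handler also references a `Question` request model defined in an unchanged part of main.py. Given the `from pydantic import BaseModel` import and the `question.question` access, it is presumably the one-field model below (a guess, since the definition sits outside the diff):

```python
from pydantic import BaseModel

# Presumed request body for POST /ask: {"question": "..."}
class Question(BaseModel):
    question: str
```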
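Because `/ask` returns a `StreamingResponse`, a client should read the body incrementally instead of waiting for the full reply. A sketch using `requests`; the URL is a placeholder (7860 is only the usual Spaces default port):

```python
import requests

# Placeholder URL: point this at wherever the app is actually served.
resp = requests.post(
    "http://localhost:7860/ask",
    json={"question": "What does streaming generation buy us?"},
    stream=True,  # don't buffer the whole response body
)
# Print each decoded chunk as soon as the server flushes it.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```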