from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

class Question(BaseModel):
    question: str

class CustomTextStreamer:
    """Decodes tokens as they are generated and hands them to the consumer
    through a thread-safe queue."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.queue = Queue()
        self.skip_prompt = True
        self.skip_special_tokens = True
        # model.generate() passes the prompt tokens to put() first; track that
        # so skip_prompt actually skips them instead of echoing the prompt
        self.next_tokens_are_prompt = True

    def put(self, value):
        # value is a tensor of token IDs (the prompt on the first call,
        # then the newly generated tokens)
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        if isinstance(value, torch.Tensor) and value.dim() > 1:
            value = value.squeeze(0)  # Remove batch dimension if present
        text = self.tokenizer.decode(value, skip_special_tokens=self.skip_special_tokens)
        if text:
            self.queue.put(text)

    def end(self):
        self.queue.put(None)  # Signal end of generation

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            yield item
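
# Note: recent transformers releases ship a built-in equivalent,
# transformers.TextIteratorStreamer, which this class mirrors. A minimal
# sketch of the drop-in replacement, if you prefer the library version:
#
#   from transformers import TextIteratorStreamer
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   # pass it to model.generate(..., streamer=streamer) and iterate it the same way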

def generate_response_chunks(prompt: str):
    try:
        # Prepare the chat-formatted input
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # Set up the custom streamer
        streamer = CustomTextStreamer(tokenizer)

        # Run generation in a separate thread so chunks can be yielded while it runs
        def generate():
            with torch.no_grad():
                model.generate(
                    inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=streamer
                )

        thread = Thread(target=generate)
        thread.start()

        # Yield chunks from the streamer as they arrive
        for text in streamer:
            yield text
        thread.join()
    except Exception as e:
        yield f"Error occurred: {str(e)}"

@app.post("/ask")  # route decorator added; the path is assumed, adjust to match your client
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain"
    )
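
# Example client (a minimal sketch, assuming the service runs locally on
# port 8000 and the route is /ask as defined above):
#
#   import requests
#   with requests.post(
#       "http://localhost:8000/ask",
#       json={"question": "Hello"},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)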