from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)


class Question(BaseModel):
    question: str


class CustomTextStreamer:
    """Minimal streamer for model.generate(streamer=...): receives token IDs
    via put(), decodes them, and hands the text chunks to the response
    generator through a thread-safe queue."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.queue = Queue()
        self.skip_prompt = True
        self.skip_special_tokens = True
        self.next_tokens_are_prompt = True

    def put(self, value):
        # Handle token IDs (value is a tensor of token IDs)
        if isinstance(value, torch.Tensor):
            if value.dim() > 1:
                value = value.squeeze(0)  # Remove batch dimension if present
            if self.skip_prompt and self.is_prompt(value):
                return  # Do not echo the prompt back to the client
            text = self.tokenizer.decode(value, skip_special_tokens=self.skip_special_tokens)
            if text:
                self.queue.put(text)

    def end(self):
        self.queue.put(None)  # Signal end of generation

    def is_prompt(self, value):
        # model.generate() passes the full prompt tensor on the first put() call
        # and only newly generated tokens afterwards, so treat the first call as the prompt.
        if self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return True
        return False

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            yield item


def generate_response_chunks(prompt: str):
    try:
        # Prepare input
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt},
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        # Set up custom streamer
        streamer = CustomTextStreamer(tokenizer)

        # Run generation in a separate thread to avoid blocking
        def generate():
            with torch.no_grad():
                model.generate(
                    inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    streamer=streamer,
                )

        # Start generation in a thread
        thread = Thread(target=generate)
        thread.start()

        # Yield chunks from the streamer as they become available
        for text in streamer:
            yield text
    except Exception as e:
        yield f"Error occurred: {str(e)}"


@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )
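
# --- Illustrative client sketch (not part of the original server code) ---
# One possible way to consume the streaming /ask endpoint, assuming the app is
# served locally on port 8000 (e.g. via `uvicorn main:app`). The `requests`
# dependency, the URL, and the sample question are assumptions for the example:
#
#     import requests
#
#     with requests.post(
#         "http://localhost:8000/ask",
#         json={"question": "What can you do?"},
#         stream=True,
#     ) as resp:
#         # Print each decoded text chunk as soon as the server emits it
#         for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#             print(chunk, end="", flush=True)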