from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from queue import Queue
from threading import Thread

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model and tokenizer
model_name = "Qwen/Qwen2.5-7B-Instruct-1M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

class Question(BaseModel):
    question: str

class CustomTextStreamer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.queue = Queue()
        self.skip_prompt = True
        self.skip_special_tokens = True
        self.next_tokens_are_prompt = True  # generate() sends the prompt tokens first

    def put(self, value):
        # generate() calls put() once with the full prompt, then once per new
        # token. Skip that first call so the prompt is not echoed to the client.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        # Handle token IDs (value is a tensor of token IDs)
        if isinstance(value, torch.Tensor):
            if value.dim() > 1:
                value = value.squeeze(0)  # Remove batch dimension if present
            text = self.tokenizer.decode(value, skip_special_tokens=self.skip_special_tokens)
            if text:
                self.queue.put(text)

    def end(self):
        self.queue.put(None)  # Sentinel signalling the end of generation

    def __iter__(self):
        while True:
            item = self.queue.get()
            if item is None:
                break
            yield item
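
# Note: transformers also ships TextIteratorStreamer, which implements this
# same queue-based pattern (including a skip_prompt option and decode kwargs);
# the hand-rolled class above just makes the decoding and queueing explicit.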

def generate_response_chunks(prompt: str):
    try:
        # Prepare input
        messages = [
            {"role": "system", "content": "You are Orion AI assistant..."},
            {"role": "user", "content": prompt}
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)

        # Set up custom streamer
        streamer = CustomTextStreamer(tokenizer)

        # Run generation in a separate thread to avoid blocking
        def generate():
            try:
                with torch.no_grad():
                    model.generate(
                        inputs,
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        streamer=streamer
                    )
            finally:
                # generate() normally calls streamer.end() itself; this extra
                # sentinel only matters if generation raises, in which case the
                # loop below would otherwise block forever on queue.get().
                streamer.end()

        # Start generation in a thread
        thread = Thread(target=generate)
        thread.start()

        # Yield chunks from the streamer
        for text in streamer:
            yield text

    except Exception as e:
        yield f"Error occurred: {str(e)}"

@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain"
    )
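
# ---------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the original file): the server runs
# under uvicorn and listens on localhost:8000. A minimal streaming client
# using the `requests` library might look like:
#
#   import requests
#
#   with requests.post(
#       "http://localhost:8000/ask",
#       json={"question": "Hello, who are you?"},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Convenience entrypoint; assumes uvicorn is installed in the environment.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)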