File size: 2,087 Bytes
7166b9d
9a6b3b4
 
7166b9d
a90d622
e65e766
 
 
 
9a6b3b4
 
cc25ca0
e65e766
 
 
b8408d1
a90d622
 
b8408d1
a90d622
 
 
 
 
bff2feb
e65e766
 
a90d622
 
 
e65e766
bff2feb
e65e766
 
 
cc25ca0
9a6b3b4
 
 
 
 
 
 
 
 
 
 
 
7166b9d
bff2feb
9a6b3b4
b8408d1
 
7166b9d
9a6b3b4
7166b9d
9a6b3b4
7166b9d
9a6b3b4
e65e766
b8408d1
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import logging

# Configure root logging once at import time; module logger follows stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Load tokenizer + model eagerly at import so the first request has no cold start.
# Any failure here is fatal: we log and re-raise so the process does not start
# serving with a half-initialized model.
try:
    model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # `device` is used later to place input tensors; the model itself is placed
    # by accelerate via device_map="auto" below.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # 4-bit quantization configuration
    # NF4 weights with fp16 compute keeps the 8B model within a single
    # consumer-GPU memory budget.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",          # let accelerate shard/place layers automatically
        low_cpu_mem_usage=True,     # stream weights instead of materializing a full fp copy in RAM
    )
    # NOTE(review): with device_map="auto" the model may not live entirely on
    # `device`; this log line reflects the computed preference, not placement.
    logger.info(f"Model loaded successfully on {device}")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise

class Query(BaseModel):
    """Incoming webhook payload.

    Accepts either a flat ``queryText`` string or a nested ``queryResult``
    dict carrying a ``queryText`` key (presumably a Dialogflow-style
    request — confirm against the caller). Both are optional; the handler
    validates that at least one is present.
    """

    queryText: Optional[str] = None
    queryResult: Optional[dict] = None

@app.post("/webhook")
async def webhook(query: Query):
    """Generate a model reply for an incoming webhook request.

    Reads the user text from ``query.queryResult["queryText"]`` when the
    nested payload is present, otherwise from the flat ``query.queryText``.

    Returns:
        dict: ``{"fulfillmentText": <generated reply>}``.

    Raises:
        HTTPException: 400 when no query text is provided; 500 on any
            generation failure.
    """
    try:
        # Prefer the nested payload field; fall back to the flat one.
        user_query = query.queryResult.get('queryText') if query.queryResult else query.queryText
        
        if not user_query:
            raise HTTPException(status_code=400, detail="No query text provided")
        
        prompt = f"Human: {user_query}\nAI:"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        
        with torch.no_grad():
            # Fix: do_sample=True is required for `temperature` to take
            # effect; without it transformers greedy-decodes and ignores
            # the temperature argument entirely.
            output = model.generate(
                input_ids,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
            )
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        
        # The decoded text includes the prompt; keep only what follows the
        # final "AI:" marker.
        ai_response = response.split("AI:")[-1].strip()
        
        return {"fulfillmentText": ai_response}
    except HTTPException:
        # Fix: re-raise as-is. Previously the 400 above fell into the broad
        # handler below and reached the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Error in webhook: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Run the ASGI app directly (development entry point); bind on all
    # interfaces so the container/host can reach it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)