from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
try:
    model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    # 4-bit quantization configuration
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    logger.info(f"Model loaded successfully on {device}")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise
class Query(BaseModel):
    queryResult: Optional[dict] = None
    queryText: Optional[str] = None
@app.post("/webhook")
async def webhook(query: Query):
    try:
        # Dialogflow nests the text under queryResult; fall back to a flat queryText field
        user_query = query.queryResult.get('queryText') if query.queryResult else query.queryText
        if not user_query:
            raise HTTPException(status_code=400, detail="No query text provided")
        prompt = f"Human: {user_query}\nAI:"
        # Send inputs to the device the model was placed on by device_map="auto"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # do_sample=True is required for temperature to take effect
            output = model.generate(input_ids, max_new_tokens=100, do_sample=True, temperature=0.7)
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        ai_response = response.split("AI:")[-1].strip()
        return {"fulfillmentText": ai_response}
    except HTTPException:
        # Re-raise client errors (e.g. the 400 above) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error in webhook: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
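# --- Example request (illustrative only, not part of the app) ---
# A minimal sketch of how a Dialogflow-style client could call the webhook,
# assuming the server is running locally on port 7860 and that the `requests`
# package is available (both are assumptions, not part of this file):
#
#   import requests
#   payload = {"queryResult": {"queryText": "Hello, who are you?"}}
#   resp = requests.post("http://localhost:7860/webhook", json=payload)
#   print(resp.json()["fulfillmentText"])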