"""FastAPI service that serves the mynuddin/chatbot PEFT adapter on top of facebook/opt-2.7b."""
import os

# Set a writable cache directory inside the container. This must happen before
# importing transformers, which resolves its cache location at import time.
os.environ["HF_HOME"] = "/app/hf_home"  # preferred cache variable in recent transformers releases
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_home"  # kept for older releases (now deprecated)
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/app/hf_home"  # only relevant if sentence-transformers is used

# Ensure the cache directory exists
os.makedirs("/app/hf_home", exist_ok=True)

import torch
from fastapi import FastAPI
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Base model and fine-tuned PEFT adapter
base_model_name = "facebook/opt-2.7b"
adapter_name = "mynuddin/chatbot"  # adapter repo on the Hugging Face Hub

# Pick the device once; use float16 only on GPU, since several fp16 ops are
# unsupported on CPU and can fail during generation.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load tokenizer and base model, then attach the PEFT adapter
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype)
model = PeftModel.from_pretrained(base_model, adapter_name)
model = model.to(device)
model.eval()
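# Optional: if the adapter is a LoRA adapter, PEFT's merge_and_unload() folds the
# adapter weights into the base model, removing the adapter indirection at inference
# time. Left commented out (a sketch, assuming a LoRA-style adapter) so the serving
# behavior above stays unchanged.
# model = model.merge_and_unload()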
app = FastAPI()
# Request schema for the /generate endpoint
class PromptInput(BaseModel):
    prompt: str
@app.post("/generate")
def generate_text(body: PromptInput):  # `body` avoids shadowing the built-in input()
    # Format the prompt in the style the model was fine-tuned on
    inputs = tokenizer(
        f"### Prompt: {body.prompt}\n### Completion:", return_tensors="pt"
    ).to(device)

    # Generate the completion. max_new_tokens bounds only the generated tokens,
    # whereas max_length would also count the prompt and could truncate the output.
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=128)

    # Decode the output and remove special tokens
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return {"generated_query": generated_text}
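
# --- Usage sketch (assumptions: server run locally; port 7860, the Hugging Face
# Spaces default, is illustrative) ---
# Start the server:
#   uvicorn app:app --host 0.0.0.0 --port 7860
# Query it:
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "your prompt here"}'

if __name__ == "__main__":
    # Allow `python app.py` as an alternative to invoking uvicorn directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)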