# Spaces: Sleeping — Hugging Face Space status banner captured in the source dump (not code).
import logging
import os
import re
from typing import Any, Dict, List

import spacy
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pymongo import MongoClient
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer
# Set up logging with detailed output
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = FastAPI()

# MongoDB Setup
# SECURITY NOTE(review): the fallback URI embeds real-looking credentials in
# source control. Rotate them and rely solely on the MONGO_URI environment
# variable in deployments.
connection_string = os.getenv("MONGO_URI", "mongodb+srv://clician:[email protected]/?retryWrites=true&w=majority&appName=Hutterdev")
client = MongoClient(connection_string)
db = client["test"]
products_collection = db["products"]
# BlenderBot Setup
model_repo = "SyedHutter/blenderbot_model"
model_subfolder = "blenderbot_model"
model_dir = "/home/user/app/blenderbot_model"
# First run: pull tokenizer+model from the Hugging Face Hub and cache them on
# local disk. Subsequent runs load the cached copy and skip the download.
if not os.path.exists(model_dir):
    logger.info(f"Downloading {model_repo}/{model_subfolder} to {model_dir}...")
    tokenizer = BlenderbotTokenizer.from_pretrained(model_repo, subfolder=model_subfolder)
    model = BlenderbotForConditionalGeneration.from_pretrained(model_repo, subfolder=model_subfolder)
    os.makedirs(model_dir, exist_ok=True)
    tokenizer.save_pretrained(model_dir)
    model.save_pretrained(model_dir)
    logger.info("Model download complete.")
else:
    logger.info(f"Loading pre-existing model from {model_dir}.")
    tokenizer = BlenderbotTokenizer.from_pretrained(model_dir)
    model = BlenderbotForConditionalGeneration.from_pretrained(model_dir)
# Static Context
# Prepended to every prompt so the chatbot stays on-topic about the company.
context_msg = "Hutter Products GmbH provides sustainable products like shirts and shorts..."

# spaCy Setup
# Loads a vendored en_core_web_sm model from a fixed path inside the container
# (avoids a network download at startup).
spacy_model_path = "/home/user/app/en_core_web_sm-3.8.0"
nlp = spacy.load(spacy_model_path)
# Pydantic Models
class PromptRequest(BaseModel):
    """Request payload for the chat endpoint."""
    input_text: str  # the user's latest message
    conversation_history: List[str] = []  # prior turns, oldest first; optional
class CombinedResponse(BaseModel):
    """Response schema: keyword extraction, generated answer, product matches."""
    ner: Dict[str, Any]  # {"extracted_keywords": [...]}
    qa: Dict[str, Any]  # {"question": ..., "answer": ..., "score": ...}
    products_matched: List[Dict[str, Any]]  # serialized product documents
# Helper Functions
def extract_keywords(text: str) -> List[str]:
    """Return the unique nouns and proper nouns found in *text* (order unspecified)."""
    parsed = nlp(text)
    unique_terms = {tok.text for tok in parsed if tok.pos_ in ("NOUN", "PROPN")}
    return list(unique_terms)
def detect_intent(text: str) -> str:
    """Classify *text* into a coarse intent.

    Returns one of "recommend_shirt", "recommend_shorts", "company_info",
    or "unknown". Product checks run first, company-info keywords second.
    """
    doc = nlp(text.lower())
    tokens = {token.text for token in doc}
    # Bug fix: the previous exact-token match ("shirt"/"short") never matched
    # the plural forms users actually type ("shirts", "shorts"), so the
    # recommendation intents were unreachable for those inputs.
    if tokens & {"shirt", "shirts"}:
        return "recommend_shirt"
    if tokens & {"short", "shorts"}:
        return "recommend_shorts"
    if tokens & {"what", "who", "company", "do", "products"}:
        return "company_info"
    return "unknown"
def search_products_by_keywords(keywords: List[str]) -> List[Dict[str, Any]]:
    """Find products whose name contains any keyword as a whole word.

    Returns a list of dicts with id, name, skuNumber and description;
    empty list when no keywords are given.
    """
    if not keywords:
        logger.info("No keywords provided, returning empty product list.")
        return []
    # Use stricter matching: only return products with exact keyword in name.
    # Bug fix: keywords come from user text, so they must be regex-escaped
    # before being embedded in the $regex pattern — otherwise metacharacters
    # like "(" or "+" crash the query or subvert the match.
    query = {"$or": [
        {"name": {"$regex": f"\\b{re.escape(keyword)}\\b", "$options": "i"}}
        for keyword in keywords
    ]}
    matched_products = [
        {
            "id": str(p["_id"]),
            "name": p.get("name", "Unknown"),
            "skuNumber": p.get("skuNumber", "N/A"),
            "description": p.get("description", "No description available")
        }
        for p in products_collection.find(query)
    ]
    return matched_products
def get_product_context(products: List[Dict]) -> str:
    """Build a short natural-language summary of at most two products."""
    if not products:
        return ""
    highlights = [f"'{item['name']}' - {item['description']}" for item in products[:2]]
    return "Here are some products: " + ", ".join(highlights)
def format_response(response: str, products: List[Dict], intent: str) -> str:
    """Append intent-specific marketing copy to the model's reply."""
    recommendation_intents = ("recommend_shirt", "recommend_shorts")
    if products and intent in recommendation_intents:
        top = products[0]
        return f"{response} For example, check out our '{top['name']}'—it’s {top['description'].lower()}!"
    if intent == "company_info":
        return f"{response} At Hutter Products GmbH, we specialize in sustainable product design and production!"
    return response
# Endpoints
# NOTE(review): route decorator was missing — without it FastAPI never exposes
# this handler. Path "/" inferred from convention; confirm against clients.
@app.get("/")
async def root():
    """Welcome/health endpoint."""
    return {"message": "Welcome to the NER and Chat API!"}
# NOTE(review): route decorator was missing — without it FastAPI never exposes
# this handler. Path "/process_prompt" inferred; confirm against clients.
@app.post("/process_prompt")
async def process_prompt(request: PromptRequest):
    """Run keyword extraction, intent detection, product search and generation.

    Returns a dict matching the CombinedResponse schema.
    Raises HTTPException(500) on any internal failure.
    """
    try:
        logger.info(f"Processing request: {request.input_text}")
        input_text = request.input_text
        # Keep only the last three turns to bound the model's context window.
        history = request.conversation_history[-3:] if request.conversation_history else []
        intent = detect_intent(input_text)
        keywords = extract_keywords(input_text)
        logger.info(f"Intent: {intent}, Keywords: {keywords}")
        products = search_products_by_keywords(keywords)
        product_context = get_product_context(products)
        logger.info(f"Products matched: {len(products)}")
        # Prompt layout: [history ||] product context + static company context || user input
        history_str = " || ".join(history)
        full_input = f"{history_str} || {product_context} {context_msg} || {input_text}" if history else f"{product_context} {context_msg} || {input_text}"
        logger.info(f"Full input to model: {full_input}")
        logger.info("Tokenizing input...")
        inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=512)
        logger.info("Input tokenized successfully.")
        logger.info("Generating model response...")
        outputs = model.generate(**inputs, max_length=50, num_beams=1, no_repeat_ngram_size=2)
        logger.info("Model generation complete.")
        logger.info("Decoding model output...")
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Model response: {response}")
        enhanced_response = format_response(response, products, intent)
        qa_response = {
            "question": input_text,
            "answer": enhanced_response,
            "score": 1.0  # fixed placeholder; generation yields no confidence score
        }
        logger.info("Returning response...")
        return {
            "ner": {"extracted_keywords": keywords},
            "qa": qa_response,
            "products_matched": products
        }
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Oops, something went wrong: {str(e)}. Try again!")
# NOTE(review): event decorator was missing — without it this handler never ran.
@app.on_event("startup")
async def startup_event():
    """Log readiness once the application starts."""
    logger.info("API is running with BlenderBot-400M-distill, connected to MongoDB.")
# NOTE(review): event decorator was missing — without it this handler never ran.
@app.on_event("shutdown")
def shutdown_event():
    """Close the MongoDB client on shutdown to release pooled connections."""
    client.close()