SyedHutter commited on
Commit
32d7156
·
verified ·
1 Parent(s): 5e3fd94

app.py Beta 2

Browse files
Files changed (1) hide show
  1. app.py +138 -148
app.py CHANGED
@@ -1,148 +1,138 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from typing import List, Dict, Any
4
- from pymongo import MongoClient
5
- from transformers import pipeline
6
- import spacy
7
- import subprocess
8
- import sys
9
-
10
- # FastAPI app setup
11
- app = FastAPI()
12
-
13
- # ==========================
14
- # MongoDB Connection Setup
15
- # ==========================
16
- connection_string = "mongodb+srv://clician:[email protected]/?retryWrites=true&w=majority&appName=Hutterdev"
17
- client = MongoClient(connection_string)
18
- db = client["test"] # Replace with your database name
19
- products_collection = db["products"] # Replace with your collection name
20
-
21
- # ==========================
22
- # Transformers Pipeline Setup
23
- # ==========================
24
- # Load the Question-Answering pipeline
25
- qa_pipeline = pipeline("question-answering", model="bert-large-uncased-whole-word-masking-finetuned-squad")
26
-
27
- # ==========================
28
- # Static Context Message
29
- # ==========================
30
- context_msg = (
31
- "Hutter Products GmbH provides a wide array of services to help businesses create high-quality, sustainable products. "
32
- "Their offerings include comprehensive product design, ensuring items are both visually appealing and functional, and product consulting, "
33
- "which provides expert advice on features, materials, and design elements. They also offer sustainability consulting to integrate eco-friendly practices, "
34
- "such as using recycled materials and Ocean Bound Plastic. Additionally, they manage customized production to ensure products meet the highest standards "
35
- "and offer product animation services, creating realistic rendered images and animations to enhance online engagement. These services collectively enable "
36
- "businesses to develop products that are sustainable, market-responsive, and aligned with their brand identity."
37
- )
38
-
39
- # ==========================
40
- # spaCy NER Setup
41
- # ==========================
42
- # ==========================
43
- # spaCy NER Setup
44
- # ==========================
45
- from spacy.util import is_package
46
-
47
- # Ensure 'en_core_web_sm' is available; otherwise, download it
48
- try:
49
- spacy_model_path = "/home/user/app/en_core_web_sm-3.8.0"
50
- nlp = spacy.load(spacy_model_path)
51
- except OSError:
52
- # print("Downloading 'en_core_web_sm' model...")
53
- # subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
54
- nlp = spacy.load(spacy_model_path)
55
-
56
- # ==========================
57
- # Pydantic Models
58
- # ==========================
59
- class PromptRequest(BaseModel):
60
- input_text: str
61
-
62
- class CombinedResponse(BaseModel):
63
- ner: Dict[str, Any]
64
- qa: Dict[str, Any]
65
- products_matched: List[Dict[str, Any]]
66
-
67
- # ==========================
68
- # Helper Functions
69
- # ==========================
70
- def extract_keywords(text: str) -> List[str]:
71
- """
72
- Extract keywords (nouns and proper nouns) using spaCy.
73
- """
74
- doc = nlp(text)
75
- keywords = [token.text for token in doc if token.pos_ in ["NOUN", "PROPN"]]
76
- return list(set(keywords))
77
-
78
- def search_products_by_keywords(keywords: List[str]) -> List[Dict[str, Any]]:
79
- """
80
- Search MongoDB for products that match any of the extracted keywords.
81
- """
82
- regex_patterns = [{"name": {"$regex": keyword, "$options": "i"}} for keyword in keywords]
83
- query = {"$or": regex_patterns}
84
-
85
- matched_products = []
86
- cursor = products_collection.find(query)
87
- for product in cursor:
88
- matched_products.append({
89
- "id": str(product.get("_id", "")),
90
- "name": product.get("name", ""),
91
- "description": product.get("description", ""),
92
- "skuNumber": product.get("skuNumber", ""),
93
- "baseModel": product.get("baseModel", ""),
94
- })
95
-
96
- return matched_products
97
-
98
- def get_combined_context(products: List[Dict]) -> str:
99
- """
100
- Combine the static context with product descriptions fetched from MongoDB.
101
- """
102
- product_descriptions = " ".join([p["description"] for p in products if "description" in p and p["description"]])
103
- combined_context = f"{product_descriptions} {context_msg}"
104
- return combined_context
105
-
106
- # ==========================
107
- # FastAPI Endpoints
108
- # ==========================
109
- @app.get("/")
110
- async def root():
111
- return {"message": "Welcome to the NER and QA API!"}
112
-
113
- @app.post("/process/", response_model=CombinedResponse)
114
- async def process_prompt(request: PromptRequest):
115
- input_text = request.input_text
116
-
117
- # Step 1: Extract keywords using spaCy NER
118
- keywords = extract_keywords(input_text)
119
- ner_response = {"extracted_keywords": keywords}
120
-
121
- # Step 2: Search MongoDB for matching products
122
- products = search_products_by_keywords(keywords)
123
-
124
- # Step 3: Generate Combined Context
125
- combined_context = get_combined_context(products)
126
-
127
- # Step 4: Use Q&A Model
128
- if combined_context.strip(): # Ensure the combined context is not empty
129
- qa_input = {"question": input_text, "context": combined_context}
130
- qa_output = qa_pipeline(qa_input)
131
- qa_response = {
132
- "question": input_text,
133
- "answer": qa_output["answer"],
134
- "score": qa_output["score"]
135
- }
136
- else:
137
- qa_response = {
138
- "question": input_text,
139
- "answer": "No relevant context available.",
140
- "score": 0.0
141
- }
142
-
143
- # Step 5: Return Combined Response
144
- return {
145
- "ner": ner_response,
146
- "qa": qa_response,
147
- "products_matched": products
148
- }
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Dict, Any
4
+ from pymongo import MongoClient
5
+ from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
6
+ import spacy
7
+ import os
8
+ import logging
9
+
10
+ # Set up logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ app = FastAPI()
15
+
16
+ # MongoDB Setup
17
+ connection_string = os.getenv("MONGO_URI", "mongodb+srv://clician:[email protected]/?retryWrites=true&w=majority&appName=Hutterdev")
18
+ client = MongoClient(connection_string)
19
+ db = client["test"]
20
+ products_collection = db["products"]
21
+
22
+ # BlenderBot Setup
23
+ model_name = "SyedHutter/blenderbot_model/blenderbot_model" # Points to subdirectory
24
+ model_dir = "/home/user/app/blenderbot_model"
25
+
26
+ if not os.path.exists(model_dir):
27
+ logger.info(f"Downloading {model_name} to {model_dir}...")
28
+ tokenizer = BlenderbotTokenizer.from_pretrained(model_name)
29
+ model = BlenderbotForConditionalGeneration.from_pretrained(model_name)
30
+ os.makedirs(model_dir, exist_ok=True)
31
+ tokenizer.save_pretrained(model_dir)
32
+ model.save_pretrained(model_dir)
33
+ logger.info("Model download complete.")
34
+ else:
35
+ logger.info(f"Loading pre-existing model from {model_dir}.")
36
+
37
+ tokenizer = BlenderbotTokenizer.from_pretrained(model_dir)
38
+ model = BlenderbotForConditionalGeneration.from_pretrained(model_dir)
39
+
40
+ # Static Context
41
+ context_msg = "Hutter Products GmbH provides sustainable products like shirts and shorts..."
42
+
43
+ # spaCy Setup
44
+ spacy_model_path = "/home/user/app/en_core_web_sm-3.8.0"
45
+ nlp = spacy.load(spacy_model_path)
46
+
47
+ # Pydantic Models
48
+ class PromptRequest(BaseModel):
49
+ input_text: str
50
+ conversation_history: List[str] = []
51
+
52
+ class CombinedResponse(BaseModel):
53
+ ner: Dict[str, Any]
54
+ qa: Dict[str, Any]
55
+ products_matched: List[Dict[str, Any]]
56
+
57
+ # Helper Functions
58
+ def extract_keywords(text: str) -> List[str]:
59
+ doc = nlp(text)
60
+ keywords = [token.text for token in doc if token.pos_ in ["NOUN", "PROPN"]]
61
+ return list(set(keywords))
62
+
63
+ def detect_intent(text: str) -> str:
64
+ doc = nlp(text.lower())
65
+ if any(token.text in ["shirt", "shirts"] for token in doc):
66
+ return "recommend_shirt"
67
+ elif any(token.text in ["short", "shorts"] for token in doc):
68
+ return "recommend_shorts"
69
+ elif any(token.text in ["what", "who", "company", "do", "products"] for token in doc):
70
+ return "company_info"
71
+ return "unknown"
72
+
73
+ def search_products_by_keywords(keywords: List[str]) -> List[Dict[str, Any]]:
74
+ query = {"$or": [{"name": {"$regex": keyword, "$options": "i"}} for keyword in keywords]}
75
+ matched_products = [dict(p, id=str(p["_id"])) for p in products_collection.find(query)]
76
+ return matched_products
77
+
78
+ def get_product_context(products: List[Dict]) -> str:
79
+ if not products:
80
+ return ""
81
+ product_str = "Here are some products: "
82
+ product_str += ", ".join([f"'{p['name']}' (SKU: {p['skuNumber']}) - {p['description']}" for p in products[:2]])
83
+ return product_str
84
+
85
+ def format_response(response: str, products: List[Dict], intent: str) -> str:
86
+ if intent in ["recommend_shirt", "recommend_shorts"] and products:
87
+ product = products[0]
88
+ return f"{response} For example, check out our '{product['name']}' (SKU: {product['skuNumber']})—it’s {product['description'].lower()}!"
89
+ elif intent == "company_info":
90
+ return f"{response} At Hutter Products GmbH, we specialize in sustainable product design and production!"
91
+ return response
92
+
93
+ # Endpoints
94
+ @app.get("/")
95
+ async def root():
96
+ return {"message": "Welcome to the NER and Chat API!"}
97
+
98
+ @app.post("/process/", response_model=CombinedResponse)
99
+ async def process_prompt(request: PromptRequest):
100
+ try:
101
+ input_text = request.input_text
102
+ history = request.conversation_history[-3:] if request.conversation_history else []
103
+
104
+ intent = detect_intent(input_text)
105
+ keywords = extract_keywords(input_text)
106
+ ner_response = {"extracted_keywords": keywords}
107
+
108
+ products = search_products_by_keywords(keywords)
109
+ product_context = get_product_context(products)
110
+
111
+ history_str = " || ".join(history)
112
+ full_input = f"{history_str} || {product_context} {context_msg} || {input_text}" if history else f"{product_context} {context_msg} || {input_text}"
113
+ inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=512)
114
+ outputs = model.generate(**inputs, max_length=150, num_beams=5, no_repeat_ngram_size=2)
115
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
116
+
117
+ enhanced_response = format_response(response, products, intent)
118
+ qa_response = {
119
+ "question": input_text,
120
+ "answer": enhanced_response,
121
+ "score": 1.0
122
+ }
123
+
124
+ return {
125
+ "ner": ner_response,
126
+ "qa": qa_response,
127
+ "products_matched": products
128
+ }
129
+ except Exception as e:
130
+ raise HTTPException(status_code=500, detail=f"Oops, something went wrong: {str(e)}. Try again!")
131
+
132
+ @app.on_event("startup")
133
+ async def startup_event():
134
+ logger.info("API is running with BlenderBot-400M-distill, connected to MongoDB.")
135
+
136
+ @app.on_event("shutdown")
137
+ def shutdown_event():
138
+ client.close()