|
import eventlet |
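# eventlet.monkey_patch() must run before any other imports so the standard
# library (sockets, threading, time) is patched for cooperative concurrency.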
|
eventlet.monkey_patch() |
|
|
|
import pandas as pd |
|
import json |
|
from PIL import Image |
|
import numpy as np |
|
|
|
import os |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from src.model.blip_embs import blip_embs |
|
from src.data.transforms import transform_test |
|
|
|
from transformers import StoppingCriteria, StoppingCriteriaList |
|
|
|
|
|
from langchain_community.chat_message_histories import ChatMessageHistory |
|
from langchain_core.runnables import RunnableWithMessageHistory |
|
from langchain_core.output_parsers import StrOutputParser

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from langchain_core.messages import SystemMessage, HumanMessage
|
from langchain_groq import ChatGroq |
|
|
|
from dotenv import load_dotenv |
|
from flask import Flask, request, render_template |
|
from flask_cors import CORS |
|
from flask_socketio import SocketIO, emit |
|
|
|
import base64

from io import BytesIO

from openai import OpenAI
|
|
|
load_dotenv(".env") |
|
USER_AGENT = os.getenv("USER_AGENT") |
|
GROQ_API_KEY = os.getenv("GROQ_API_KEY") |
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") |
|
SECRET_KEY = os.getenv("SECRET_KEY") |
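# A minimal sketch of the expected .env file (values below are placeholders,
# not real credentials):
#
#   USER_AGENT=recipe-chatbot/1.0
#   GROQ_API_KEY=gsk_...
#   OPENAI_API_KEY=sk-...
#   SECRET_KEY=change-me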
|
|
|
|
|
os.environ['USER_AGENT'] = USER_AGENT |
|
os.environ["GROQ_API_KEY"] = GROQ_API_KEY |
|
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY |
|
os.environ["TOKENIZERS_PARALLELISM"] = 'true' |
|
|
|
|
|
app = Flask(__name__)

CORS(app)

app.config['SECRET_KEY'] = SECRET_KEY

# Allow very large payloads (up to 1 GiB) for image uploads, both over HTTP
# and over the Socket.IO transport.
app.config['MAX_CONTENT_LENGTH'] = 1024 * 1024 * 1024

socketio = SocketIO(app, cors_allowed_origins="*", logger=True, max_http_buffer_size=1024 * 1024 * 1024)
|
|
|
|
|
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2) |
|
|
|
|
|
json_llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, max_tokens=1024, max_retries=2, model_kwargs={"response_format": {"type": "json_object"}}) |
|
|
|
|
|
router = ChatGroq(model="llama-3.2-3b-preview", temperature=0, max_tokens=1024, max_retries=2, model_kwargs={"response_format": {"type": "json_object"}}) |
|
|
|
|
|
answer_formatter = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2) |
|
|
|
|
|
client = OpenAI() |
|
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the most recent tokens match any given stop sequence."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a mutable default argument; store the (previously unused) encounters.
        self.stops = stops if stops is not None else []
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        for stop in self.stops:
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True
        return False
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config["pretrained"] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config
|
|
|
print("Creating model") |
|
config = get_blip_config("large") |
|
|
|
model = blip_embs( |
|
pretrained=config["pretrained"], |
|
image_size=config["image_size"], |
|
vit=config["vit"], |
|
vit_grad_ckpt=config["vit_grad_ckpt"], |
|
vit_ckpt_layer=config["vit_ckpt_layer"], |
|
queue_size=config["queue_size"], |
|
negative_all_rank=config["negative_all_rank"], |
|
) |
|
|
|
model = model.to(device) |
|
model.eval() |
|
|
|
transform = transform_test(384) |
|
|
|
df = pd.read_json("my_recipes.json") |
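# Each recipe id maps to a precomputed BLIP image embedding stored as a .pth
# tensor; loading them all up front means an image query only needs a single
# matrix multiply instead of re-encoding the whole catalogue.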
|
|
|
tar_img_feats = [] |
|
for _id in df["id_"].tolist(): |
|
tar_img_feats.append(torch.load("./datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0)) |
|
|
|
tar_img_feats = torch.cat(tar_img_feats, dim=0) |
|
|
|
class Chat:

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_array):
        # `image_array` is an H x W x C numpy array, not a file path.
        img = Image.fromarray(image_array).convert("RGB")
        img = self.transform(img).unsqueeze(0)
        img = img.to(self.device)
        img_embs = self.model.visual_encoder(img)
        img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        self.img_feats = img_feats

        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        # Similarity of the query embedding against every precomputed recipe
        # embedding; pick the closest match.
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        index = np.argsort(score)[::-1][0]
        self.target_recipe = self.df.iloc[index]

    def ask(self):
        # to_json() already returns a JSON string; wrapping it in json.dumps()
        # would double-encode it.
        return self.target_recipe.to_json()
|
|
|
chat = Chat(model, transform, df, tar_img_feats, device)
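# Illustrative usage (assuming `some_image` is an H x W x C numpy array):
#
#   chat.encode_image(some_image)   # embeds the image and finds the closest recipe
#   recipe_json = chat.ask()        # JSON string for the matched recipe row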
|
|
|
message_histories = {}

def answer_generator(formatted_input, session_id):
|
|
|
qa_system_prompt = """ |
|
You are an AI assistant developed by Nutrigenics AI. Your purpose is to help users by providing accurate and relevant answers to their questions. |
|
Operational Guidelines: |
|
|
|
1. Input Structure: |
|
- Context: You may receive contextual information related to recipes or other topics. |
|
- User Query: Users will pose questions or requests on various topics. |
|
|
|
2. Response Strategy: |
|
- Utilize Provided Context: If the context contains relevant information that addresses the user's query, base your response on this provided data. |
|
- Respond to User Query Directly: If the context does not contain the necessary information, answer the question to the best of your ability. |
|
|
|
Output Format: |
|
- Provide clear and concise answers. |
|
- Format your response in JSON with a key 'content' containing your answer. |
|
|
|
Additional Instructions: |
|
- Precision and Personalization: Always aim to provide precise, personalized, and relevant information. |
|
- Clarity and Coherence: Ensure all responses are clear, well-structured, and easy to understand. |
|
- Do not mention the context in your response; phrase the answer in a natural and friendly way.
|
""" |
|
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}")
        ]
    )
|
|
|
|
|
base_chain = qa_prompt | llm | StrOutputParser() |
|
|
|
|
|
    # Persist history per session: creating a fresh ChatMessageHistory on every
    # lookup would silently discard the conversation.
    question_answer_chain = RunnableWithMessageHistory(
        base_chain,
        lambda session_id: message_histories.setdefault(session_id, ChatMessageHistory()),
        input_messages_key="input",
        history_messages_key="chat_history"
    )

    response = question_answer_chain.invoke(formatted_input, config={"configurable": {"session_id": session_id}})

    return response
|
|
|
def json_answer_generator(user_query, context): |
|
system_prompt = """ |
|
Given a context in JSON format, respond to user queries by extracting and returning the requested information in JSON format with an additional `"header"` key containing a response starter. Use the following rules: |
|
|
|
1. **Information Extraction**: |
|
- If the user query explicitly requests specific data (e.g., ingredients, nutrients, or instructions), return only those JSON objects from the provided context. |
|
- Include `"header": "Here is the information you requested:"` at the start of each response. |
|
|
|
2. **General Responses**: |
|
- If the query is not directly related to the context, provide a helpful and accurate answer. |
|
- Include `"header": "Here is your answer:"` at the start of the response. |
|
- Return a JSON object with a single key `"content"` and your response as its value. |
|
|
|
Always format the output as a JSON object with key-value pairs.
|
""" |
|
|
|
formatted_input = f""" |
|
User Query: {user_query} |
|
Context: |
|
{context} |
|
""" |
|
    response = json_llm.invoke(
        [
            SystemMessage(content=system_prompt),
            HumanMessage(content=formatted_input)
        ]
    )
|
res = json.loads(response.content) |
|
return res |
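# Illustrative shape of json_answer_generator's result:
#   {"header": "Here is the information you requested:", "recipe_ingredients": [...]}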
|
|
|
def router_node(query): |
|
|
|
    router_instructions = """You are an expert at determining the appropriate task for a user's question based on chat history and the current query context. You have three available tasks:
    1. Retrieval: Fetch information based on the user's chat history and current query.
    2. Recommendation/Suggestion: Recommend recipes to the user based on the query.
    3. General: Answer general questions not related to recipes or the current context.
    Return a JSON response with a single key named "task" whose value is "retrieval", "recommendation", or "general", based on your decision.
    """
|
    response = router.invoke(
        [
            SystemMessage(content=router_instructions),
            HumanMessage(content=query)
        ]
    )
|
res = json.loads(response.content) |
|
return res['task'] |
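# router_node returns just the task string, e.g. a model reply of
# {"task": "recommendation"} yields "recommendation".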
|
|
|
def recommendation_node(query): |
|
prompt = """ |
|
You are a helpful assistant that writes Python code to filter recipes from a JSON file based on the user query. |
|
|
|
JSON file path = 'recipes.json' |
|
|
|
The JSON file is a list of recipes with the following structure: |
|
{ |
|
"recipe_name": string, |
|
"recipe_time": integer, |
|
"recipe_yields": string, |
|
"recipe_ingredients": list of ingredients, |
|
"recipe_instructions": list of instructions, |
|
"recipe_image": string, |
|
"blogger": string, |
|
"recipe_nutrients": JSON object with key-value pairs such as "protein: 10g", |
|
"tags": list of tags related to a recipe |
|
} |
|
|
|
Based on the user query, provide a Python function to filter the JSON data. The output of the function should be a list of JSON objects. |
|
|
|
    Recipe filtering instructions:
    - If the user asks for recipes high in a nutrient (e.g., "high protein" or "high calories"), return the recipes ranked highest by that nutrient.
    - Sort or rearrange recipes so the most appropriate ones for the user come first.

    Your output instructions:
    - The function name should be filter_recipes. The input to the function should be the file name.
    - Return no more than 6 recipes.
    - Only give me the function definition. Do not call the function.
    - Give the Python function as the value of a key named "code" in a JSON object.
    - Do not include any other text with the output; only give Python code.
    - If you do not follow the above instructions, the chat may be terminated.
|
""" |
|
    # The original loop only decremented max_tries on exceptions, so an empty
    # (but valid) result could spin forever; bound the attempts explicitly.
    max_tries = 3
    for _ in range(max_tries):
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                # Ask for a JSON object so json.loads() below is reliable.
                response_format={"type": "json_object"},
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": query}
                ]
            )

            content = response.choices[0].message.content

            res = json.loads(content)
            script = res['code']
            # Executing model-generated code is inherently risky; this trusts
            # the LLM output. exec() defines filter_recipes in globals().
            exec(script, globals())
            filtered_recipes = filter_recipes('my_recipes.json')
            if len(filtered_recipes) > 0:
                return filtered_recipes
        except Exception as e:
            print(e)
    return [{"content": "max retries reached"}]
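# A hypothetical example of the kind of function the LLM is asked to generate,
# e.g. for the query "high protein recipes" (the helper name and nutrient
# parsing below are assumptions, not actual model output):
#
#   def filter_recipes(file_name):
#       with open(file_name) as f:
#           recipes = json.load(f)
#       def protein_grams(recipe):
#           value = recipe["recipe_nutrients"].get("protein", "0g")
#           return float(value.rstrip("g"))
#       return sorted(recipes, key=protein_grams, reverse=True)[:6]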
|
|
|
def answer_formatter_node(question, context): |
|
prompt = f"""You are a highly clever question-answering assistant trained to provide clear and concise answers based on the user query and provided context. |
|
Your task is to generate answers for the user query based on the context provided. |
|
Instructions for your response: |
|
1. Directly answer the user query using only the information provided in the context. |
|
2. Ensure your response is clear and concise. |
|
3. Mention only details related to the recipe, including the recipe name, instructions, nutrients, yield, ingredients, and image. |
|
4. Do not include any information that is not related to the recipe context. |
|
Please format an answer based on the following user question and context provided: |
|
User Question: |
|
{question} |
|
Context: |
|
{context} |
|
""" |
|
response = answer_formatter.invoke( |
|
[SystemMessage(content=prompt)] |
|
) |
|
res = response.content |
|
return res |
|
|
|
def regular_answer_node(question, context):
|
prompt = f"""You are a highly clever question-answering assistant trained to provide clear and concise answers based on the user query and provided context. |
|
Your task is to generate answers for the user query based on the context provided. |
|
Instructions for your response: |
|
1. Directly answer the user query. Make use of provided context if necessary. |
|
2. Ensure your response is clear and concise. |
|
3. Give the answer in JSON format with a single key named 'content' with value as your response. |
|
4. It is important to give response in JSON format, otherwise the chat may terminate. |
|
Please format an answer based on the following user question and context provided: |
|
User Question: |
|
{question} |
|
Context: |
|
{context} |
|
""" |
|
response = answer_formatter.invoke( |
|
[SystemMessage(content=prompt)] |
|
) |
|
res = response.content |
|
return res |
|
|
|
def general_answer_node(question): |
|
prompt = f"""You are an assistant that provides helpful and accurate answers to any question. Please answer the following question in a JSON format with a single key 'content' containing your answer. |
|
|
|
Question: {question} |
|
""" |
|
response = llm.invoke( |
|
[SystemMessage(content=prompt)] |
|
) |
|
try: |
|
res = json.loads(response.content) |
|
return res |
|
except json.JSONDecodeError: |
|
|
|
return {'content': response.content} |
|
|
|
CURR_CONTEXT = '' |
|
|
|
def get_answer(image=None, message='', sessionID='abc123'):
    global CURR_CONTEXT
    if image is not None and len(image) > 0:
        try:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            chat = Chat(model, transform, df, tar_img_feats, device)
            chat.encode_image(image)
            data = chat.ask()
            CURR_CONTEXT = data
            response = json_answer_generator(message, data)
        except Exception as e:
            print(e)
            response = {'content': "An error occurred while processing your request."}
    elif message:
        task = router_node(message)
        if task == 'retrieval':
            response = json_answer_generator(message, CURR_CONTEXT)
        elif task == "recommendation":
            recipes = recommendation_node(message)
            if not recipes:
                response = {'content': "An error occurred while processing your request."}
            else:
                response = answer_formatter_node(message, recipes)
        elif task == "general":
            response = general_answer_node(message)
            if response is None:
                response = {'content': "An error occurred while processing your request."}
        else:
            response = {'content': "Sorry, I didn't understand your request."}
    else:
        response = {'content': "Please provide a message to process."}

    return response
|
|
|
|
|
@socketio.on('ping') |
|
def handle_ping(): |
|
emit('Ping-return', {'message': 'Connected'}, room=request.sid) |
|
|
|
|
|
@socketio.on('connect') |
|
def handle_connect(): |
|
print(f"Client connected: {request.sid}") |
|
|
|
|
|
@socketio.on('disconnect') |
|
def handle_disconnect(): |
|
print(f"Client disconnected: {request.sid}") |
|
|
|
|
|
|
|
|
session_store = {} |
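# session_store holds per-connection scratch state (accumulated image chunks
# plus the latest message), keyed by the Socket.IO session id.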
|
|
|
@socketio.on('message')
def handle_message(data):
    global session_store
    global CURR_CONTEXT
    session_id = request.sid
    if session_id not in session_store:
        # Image chunks are assumed to arrive as base64-encoded strings and are
        # accumulated here until the client has sent the whole image.
        session_store[session_id] = {'image_data': "", 'message': None, 'image_received': False}

    if 'message' in data:
        session_store[session_id]['message'] = data['message']

    if 'image' in data:
        try:
            session_store[session_id]['image_data'] += data['image']
        except Exception as e:
            print(f"Error processing image chunk: {str(e)}")
            emit('response', "An error occurred while receiving the image chunk.", room=session_id)
            return

    message = session_store[session_id]['message'] or ''

    # Only take the image path when image data has actually arrived; the old
    # combined check tried to decode an empty image for text-only requests.
    if session_store[session_id]['image_data']:
        try:
            image_bytes = session_store[session_id]['image_data']
            if isinstance(image_bytes, str):
                image_bytes = base64.b64decode(image_bytes)
            image = Image.open(BytesIO(image_bytes))
            image_array = np.array(image)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            chat = Chat(model, transform, df, tar_img_feats, device)
            chat.encode_image(image_array)
            context = chat.ask()
            CURR_CONTEXT = context
            response = json_answer_generator(message, context)
            emit('response', response, room=session_id)
        except Exception as e:
            print(f"Error processing image or message: {str(e)}")
            emit('response', "An error occurred while processing your request.", room=session_id)
            return
        finally:
            session_store.pop(session_id, None)
    else:
        task = router_node(message)
        if task == 'retrieval':
            response = json_answer_generator(message, CURR_CONTEXT)
            emit('response', response, room=session_id)
        elif task == "recommendation":
            recipes = recommendation_node(message)
            if not recipes:
                response = {'content': "An error occurred while processing your request."}
            else:
                response = answer_formatter_node(message, recipes)
            emit('json_response', response, room=session_id)
        elif task == "general":
            response = general_answer_node(message)
            if response is None:
                response = {'content': "An error occurred while processing your request."}
            emit('json_response', response, room=session_id)
        else:
            response = {'content': "Sorry, I didn't understand your request."}
            emit('json_response', response, room=session_id)
        session_store.pop(session_id, None)
|
|
|
|
|
|
def base64_to_numpy(base64_string):
    # Strip a data-URL prefix (e.g. "data:image/png;base64,") if one is
    # present; this assumes some clients send full data URLs.
    if base64_string.strip().startswith('data:') and ',' in base64_string:
        base64_string = base64_string.split(',', 1)[1]

    image_data = base64.b64decode(base64_string)

    image = Image.open(BytesIO(image_data))

    image_np = np.array(image)

    return image_np
|
|
|
@socketio.on('example') |
|
def handle_example(data):
    # Renamed from handle_message to avoid redefining the 'message' handler above.
|
img_url = data['img_url'] |
|
message = data['message'] |
|
session_id = request.sid |
|
image_array = base64_to_numpy(img_url) |
|
response = get_answer(image=image_array, message=message, sessionID=request.sid) |
|
emit('response', response, room=session_id) |
|
return response |
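# A minimal sketch of a client for these Socket.IO events using the
# python-socketio package (the URL and port are assumptions):
#
#   import socketio
#   sio = socketio.Client()
#
#   @sio.on('json_response')
#   def on_json_response(data):
#       print(data)
#
#   sio.connect('http://localhost:5000')
#   sio.emit('message', {'message': 'Recommend a high protein dinner'})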
|
|
|
|
|
@app.route("/") |
|
def index_view(): |
|
return render_template('chat.html') |
|
|
|
|
|
if __name__ == '__main__': |
|
socketio.run(app, debug=True) |
|
|