import asyncio
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from flask import Flask, request, render_template
from flask_cors import CORS
from flask_socketio import SocketIO, emit, join_room, leave_room
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

from src.model.blip_embs import blip_embs
from src.data.transforms import transform_test
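
# Overview: a BLIP vision encoder embeds an uploaded food image and retrieves
# the closest recipe from precomputed embeddings; a Groq-hosted LLM then
# answers the user's question grounded in that recipe. Flask/Socket.IO is
# configured below, while the demo UI itself is served by Gradio at the end
# of this script.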
|
# Credentials are loaded from .env; never hard-code API keys in source.
load_dotenv(".env")

USER_AGENT = os.getenv("USER_AGENT")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
SECRET_KEY = os.getenv("SECRET_KEY")

# os.environ values must be strings, so guard against entries missing from .env.
if USER_AGENT:
    os.environ['USER_AGENT'] = USER_AGENT
if GROQ_API_KEY:
    os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["TOKENIZERS_PARALLELISM"] = 'true'
|
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")
app.config['SESSION_COOKIE_SECURE'] = True
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
app.config['SECRET_KEY'] = SECRET_KEY
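
# The session cookie flags above: Secure restricts cookies to HTTPS, HttpOnly
# hides them from JavaScript, and SameSite=Lax limits cross-site sends.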
|
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2) |
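
# temperature=0 keeps the answers deterministic; max_tokens caps reply length.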
|
qa_system_prompt = """ |
|
Prompt: |
|
You are a highly intelligent assistant. Use the following context to answer user questions. Analyze the data carefully and generate a clear, concise, and informative response to the user's question based on this data. |
|
|
|
Response Guidelines: |
|
- Use only the information provided in the data to answer the question. |
|
- Ensure the answer is accurate and directly related to the question. |
|
- If the data is insufficient to answer the question, politey apologise and tell the user that there is insufficient data available to answer their question. |
|
- Provide the response in a conversational yet professional tone. |
|
|
|
Context: |
|
{context} |
|
""" |
|
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        ("human", "{input}"),
    ]
)

question_answer_chain = qa_prompt | llm | StrOutputParser()
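
# Illustrative invocation (values are placeholders):
#   question_answer_chain.invoke({"input": "How many calories?", "context": "<recipe JSON>"})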
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any of the given stop-token sequences appears."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a mutable default argument; fall back to an empty list.
        self.stops = stops if stops is not None else []
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        for stop in self.stops:
            # Compare the tail of the generated ids against each stop sequence.
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True
        return False
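
# Used by Chat below to stop generation at token id 2 (commonly the EOS id)
# when no custom stopping criteria are supplied.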
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config["pretrained"] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config
|
print("Creating model") |
|
config = get_blip_config("large") |
|
|
|
model = blip_embs( |
|
pretrained=config["pretrained"], |
|
image_size=config["image_size"], |
|
vit=config["vit"], |
|
vit_grad_ckpt=config["vit_grad_ckpt"], |
|
vit_ckpt_layer=config["vit_ckpt_layer"], |
|
queue_size=config["queue_size"], |
|
negative_all_rank=config["negative_all_rank"], |
|
) |
|
|
|
model = model.to(device) |
|
model.eval() |
|
print("Model Loaded !") |
|
print("="*50) |
|
transform = transform_test(384)

print("Loading data")
df = pd.read_json("datasets/sidechef/my_recipes.json")

print("Loading target embeddings")
# Precomputed BLIP image embeddings, one per recipe, stacked in dataframe order.
tar_img_feats = []
for _id in df["id_"].tolist():
    tar_img_feats.append(torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0))

tar_img_feats = torch.cat(tar_img_feats, dim=0)
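
# tar_img_feats has shape (num_recipes, embed_dim); its rows align with df's
# rows, so the best-scoring row index maps directly onto a recipe.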
|
class Chat:
    """Encodes a food image with BLIP and retrieves the closest recipe as context."""

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_array):
        # Gradio passes the upload as a numpy array, not a file path.
        img = Image.fromarray(image_array).convert("RGB")
        img = self.transform(img).unsqueeze(0)
        img = img.to(self.device)
        img_embs = self.model.visual_encoder(img)
        img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)
|
    def get_target(self, img_feats, tar_img_feats):
        # Cosine similarity against every recipe embedding (features are L2-normalized).
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        # Index of the highest-similarity recipe; rows of tar_img_feats align with self.df.
        index = int(np.argmax(score))
        self.target_recipe = self.df.iloc[index]

    def ask(self):
        # Series.to_json() already returns a JSON string.
        return self.target_recipe.to_json()
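
# Illustrative usage (image_array is an RGB numpy array, e.g. a Gradio upload):
#   chat = Chat(model, transform, df, tar_img_feats, device)
#   chat.encode_image(image_array)
#   recipe_json = chat.ask()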
|
chat = Chat(model, transform, df, tar_img_feats, device)
print("Chat initialized!")
|
custom_css = """ |
|
.primary{ |
|
background-color: #4CAF50; /* Green */ |
|
} |
|
""" |
|
|
|
@spaces.GPU
def respond_to_user(image, message):
    # Build a per-request Chat so concurrent users don't share retrieval state.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    chat = Chat(model, transform, df, tar_img_feats, device)
    chat.encode_image(image)
    data = chat.ask()
    formatted_input = {
        'input': message,
        'context': data
    }
    try:
        response = question_answer_chain.invoke(formatted_input)
    except Exception:
        # Return a plain string so the failure path matches the success type.
        response = "An error occurred while processing your request."
    return response
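
# Illustrative call (hypothetical file name):
#   respond_to_user(np.array(Image.open("pizza.jpg")), "How many calories does this have?")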
|
iface = gr.Interface(
    fn=respond_to_user,
    inputs=[gr.Image(), gr.Textbox(label="Ask Query")],
    outputs=gr.Textbox(label="Nutrition-GPT"),
    title="Nutrition-GPT Demo",
    description="Upload a food image and ask queries!",
    css=".component-12 {background-color: red}",
)

iface.launch()
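
# Gradio serves the demo UI; the Flask/Socket.IO app above is configured but
# not started here (it would need socketio.run(app) to serve requests).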