|
import os |
|
import json |
|
import streamlit as st |
|
from groq import Groq |
|
from PIL import Image, UnidentifiedImageError, ExifTags |
|
import requests |
|
from io import BytesIO |
|
from transformers import pipeline |
|
from final_captioner import generate_final_caption |
|
import hashlib |
|
|
|
|
|
# Heading shown at the top of the page.
APP_TITLE = "PicSamvaad : Image Conversational Chatbot"
st.title(APP_TITLE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): GROQ_API_KEY is not assigned anywhere in the visible part of
# this file - confirm where it is meant to come from (config file, secrets,
# ...) before relying on this line.
# The key is exported to the process environment, presumably so that the
# argument-less Groq() constructor below can pick it up - TODO confirm.
os.environ.update({"GROQ_API_KEY": GROQ_API_KEY})

# Shared client used for every chat completion request further down.
client = Groq()
|
|
|
|
|
# NOTE(review): the indentation of this script was lost in the reviewed copy,
# so only the three widget calls are nested under the sidebar here - confirm
# this matches the intended layout.
with st.sidebar:
    st.header("Upload Image or Enter URL")

    # Two alternative image sources: a local upload or a remote URL.
    uploaded_file = st.file_uploader(
        "Upload an image to chat...",
        type=["jpg", "jpeg", "png"],
    )
    url = st.text_input("Or enter a valid image URL...")

# Both stay None until the loading code below sets them.
image, error_message = None, None
|
|
|
|
|
def correct_image_orientation(img):
    """Return ``img`` rotated upright according to its EXIF orientation tag.

    The public ``getexif()`` accessor is used instead of the private
    ``_getexif()`` so the correction is not limited to the image plugins
    that happen to implement the private method.

    Args:
        img: An image object exposing ``getexif()`` and ``rotate()``
            (a PIL image in this application).

    Returns:
        A rotated copy of the image for EXIF orientations 3, 6 and 8;
        otherwise ``img`` itself, untouched.
    """
    # 0x0112 is the standard EXIF tag id for "Orientation"; using the
    # constant avoids scanning the whole tag table on every call.
    orientation_tag = 0x0112
    # EXIF orientation value -> counter-clockwise rotation that undoes it.
    rotation_for_orientation = {3: 180, 6: 270, 8: 90}

    try:
        exif = img.getexif()
    except AttributeError:
        # Object without EXIF support: best effort, leave it as it is.
        return img

    orientation = exif.get(orientation_tag) if exif else None
    degrees = rotation_for_orientation.get(orientation)
    if degrees is None:
        # No EXIF data, no orientation tag, or already upright.
        return img
    return img.rotate(degrees, expand=True)
|
|
|
|
|
def get_image_hash(image):
    """Return the hexadecimal MD5 digest of ``image.tobytes()``.

    The digest is only used as a change-detection fingerprint for the
    currently loaded image, not for anything security related.
    """
    digest = hashlib.md5()
    digest.update(image.tobytes())
    return digest.hexdigest()
|
|
|
|
|
|
|
# Fingerprint of the image the current conversation belongs to.
if "last_uploaded_hash" not in st.session_state:
    st.session_state.last_uploaded_hash = None


def _start_new_chat_if_image_changed(img):
    """Clear the chat history when ``img`` differs from the last image seen."""
    image_hash = get_image_hash(img)
    if st.session_state.last_uploaded_hash != image_hash:
        st.session_state.chat_history = []
        st.session_state.last_uploaded_hash = image_hash


if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Hashing calls tobytes(), which forces the pixel data to be decoded,
        # so corrupt or truncated uploads fail here rather than later on.
        _start_new_chat_if_image_changed(image)
    except OSError:
        # UnidentifiedImageError and Pillow's truncated-file errors are both
        # OSError subclasses; report them instead of crashing the page.
        image = None
        error_message = "Error: The uploaded file could not be read as an image. Please try a different JPG or PNG file."
    else:
        image = correct_image_orientation(image)
        st.image(image, caption="Uploaded Image.", use_column_width=True)

elif url:
    try:
        # Without a timeout an unresponsive host would block the app forever.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        _start_new_chat_if_image_changed(image)

        image = correct_image_orientation(image)
        st.image(image, caption="Image from URL.", use_column_width=True)
    except (requests.exceptions.RequestException, UnidentifiedImageError, OSError):
        # OSError additionally covers image data that is recognised but
        # cannot be fully decoded (e.g. a truncated download).
        image = None
        error_message = "Error: The provided URL is invalid or the image could not be loaded. Sometimes some image URLs don't work. We suggest you upload the downloaded image instead ;)"
|
|
|
# The caption stays empty when no image could be loaded; it is later appended
# to the system prompt of the chat request.
caption = ""

if image is not None:
    caption = caption + generate_final_caption(image)
    st.write(f"ChatBot : {caption}")


# Surface any problem recorded while loading the image.
if error_message:
    st.error(error_message)
|
|
|
|
|
# Create the conversation log on first use.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []


# Re-render every stored turn under its role.
for message in st.session_state.chat_history:
    st.chat_message(message["role"]).markdown(message["content"])
|
|
|
|
|
user_prompt = st.chat_input("Ask the Chatbot about the image...")

if user_prompt:
    # Show the new user turn and record it in the conversation log.
    with st.chat_message("user"):
        st.markdown(user_prompt)
    st.session_state.chat_history.append({"role": "user", "content": user_prompt})

    # The system prompt carries the generated caption; it is followed by the
    # whole conversation so far, including the turn just added.
    system_message = {
        "role": "system",
        "content": "You are a helpful, accurate image conversational assistant. You don't hallucinate, and your answers are very precise and have a positive approach.The caption of the image is: "
        + caption,
    }
    messages = [system_message] + list(st.session_state.chat_history)

    response = client.chat.completions.create(
        messages=messages,
        model="llama-3.1-8b-instant",
    )

    # Only the first returned choice is used.
    assistant_response = response.choices[0].message.content
    st.session_state.chat_history.append(
        {"role": "assistant", "content": assistant_response}
    )

    st.chat_message("assistant").markdown(assistant_response)
|
|