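# NOTE: page-header text captured together with this file; kept as a comment so the module parses.
# KeerthiVM's picture / Fix added / 87c2216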
import nest_asyncio
nest_asyncio.apply()
import streamlit as st
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torch.nn.functional as F
from evo_vit import EvoViTModel
import io
import os
import cohere
from fpdf import FPDF
from torchvision.models import resnet50
from huggingface_hub import hf_hub_download
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from SkinCancerDiagnosis import initialize_classifier
from rag_pipeline import (
available_models,
initialize_llm,
load_rag_chain,
get_reranked_response,
initialize_rag_components
)
from langchain_core.messages import HumanMessage, AIMessage
from groq import Groq
import google.generativeai as genai
# Use the GPU when PyTorch reports one, otherwise stay on the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Page-level settings for the Streamlit app.
st.set_page_config(
    page_title="DermBOT",
    page_icon="🧬",
    layout="centered",
)
@st.cache_resource(show_spinner=False)
def _load_shared_models():
    """Load the resources that do not depend on the selected LLM.

    `st.cache_resource` hands the *same* object to every session and rerun,
    so only the classifier and the RAG components live here.
    """
    return {
        'classifier': initialize_classifier(),
        'rag_components': initialize_rag_components(),
    }


def load_models():
    """Build the model bundle for the current session.

    Returns:
        dict with the keys 'classifier', 'rag_components', 'llm' and
        'rag_chain'. The first two are shared, cached objects; 'llm' and
        'rag_chain' are built for ``st.session_state["selected_model"]``.

    The returned dict is a fresh object on every call. Callers replace
    'llm' / 'rag_chain' when the user switches models; if the dict itself
    were the cached object, that mutation would change the LLM for every
    other session as well.
    """
    with st.spinner("Loading all AI models (one-time operation)..."):
        # Copy the cached mapping so per-session keys never leak into the cache.
        models = dict(_load_shared_models())
        models['llm'] = initialize_llm(st.session_state["selected_model"])
        models['rag_chain'] = load_rag_chain(models['llm'])
    return models
# === Model selection ===
# "selected_model" always names the model the session's LLM was built for.
# The sidebar widget keeps the user's (possibly unconfirmed) choice under its
# own key, so a pending change survives the rerun that a button click causes.
if "selected_model" not in st.session_state:
    st.session_state["selected_model"] = available_models[0]
if "model_change_confirmed" not in st.session_state:
    st.session_state.model_change_confirmed = False
previous_model = st.session_state["selected_model"]

# Seed the widget through session state (rather than `index=`) so that the
# Cancel callback below can point it back at the active model.
if "model_selector" not in st.session_state:
    st.session_state["model_selector"] = previous_model
requested_model = st.sidebar.selectbox(
    "Select LLM Model",
    available_models,
    key="model_selector",
)

if 'app_models' not in st.session_state:
    # Shallow copy: the LLM entries are replaced per session below, and the
    # object returned by load_models() may be shared between sessions.
    st.session_state.app_models = dict(load_models())
classifier = st.session_state.app_models['classifier']


def _switch_llm(model_name):
    """Rebuild this session's LLM and RAG chain and record `model_name` as active."""
    app_models = st.session_state.app_models
    app_models['llm'] = initialize_llm(model_name)
    app_models['rag_chain'] = load_rag_chain(app_models['llm'])
    # Only mark the switch as done once both objects were built successfully.
    st.session_state["selected_model"] = model_name
    st.session_state.model_change_confirmed = True


def _cancel_model_change():
    """Button callback: reset the sidebar selector to the active model."""
    st.session_state["model_selector"] = st.session_state["selected_model"]


if requested_model != previous_model:
    # "messages" is initialised further down the script, so read it defensively.
    if st.session_state.get("messages"):
        # A conversation exists: ask before discarding it.
        st.session_state.model_change_confirmed = False
        with st.sidebar:
            st.warning("Changing models will clear current conversation.")
            col1, col2 = st.columns(2)
            with col1:
                if st.button("Confirm Change", key="confirm_model_change"):
                    st.session_state.messages = []
                    st.session_state.current_image = None
                    _switch_llm(requested_model)
                    st.rerun()
            with col2:
                # Widget state may only be changed from a callback once the
                # widget has been created in the current run.
                st.button(
                    "Cancel",
                    key="cancel_model_change",
                    on_click=_cancel_model_change,
                )
    else:
        # Nothing to lose: switch immediately.
        _switch_llm(requested_model)

llm = st.session_state.app_models['llm']
# === Session Init ===
# Make sure the per-session keys used by the rest of the script exist.
for state_key, initial_value in (("messages", []), ("current_image", None)):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
# === Image Processing Function ===
def run_inference(image):
    """Classify `image` with the module-level `classifier`.

    Returns:
        A 2-tuple: the first element of the top entry in the classifier's
        "top_predictions" result, and whatever `classifier.predict_skincon`
        returns for the same image (both requested with ``top_k=1``).
    """
    best_prediction = classifier.predict(image, top_k=1)["top_predictions"][0]
    skincon_result = classifier.predict_skincon(image, top_k=1)
    return best_prediction[0], skincon_result
# === PDF Export ===
def export_chat_to_pdf(messages):
    """Render the chat transcript into an in-memory PDF.

    Args:
        messages: iterable of dicts with "role" and "content" keys. A role of
            "user" is printed as "You"; any other role is printed as "AI".

    Returns:
        io.BytesIO holding the PDF, rewound to position 0.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # The built-in (core) fonts only cover Latin-1. Model output often
        # contains emoji, dashes or bullets, which would otherwise make the
        # export fail; unsupported characters are replaced with "?".
        entry = f"{role}: {msg['content']}\n".encode("latin-1", "replace").decode("latin-1")
        # Start every entry at the left margin. NOTE(review): with fpdf2 the
        # cursor can be left at the right edge after multi_cell, leaving no
        # width for the next full-width cell -- confirm against the installed version.
        pdf.set_x(pdf.l_margin)
        pdf.multi_cell(0, 10, entry)
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf
# === App UI ===
st.title("🧬 DermBOT — Skin AI Assistant")
st.caption(f"🧠 Using model: {st.session_state['selected_model']}")
uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader"
)

# A newly uploaded image starts a fresh conversation: classify it, then ask
# the RAG pipeline an opening question about the predicted conditions.
if uploaded_file is not None and uploaded_file != st.session_state.current_image:
    st.session_state.messages = []
    st.session_state.current_image = uploaded_file
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded image", use_column_width=True)
    with st.spinner("Analyzing the image..."):
        predicted_label, predicted_label_multi = run_inference(image)
    # Join once and reuse, so the query below carries the same readable label
    # list as the UI instead of the Python repr of the container.
    skin_issues = ', '.join(predicted_label_multi)
    st.markdown(f"🧾 **Skin Issues**: {skin_issues}")
    st.markdown(f" Most Likely Diagnosis : {predicted_label}")
    initial_query = f"What are my treatment options for {predicted_label} & {skin_issues}?"
    st.session_state.messages.append({"role": "user", "content": initial_query})
    with st.spinner("Retrieving medical information..."):
        response = get_reranked_response(
            initial_query,
            st.session_state.app_models['llm'],
            st.session_state.app_models['rag_components'],
        )
    st.session_state.messages.append({"role": "assistant", "content": response})

# Replay the stored conversation on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# === Chat Interface ===
prompt = st.chat_input("Ask a follow-up question...")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Everything except the prompt that was just appended.
            earlier_turns = st.session_state.messages[:-1]
            if earlier_turns:
                conversation_context = "\n".join(
                    [f"{m['role']}: {m['content']}" for m in earlier_turns]
                )
                query = (
                    f"Conversation history:\n{conversation_context}\n\n"
                    f"Current question: {prompt}"
                )
            else:
                # First message of the session: send the prompt unchanged.
                query = prompt
            response = get_reranked_response(
                query,
                st.session_state.app_models['llm'],
                st.session_state.app_models['rag_components'],
            )
        st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})
# Offer the PDF export only when there is a conversation to export. The PDF is
# built when the first button is clicked; the download button then serves it.
if st.session_state.messages and st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="dermbot_chat_history.pdf",
        mime="application/pdf"
    )