"""DermBOT Streamlit app.

Classifies an uploaded skin image with the project classifier, then answers
treatment / follow-up questions through the project's RAG pipeline and LLM.
The chat can be exported as a PDF.
"""
import nest_asyncio

# Patch asyncio before anything else imports/uses an event loop so a running
# loop can be re-entered (presumably needed by async LLM/RAG clients under
# Streamlit -- TODO confirm which dependency requires it).
nest_asyncio.apply()

# NOTE(review): several of the imports below are not referenced in this script;
# they are kept in their original order in case another module relies on them
# or on their import-time side effects. Verify before removing.
import streamlit as st
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
from PIL import Image, UnidentifiedImageError
import torch.nn.functional as F
from evo_vit import EvoViTModel
import io
import os
import cohere
from fpdf import FPDF
from torchvision.models import resnet50
from huggingface_hub import hf_hub_download
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from SkinCancerDiagnosis import initialize_classifier
from rag_pipeline import (
    available_models,
    initialize_llm,
    load_rag_chain,
    get_reranked_response,
    initialize_rag_components,
)
from langchain_core.messages import HumanMessage, AIMessage
from groq import Groq
import google.generativeai as genai

# Not referenced elsewhere in this script.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Must be the first Streamlit command executed by the script.
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")


# === Model loading ===
@st.cache_resource(show_spinner=False)
def _load_shared_components():
    """Load the model-independent components once per server process.

    The returned dict is shared by every session (``st.cache_resource``), so it
    must never be mutated -- callers copy it first.
    """
    with st.spinner("Loading all AI models (one-time operation)..."):
        return {
            'classifier': initialize_classifier(),
            'rag_components': initialize_rag_components(),
        }


def load_models(model_name=None):
    """Build a per-session model bundle.

    Args:
        model_name: LLM to initialise; defaults to the model currently selected
            in the sidebar (``st.session_state["selected_model"]``).

    Returns:
        A fresh dict with keys ``classifier`` and ``rag_components`` (shared,
        cached objects) plus ``llm`` and ``rag_chain`` (created for this call).
        Because the dict itself is new, sessions can replace entries without
        affecting each other.
    """
    if model_name is None:
        model_name = st.session_state["selected_model"]
    # Copy: the cached dict is shared across all sessions.
    models = dict(_load_shared_components())
    with st.spinner(f"Loading {model_name}..."):
        models['llm'] = initialize_llm(model_name)
        models['rag_chain'] = load_rag_chain(models['llm'])
    return models


# === Session Init ===
# Done before anything reads these keys (the model-switch check reads `messages`).
if "messages" not in st.session_state:
    st.session_state.messages = []
if "current_image" not in st.session_state:
    st.session_state.current_image = None
if "diagnosis" not in st.session_state:
    st.session_state.diagnosis = None
if "selected_model" not in st.session_state:
    st.session_state["selected_model"] = available_models[0]
# `active_model` is the model actually answering; `selected_model` is the
# sidebar widget value, which may differ while a switch awaits confirmation.
if "active_model" not in st.session_state:
    st.session_state.active_model = st.session_state["selected_model"]


# === Model selection ===
def _confirm_model_change():
    """Button callback: adopt the selected model and clear the conversation."""
    st.session_state.messages = []
    # Forces the uploaded image (if any) to be re-analysed with the new model.
    st.session_state.current_image = None
    st.session_state.diagnosis = None
    st.session_state.active_model = st.session_state["selected_model"]


def _cancel_model_change():
    """Button callback: revert the sidebar selection to the model still in use."""
    # Allowed here because callbacks run before the widget is re-created.
    st.session_state["selected_model"] = st.session_state.active_model


# Keyed widget: its value lives in st.session_state["selected_model"].
st.sidebar.selectbox("Select LLM Model", available_models, key="selected_model")

if st.session_state["selected_model"] != st.session_state.active_model:
    if st.session_state.messages:
        # Rendered on every rerun until resolved, so the button clicks are not lost.
        with st.sidebar:
            st.warning("Changing models will clear current conversation.")
            col1, col2 = st.columns(2)
            with col1:
                st.button("Confirm Change", key="confirm_model_change",
                          on_click=_confirm_model_change)
            with col2:
                st.button("Cancel", key="cancel_model_change",
                          on_click=_cancel_model_change)
    else:
        # Nothing to lose: switch immediately.
        st.session_state.active_model = st.session_state["selected_model"]

# (Re)build this session's models only when the active model actually changes,
# instead of on every rerun.
if st.session_state.get("app_models_for") != st.session_state.active_model:
    st.session_state.app_models = load_models(st.session_state.active_model)
    st.session_state.app_models_for = st.session_state.active_model

classifier = st.session_state.app_models['classifier']
llm = st.session_state.app_models['llm']


# === Image Processing Function ===
def run_inference(image):
    """Classify an image with the session classifier.

    Args:
        image: PIL image (the caller passes an RGB-converted image).

    Returns:
        ``(predicted_label, predicted_label_multi)``: the first label of the
        top-1 ``classifier.predict`` result, and the raw return value of
        ``classifier.predict_skincon(image, top_k=1)``.
    """
    result = classifier.predict(image, top_k=1)
    predicted_label = result["top_predictions"][0][0]
    predicted_label_multi = classifier.predict_skincon(image, top_k=1)
    return predicted_label, predicted_label_multi


# === PDF Export ===
# FPDF core fonts such as Arial only cover latin-1; map common typographic
# characters that LLM output often contains to ASCII look-alikes.
_PDF_CHAR_MAP = str.maketrans({
    "\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"',
    "\u2013": "-", "\u2014": "-", "\u2022": "-", "\u2026": "...",
})


def _to_latin1(text):
    """Return ``text`` restricted to latin-1 (unmappable characters become '?')."""
    return text.translate(_PDF_CHAR_MAP).encode("latin-1", "replace").decode("latin-1")


def export_chat_to_pdf(messages):
    """Render the chat history into an in-memory PDF.

    Args:
        messages: list of dicts with ``role`` ("user" or other) and ``content``.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 containing the PDF bytes.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # fpdf2's multi_cell leaves x at the right margin by default; without
        # this reset the next full-width multi_cell has zero width and raises.
        pdf.set_x(pdf.l_margin)
        pdf.multi_cell(0, 10, _to_latin1(f"{role}: {msg['content']}\n"))
    buf = io.BytesIO()
    # Writing to a file-like object is supported by fpdf2's FPDF.output.
    pdf.output(buf)
    buf.seek(0)
    return buf


# === App UI ===
st.title("🧬 DermBOT — Skin AI Assistant")
# Show the model that is actually answering (not a pending, unconfirmed selection).
st.caption(f"🧠 Using model: {st.session_state.active_model}")

uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader",
)

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except UnidentifiedImageError:
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
        image = None

    if image is not None:
        # Shown on every rerun, not only on the run that analysed the image.
        st.image(image, caption="Uploaded image", use_column_width=True)

        is_new_image = uploaded_file != st.session_state.current_image
        if is_new_image:
            # A new image starts a fresh conversation.
            st.session_state.messages = []
            st.session_state.current_image = uploaded_file
            with st.spinner("Analyzing the image..."):
                # The multi-label (skincon) prediction is currently not shown or queried.
                predicted_label, _predicted_label_multi = run_inference(image)
            st.session_state.diagnosis = predicted_label

        st.markdown(f" Most Likely Diagnosis : {st.session_state.diagnosis}")

        if is_new_image:
            initial_query = f"What are my treatment options for {st.session_state.diagnosis}?"
            st.session_state.messages.append({"role": "user", "content": initial_query})
            with st.spinner("Retrieving medical information..."):
                response = get_reranked_response(
                    initial_query,
                    st.session_state.app_models['llm'],
                    st.session_state.app_models['rag_components'],
                )
            st.session_state.messages.append({"role": "assistant", "content": response})

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            history = st.session_state.messages[:-1]  # exclude the prompt just added
            if history:
                conversation_context = "\n".join(
                    f"{m['role']}: {m['content']}" for m in history
                )
                query = (
                    f"Conversation history:\n{conversation_context}\n\n"
                    f"Current question: {prompt}"
                )
            else:
                query = prompt
            response = get_reranked_response(
                query,
                st.session_state.app_models['llm'],
                st.session_state.app_models['rag_components'],
            )
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})

if st.session_state.messages and st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="dermbot_chat_history.pdf",
        mime="application/pdf",
    )