import nest_asyncio

# Allow nested asyncio event loops so async LLM clients can run inside
# Streamlit's own loop.
nest_asyncio.apply()

import io

import streamlit as st
import torch
from PIL import Image
from fpdf import FPDF

from SkinCancerDiagnosis import initialize_classifier
from rag_pipeline import (
    available_models,
    initialize_llm,
    load_rag_chain,
    get_reranked_response,
    initialize_rag_components,
)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")


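# === Model Loading ===
# st.cache_resource keeps one instance of these heavy objects per server
# process, shared across reruns and sessions, so they are loaded only once.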
@st.cache_resource(show_spinner=False)
def load_models():
    """Cache all models to load only once"""
    with st.spinner("Loading all AI models (one-time operation)..."):
        models = {
            'classifier': initialize_classifier(),
            'rag_components': initialize_rag_components(),
            'llm': initialize_llm(st.session_state["selected_model"])
        }
        models['rag_chain'] = load_rag_chain(models['llm'])
        return models

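# === Model Selection ===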
if "selected_model" not in st.session_state:
    st.session_state["selected_model"] = available_models[0]

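# Capture the previous selection before the selectbox overwrites it, so a
# model change can be detected below in this same rerun.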
previous_model = st.session_state.get("selected_model", available_models[0])

st.session_state["selected_model"] = st.sidebar.selectbox(
    "Select LLM Model",
    available_models,
    index=available_models.index(st.session_state["selected_model"])
)

if 'app_models' not in st.session_state:
    st.session_state.app_models = load_models()
classifier = st.session_state.app_models['classifier']
llm = st.session_state.app_models['llm']

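# Switching models mid-conversation clears the chat, so ask the user to
# confirm before committing the change.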
if "model_change_confirmed" not in st.session_state:
    st.session_state.model_change_confirmed = False

if st.session_state["selected_model"] != previous_model:
    # messages may not exist yet this early in the script, so use .get()
    if st.session_state.get("messages"):
        st.session_state.model_change_confirmed = False  # Reset confirmation state
        with st.sidebar:
            st.warning("Changing models will clear current conversation.")
            col1, col2 = st.columns(2)
            with col1:
                if st.button("Confirm Change", key="confirm_model_change"):
                    st.session_state.messages = []
                    st.session_state.current_image = None
                    st.session_state.model_change_confirmed = True
                    st.rerun()
            with col2:
                if st.button("Cancel", key="cancel_model_change"):
                    st.session_state["selected_model"] = previous_model
                    st.rerun()
    else:
        st.session_state.model_change_confirmed = True

if "model_change_confirmed" not in st.session_state or st.session_state.model_change_confirmed:
    st.session_state.app_models['llm'] = initialize_llm(st.session_state["selected_model"])
    st.session_state.app_models['rag_chain'] = load_rag_chain(st.session_state.app_models['llm'])
    llm = st.session_state.app_models['llm']
else:
    pass


# === Session Init ===
if "messages" not in st.session_state:
    st.session_state.messages = []

if "current_image" not in st.session_state:
    st.session_state.current_image = None

# === Image Processing Function ===
def run_inference(image):
    """Run both classifier heads on a PIL image and return their top labels."""
    # Single most likely diagnosis
    result = classifier.predict(image, top_k=1)
    predicted_label = result["top_predictions"][0][0]
    # Top label(s) from the multi-label SkinCon head
    predicted_label_multi = classifier.predict_skincon(image, top_k=1)
    return predicted_label, predicted_label_multi


# === PDF Export ===
def export_chat_to_pdf(messages):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # The built-in Arial font only covers Latin-1, so replace anything
        # FPDF cannot encode (e.g. emoji) instead of raising an exception.
        text = f"{role}: {msg['content']}\n".encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, text)
    # output(dest='S') renders the document in memory; classic PyFPDF returns
    # a str while fpdf2 returns a bytearray, so normalise to bytes.
    data = pdf.output(dest="S")
    if isinstance(data, str):
        data = data.encode("latin-1")
    return io.BytesIO(data)


# === App UI ===

st.title("🧬 DermBOT – Skin AI Assistant")
st.caption(f"🧠 Using model: {st.session_state['selected_model']}")
uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader"
)

# Compare by the stable file_id: Streamlit hands back a fresh UploadedFile
# object on each rerun, so comparing the objects themselves would re-trigger
# the analysis (and clear the chat) on every interaction.
if uploaded_file is not None and uploaded_file.file_id != st.session_state.current_image:
    st.session_state.messages = []
    st.session_state.current_image = uploaded_file.file_id

    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded image", use_container_width=True)
    with st.spinner("Analyzing the image..."):
        predicted_label, predicted_label_multi = run_inference(image)

    st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_label_multi)}")
    st.markdown(f" Most Likely Diagnosis : {predicted_label}")

    initial_query = (
        f"What are my treatment options for {predicted_label} "
        f"and {', '.join(predicted_label_multi)}?"
    )
    st.session_state.messages.append({"role": "user", "content": initial_query})
    with st.spinner("Retrieving medical information..."):
        response = get_reranked_response(
            initial_query,
            st.session_state.app_models['llm'],
            st.session_state.app_models['rag_components'],
        )
        st.session_state.messages.append({"role": "assistant", "content": response})


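# Re-render the whole conversation on every rerun (Streamlit executes the
# script top to bottom each time).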
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            if len(st.session_state.messages) > 1:
                conversation_context = "\n".join(
                    f"{m['role']}: {m['content']}"
                    for m in st.session_state.messages[:-1]  # Exclude current prompt
                )
                augmented_prompt = (
                    f"Conversation history:\n{conversation_context}\n\n"
                    f"Current question: {prompt}"
                )
                response = get_reranked_response(
                    augmented_prompt,
                    st.session_state.app_models['llm'],
                    st.session_state.app_models['rag_components'],
                )
            else:
                response = get_reranked_response(
                    prompt,
                    st.session_state.app_models['llm'],
                    st.session_state.app_models['rag_components'],
                )

            st.markdown(response)
            st.session_state.messages.append({"role": "assistant", "content": response})

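# Two-step export: the first button generates the PDF, then st.download_button
# serves the generated bytes.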
if st.session_state.messages and st.button("πŸ“„ Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="dermbot_chat_history.pdf",
        mime="application/pdf"
    )