Spaces:
Running
Running
File size: 6,914 Bytes
4310148 50e8a0a dc9062b 50e8a0a 75b0792 79cab30 dc9062b ae2bc75 dc9062b 50e8a0a 4310148 50e8a0a c808cd0 dc9062b c808cd0 8bad604 c808cd0 50e8a0a dc9062b 8bad604 2860bc8 8bad604 2860bc8 8bad604 2860bc8 8bad604 2860bc8 dc9062b 2860bc8 dc9062b c808cd0 50e8a0a c808cd0 50e8a0a 87c2216 50e8a0a c808cd0 50e8a0a c808cd0 87c2216 50e8a0a 87c2216 50e8a0a 87c2216 c808cd0 dc9062b 50e8a0a c808cd0 50e8a0a c808cd0 50e8a0a ef73894 9905e70 dc9062b ef73894 dc9062b ef73894 dc9062b 50e8a0a c808cd0 50e8a0a c808cd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import nest_asyncio
nest_asyncio.apply()
import streamlit as st
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torch.nn.functional as F
from evo_vit import EvoViTModel
import io
import os
import cohere
from fpdf import FPDF
from torchvision.models import resnet50
from huggingface_hub import hf_hub_download
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from SkinCancerDiagnosis import initialize_classifier
from rag_pipeline import (
available_models,
initialize_llm,
load_rag_chain,
get_reranked_response,
initialize_rag_components
)
from langchain_core.messages import HumanMessage, AIMessage
from groq import Groq
import google.generativeai as genai
# Preferred torch device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Browser-tab title, icon and page layout for the app.
# The icon literal was mis-encoded (UTF-8 bytes read as a Greek code page);
# it is restored here to the DNA emoji it was meant to be.
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
@st.cache_resource(show_spinner=False)
def _load_shared_models():
    """Load the resources that do not depend on the selected LLM.

    Only these are cached with ``st.cache_resource``: the function takes no
    arguments and reads no session state, so the cached value is valid for
    every caller.

    Returns:
        dict: ``'classifier'`` (from ``initialize_classifier()``) and
        ``'rag_components'`` (from ``initialize_rag_components()``).
    """
    with st.spinner("Loading all AI models (one-time operation)..."):
        return {
            'classifier': initialize_classifier(),
            'rag_components': initialize_rag_components(),
        }


def load_models():
    """Build the model bundle for the current session.

    The LLM and the RAG chain depend on ``st.session_state["selected_model"]``,
    which is not part of any cache key, so they are created on every call
    instead of being cached. The returned dict is a fresh object each time;
    callers may replace its entries without touching the cached resources.

    Returns:
        dict: keys ``'classifier'``, ``'rag_components'``, ``'llm'`` and
        ``'rag_chain'``.
    """
    # Copy the cached mapping so the additions below stay local to this call.
    models = dict(_load_shared_models())
    models['llm'] = initialize_llm(st.session_state["selected_model"])
    models['rag_chain'] = load_rag_chain(models['llm'])
    return models
# === Session Init ===
# Create every session key up front so the model-switch logic below can read
# it safely, including on the very first run of the script.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "current_image" not in st.session_state:
    st.session_state.current_image = None
if "selected_model" not in st.session_state:
    st.session_state["selected_model"] = available_models[0]
if "model_change_confirmed" not in st.session_state:
    st.session_state.model_change_confirmed = False

# Widget key of the sidebar selectbox. The widget keeps the *requested* model
# under this key, while "selected_model" keeps naming the model whose LLM is
# actually loaded. Keeping the two apart lets an unconfirmed request survive
# the rerun that a Confirm/Cancel click triggers.
MODEL_SELECTOR_KEY = "model_selector"


def _revert_model_selector():
    """Cancel callback: point the selectbox back at the model in use.

    Runs as a button ``on_click`` callback, i.e. before the script body, which
    is when Streamlit allows a widget's state to be assigned through its key.
    """
    st.session_state[MODEL_SELECTOR_KEY] = st.session_state["selected_model"]


# === Model Selection ===
previous_model = st.session_state["selected_model"]
# Seed (or repair) the widget state before the widget is created.
if st.session_state.get(MODEL_SELECTOR_KEY) not in available_models:
    st.session_state[MODEL_SELECTOR_KEY] = previous_model
requested_model = st.sidebar.selectbox(
    "Select LLM Model",
    available_models,
    key=MODEL_SELECTOR_KEY,
)

if 'app_models' not in st.session_state:
    # Shallow-copy the bundle: load_models() may return an object shared with
    # other sessions, and its 'llm' / 'rag_chain' entries are reassigned below.
    st.session_state.app_models = dict(load_models())
classifier = st.session_state.app_models['classifier']
llm = st.session_state.app_models['llm']

if requested_model != previous_model:
    confirmed_by_click = False
    if st.session_state.messages:
        # Switching would discard an ongoing conversation, so ask first.
        st.session_state.model_change_confirmed = False
        with st.sidebar:
            st.warning("Changing models will clear current conversation.")
            col1, col2 = st.columns(2)
            with col1:
                confirmed_by_click = st.button(
                    "Confirm Change", key="confirm_model_change"
                )
            with col2:
                st.button(
                    "Cancel",
                    key="cancel_model_change",
                    on_click=_revert_model_selector,
                )
        if confirmed_by_click:
            st.session_state.messages = []
            st.session_state.current_image = None
            st.session_state.model_change_confirmed = True
    else:
        # Nothing to lose: switch immediately.
        st.session_state.model_change_confirmed = True

    if st.session_state.model_change_confirmed:
        st.session_state["selected_model"] = requested_model
        st.session_state.app_models['llm'] = initialize_llm(requested_model)
        st.session_state.app_models['rag_chain'] = load_rag_chain(
            st.session_state.app_models['llm']
        )
        llm = st.session_state.app_models['llm']
        if confirmed_by_click:
            # Redraw without the warning and the two buttons.
            st.rerun()
# === Image Processing Function ===
def run_inference(image):
    """Run both classifier heads on *image*.

    Uses the module-level ``classifier``.

    Args:
        image: the image object handed to ``classifier.predict`` and
            ``classifier.predict_skincon``.

    Returns:
        tuple: ``(predicted_label, predicted_label_multi)`` where the first
        item is element ``[0][0]`` of the ``"top_predictions"`` entry returned
        by ``predict`` and the second is whatever ``predict_skincon`` returns
        (presumably a list of label strings; verify against the classifier).
    """
    best_hit = classifier.predict(image, top_k=1)["top_predictions"][0]
    diagnosis = best_hit[0]
    skin_concepts = classifier.predict_skincon(image, top_k=1)
    return diagnosis, skin_concepts
# === PDF Export ===
def _latin1_safe(text):
    """Return *text* with every character outside Latin-1 replaced by ``?``."""
    return str(text).encode("latin-1", "replace").decode("latin-1")


def export_chat_to_pdf(messages):
    """Render the chat transcript into an in-memory PDF.

    Args:
        messages: iterable of dicts with ``"role"`` and ``"content"`` keys.
            A role of ``"user"`` is printed as "You"; any other role as "AI".

    Returns:
        io.BytesIO: buffer holding the PDF, rewound to position 0.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # The built-in fonts can only encode Latin-1; chat text routinely
        # contains bullets, smart quotes or emoji, which would otherwise make
        # the export raise instead of producing a file.
        pdf.multi_cell(0, 10, _latin1_safe(f"{role}: {msg['content']}\n"))
    buf = io.BytesIO()
    # NOTE(review): writing to a file object relies on the installed FPDF
    # accepting one in output() - confirm the dependency is fpdf2.
    pdf.output(buf)
    buf.seek(0)
    return buf
# === App UI ===
# The emoji and the dash in these literals were mis-encoded; they are restored.
st.title("🧬 DermBOT — Skin AI Assistant")
st.caption(f"🧠 Using model: {st.session_state['selected_model']}")

uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader"
)

# A file that differs from the one remembered in the session starts a fresh
# conversation, seeded with the classifier output and a first RAG answer.
if uploaded_file is not None and uploaded_file != st.session_state.current_image:
    st.session_state.messages = []
    st.session_state.current_image = uploaded_file

    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded image", use_column_width=True)

    with st.spinner("Analyzing the image..."):
        predicted_label, predicted_label_multi = run_inference(image)

    # Join once and reuse: interpolating the raw object into the query put a
    # Python list repr (brackets and quotes) into the prompt and the chat.
    skin_issues = ', '.join(predicted_label_multi)
    st.markdown(f"🧾 **Skin Issues**: {skin_issues}")
    st.markdown(f" Most Likely Diagnosis : {predicted_label}")

    initial_query = (
        f"What are my treatment options for {predicted_label} & {skin_issues}?"
    )
    st.session_state.messages.append({"role": "user", "content": initial_query})

    with st.spinner("Retrieving medical information..."):
        response = get_reranked_response(
            initial_query,
            st.session_state.app_models['llm'],
            st.session_state.app_models['rag_components'],
        )
    st.session_state.messages.append({"role": "assistant", "content": response})
# Replay the stored conversation on every run of the script.
for past_message in st.session_state.messages:
    with st.chat_message(past_message["role"]):
        st.markdown(past_message["content"])

# === Chat Interface ===
prompt = st.chat_input("Ask a follow-up question...")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Everything before the prompt that was just appended.
            earlier_turns = st.session_state.messages[:-1]
            if earlier_turns:
                history_lines = [
                    f"{turn['role']}: {turn['content']}" for turn in earlier_turns
                ]
                conversation_context = "\n".join(history_lines)
                query = (
                    f"Conversation history:\n{conversation_context}\n\n"
                    f"Current question: {prompt}"
                )
            else:
                # First message of the session: send the prompt unchanged.
                query = prompt
            response = get_reranked_response(
                query,
                st.session_state.app_models['llm'],
                st.session_state.app_models['rag_components'],
            )
            st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
# Offer the transcript as a PDF once there is at least one message. The PDF is
# built only after the first button is pressed; the download button appears
# for that run. The emoji in the label was mis-encoded and is restored here.
if st.session_state.messages and st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="dermbot_chat_history.pdf",
        mime="application/pdf"
    )