# NOTE(review): removed a non-code artifact here ("Spaces: / Running / Running"
# -- a Hugging Face Spaces status banner captured during export; not Python).
import nest_asyncio
nest_asyncio.apply()
import streamlit as st
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torch.nn.functional as F
from evo_vit import EvoViTModel
import io
import os
import cohere
from fpdf import FPDF
from torchvision.models import resnet50
from huggingface_hub import hf_hub_download
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from SkinCancerDiagnosis import initialize_classifier
from rag_pipeline import (
    available_models,
    initialize_llm,
    load_rag_chain,
    get_reranked_response,
    initialize_rag_components
)
from langchain_core.messages import HumanMessage, AIMessage
from groq import Groq
import google.generativeai as genai
# Pick the compute device once at startup: CUDA when available, CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Page config must be the first Streamlit call of the script.
st.set_page_config(page_title="DermBOT", page_icon="π§¬", layout="centered")
def load_models():
    """Build every heavy model object for this session.

    Returns:
        dict with keys 'classifier', 'rag_components', 'llm' and 'rag_chain'.
        The caller stores the dict in st.session_state, so this runs only
        once per session.
    """
    with st.spinner("Loading all AI models (one-time operation)..."):
        bundle = {
            'classifier': initialize_classifier(),
            'rag_components': initialize_rag_components(),
            'llm': initialize_llm(st.session_state["selected_model"]),
        }
        # The RAG chain wraps the freshly created LLM.
        bundle['rag_chain'] = load_rag_chain(bundle['llm'])
    return bundle
# --- Model selection state -------------------------------------------------
if "selected_model" not in st.session_state:
    st.session_state["selected_model"] = available_models[0]

# Snapshot the model in effect before the sidebar widget renders, so a
# change can be detected after the selectbox returns.
previous_model = st.session_state.get("selected_model", available_models[0])

st.session_state["selected_model"] = st.sidebar.selectbox(
    "Select LLM Model",
    available_models,
    index=available_models.index(st.session_state["selected_model"]),
)

# Heavy model objects are created once per session and cached in session_state.
if "app_models" not in st.session_state:
    st.session_state.app_models = load_models()

classifier = st.session_state.app_models["classifier"]
llm = st.session_state.app_models["llm"]

if "model_change_confirmed" not in st.session_state:
    st.session_state.model_change_confirmed = False
# --- Model switching -------------------------------------------------------
# Switching the LLM invalidates the conversation, so ask for confirmation
# before wiping history and rebuilding the chain.
if st.session_state["selected_model"] != previous_model:
    # BUGFIX: "messages" is only initialised further down the script, so
    # attribute access (st.session_state.messages) could raise AttributeError
    # on an early rerun. .get() is safe for both the missing and empty cases.
    if st.session_state.get("messages"):
        st.session_state.model_change_confirmed = False  # Reset confirmation state
        with st.sidebar:
            st.warning("Changing models will clear current conversation.")
            col1, col2 = st.columns(2)
            with col1:
                if st.button("Confirm Change", key="confirm_model_change"):
                    st.session_state.messages = []
                    st.session_state.current_image = None
                    st.session_state.model_change_confirmed = True
                    st.rerun()
            with col2:
                if st.button("Cancel", key="cancel_model_change"):
                    st.session_state["selected_model"] = previous_model
                    st.rerun()
    else:
        # Nothing to lose: accept the switch silently.
        st.session_state.model_change_confirmed = True

# The key is guaranteed to exist (initialised just above), so the original
# redundant "not in session_state" re-check was dropped.
# NOTE(review): this re-initialises the LLM and RAG chain on *every* rerun
# once a change has been confirmed, because model_change_confirmed is never
# reset to False afterwards -- consider tracking the active model name.
if st.session_state.model_change_confirmed:
    st.session_state.app_models['llm'] = initialize_llm(st.session_state["selected_model"])
    st.session_state.app_models['rag_chain'] = load_rag_chain(st.session_state.app_models['llm'])
    llm = st.session_state.app_models['llm']
# === Session Init ===
# Conversation history and the upload currently under analysis.
for _key, _default in (("messages", []), ("current_image", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# === Image Processing Function ===
def run_inference(image):
    """Run both classifier heads on a PIL image.

    Returns:
        (top_diagnosis, skin_issues): the single most likely condition from
        the main head, and the multi-label SkinCon head's output.
    """
    # Main head: keep only the label of the top-ranked prediction.
    diagnosis = classifier.predict(image, top_k=1)["top_predictions"][0][0]
    # Multi-label head -- presumably returns an iterable of condition labels
    # (it is later joined with ', '); confirm against predict_skincon.
    skin_issues = classifier.predict_skincon(image, top_k=1)
    return diagnosis, skin_issues
# === PDF Export ===
def export_chat_to_pdf(messages):
    """Render the chat transcript to an in-memory PDF.

    Args:
        messages: list of {"role": ..., "content": ...} chat dicts.

    Returns:
        io.BytesIO positioned at byte 0, ready for st.download_button.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        text = f"{role}: {msg['content']}\n"
        # BUGFIX: the built-in Arial font only covers Latin-1, while chat
        # content routinely contains emoji/unicode; unencodable characters
        # would crash PDF generation, so replace them instead.
        text = text.encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, text)
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf
# === App UI ===
st.title("𧬠DermBOT β Skin AI Assistant")
st.caption(f"π§ Using model: {st.session_state['selected_model']}")

# Image upload widget; accepted formats mirror what PIL can open here.
uploaded_file = st.file_uploader(
    "Upload a skin image",
    type=["jpg", "jpeg", "png"],
    key="file_uploader",
)
# A new upload wipes the conversation and triggers a fresh analysis.
# NOTE(review): comparing UploadedFile objects with != may not be a stable
# identity check across reruns -- verify against the Streamlit version in use.
if uploaded_file is not None and uploaded_file != st.session_state.current_image:
    st.session_state.messages = []
    st.session_state.current_image = uploaded_file

    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded image", use_column_width=True)

    with st.spinner("Analyzing the image..."):
        predicted_label, predicted_label_multi = run_inference(image)

    st.markdown(f"π§Ύ **Skin Issues**: {', '.join(predicted_label_multi)}")
    st.markdown(f" Most Likely Diagnosis : {predicted_label}")

    # Seed the conversation with an automatic treatment question and answer.
    initial_query = f"What are my treatment options for {predicted_label} & {predicted_label_multi}?"
    st.session_state.messages.append({"role": "user", "content": initial_query})
    with st.spinner("Retrieving medical information..."):
        response = get_reranked_response(
            initial_query,
            st.session_state.app_models['llm'],
            st.session_state.app_models['rag_components'],
        )
    st.session_state.messages.append({"role": "assistant", "content": response})

# Replay the stored transcript so it survives Streamlit reruns.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            if len(st.session_state.messages) > 1:
                # Fold every prior turn (excluding the prompt just appended)
                # into the query so the RAG step sees the conversation.
                history = st.session_state.messages[:-1]
                transcript = "\n".join(f"{m['role']}: {m['content']}" for m in history)
                query = (
                    "Conversation history:\n"
                    + transcript
                    + f"\n\nCurrent question: {prompt}"
                )
            else:
                query = prompt
            response = get_reranked_response(
                query,
                st.session_state.app_models['llm'],
                st.session_state.app_models['rag_components'],
            )
            st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})
# Offer the transcript as a PDF once there is something to export.
if st.session_state.messages and st.button("π Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="dermbot_chat_history.pdf",
        mime="application/pdf",
    )