import gradio as gr
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
# ---- LOAD LLM ----
model_name = "Qwen/Qwen1.5-0.5B"
# No need for token usually; Qwen is public, but keeping it flexible
hf_token = os.getenv("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(
model_name,
token=hf_token, # can be None if not set
trust_remote_code=True # required for Qwen
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
model_name,
token=hf_token,
trust_remote_code=True,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None
).to(device)
# --- Define llm generation function ---
def llm(prompt, max_new_tokens=400, temperature=0.3, do_sample=True):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"pad_token_id": tokenizer.eos_token_id,
}
# Only add temperature/top_p if sampling is enabled
if do_sample:
generation_kwargs.update({
"temperature": temperature,
"top_p": 0.95,
"top_k": 50
})
output = model.generate(
**inputs,
**generation_kwargs
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
return [{"generated_text": generated_text}]
# Define all the screening questions
questions = [
# Generalized Anxiety & Somatic Concerns
# =========================
("Have you been a chronic worrier for 6 months or more?", "क्या आप 6 महीने या उससे अधिक समय से लगातार चिंता कर रहे हैं?"),
("Have you been preoccupied with worries about work, family, or health for the past 6 months?", "क्या आप पिछले 6 महीनों से काम, परिवार या स्वास्थ्य को लेकर चिंतित रहे हैं?"),
("Have you had frequent headaches, body pain, or fatigue for several weeks without a clear physical cause?", "क्या आपको पिछले कई हफ्तों से बिना स्पष्ट शारीरिक कारण के सिरदर्द, बदन दर्द या थकान रही है?"),
("Have you experienced anxiety symptoms like palpitations, choking, or dry mouth for the past 6 months?", "क्या आपको पिछले 6 महीनों से घबराहट जैसे लक्षण (जैसे धड़कन तेज होना, गला बंद लगना, मुंह सूखना) महसूस हुए हैं?"),
("Have these anxiety symptoms occurred in most situations over the past 6 months?", "क्या ये घबराहट के लक्षण पिछले 6 महीनों में ज्यादातर स्थितियों में हुए हैं?"),
("Have you had difficulty concentrating or thinking clearly for at least 2 weeks?", "क्या आपको पिछले 2 हफ्तों से ध्यान केंद्रित करने या स्पष्ट रूप से सोचने में कठिनाई हुई है?"),
("Do you often have difficulty staying focused during tasks or conversations?", "क्या आपको कार्य करते समय या बातचीत के दौरान ध्यान केंद्रित करने में कठिनाई होती है?"),
("Do you frequently lose items or forget daily responsibilities (e.g., appointments, bills)?", "क्या आप अक्सर चीजें खो देते हैं या दैनिक ज़िम्मेदारियाँ भूल जाते हैं (जैसे अपॉइंटमेंट, बिल)?"),
("Do you struggle to organize tasks or manage your time effectively?", "क्या आपको कार्यों को व्यवस्थित करने या समय का सही उपयोग करने में कठिनाई होती है?"),
("Do you feel restless or find it hard to sit still for extended periods?", "क्या आप बेचैनी महसूस करते हैं या लंबे समय तक शांत बैठना मुश्किल लगता है?"),
("Do you often interrupt others or speak out without waiting your turn?", "क्या आप अक्सर दूसरों की बात काटते हैं या बिना रुके बोल पड़ते हैं?"),
("Do you procrastinate or avoid tasks that require sustained mental effort?", "क्या आप ऐसे कार्यों को टालते हैं जिनमें लंबे समय तक ध्यान केंद्रित करना होता है?"),
("Have these difficulties been present since childhood and continue to affect your work or relationships?", "क्या ये कठिनाइयाँ बचपन से रही हैं और अब भी आपके काम या रिश्तों को प्रभावित कर रही हैं?"),
]
# ---- STATE ----
state = {
"index": 0,
"responses": [""] * len(questions)
}
# ---- FUNCTIONS ----
def render_question():
idx = state["index"]
total = len(questions)
en, hi = questions[idx]
question_html = f"**Q{idx+1}. {en}**
{hi}"
progress = f"Progress: Question {idx+1} of {total}"
return (
gr.update(value=question_html),
gr.update(value=progress),
gr.update(visible=True), gr.update(visible=True),
gr.update(visible=(idx > 0)),
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
)
def next_step(response):
state["responses"][state["index"]] = response
state["index"] += 1
if state["index"] < len(questions):
return render_question()
else:
return (
gr.update(value="✅ All questions completed. Click 'Submit for AI Analysis'."), # (question_display)
gr.update(value=""), # (progress_bar)
gr.update(visible=False), # yes_btn
gr.update(visible=False), # no_btn
gr.update(visible=False), # back_btn
gr.update(visible=True), # result_btn (Submit Button becomes visible here!)
gr.update(visible=False), # result_box (Textbox hidden)
gr.update(visible=False) # restart_btn hidden
)
def format_yes_responses():
yes_topics = []
for (en, _), ans in zip(questions, state["responses"]):
if ans.lower() == "yes":
yes_topics.append(en)
if not yes_topics:
return "No significant symptoms reported."
return "\n".join(yes_topics)
def run_final_analysis():
yield (
gr.update(value="⏳ Please wait... analyzing your responses 🧠"), # question_display
gr.update(value=""), # progress_bar
*[gr.update(visible=False) for _ in range(4)], # yes_btn, no_btn, back_btn, result_btn
gr.update(visible=False), # result_box
gr.update(visible=False) # restart_btn
)
time.sleep(1)
yes_summary = format_yes_responses()
prompt = (
f"""The user has reported the following symptoms: {yes_summary}
Based on these symptoms, please write a short clinical impression summarizing the likely psychiatric condition and its further management.
On the basis of symptomes mentioned
"""
)
output = llm(prompt, max_new_tokens=300, temperature=0.3, do_sample=False)
ai_result = output[0]["generated_text"]
yield (
gr.update(value="✅ AI Analysis Completed."), # question_display
gr.update(value=""), # progress_bar
*[gr.update(visible=False) for _ in range(4)], # yes_btn, no_btn, back_btn, result_btn
gr.update(value=ai_result, visible=True), # result_box (important: both value+visible)
gr.update(visible=True) # restart_btn
)
def go_back():
if state["index"] > 0:
state["index"] -= 1
return render_question()
def start_app():
state["index"] = 0
state["responses"] = [""] * len(questions)
return render_question()
def restart_screening():
state["index"] = 0
state["responses"] = [""] * len(questions)
return render_question()
# ---- GRADIO APP ----
with gr.Blocks(theme=gr.themes.Soft()) as app:
gr.Markdown("## 🧠 MindScreen: Mental Health Self-Screening")
gr.Markdown("### मानसिक स्वास्थ्य आत्म-स्क्रीनिंग प्रश्नावली")
gr.Markdown(
"**Note: Please choose 'Yes' only if the symptoms cause distress to you or your family, or if they interfere with your day-to-day functioning.**\n\n"
"**नोट: कृपया 'हाँ' तभी चुनें जब ये लक्षण आपके या आपके परिवार के लिए परेशानी का कारण बन रहे हों, या आपकी दिनचर्या को प्रभावित करते हों।**"
)
progress_bar = gr.Markdown("")
question_display = gr.Markdown("", elem_id="question-box")
with gr.Row():
yes_btn = gr.Button("Yes / हाँ", visible=False)
no_btn = gr.Button("No / नहीं", visible=False)
with gr.Row():
back_btn = gr.Button("⬅️ Back / पीछे जाएं", visible=False)
result_btn = gr.Button("🎯 Submit for AI Analysis", visible=False)
with gr.Row():
restart_btn = gr.Button("🔄 Start New Screening", visible=False) # <-- Restart button
with gr.Row():
result_box = gr.Textbox(
label="🧠 AI Interpretation (English + Hindi)",
visible=False,
lines=20,
max_lines=25,
show_copy_button=True,
interactive=False
)
app.load(start_app, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn])
yes_btn.click(lambda: next_step("Yes"), outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn])
no_btn.click(lambda: next_step("No"), outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn])
back_btn.click(go_back, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn])
result_btn.click(run_final_analysis, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn], concurrency_limit=1)
restart_btn.click(restart_screening, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn])
app.launch()