Spaces:
Running
Running
import gradio as gr | |
import time | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import os | |
# ---- LOAD LLM ---- | |
model_name = "Qwen/Qwen1.5-0.5B" | |
# No need for token usually; Qwen is public, but keeping it flexible | |
hf_token = os.getenv("HF_TOKEN") | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_name, | |
token=hf_token, # can be None if not set | |
trust_remote_code=True # required for Qwen | |
) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
token=hf_token, | |
trust_remote_code=True, | |
torch_dtype=torch.float16 if device == "cuda" else torch.float32, | |
device_map="auto" if device == "cuda" else None | |
).to(device) | |
# --- Define llm generation function --- | |
def llm(prompt, max_new_tokens=400, temperature=0.3, do_sample=True): | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
generation_kwargs = { | |
"max_new_tokens": max_new_tokens, | |
"do_sample": do_sample, | |
"pad_token_id": tokenizer.eos_token_id, | |
} | |
# Only add temperature/top_p if sampling is enabled | |
if do_sample: | |
generation_kwargs.update({ | |
"temperature": temperature, | |
"top_p": 0.95, | |
"top_k": 50 | |
}) | |
output = model.generate( | |
**inputs, | |
**generation_kwargs | |
) | |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
return [{"generated_text": generated_text}] | |
# Define all the screening questions | |
questions = [ | |
# Generalized Anxiety & Somatic Concerns | |
# ========================= | |
("Have you been a chronic worrier for 6 months or more?", "क्या आप 6 महीने या उससे अधिक समय से लगातार चिंता कर रहे हैं?"), | |
("Have you been preoccupied with worries about work, family, or health for the past 6 months?", "क्या आप पिछले 6 महीनों से काम, परिवार या स्वास्थ्य को लेकर चिंतित रहे हैं?"), | |
("Have you had frequent headaches, body pain, or fatigue for several weeks without a clear physical cause?", "क्या आपको पिछले कई हफ्तों से बिना स्पष्ट शारीरिक कारण के सिरदर्द, बदन दर्द या थकान रही है?"), | |
("Have you experienced anxiety symptoms like palpitations, choking, or dry mouth for the past 6 months?", "क्या आपको पिछले 6 महीनों से घबराहट जैसे लक्षण (जैसे धड़कन तेज होना, गला बंद लगना, मुंह सूखना) महसूस हुए हैं?"), | |
("Have these anxiety symptoms occurred in most situations over the past 6 months?", "क्या ये घबराहट के लक्षण पिछले 6 महीनों में ज्यादातर स्थितियों में हुए हैं?"), | |
("Have you had difficulty concentrating or thinking clearly for at least 2 weeks?", "क्या आपको पिछले 2 हफ्तों से ध्यान केंद्रित करने या स्पष्ट रूप से सोचने में कठिनाई हुई है?"), | |
("Do you often have difficulty staying focused during tasks or conversations?", "क्या आपको कार्य करते समय या बातचीत के दौरान ध्यान केंद्रित करने में कठिनाई होती है?"), | |
("Do you frequently lose items or forget daily responsibilities (e.g., appointments, bills)?", "क्या आप अक्सर चीजें खो देते हैं या दैनिक ज़िम्मेदारियाँ भूल जाते हैं (जैसे अपॉइंटमेंट, बिल)?"), | |
("Do you struggle to organize tasks or manage your time effectively?", "क्या आपको कार्यों को व्यवस्थित करने या समय का सही उपयोग करने में कठिनाई होती है?"), | |
("Do you feel restless or find it hard to sit still for extended periods?", "क्या आप बेचैनी महसूस करते हैं या लंबे समय तक शांत बैठना मुश्किल लगता है?"), | |
("Do you often interrupt others or speak out without waiting your turn?", "क्या आप अक्सर दूसरों की बात काटते हैं या बिना रुके बोल पड़ते हैं?"), | |
("Do you procrastinate or avoid tasks that require sustained mental effort?", "क्या आप ऐसे कार्यों को टालते हैं जिनमें लंबे समय तक ध्यान केंद्रित करना होता है?"), | |
("Have these difficulties been present since childhood and continue to affect your work or relationships?", "क्या ये कठिनाइयाँ बचपन से रही हैं और अब भी आपके काम या रिश्तों को प्रभावित कर रही हैं?"), | |
] | |
# ---- STATE ---- | |
state = { | |
"index": 0, | |
"responses": [""] * len(questions) | |
} | |
# ---- FUNCTIONS ---- | |
def render_question(): | |
idx = state["index"] | |
total = len(questions) | |
en, hi = questions[idx] | |
question_html = f"**Q{idx+1}. {en}**<br><span style='color:#666'>{hi}</span>" | |
progress = f"Progress: Question {idx+1} of {total}" | |
return ( | |
gr.update(value=question_html), | |
gr.update(value=progress), | |
gr.update(visible=True), gr.update(visible=True), | |
gr.update(visible=(idx > 0)), | |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) | |
) | |
def next_step(response): | |
state["responses"][state["index"]] = response | |
state["index"] += 1 | |
if state["index"] < len(questions): | |
return render_question() | |
else: | |
return ( | |
gr.update(value="✅ All questions completed. Click 'Submit for AI Analysis'."), # (question_display) | |
gr.update(value=""), # (progress_bar) | |
gr.update(visible=False), # yes_btn | |
gr.update(visible=False), # no_btn | |
gr.update(visible=False), # back_btn | |
gr.update(visible=True), # result_btn (Submit Button becomes visible here!) | |
gr.update(visible=False), # result_box (Textbox hidden) | |
gr.update(visible=False) # restart_btn hidden | |
) | |
def format_yes_responses(): | |
yes_topics = [] | |
for (en, _), ans in zip(questions, state["responses"]): | |
if ans.lower() == "yes": | |
yes_topics.append(en) | |
if not yes_topics: | |
return "No significant symptoms reported." | |
return "\n".join(yes_topics) | |
def run_final_analysis(): | |
yield ( | |
gr.update(value="⏳ Please wait... analyzing your responses 🧠"), # question_display | |
gr.update(value=""), # progress_bar | |
*[gr.update(visible=False) for _ in range(4)], # yes_btn, no_btn, back_btn, result_btn | |
gr.update(visible=False), # result_box | |
gr.update(visible=False) # restart_btn | |
) | |
time.sleep(1) | |
yes_summary = format_yes_responses() | |
prompt = ( | |
f"""The user has reported the following symptoms: {yes_summary} | |
Based on these symptoms, please write a short clinical impression summarizing the likely psychiatric condition and its further management. | |
On the basis of symptomes mentioned | |
""" | |
) | |
output = llm(prompt, max_new_tokens=300, temperature=0.3, do_sample=False) | |
ai_result = output[0]["generated_text"] | |
yield ( | |
gr.update(value="✅ AI Analysis Completed."), # question_display | |
gr.update(value=""), # progress_bar | |
*[gr.update(visible=False) for _ in range(4)], # yes_btn, no_btn, back_btn, result_btn | |
gr.update(value=ai_result, visible=True), # result_box (important: both value+visible) | |
gr.update(visible=True) # restart_btn | |
) | |
def go_back(): | |
if state["index"] > 0: | |
state["index"] -= 1 | |
return render_question() | |
def start_app(): | |
state["index"] = 0 | |
state["responses"] = [""] * len(questions) | |
return render_question() | |
def restart_screening(): | |
state["index"] = 0 | |
state["responses"] = [""] * len(questions) | |
return render_question() | |
# ---- GRADIO APP ---- | |
with gr.Blocks(theme=gr.themes.Soft()) as app: | |
gr.Markdown("## 🧠 MindScreen: Mental Health Self-Screening") | |
gr.Markdown("### मानसिक स्वास्थ्य आत्म-स्क्रीनिंग प्रश्नावली") | |
gr.Markdown( | |
"**Note: Please choose 'Yes' only if the symptoms cause distress to you or your family, or if they interfere with your day-to-day functioning.**\n\n" | |
"**नोट: कृपया 'हाँ' तभी चुनें जब ये लक्षण आपके या आपके परिवार के लिए परेशानी का कारण बन रहे हों, या आपकी दिनचर्या को प्रभावित करते हों।**" | |
) | |
progress_bar = gr.Markdown("") | |
question_display = gr.Markdown("", elem_id="question-box") | |
with gr.Row(): | |
yes_btn = gr.Button("Yes / हाँ", visible=False) | |
no_btn = gr.Button("No / नहीं", visible=False) | |
with gr.Row(): | |
back_btn = gr.Button("⬅️ Back / पीछे जाएं", visible=False) | |
result_btn = gr.Button("🎯 Submit for AI Analysis", visible=False) | |
with gr.Row(): | |
restart_btn = gr.Button("🔄 Start New Screening", visible=False) # <-- Restart button | |
with gr.Row(): | |
result_box = gr.Textbox( | |
label="🧠 AI Interpretation (English + Hindi)", | |
visible=False, | |
lines=20, | |
max_lines=25, | |
show_copy_button=True, | |
interactive=False | |
) | |
app.load(start_app, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn]) | |
yes_btn.click(lambda: next_step("Yes"), outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn]) | |
no_btn.click(lambda: next_step("No"), outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn]) | |
back_btn.click(go_back, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn]) | |
result_btn.click(run_final_analysis, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn], concurrency_limit=1) | |
restart_btn.click(restart_screening, outputs=[question_display, progress_bar, yes_btn, no_btn, back_btn, result_btn, result_box, restart_btn]) | |
app.launch() | |