|
import xgboost as xgb |
|
import pickle |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForQuestionAnswering |
|
import nltk |
|
from nltk.tokenize import word_tokenize |
|
from nltk.corpus import stopwords |
|
import re |
|
|
|
|
|
# Fetch the NLTK resources this module relies on: the English stop-word list
# and the "punkt" / "punkt_tab" tokenizer data.
for _resource in ("stopwords", "punkt", "punkt_tab"):
    nltk.download(_resource)

# English stop words, kept as a set for O(1) membership checks.
stop_words = set(stopwords.words("english"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face checkpoint used for extractive question answering.
model_name = "dmis-lab/biobert-large-cased-v1.1-squad"


@st.cache_resource
def _load_qa_components(checkpoint):
    """Load the tokenizer and question-answering model for *checkpoint*.

    Streamlit re-executes the whole script on every user interaction, so
    loading the checkpoint at plain module level would repeat the expensive
    load on each rerun. ``st.cache_resource`` keeps a single shared instance
    per checkpoint name for the lifetime of the server process.

    Args:
        checkpoint: Hugging Face model identifier or local path.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model


# Module-level handles used by get_medical_answer().
tokenizer, qa_model = _load_qa_components(model_name)
|
|
|
|
|
|
|
# Trained symptom -> disease classifier, restored from its JSON dump.
model = xgb.XGBClassifier()
model.load_model("symptom_disease_model.json")

# Label encoder used to turn predicted class ids back into disease names.
# NOTE(security): unpickling executes arbitrary code embedded in the file, so
# "label_encoder.pkl" must only ever come from a trusted source.
# The context manager closes the file handle, which the previous
# ``pickle.load(open(...))`` form left open.
with open("label_encoder.pkl", "rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)

# The column order of the training matrix defines the feature order that
# predict_disease() uses when it builds its input vector.
X_train = pd.read_csv("X_train.csv")
symptom_list = X_train.columns.tolist()
|
|
|
|
|
|
|
|
|
def _index_precautions(frame):
    """Map each normalised disease name in *frame* to its non-missing precautions.

    The disease name is taken from the ``Disease`` column, stripped and
    lower-cased. Precautions are read from ``Precaution_1`` .. ``Precaution_4``
    in that order, skipping missing cells. A disease that appears more than
    once keeps the values of its last row.
    """
    precaution_columns = [f"Precaution_{i}" for i in range(1, 5)]
    lookup = {}
    for _, record in frame.iterrows():
        disease_key = record["Disease"].strip().lower()
        lookup[disease_key] = [
            record[column]
            for column in precaution_columns
            if pd.notna(record[column])
        ]
    return lookup


# Raw precaution table and the per-disease lookup derived from it.
precaution_df = pd.read_csv("Disease precaution.csv")
precaution_dict = _index_precautions(precaution_df)
|
|
|
|
|
|
|
|
|
def load_medical_context():
    """Return the full text of ``medical_context.txt``, decoded as UTF-8.

    The path is relative, so the file is looked up in the current working
    directory. Any error raised by ``open`` (for example a missing file)
    propagates to the caller.
    """
    with open("medical_context.txt", encoding="utf-8") as handle:
        text = handle.read()
    return text
|
|
|
# Reference text that the question-answering model extracts answers from;
# read once when the module is loaded.
medical_context = load_medical_context()


# Doctors available per disease. Keys are lower-case disease names; each value
# is a list of doctor records (name, specialty, location, contact).
doctor_database = {
    "malaria": [
        {
            "name": "Dr. Rajesh Kumar",
            "specialty": "Infectious Diseases",
            "location": "Apollo Hospital",
            "contact": "9876543210",
        },
    ],
    "diabetes": [
        {
            "name": "Dr. Anil Mehta",
            "specialty": "Endocrinologist",
            "location": "AIIMS Delhi",
            "contact": "9876543233",
        },
    ],
    "heart attack": [
        {
            "name": "Dr. Vikram Singh",
            "specialty": "Cardiologist",
            "location": "Medanta Hospital",
            "contact": "9876543255",
        },
    ],
}
|
|
|
|
|
|
|
|
|
def predict_disease(user_symptoms):
    """Predict a disease from the given symptoms with the trained XGBoost model.

    A one-hot feature vector is built in the column order of ``symptom_list``:
    every recognised symptom sets its column to 1, everything else stays 0.
    Unrecognised symptoms are ignored, so an input with no recognised symptom
    is still classified (from an all-zero vector).

    Args:
        user_symptoms: Iterable of symptom names. A single string is also
            accepted and is split on commas (previously it was iterated
            character by character, which matched nothing). ``None`` is
            treated as "no symptoms".

    Returns:
        The predicted disease label, decoded with ``label_encoder``.
    """
    if user_symptoms is None:
        user_symptoms = ()
    elif isinstance(user_symptoms, str):
        user_symptoms = user_symptoms.split(",")

    def normalise(name):
        # Ignore case and surrounding whitespace, and treat spaces and
        # underscores as the same separator.
        return "_".join(str(name).lower().replace("_", " ").split())

    # One pass over the columns replaces the former `in` + `.index()` pair,
    # which scanned the list twice for every symptom. setdefault keeps the
    # first position for a name, matching list.index().
    exact_positions = {}
    loose_positions = {}
    for position, column in enumerate(symptom_list):
        exact_positions.setdefault(column, position)
        loose_positions.setdefault(normalise(column), position)

    input_vector = np.zeros(len(symptom_list))
    for symptom in user_symptoms:
        position = exact_positions.get(symptom)
        if position is None:
            # Fall back to a case/spacing-insensitive match.
            position = loose_positions.get(normalise(symptom))
        if position is not None:
            input_vector[position] = 1

    # The classifier expects a 2-D array: one row per sample.
    input_vector = input_vector.reshape(1, -1)
    predicted_class = model.predict(input_vector)[0]
    predicted_disease = label_encoder.inverse_transform([predicted_class])[0]

    return predicted_disease
|
|
|
|
|
|
|
|
|
def get_precautions(disease):
    """Return the list of precautions recorded for *disease*.

    The lookup key is normalised the same way the keys of ``precaution_dict``
    are built (stripped and lower-cased), so names with surrounding whitespace
    still match.

    Args:
        disease: Disease name; case and surrounding whitespace are ignored.

    Returns:
        A new list of precautions, or ``["No precautions available"]`` when
        the disease is unknown. A copy is returned so callers cannot mutate
        the shared lookup table.
    """
    key = disease.strip().lower()
    return list(precaution_dict.get(key, ["No precautions available"]))
|
|
|
|
|
|
|
|
|
def get_medical_answer(question):
    """Answer a medical question by extracting a span from ``medical_context``.

    The context is split into overlapping windows of at most 512 tokens so
    that the whole text is searched. A single truncated encoding would only
    ever look at the first ~500 tokens. Each window is scored by the QA model
    and the highest-scoring span across all windows is returned. Only context
    tokens are eligible, so the answer can never be a piece of the question or
    a special token.

    Note: one forward pass is run per window, so latency grows with the length
    of the context.

    Args:
        question: Natural-language question.

    Returns:
        The extracted answer text, or a fallback message when the question is
        empty or no usable span is found.
    """
    fallback = "I'm not sure. Please consult a medical professional."
    max_length = 512          # model input limit, in tokens
    stride = 128              # overlap between consecutive context windows
    max_question_tokens = 64  # keeps room for context in every window

    question = (question or "").strip()
    if not question or not medical_context.strip():
        return fallback

    # "only_second" truncation fails if the question alone nearly fills the
    # window, so cap over-long questions first.
    question_ids = tokenizer.encode(question, add_special_tokens=False)
    if len(question_ids) > max_question_tokens:
        question = tokenizer.decode(question_ids[:max_question_tokens])

    tokenizer_kwargs = {
        "return_tensors": "pt",
        "truncation": "only_second",  # never cut the question, only the context
        "max_length": max_length,
    }
    if getattr(tokenizer, "is_fast", False):
        # Sliding windows need a fast tokenizer; a slow tokenizer falls back
        # to a single truncated window.
        tokenizer_kwargs.update(
            stride=stride,
            return_overflowing_tokens=True,
            padding=True,
        )

    encoded = tokenizer(question, medical_context, **tokenizer_kwargs)
    # Bookkeeping entry added by return_overflowing_tokens; not a model input.
    encoded.pop("overflow_to_sample_mapping", None)

    input_ids = encoded["input_ids"]
    best_score = None
    best_span = None

    with torch.no_grad():
        # One window at a time keeps peak memory independent of context size.
        for row in range(input_ids.shape[0]):
            window = {key: tensor[row:row + 1] for key, tensor in encoded.items()}
            outputs = qa_model(**window)

            # Restrict candidates to real (non-padding) context tokens.
            allowed = torch.ones_like(input_ids[row], dtype=torch.bool)
            if "attention_mask" in window:
                allowed = allowed & window["attention_mask"][0].bool()
            if "token_type_ids" in window:
                allowed = allowed & window["token_type_ids"][0].bool()
            if not bool(allowed.any()):
                continue

            start_logits = outputs.start_logits[0].masked_fill(~allowed, float("-inf"))
            end_logits = outputs.end_logits[0].masked_fill(~allowed, float("-inf"))
            start = int(torch.argmax(start_logits))
            end = int(torch.argmax(end_logits))
            if end < start:
                # Inverted span: this window has no usable answer.
                continue

            score = float(start_logits[start] + end_logits[end])
            if best_score is None or score > best_score:
                best_score = score
                best_span = input_ids[row, start:end + 1]

    if best_span is None:
        return fallback

    answer = tokenizer.decode(best_span, skip_special_tokens=True).strip()
    return answer if answer else fallback
|
|
|
|
|
|
|
def book_appointment(disease):
    """Return a booking message for the first doctor listed for *disease*.

    The disease name is lower-cased and stripped before it is looked up in
    ``doctor_database``. When no doctor is listed, an apology message that
    echoes the normalised name is returned instead.
    """
    disease = disease.lower().strip()

    candidates = doctor_database.get(disease, [])
    if candidates:
        # Only the first listed doctor is ever offered.
        chosen = candidates[0]
        return (
            f"Appointment booked with **{chosen['name']}** ({chosen['specialty']}) "
            f"at **{chosen['location']}**.\nContact: {chosen['contact']}"
        )

    return f"Sorry, no available doctors found for {disease}."
|
|
|
|
|
|
|
|
|
# Intent patterns. Word boundaries stop them from firing inside unrelated
# words (e.g. "designs", "retreat"). Only non-capturing groups are used so
# that ``pattern.split`` returns just the text around the match.
_SYMPTOM_INTENT = re.compile(r"\b(?:symptoms?|signs?)\b")
_TREATMENT_INTENT = re.compile(r"\btreat(?:ments?|ing|ed|s)?\b")
_DOCTOR_INTENT = re.compile(r"\bwho\s+should\s+i\s+see\b")
_BOOKING_INTENT = re.compile(r"\bbook\s+(?:an?\s+)?appointment\b")

# Words that carry no topic information at the start / end of a fragment.
_LEADING_FILLER = frozenset(
    "what what's whats which who how are is was do does did can could should "
    "would will i you we me my tell give show the a an of for about to in on "
    "with and or please any some have has having".split()
)
_TRAILING_FILLER = frozenset(
    "of for about to in on with and or please the".split()
)


def _extract_topic(query, intent_pattern):
    """Return what is left of *query* once the intent phrase and filler are removed."""
    kept = []
    for fragment in intent_pattern.split(query):
        # Tokenising on word characters also drops punctuation such as "?".
        words = re.findall(r"\w[\w'-]*", fragment)
        while words and words[0] in _LEADING_FILLER:
            words.pop(0)
        while words and words[-1] in _TRAILING_FILLER:
            words.pop()
        kept.extend(words)
    return " ".join(kept)


def handle_user_query(user_query):
    """Route a free-text query to the QA model or to the appointment lookup.

    Intents are checked in this order, first match wins:

    1. symptoms / signs  -> ``get_medical_answer`` with a symptoms question
    2. treat / treatment -> ``get_medical_answer`` with a treatment question
    3. "who should i see" -> ``book_appointment``
    4. "book (an) appointment" -> ``book_appointment``
    5. anything else -> ``get_medical_answer`` with the query itself

    The disease name is taken from the words around the intent phrase with
    question words, articles, prepositions and punctuation removed. The former
    ``str.replace`` approach left the rest of the sentence in place, turning
    "what are the symptoms of malaria" into "What are the symptoms of what are
    the  of malaria?".

    Args:
        user_query: Raw user text; ``None`` is treated as an empty query.

    Returns:
        The answer or appointment message as a string.
    """
    query = (user_query or "").lower().strip()

    if _SYMPTOM_INTENT.search(query):
        disease = _extract_topic(query, _SYMPTOM_INTENT)
        # Without a recognisable topic, ask the model the original query.
        question = f"What are the symptoms of {disease}?" if disease else query
        return get_medical_answer(question)

    if _TREATMENT_INTENT.search(query):
        disease = _extract_topic(query, _TREATMENT_INTENT)
        question = f"What is the treatment for {disease}?" if disease else query
        return get_medical_answer(question)

    if _DOCTOR_INTENT.search(query):
        return book_appointment(_extract_topic(query, _DOCTOR_INTENT))

    if _BOOKING_INTENT.search(query):
        return book_appointment(_extract_topic(query, _BOOKING_INTENT))

    return get_medical_answer(query)
|
|