noobmaster1246 committed on
Commit
49a6a82
·
verified ·
1 Parent(s): 40d2442

Update ai.py

Browse files
Files changed (1) hide show
  1. ai.py +144 -120
ai.py CHANGED
@@ -1,120 +1,144 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import pandas as pd
5
- from sklearn.preprocessing import LabelEncoder, StandardScaler
6
- from sentence_transformers import SentenceTransformer, util
7
- import json
8
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
9
- from symspellpy import SymSpell, Verbosity
10
-
11
- device = torch.device("cpu")
12
-
13
- class DiseaseClassifier(nn.Module):
14
- def __init__(self, input_size, num_classes, dropout_rate=0.35665610394511454):
15
- super(DiseaseClassifier, self).__init__()
16
- self.fc1 = nn.Linear(input_size, 382)
17
- self.fc2 = nn.Linear(382, 389)
18
- self.fc3 = nn.Linear(389, 433)
19
- self.fc4 = nn.Linear(433, num_classes)
20
- self.activation = nn.LeakyReLU()
21
- self.dropout = nn.Dropout(dropout_rate)
22
-
23
- def forward(self, x):
24
- x = self.activation(self.fc1(x))
25
- x = self.dropout(x)
26
- x = self.activation(self.fc2(x))
27
- x = self.dropout(x)
28
- x = self.activation(self.fc3(x))
29
- x = self.dropout(x)
30
- x = self.fc4(x) # Logits
31
- return x
32
-
33
-
34
- class DiseasePredictionModel:
35
- def __init__(self, ai_model_name="model.pth", data_file="data.csv", symptom_json="symptoms.json", dictionary_file="frequency_dictionary_en_82_765.txt"):
36
- self.df = pd.read_csv(data_file)
37
- self.symptom_columns = self.load_symptoms(symptom_json)
38
- self.label_encoder = LabelEncoder()
39
- self.label_encoder.fit(self.df.iloc[:, 0])
40
- self.scaler = StandardScaler()
41
- self.scaler.fit(self.df.iloc[:, 1:].values)
42
- self.input_size = len(self.symptom_columns)
43
- self.num_classes = len(self.label_encoder.classes_)
44
- self.model = self._load_model(ai_model_name)
45
- self.SYMPTOM_LIST = self.load_symptoms(symptom_json)
46
- self.sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
47
- self.sym_spell.load_dictionary(dictionary_file, term_index=0, count_index=1)
48
- self.tokenizer = AutoTokenizer.from_pretrained("./biobert_diseases_ner")
49
- # self.tokenizer.save_pretrained("./biobert_diseases_ner")
50
- self.nlp_model = AutoModelForTokenClassification.from_pretrained("./biobert_diseases_ner")
51
- # self.nlp_model.save_pretrained("./biobert_diseases_ner")
52
- self.ner_pipeline = pipeline("ner", model=self.nlp_model, tokenizer=self.tokenizer, aggregation_strategy="simple")
53
- self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
54
-
55
- def _load_model(self, ai_model_name):
56
- model = DiseaseClassifier(self.input_size, self.num_classes).to(device)
57
- model.load_state_dict(torch.load(ai_model_name, map_location=device, weights_only=True))
58
- model.eval()
59
- return model
60
-
61
- def predict_disease(self, symptoms):
62
- input_vector = np.zeros(len(self.symptom_columns))
63
- for symptom in symptoms:
64
- if symptom in self.symptom_columns:
65
- input_vector[list(self.symptom_columns).index(symptom)] = 1
66
-
67
- input_vector = self.scaler.transform([input_vector])
68
-
69
- input_tensor = torch.tensor(input_vector, dtype=torch.float32).to(device)
70
-
71
- with torch.no_grad():
72
- outputs = self.model(input_tensor)
73
- _, predicted_class = torch.max(outputs, 1)
74
-
75
- predicted_disease = self.label_encoder.inverse_transform([predicted_class.cpu().numpy()[0]])[0]
76
- return predicted_disease
77
-
78
- def load_symptoms(self, json_file):
79
- with open(json_file, "r", encoding="utf-8") as f:
80
- return json.load(f)
81
-
82
- def correct_text(self, text):
83
- words = text.split()
84
- corrected_words = []
85
-
86
- for word in words:
87
- if word.lower() in [symptom.lower() for symptom in self.SYMPTOM_LIST]:
88
- corrected_words.append(word)
89
- else:
90
- suggestions = self.sym_spell.lookup(word, Verbosity.CLOSEST, max_edit_distance=2)
91
- if suggestions:
92
- corrected_words.append(suggestions[0].term)
93
- else:
94
- corrected_words.append(word)
95
- return ' '.join(corrected_words)
96
-
97
- def extract_symptoms(self, text):
98
- ner_results = self.ner_pipeline(text)
99
- symptoms = set()
100
- for entity in ner_results:
101
- if entity["entity_group"] == "DISEASE":
102
- symptoms.add(entity["word"].lower())
103
- return list(symptoms)
104
-
105
- def match_symptoms(self, extracted_symptoms):
106
- matched = {}
107
-
108
- symptom_embeddings = self.semantic_model.encode(self.SYMPTOM_LIST, convert_to_tensor=True)
109
-
110
- for symptom in extracted_symptoms:
111
- symptom_embedding = self.semantic_model.encode(symptom, convert_to_tensor=True)
112
-
113
- similarities = util.pytorch_cos_sim(symptom_embedding, symptom_embeddings)[0]
114
-
115
- most_similar_idx = similarities.argmax()
116
- best_match = self.SYMPTOM_LIST[most_similar_idx]
117
- matched[symptom] = best_match
118
-
119
- return matched.values()
120
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# The Hugging Face libraries resolve their cache directory while they are
# being imported, so the variable has to be exported *before* the
# `sentence_transformers` / `transformers` imports below. Assigning it after
# those imports does not change where models are cached.
os.environ["TRANSFORMERS_CACHE"] = "/home/user/.cache/huggingface"

# The remaining imports intentionally follow the environment setup above.
import json

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer, util
from sklearn.preprocessing import LabelEncoder, StandardScaler
from symspellpy import SymSpell, Verbosity
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# All tensors and the classifier are kept on the CPU.
device = torch.device("cpu")
18
+
19
+ # Define DiseaseClassifier Model
20
class DiseaseClassifier(nn.Module):
    """Feed-forward classifier: three hidden layers followed by a logit layer.

    The layers are named ``fc1`` .. ``fc4`` so that state dicts saved by the
    earlier fixed-width version of this class still load unchanged.

    Args:
        input_size: Number of input features.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied after each hidden layer.
        hidden_sizes: Widths of the three hidden layers, in order. The default
            reproduces the original hard-coded architecture.

    Raises:
        ValueError: If ``hidden_sizes`` does not contain exactly three widths.
    """

    def __init__(
        self,
        input_size,
        num_classes,
        dropout_rate=0.35665610394511454,
        hidden_sizes=(382, 389, 433),
    ):
        super().__init__()
        if len(hidden_sizes) != 3:
            raise ValueError(
                f"hidden_sizes must contain exactly 3 widths, got {len(hidden_sizes)}"
            )
        width1, width2, width3 = hidden_sizes

        self.fc1 = nn.Linear(input_size, width1)
        self.fc2 = nn.Linear(width1, width2)
        self.fc3 = nn.Linear(width2, width3)
        self.fc4 = nn.Linear(width3, num_classes)
        self.activation = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return the raw class logits for ``x`` (no softmax is applied)."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = self.dropout(self.activation(hidden(x)))
        return self.fc4(x)
39
+
40
+ # Define DiseasePredictionModel
41
class DiseasePredictionModel:
    """Pipeline that turns a free-text description into a predicted disease.

    The steps, each exposed as a method, are: spell-correct the text
    (``correct_text``), pull entities out of it with a token-classification
    pipeline (``extract_symptoms``), map each entity to the closest known
    symptom name by embedding similarity (``match_symptoms``) and feed the
    resulting indicator vector to the trained classifier (``predict_disease``).

    Args:
        ai_model_name: Path of the saved ``DiseaseClassifier`` state dict.
        data_file: CSV whose first column holds the disease label and whose
            remaining columns are used to fit the feature scaler.
        symptom_json: JSON file containing the list of known symptom names.
        dictionary_file: SymSpell frequency dictionary (term, count columns).
        ner_model_name: Checkpoint used for the token-classification pipeline.
        semantic_model_name: Checkpoint used for the sentence encoder.
    """

    def __init__(
        self,
        ai_model_name="model.pth",
        data_file="data.csv",
        symptom_json="symptoms.json",
        dictionary_file="frequency_dictionary_en_82_765.txt",
        ner_model_name="dmis-lab/biobert-v1.1",
        semantic_model_name="sentence-transformers/all-MiniLM-L6-v2",
    ):
        # Dataset: column 0 is the label, the rest are the scaler's features.
        self.df = pd.read_csv(data_file)
        self.symptom_columns = self.load_symptoms(symptom_json)
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(self.df.iloc[:, 0])
        self.scaler = StandardScaler()
        self.scaler.fit(self.df.iloc[:, 1:].values)

        self.input_size = len(self.symptom_columns)
        self.num_classes = len(self.label_encoder.classes_)
        self.model = self._load_model(ai_model_name)

        self.SYMPTOM_LIST = self.load_symptoms(symptom_json)
        # Filled lazily by match_symptoms(); see the note there.
        self._symptom_embeddings = None

        # Spell checker.
        self.sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
        self.sym_spell.load_dictionary(dictionary_file, term_index=0, count_index=1)

        # Token-classification (NER) pipeline.
        # NOTE(review): the default checkpoint looks like a base BioBERT model
        # rather than one fine-tuned for NER; if so it will presumably never
        # emit the "DISEASE" group that extract_symptoms() filters on. Confirm,
        # and pass a fine-tuned checkpoint through `ner_model_name` if needed.
        self.tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
        self.nlp_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
        self.ner_pipeline = pipeline(
            "ner",
            model=self.nlp_model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
        )

        # Sentence encoder used for symptom matching.
        self.semantic_model = SentenceTransformer(semantic_model_name)

    def _load_model(self, ai_model_name):
        """Build the classifier, load its weights and switch it to eval mode."""
        model = DiseaseClassifier(self.input_size, self.num_classes).to(device)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute arbitrary code.
        state_dict = torch.load(ai_model_name, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def predict_disease(self, symptoms):
        """Return the predicted disease label for an iterable of symptom names.

        Names that are not in ``self.symptom_columns`` are ignored; an empty
        iterable therefore yields the prediction for an all-zero vector.
        """
        # One pass to build name -> column position. setdefault keeps the
        # first position of a repeated name, matching list.index().
        position_of = {}
        for position, name in enumerate(self.symptom_columns):
            position_of.setdefault(name, position)

        input_vector = np.zeros(len(self.symptom_columns))
        for symptom in symptoms:
            position = position_of.get(symptom)
            if position is not None:
                input_vector[position] = 1

        input_vector = self.scaler.transform([input_vector])
        input_tensor = torch.tensor(input_vector, dtype=torch.float32).to(device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            _, predicted_class = torch.max(outputs, 1)

        class_index = int(predicted_class[0])
        return self.label_encoder.inverse_transform([class_index])[0]

    def load_symptoms(self, json_file):
        """Return the parsed content of the UTF-8 JSON file ``json_file``."""
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)

    def correct_text(self, text):
        """Spell-correct ``text`` word by word and return the re-joined string.

        Words that case-insensitively equal a known symptom name are kept as
        typed. Every other word is replaced by SymSpell's closest suggestion
        (edit distance <= 2) when there is one, and kept otherwise.
        """
        # Built once per call instead of once per word.
        known = {symptom.lower() for symptom in self.SYMPTOM_LIST}

        corrected_words = []
        for word in text.split():
            if word.lower() in known:
                corrected_words.append(word)
                continue
            suggestions = self.sym_spell.lookup(word, Verbosity.CLOSEST, max_edit_distance=2)
            corrected_words.append(suggestions[0].term if suggestions else word)
        return ' '.join(corrected_words)

    def extract_symptoms(self, text):
        """Return the distinct, lower-cased "DISEASE" entities found in ``text``.

        The result is sorted so that it does not depend on set iteration order.
        """
        symptoms = {
            entity["word"].lower()
            for entity in self.ner_pipeline(text)
            if entity["entity_group"] == "DISEASE"
        }
        return sorted(symptoms)

    def match_symptoms(self, extracted_symptoms):
        """Map each extracted phrase to its most similar known symptom name.

        Returns a list with one known symptom per distinct extracted phrase,
        in first-seen order. An empty input returns an empty list without
        running the encoder.
        """
        extracted_symptoms = list(extracted_symptoms)
        if not extracted_symptoms:
            return []

        # The known-symptom embeddings do not depend on the query, so they are
        # computed on first use and reused. Reset `_symptom_embeddings` to
        # None if SYMPTOM_LIST is ever replaced.
        symptom_embeddings = getattr(self, "_symptom_embeddings", None)
        if symptom_embeddings is None:
            symptom_embeddings = self.semantic_model.encode(
                self.SYMPTOM_LIST, convert_to_tensor=True
            )
            self._symptom_embeddings = symptom_embeddings

        matched = {}
        for symptom in extracted_symptoms:
            symptom_embedding = self.semantic_model.encode(symptom, convert_to_tensor=True)
            similarities = util.pytorch_cos_sim(symptom_embedding, symptom_embeddings)[0]
            best_index = int(similarities.argmax())
            matched[symptom] = self.SYMPTOM_LIST[best_index]

        return list(matched.values())
130
+
131
+ # Initialize Model
132
# Single shared pipeline instance; built once when the module is loaded.
model = DiseasePredictionModel()


def predict(symptoms):
    """Return the predicted disease for a free-text symptom description.

    The text is spell-corrected, entities are extracted and matched to known
    symptom names, and the matches are classified. When the input is empty or
    nothing in it can be matched to a known symptom, a short explanatory
    message is returned instead of classifying an all-zero feature vector.
    """
    if not symptoms or not symptoms.strip():
        return "Please describe your symptoms."

    corrected = model.correct_text(symptoms)
    extracted = model.extract_symptoms(corrected)
    matched = list(model.match_symptoms(extracted))
    if not matched:
        return "No recognizable symptoms were found in the description."

    # str() so the text output always receives a plain Python string.
    return str(model.predict_disease(matched))


# Text-in / text-out web UI around predict().
iface = gr.Interface(fn=predict, inputs="text", outputs="text", title="Disease Prediction AI")
iface.launch()