paragon-analytics committed on
Commit
74d7743
· verified · 1 Parent(s): d2b9b3e

Update app.py

Files changed (1)
  1. app.py +160 -114
app.py CHANGED
@@ -1,127 +1,173 @@
 import numpy as np
 import torch
-import shap
-from transformers import (
-    pipeline,
-    AutoTokenizer,
-    AutoModelForSequenceClassification,
-    AutoModelForTokenClassification
-)
-import gradio as gr
-
-# ————————— 1) Device setup —————————
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# ————————— 2) ADR classifier —————————
-model_name = "paragon-analytics/ADRv1"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
-
-pred_pipeline = pipeline(
-    "text-classification",
-    model=model,
-    tokenizer=tokenizer,
-    return_all_scores=True,
-    device=0 if device.type == "cuda" else -1
-)
-
-def predict_proba(texts):
-    if isinstance(texts, str):
-        texts = [texts]
-    results = pred_pipeline(texts)
-    return np.array([[d["score"] for d in sample] for sample in results])
-
-def predict_proba_shap(inputs):
-    texts = [" ".join(x) if isinstance(x, list) else x for x in inputs]
-    return predict_proba(texts)
-
-# ————————— 3) SHAP explainer —————————
-masker = shap.maskers.Text(tokenizer)
-_example = pred_pipeline(["test"])[0]
-class_labels = [d["label"] for d in _example]
-explainer = shap.Explainer(
-    predict_proba_shap,
-    masker=masker,
-    output_names=class_labels
-)
-
-# ————————— 4) Biomedical NER —————————
-ner_name = "d4data/biomedical-ner-all"
-ner_tokenizer = AutoTokenizer.from_pretrained(ner_name)
-ner_model = AutoModelForTokenClassification.from_pretrained(ner_name).to(device)
-ner_pipe = pipeline(
-    "ner",
-    model=ner_model,
-    tokenizer=ner_tokenizer,
-    aggregation_strategy="simple",
-    device=0 if device.type == "cuda" else -1
-)
-
-ENTITY_COLORS = {
-    "Severity": "red",
-    "Sign_symptom": "green",
-    "Medication": "lightblue",
-    "Age": "yellow",
-    "Sex": "yellow",
-    "Diagnostic_procedure": "gray",
-    "Biological_structure": "silver"
-}
-
-# ————————— 5) Prediction + SHAP + NER —————————
-def adr_predict(text: str):
-    # Probabilities
-    probs = predict_proba([text])[0]
-    prob_dict = {cls: float(probs[i]) for i, cls in enumerate(class_labels)}
-    # SHAP
-    shap_vals = explainer([text])
-    fig = shap.plots.text(shap_vals[0], display=False)
-    # NER highlight
-    ents = ner_pipe(text)
-    highlighted, last = "", 0
-    for ent in ents:
-        s, e = ent["start"], ent["end"]
-        w = ent["word"].replace("##", "")
-        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
-        highlighted += text[last:s] + f"<mark style='background-color:{color};'>{w}</mark>"
-        last = e
-    highlighted += text[last:]
-    return prob_dict, fig, highlighted
-
-# ————————— 6) Gradio UI —————————
-with gr.Blocks() as demo:
-    gr.Markdown("## Welcome to **ADR Detector** 🪐")
-    gr.Markdown(
-        "Predicts how likely your text describes a **severe** vs. **non-severe** adverse reaction. \n"
-        "_(Not for medical or diagnostic use.)_"
-    )
-
-    txt = gr.Textbox(
-        label="Enter Your Text Here:", lines=3,
-        placeholder="Type a sentence about an adverse reaction…"
-    )
-    btn = gr.Button("Analyze")
-
-    with gr.Row():
-        out_prob = gr.Label(label="Predicted Probabilities")
-        out_shap = gr.Plot(label="SHAP Explanation")
-        out_ner = gr.HTML(label="Biomedical Entities Highlighted")
-
-    btn.click(
-        fn=adr_predict,
-        inputs=txt,
-        outputs=[out_prob, out_shap, out_ner]
     )

     gr.Examples(
         examples=[
-            "A 35-year-old male experienced severe headache after taking Aspirin.",
-            "A 35-year-old female had minor abdominal pain after Acetaminophen."
         ],
-        inputs=txt,
-        outputs=[out_prob, out_shap, out_ner],
-        fn=adr_predict,
-        cache_examples=False  # ← disable startup caching here
     )

-if __name__ == "__main__":
-    demo.launch()
+import streamlit as st
+import gradio as gr
+import shap
 import numpy as np
+import scipy as sp
 import torch
+import tensorflow as tf
+import transformers
+from transformers import pipeline
+from transformers import RobertaTokenizer, RobertaModel
+from transformers import AutoModelForSequenceClassification
+from transformers import TFAutoModelForSequenceClassification  # Imported but not used below
+from transformers import AutoTokenizer, AutoModelForTokenClassification
+import matplotlib.pyplot as plt
+import sys
+import csv
+
+csv.field_size_limit(sys.maxsize)
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# Load models and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1")
+model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").to(device)
+
+# Build a pipeline object for predictions.
+# Note: return_all_scores is deprecated; newer transformers releases expect
+# top_k=None instead and will emit a deprecation warning (or an error) for this call.
+pred = transformers.pipeline("text-classification", model=model,
+                             tokenizer=tokenizer, return_all_scores=True)
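
Note: the deprecation flagged above is easy to act on. A minimal sketch of the replacement call, assuming a transformers 4.x release where `top_k=None` supersedes `return_all_scores=True`:

```python
from transformers import pipeline

# Same checkpoint as in this commit; top_k=None replaces return_all_scores=True.
pred = pipeline(
    "text-classification",
    model="paragon-analytics/ADRv1",
    top_k=None,
)

# Each input now yields a list of {"label": ..., "score": ...} dicts, one per class.
# Caution: with top_k set, classes come back sorted by descending score rather than
# by label id, so downstream code should match on "label" instead of position.
print(pred("Severe rash and swelling after the first dose."))
```

Matching on the label name would also remove the positional assumption made later in `label_output`.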
+
+# SHAP explainer
+# Check the SHAP documentation for changes to Explainer initialization if this fails.
+explainer = shap.Explainer(pred)
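
Note: if `shap.Explainer(pred)` rejects the pipeline object (behavior varies across SHAP versions), a thin wrapper that returns a fixed-shape probability matrix works; this is essentially the approach the previous version of this file used. A sketch, with the helper name `f` purely illustrative:

```python
import numpy as np
import shap

def f(texts):
    # Return an (n_samples, n_classes) array of probabilities with a stable
    # column order, the shape SHAP expects from a prediction function.
    rows = pred(list(texts))  # list of [{"label", "score"}, ...] per input
    return np.array(
        [[d["score"] for d in sorted(row, key=lambda d: d["label"])] for row in rows]
    )

masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(f, masker=masker)
shap_values = explainer(["Severe headache after taking Aspirin."])
```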
+
+# NER pipeline
+ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
+ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
+ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer,
+                    aggregation_strategy="simple")  # pass device=0 if using a GPU
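
Note: as the inline comment says, this NER pipeline runs on CPU. A sketch of the GPU-aware variant, mirroring the device check used for the classifier above:

```python
import torch
from transformers import pipeline

ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=0 if torch.cuda.is_available() else -1,  # 0 = first GPU, -1 = CPU
)
```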
+
+def adr_predict(x):
+    # Ensure the input is treated as a string
+    text_input = str(x).lower()
+
+    encoded_input = tokenizer(text_input, return_tensors='pt').to(device)  # move input to device
+    output = model(**encoded_input)
+
+    # Use torch.softmax (not tf.nn.softmax) for consistency with the torch model,
+    # then move the result to CPU and convert to numpy.
+    scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
+
+    # SHAP values: `explainer([text_input])` assumes the explainer can consume raw
+    # text and the pipeline's output structure directly. Depending on the SHAP
+    # version this may need a wrapper function around the pipeline instead
+    # (e.g. `def f(text): return pred(text)` passed to shap.Explainer).
+    # With display=False, shap.plots.text returns the plot as an HTML string,
+    # which is what the gr.HTML output below expects.
+    try:
+        shap_values = explainer([text_input])
+        local_plot = shap.plots.text(shap_values[0], display=False)
+    except Exception as e:
+        print(f"SHAP explanation failed: {e}")
+        local_plot = "<p>SHAP explanation not available.</p>"  # fallback
+
+    # NER processing
+    try:
+        res = ner_pipe(text_input)
+
+        entity_colors = {
+            'Severity': 'red',
+            'Sign_symptom': 'green',
+            'Medication': 'lightblue',
+            'Age': 'yellow',
+            'Sex': 'yellow',
+            'Diagnostic_procedure': 'gray',
+            'Biological_structure': 'silver'
+        }
+        htext = ""
+        prev_end = 0
+        # Sort entities by start position so the highlighted text is built in order
+        res = sorted(res, key=lambda x: x['start'])
+
+        for entity in res:
+            start = entity['start']
+            end = entity['end']
+            word = text_input[start:end]  # extract the original text segment
+            entity_type = entity['entity_group']
+            color = entity_colors.get(entity_type, 'lightgray')  # default for unknown types
+
+            # Append the text before the entity, then the highlighted entity
+            htext += f"{text_input[prev_end:start]}"
+            htext += f"<mark style='background-color:{color};'>{word}</mark>"
+            prev_end = end
+        # Append any remaining text after the last entity
+        htext += text_input[prev_end:]
+    except Exception as e:
+        print(f"NER processing failed: {e}")
+        htext = "<p>NER processing not available.</p>"  # fallback
+
+    # Gradio's click handler expects one return value per output component
+    # (label, local_plot, htext), and gr.Label expects a dict mapping class
+    # names to probabilities.
+    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
+
+    return label_output, local_plot, htext
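
Note: the hard-coded indices in `label_output` assume index 1 is the severe class. A more defensive sketch reads the index-to-name mapping from the model config; the checkpoint may only expose generic names like `LABEL_0`, in which case a manual rename is still needed:

```python
# scores is the 1-D softmax array computed above, indexed by class id.
id2label = model.config.id2label  # e.g. {0: "LABEL_0", 1: "LABEL_1"}
label_output = {id2label[i]: float(s) for i, s in enumerate(scores)}
```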
+
+def main(prob1):
+    # Thin wrapper around adr_predict; matches the signature Gradio expects.
+    return adr_predict(prob1)
+
+title = "Welcome to **ADR Detector** 🪐"
+description1 = """This app takes text (up to a few sentences) and predicts to what extent it describes a severe (or non-severe) adverse reaction to medications. Please do NOT use it for medical diagnosis."""
+
+# Use the 'with' syntax for Blocks
+with gr.Blocks(title=title) as demo:
+    gr.Markdown(f"## {title}")
+    gr.Markdown(description1)
+    gr.Markdown("""---""")
+
+    # Input component
+    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
+
+    # Output components matching the return values of main
+    label = gr.Label(label="Predicted Label")
+    local_plot = gr.HTML(label="SHAP Explanation")       # label renamed for clarity
+    htext = gr.HTML(label="Named Entity Recognition")    # label renamed for clarity
+
+    submit_btn = gr.Button("Analyze")
+
+    # Wire the button to the prediction function
+    submit_btn.click(
+        fn=main,                              # the function to call
+        inputs=[prob1],                       # input components
+        outputs=[label, local_plot, htext],   # output components
+        api_name="adr"                        # keep api_name if the HTTP API endpoint is used
     )
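
Note: because the click handler sets `api_name="adr"`, the endpoint can be called programmatically once the Space is live. A usage sketch with `gradio_client`; the Space id here is an assumption for illustration:

```python
from gradio_client import Client

client = Client("paragon-analytics/ADRv1")  # assumed Space id
label_out, shap_html, ner_html = client.predict(
    "A 40 year-old male developed a severe rash after starting the medication.",
    api_name="/adr",
)
print(label_out)
```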
+
+    # Examples section
+    gr.Markdown("### Click on any of the examples below to see how it works:")
+    # Gradio 4.0+ Examples usage: pass the input and output components directly.
+    # cache_examples=True precomputes the outputs so examples load instantly;
+    # run_on_click is not needed here, since caching already runs fn per example.
     gr.Examples(
         examples=[
+            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
+            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
         ],
+        inputs=[prob1],
+        outputs=[label, local_plot, htext],
+        fn=main,  # function used to precompute the cached example outputs
+        cache_examples=True,
     )
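
Note: the previous version of this file disabled example caching to avoid running the model at startup. If that trade-off is preferred here, `gr.Examples` also supports computing on click instead; a sketch, hedged on the exact Gradio 4.x semantics of `run_on_click`:

```python
# Placed inside the `with gr.Blocks()` context, alongside the block above.
gr.Examples(
    examples=[["A 35 year-old male had severe headache after taking Aspirin."]],
    inputs=[prob1],
    outputs=[label, local_plot, htext],
    fn=main,
    cache_examples=False,   # do not precompute at startup
    run_on_click=True,      # run main() when an example is clicked
)
```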
+
+# Launch the demo
+demo.launch()