Update app.py
app.py
CHANGED
@@ -1,127 +1,173 @@
 import numpy as np
 import torch
-import
-
-
-
-
-
-
-import
 
-
-device =
-
-# ───────── 2) ADR classifier ─────────
-model_name = "paragon-analytics/ADRv1"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
-
-pred_pipeline = pipeline(
-    "text-classification",
-    model=model,
-    tokenizer=tokenizer,
-    return_all_scores=True,
-    device=0 if device.type == "cuda" else -1
-)
-
-def predict_proba(texts):
-    if isinstance(texts, str):
-        texts = [texts]
-    results = pred_pipeline(texts)
-    return np.array([[d["score"] for d in sample] for sample in results])
-
-def predict_proba_shap(inputs):
-    texts = [" ".join(x) if isinstance(x, list) else x for x in inputs]
-    return predict_proba(texts)
-
-# ───────── 3) SHAP explainer ─────────
-masker = shap.maskers.Text(tokenizer)
-_example = pred_pipeline(["test"])[0]
-class_labels = [d["label"] for d in _example]
-explainer = shap.Explainer(
-    predict_proba_shap,
-    masker=masker,
-    output_names=class_labels
-)
-
-# ───────── 4) Biomedical NER ─────────
-ner_name = "d4data/biomedical-ner-all"
-ner_tokenizer = AutoTokenizer.from_pretrained(ner_name)
-ner_model = AutoModelForTokenClassification.from_pretrained(ner_name).to(device)
-ner_pipe = pipeline(
-    "ner",
-    model=ner_model,
-    tokenizer=ner_tokenizer,
-    aggregation_strategy="simple",
-    device=0 if device.type == "cuda" else -1
-)
-
-ENTITY_COLORS = {
-    "Severity": "red",
-    "Sign_symptom": "green",
-    "Medication": "lightblue",
-    "Age": "yellow",
-    "Sex": "yellow",
-    "Diagnostic_procedure": "gray",
-    "Biological_structure": "silver"
-}
-
-# ───────── 5) Prediction + SHAP + NER ─────────
-def adr_predict(text: str):
-    # Probabilities
-    probs = predict_proba([text])[0]
-    prob_dict = {cls: float(probs[i]) for i, cls in enumerate(class_labels)}
-    # SHAP
-    shap_vals = explainer([text])
-    fig = shap.plots.text(shap_vals[0], display=False)
-    # NER highlight
-    ents = ner_pipe(text)
-    highlighted, last = "", 0
-    for ent in ents:
-        s, e = ent["start"], ent["end"]
-        w = ent["word"].replace("##", "")
-        color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
-        highlighted += text[last:s] + f"<mark style='background-color:{color};'>{w}</mark>"
-        last = e
-    highlighted += text[last:]
-    return prob_dict, fig, highlighted
-
-# ───────── 6) Gradio UI ─────────
-with gr.Blocks() as demo:
-    gr.Markdown("## Welcome to **ADR Detector** 💪")
-    gr.Markdown(
-        "Predicts how likely your text describes a **severe** vs. **non-severe** adverse reaction. \n"
-        "_(Not for medical or diagnostic use.)_"
-    )
 
-
-
-
-
-
 
-
-    out_prob = gr.Label(label="Predicted Probabilities")
-    out_shap = gr.Plot(label="SHAP Explanation")
-    out_ner = gr.HTML(label="Biomedical Entities Highlighted")
 
-
-
-
-
     )
 
     gr.Examples(
         examples=[
-            "A 35
-            "A 35
         ],
-        inputs=
-        outputs=[
-        fn=
-        cache_examples=
     )
 
-
-
+import streamlit as st
+import gradio as gr
+import shap
 import numpy as np
+import scipy as sp
 import torch
+import tensorflow as tf
+import transformers
+from transformers import pipeline
+from transformers import RobertaTokenizer, RobertaModel
+from transformers import AutoModelForSequenceClassification
+from transformers import TFAutoModelForSequenceClassification  # imported but not used in this file
+from transformers import AutoTokenizer, AutoModelForTokenClassification
+import matplotlib.pyplot as plt
+import sys
+import csv
 
+csv.field_size_limit(sys.maxsize)
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
+# Load models and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1")
+model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").to(device)
+
+# Build a pipeline object for predictions.
+# Note: return_all_scores is deprecated; newer transformers releases expect
+# top_k=None instead, so this call may warn or fail there.
+pred = transformers.pipeline("text-classification", model=model,
+                             tokenizer=tokenizer, return_all_scores=True)
+
+# SHAP explainer; check the SHAP documentation if Explainer initialization fails.
+explainer = shap.Explainer(pred)
+
+# NER pipeline
+ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
+ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
+ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")  # pass device=0 if using a GPU
+
+def adr_predict(x):
+    # Ensure the input is treated as a string
+    text_input = str(x).lower()
+
+    encoded_input = tokenizer(text_input, return_tensors='pt').to(device)  # move input to the model's device
+    output = model(**encoded_input)
+
+    # Use torch.softmax (not tf.nn.softmax) for consistency with the torch model:
+    # apply softmax to the logits, move to CPU, and convert to numpy.
+    scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
+
+    # SHAP values: `explainer([text_input])` assumes the explainer can consume raw
+    # text and the pipeline's output structure directly. Depending on the SHAP
+    # version, you may instead need to wrap the pipeline in a plain function
+    # (e.g. `def f(text): return pred(text)`) and pass the wrapper to shap.Explainer.
+    try:
+        shap_values = explainer([text_input])
+        # With display=False, shap.plots.text should return a string representation
+        # of the plot that can be embedded in the gr.HTML output below.
+        local_plot = shap.plots.text(shap_values[0], display=False)
+    except Exception as e:
+        print(f"SHAP explanation failed: {e}")
+        local_plot = "<p>SHAP explanation not available.</p>"  # fallback
+
+    # NER processing
+    try:
+        res = ner_pipe(text_input)
+
+        entity_colors = {
+            'Severity': 'red',
+            'Sign_symptom': 'green',
+            'Medication': 'lightblue',
+            'Age': 'yellow',
+            'Sex': 'yellow',
+            'Diagnostic_procedure': 'gray',
+            'Biological_structure': 'silver'
+        }
+        htext = ""
+        prev_end = 0
+        # Sort entities by start position so the highlighted text is built in order
+        res = sorted(res, key=lambda x: x['start'])
+
+        for entity in res:
+            start = entity['start']
+            end = entity['end']
+            word = text_input[start:end]  # extract the original text segment
+            entity_type = entity['entity_group']
+            color = entity_colors.get(entity_type, 'lightgray')  # default for unlisted types
+
+            # Append the text before the entity, then the highlighted entity itself
+            htext += f"{text_input[prev_end:start]}"
+            htext += f"<mark style='background-color:{color};'>{word}</mark>"
+            prev_end = end
+        # Append any remaining text after the last entity
+        htext += text_input[prev_end:]
+    except Exception as e:
+        print(f"NER processing failed: {e}")
+        htext = "<p>NER processing not available.</p>"  # fallback
+
+    # Gradio's click handler expects one return value per output component
+    # (label, local_plot, htext), and gr.Label takes a dict of class probabilities.
+    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
+
+    return label_output, local_plot, htext
+
+def main(prob1):
+    # Forward to adr_predict; this matches the signature Gradio expects.
+    return adr_predict(prob1)
+
+title = "Welcome to **ADR Detector** 💪"
+description1 = """This app takes text (up to a few sentences) and predicts to what extent it describes a severe (or non-severe) adverse reaction to medications. Please do NOT use it for medical diagnosis."""
+
+# Use the 'with' syntax for Blocks
+with gr.Blocks(title=title) as demo:
+    gr.Markdown(f"## {title}")
+    gr.Markdown(description1)
+    gr.Markdown("""---""")
+
+    # Input component
+    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
+
+    # Output components matching the return values of main
+    label = gr.Label(label="Predicted Label")
+    local_plot = gr.HTML(label="SHAP Explanation")
+    htext = gr.HTML(label="Named Entity Recognition")
 
+    submit_btn = gr.Button("Analyze")
 
+    # Wire the button to the prediction function
+    submit_btn.click(
+        fn=main,                             # the function to call
+        inputs=[prob1],                      # input components
+        outputs=[label, local_plot, htext],  # output components
+        api_name="adr"                       # keep api_name to expose the endpoint
     )
 
+    # Examples section
+    gr.Markdown("### Click on any of the examples below to see how it works:")
+    # Gradio 4.0+ Examples usage: pass the input and output components directly;
+    # cache_examples=True precomputes the results so examples load faster.
     gr.Examples(
         examples=[
+            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
+            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
         ],
+        inputs=[prob1],
+        outputs=[label, local_plot, htext],
+        fn=main,  # needed so cached examples can be computed
+        cache_examples=True,
+        # run_on_click=True  # may be needed on some Gradio versions, though caching already runs fn
     )
 
+# Launch the demo
+demo.launch()
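
Review note: the new file keeps `return_all_scores=True` even though its own comment flags the deprecation, and the SHAP comments suggest wrapping the pipeline in a plain prediction function. A minimal sketch of both fixes together, assuming recent transformers and shap releases; the `predict_proba` name is illustrative (the removed version of this file used the same idea):

import numpy as np

pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,  # modern replacement for return_all_scores=True
)

def predict_proba(texts):
    # SHAP expects an (n_samples, n_classes) array with a stable column order,
    # so sort each sample's scores by label before stacking them.
    if isinstance(texts, str):
        texts = [texts]
    results = pred(list(texts))
    return np.array(
        [[d["score"] for d in sorted(r, key=lambda d: d["label"])] for r in results]
    )

explainer = shap.Explainer(predict_proba, masker=shap.maskers.Text(tokenizer))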
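Review note: the `# pass device=0 if using a GPU` comment is left unresolved, so `ner_model` stays on the CPU even when `device` is `cuda:0`. A device-aware construction, mirroring what the removed version of the file did:

ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=0 if torch.cuda.is_available() else -1,  # GPU index 0, or -1 for CPU
)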
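Review note: on the open question about embedding the SHAP plot, `shap.plots.text` returns the rendered HTML as a string when called with `display=False` (at least in recent shap releases), so routing it to the `gr.HTML` output should work. A sketch with the same fallback behaviour as the app, under that assumption:

def shap_html(text: str) -> str:
    # display=False makes shap.plots.text return raw HTML instead of rendering it,
    # which gr.HTML can display as-is; fall back to a plain message on failure.
    try:
        sv = explainer([text])
        return shap.plots.text(sv[0], display=False)
    except Exception as e:
        return f"<p>SHAP explanation not available: {e}</p>"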