ADR_Detector / app.py
paragon-analytics's picture
Update app.py
9c00224 verified
import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import tensorflow as tf
import transformers
from transformers import pipeline
from transformers import RobertaTokenizer, RobertaModel
from transformers import AutoModelForSequenceClassification
from transformers import TFAutoModelForSequenceClassification # Although imported, this is not used in the provided code
from transformers import AutoTokenizer, AutoModelForTokenClassification
import matplotlib.pyplot as plt
import sys
import csv
# Raise the csv module's per-field size cap as high as the platform allows.
# csv.field_size_limit() converts its argument to a C long; where that type is
# narrower than sys.maxsize (e.g. 64-bit Windows) the call raises
# OverflowError, so halve the candidate until it is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
# Use the first CUDA device when torch can see one, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub checkpoint ids, named once so the tokenizer and the model
# of each task cannot drift apart.
ADR_MODEL_ID = "paragon-analytics/ADRv1"
NER_MODEL_ID = "d4data/biomedical-ner-all"

# Sequence classifier (and its tokenizer) used for the severity scores.
tokenizer = AutoTokenizer.from_pretrained(ADR_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(ADR_MODEL_ID).to(device)

# Text-classification pipeline wrapped around the same model; top_k=None asks
# for the score of every label (it replaces the deprecated return_all_scores).
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=device,
)

# SHAP explainer built directly on the pipeline object.
explainer = shap.Explainer(pred)

# Biomedical NER pipeline. aggregation_strategy="simple" groups word pieces
# into entity spans; it is placed on the same device as the classifier.
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_ID)
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_ID)
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
    device=device,
)
# Prediction, SHAP explanation and NER highlighting for a single input text.
# Background colour for each NER entity group; groups not listed here use
# _DEFAULT_ENTITY_COLOR.
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}
_DEFAULT_ENTITY_COLOR = 'lightgray'


def _highlight_entities(text, entities):
    """Return *text* as HTML with every entity span wrapped in ``<mark>``.

    Args:
        text: The string the entity offsets refer to.
        entities: Iterable of dicts with ``'start'``, ``'end'`` and
            ``'entity_group'`` keys.

    Returns:
        An HTML string. All text taken from *text* is HTML-escaped so that
        user-supplied markup is displayed literally instead of being rendered.
    """
    import html  # function-scope import: the file's import block is untouched

    pieces = []
    prev_end = 0
    # Process spans left to right so the plain text between them can be copied.
    for entity in sorted(entities, key=lambda ent: ent['start']):
        start, end = entity['start'], entity['end']
        if start < prev_end:
            # Overlaps a span that was already emitted; skipping it avoids
            # duplicating characters in the output.
            continue
        color = _ENTITY_COLORS.get(entity['entity_group'], _DEFAULT_ENTITY_COLOR)
        pieces.append(html.escape(text[prev_end:start]))
        pieces.append(
            f"<mark style='background-color:{color};'>"
            f"{html.escape(text[start:end])}</mark>"
        )
        prev_end = end
    # Whatever follows the last entity.
    pieces.append(html.escape(text[prev_end:]))
    return "".join(pieces)


def adr_predict(x):
    """Score *x* for adverse-reaction severity and build two HTML explanations.

    Args:
        x: Input text. It is converted with ``str`` and lower-cased before
            being passed to the classifier, the SHAP explainer and the NER
            pipeline.

    Returns:
        A 3-tuple ``(label_output, local_plot, htext)``:

        * ``label_output`` -- dict with keys ``"Severe Reaction"`` (softmax
          probability at logit index 1) and ``"Non-severe Reaction"``
          (index 0).
        * ``local_plot`` -- the value returned by
          ``shap.plots.text(..., display=False)`` (expected to be an HTML
          string; confirm against the installed SHAP version), or a fallback
          ``<p>`` message if the explanation raised.
        * ``htext`` -- the lower-cased input as HTML with recognised entities
          highlighted, or a fallback ``<p>`` message if NER raised.
    """
    text_input = str(x).lower()

    # truncation=True caps the encoding at the tokenizer's maximum length so
    # over-long input does not make the forward pass fail.
    encoded_input = tokenizer(
        text_input, return_tensors='pt', truncation=True
    ).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**encoded_input)
    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # SHAP explanation is best-effort: any failure degrades to a short
    # message instead of breaking the whole prediction.
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"

    # NER highlighting is best-effort for the same reason.
    try:
        htext = _highlight_entities(text_input, ner_pipe(text_input))
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"

    label_output = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return label_output, local_plot, htext
def main(prob1):
    """Gradio callback: hand the textbox value to ``adr_predict``.

    Args:
        prob1: The text entered by the user.

    Returns:
        Exactly what ``adr_predict(prob1)`` returns, passed through unchanged.
    """
    prediction_outputs = adr_predict(prob1)
    return prediction_outputs
# Heading text; also passed to gr.Blocks as the page title.
# NOTE(review): the trailing emoji was mis-encoded ("πŸͺ") in the reviewed
# copy and has been restored as the ringed-planet emoji -- confirm that this
# is the intended character.
title = "Welcome to **ADR Detector** 🪐"
# Short description rendered under the heading.
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis."""
# Sample inputs listed under the form; each row holds one value per input
# component (here: just the textbox).
example_rows = [
    ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
    ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
]

with gr.Blocks(title=title) as demo:
    # Page header: heading, description, horizontal rule.
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    # Single free-text input.
    prob1 = gr.Textbox(
        label="Enter Your Text Here:",
        lines=2,
        placeholder="Type it here ...",
    )

    # One output component per value returned by `main`, in the same order.
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label="Shap Explanation")
    htext = gr.HTML(label="Named Entity Recognition")
    result_components = [label, local_plot, htext]

    # Clicking the button runs `main`; api_name exposes the same call as the
    # "adr" API endpoint.
    submit_btn = gr.Button("Analyze")
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=result_components,
        api_name="adr",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Example outputs are not cached; run_on_click=True makes a click on an
    # example run `main` as well as fill in the textbox.
    gr.Examples(
        examples=example_rows,
        inputs=[prob1],
        outputs=result_components,
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

# Start the Gradio server.
demo.launch()