import gradio as gr
import shap
import numpy as np
import scipy as sp
import torch
import transformers
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification
import matplotlib.pyplot as plt
import html
import sys
import csv

csv.field_size_limit(sys.maxsize)
device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
# Load models and tokenizer | |
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1") | |
model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").to(device) | |
# Build a pipeline object for predictions | |
# Note: return_all_scores is deprecated, use top_k=None instead | |
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device) # Added device=device for consistency | |
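# With top_k=None the pipeline returns per-label scores for each input, e.g.
# (label names here are illustrative; the actual ones come from the model config):
#   pred(["text"]) -> [[{'label': 'LABEL_1', 'score': 0.9},
#                       {'label': 'LABEL_0', 'score': 0.1}]]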

# SHAP explainer built directly on the pipeline. Recent SHAP versions accept a
# transformers pipeline here; check the SHAP docs if initialization fails.
explainer = shap.Explainer(pred)
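# If this SHAP version cannot consume the pipeline directly, a thin wrapper that
# returns a fixed-shape probability array usually works. A minimal sketch, assuming
# the model's two labels are "LABEL_0"/"LABEL_1" (check model.config.id2label):
#
#   def pipeline_proba(texts):
#       outputs = pred(list(texts))          # one list of {'label', 'score'} dicts per text
#       order = {"LABEL_0": 0, "LABEL_1": 1}
#       proba = np.zeros((len(outputs), 2))
#       for i, label_scores in enumerate(outputs):
#           for s in label_scores:
#               proba[i, order[s["label"]]] = s["score"]
#       return proba
#
#   explainer = shap.Explainer(pipeline_proba, shap.maskers.Text(tokenizer))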

# NER pipeline
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")  # pass device=0 to run on GPU
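# With aggregation_strategy="simple" each detected entity is a dict like
# (values illustrative):
#   {'entity_group': 'Sign_symptom', 'score': 0.98, 'word': 'headache', 'start': 27, 'end': 35}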

def adr_predict(x):
    # Normalize the input to a lowercased string
    text_input = str(x).lower()
    encoded_input = tokenizer(text_input, return_tensors='pt').to(device)
    output = model(**encoded_input)
    # torch.softmax keeps everything in torch (this is a PyTorch model, so
    # tf.nn.softmax is not appropriate); detach, move to CPU, convert to numpy
    scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
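    # scores is a 2-element probability vector; the return below assumes
    # index 0 = non-severe and index 1 = severe (verify against model.config.id2label)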
    # SHAP explanation. shap.plots.text(..., display=False) returns the plot as an
    # HTML string, which can be embedded directly in a gr.HTML component. If the
    # explainer cannot handle the pipeline's top_k=None output format, see the
    # wrapper sketch next to the explainer definition above.
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"  # fallback
    # NER processing: wrap each detected entity in a colored <mark> tag
    try:
        res = ner_pipe(text_input)
        entity_colors = {
            'Severity': 'red',
            'Sign_symptom': 'green',
            'Medication': 'lightblue',
            'Age': 'yellow',
            'Sex': 'yellow',
            'Diagnostic_procedure': 'gray',
            'Biological_structure': 'silver',
        }
        htext = ""
        prev_end = 0
        # Sort entities by start position so the highlighted text is built in order
        res = sorted(res, key=lambda x: x['start'])
        for entity in res:
            start = entity['start']
            end = entity['end']
            word = text_input[start:end]  # original text span for this entity
            entity_type = entity['entity_group']
            color = entity_colors.get(entity_type, 'lightgray')  # default for unlisted types
            # Escape user text before interpolating it into HTML
            htext += html.escape(text_input[prev_end:start])
            htext += f"<mark style='background-color:{color};'>{html.escape(word)}</mark>"
            prev_end = end
        # Append any remaining text after the last entity
        htext += html.escape(text_input[prev_end:])
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"  # fallback
    # gr.Label expects a dict mapping class names to confidences; the number of
    # returned values must match the number of Gradio output components
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext

def main(prob1):
    # Thin wrapper so the Gradio callback signature matches the input component
    return adr_predict(prob1)
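
# Quick sanity check outside the UI (hypothetical input; uncomment to run):
#   label_out, shap_html, ner_html = main("A 35 year-old male had severe headache after taking Aspirin.")
#   print(label_out)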
title = "Welcome to **ADR Detector** πͺ" | |
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medicaitons. Please do NOT use for medical diagnosis.""" | |
# Build the UI with the gr.Blocks context manager
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    # Input and output components
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label="SHAP Explanation")
    htext = gr.HTML(label="Named Entity Recognition")
    submit_btn = gr.Button("Analyze")
    # Wire the button to the prediction function
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",  # keep api_name if the endpoint is called via the API
    )
    # Examples section
    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True,  # with cache_examples=False, this runs fn when an example is clicked
    )

# Launch the demo
demo.launch()
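# demo.launch(share=True) would expose a temporary public URL, and
# server_name="0.0.0.0" is useful when running inside a container
# (both are standard gr.Blocks.launch() parameters).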