import sys
import csv

import gradio as gr
import shap
import torch
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Allow very large CSV fields
csv.field_size_limit(sys.maxsize)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the ADR classifier and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1")
model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").to(device)

# Text-classification pipeline used by the SHAP explainer.
# top_k=None returns the scores for all labels (it replaces the deprecated return_all_scores=True).
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)

# SHAP explainer. shap.Explainer can wrap a transformers pipeline directly; if the
# Explainer initialization fails with your SHAP version, check the SHAP documentation.
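# Sketch of a fixed-order score helper (hypothetical, not part of the original app).
# With top_k=None the pipeline returns, for each input, a list of
# {"label": ..., "score": ...} dicts sorted by score, so the label order can differ
# between inputs. A plain prediction function passed to shap.Explainer must instead
# return one row per input with a fixed column order. The helper name and the
# `labels` argument are assumptions made for illustration.
import numpy as np


def pipeline_output_to_matrix(outputs, labels):
    """Convert text-classification pipeline output into an (n_inputs, n_labels)
    NumPy array whose columns follow the order given in `labels`."""
    # A single string input yields a flat list of dicts; treat it as a batch of one
    if outputs and isinstance(outputs[0], dict):
        outputs = [outputs]
    column = {label: j for j, label in enumerate(labels)}
    matrix = np.zeros((len(outputs), len(labels)))
    for i, label_scores in enumerate(outputs):
        for item in label_scores:
            matrix[i, column[item["label"]]] = item["score"]
    return matrix


# Possible wiring, assuming the label order is taken from model.config.id2label:
#   labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
#   def predict_matrix(texts):
#       return pipeline_output_to_matrix(pred([str(t) for t in texts]), labels)
#   explainer = shap.Explainer(predict_matrix, shap.maskers.Text(tokenizer), output_names=labels)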
explainer = shap.Explainer(pred)

# NER pipeline (pass device=0 to run it on the first GPU)
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")


def adr_predict(x):
    # Treat the input as a string and lowercase it
    text_input = str(x).lower()

    # Classify with the PyTorch model directly; inference only, so skip gradient tracking
    encoded_input = tokenizer(text_input, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        output = model(**encoded_input)
    # Softmax over the logits, then move the class probabilities to the CPU as a NumPy array
    scores = torch.softmax(output.logits, dim=-1)[0].cpu().numpy()

    # SHAP explanation. explainer() runs the pipeline on the masked text variants.
    # If wrapping the pipeline directly fails with your SHAP version, pass shap.Explainer
    # a prediction function that returns an (n_inputs, n_labels) score array, together
    # with a text masker such as shap.maskers.Text(tokenizer).
    try:
        shap_values = explainer([text_input])
        # With display=False, shap.plots.text returns the explanation as an HTML
        # string instead of rendering it, so it can be shown in a gr.HTML component.
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"  # Fallback

    # Named entity recognition, rendered as HTML with colored highlights
    try:
        res = ner_pipe(text_input)
        entity_colors = {
            'Severity': 'red',
            'Sign_symptom': 'green',
            'Medication': 'lightblue',
            'Age': 'yellow',
            'Sex': 'yellow',
            'Diagnostic_procedure': 'gray',
            'Biological_structure': 'silver',
        }
        htext = ""
        prev_end = 0
        # Sort entities by start position so the text is rebuilt in order
        res = sorted(res, key=lambda e: e['start'])
        for entity in res:
            start = entity['start']
            end = entity['end']
            word = text_input[start:end]  # Original text segment for this entity
            entity_type = entity['entity_group']
            color = entity_colors.get(entity_type, 'lightgray')  # Default color for unlisted types
            # Append the text before the entity, then the highlighted entity
            htext += text_input[prev_end:start]
            htext += f'<mark style="background-color:{color};">{word}</mark>'
            prev_end = end
        # Append any text after the last entity
        htext += text_input[prev_end:]
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>"  # Fallback

    # Gradio assigns the returned values to the output components in order
    # (label, local_plot, htext). gr.Label expects a {label: confidence} dict;
    # index 1 is the severe class and index 0 the non-severe class.
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
    return label_output, local_plot, htext


def main(prob1):
    # Single entry point shared by the button handler and the examples
    return adr_predict(prob1)


title = "Welcome to **ADR Detector** 🪐"
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes a severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis."""

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")

    # Input component
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")

    # Output components, in the same order as the values returned by main()
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label="SHAP Explanation")
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")

    # Run main() when the button is clicked
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",  # Exposes this handler as a named API endpoint
    )

    # Examples section
    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Gradio 4.0+ Examples usage: pass the input and output components directly.
    # cache_examples=True would precompute the example outputs at startup (running
    # both models once per example); with caching off, run_on_click=True runs fn
    # when an example is clicked.
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True,
    )

# Launch the demo
demo.launch()
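# Sketch of a standalone entity highlighter (hypothetical, not part of the original
# app). It restates the highlighting loop in adr_predict as a pure function that can
# be tested without loading a model. It also HTML-escapes the user's text, so
# characters such as "<" or "&" appear literally instead of being parsed by gr.HTML.
# It assumes entities shaped like the "ner" pipeline output with
# aggregation_strategy="simple": dicts with "start", "end" and "entity_group" keys.
import html


def highlight_entities(text, entities, colors, default_color="lightgray"):
    """Return `text` as HTML, wrapping each entity span in a colored <mark> tag."""
    parts = []
    prev_end = 0
    for ent in sorted(entities, key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        if start < prev_end:
            # Skip a span that overlaps one already highlighted
            continue
        color = colors.get(ent["entity_group"], default_color)
        parts.append(html.escape(text[prev_end:start]))
        parts.append(f'<mark style="background-color:{color};">{html.escape(text[start:end])}</mark>')
        prev_end = end
    # Escape and append whatever follows the last entity
    parts.append(html.escape(text[prev_end:]))
    return "".join(parts)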