import gradio as gr
import shap
import numpy as np
import torch
import transformers
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModelForTokenClassification
import sys
import csv

# Raise the csv module's per-field size limit so very long text fields parse.
csv.field_size_limit(sys.maxsize)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("paragon-analytics/ADRv1")
model = AutoModelForSequenceClassification.from_pretrained("paragon-analytics/ADRv1").to(device)

# Build a pipeline object for predictions.
# top_k=None returns scores for all classes (return_all_scores is deprecated).
pred = transformers.pipeline(
    "text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device
)

# SHAP explainer. Recent SHAP versions can wrap a transformers
# text-classification pipeline directly, using its tokenizer as the masker.
explainer = shap.Explainer(pred)
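
# If a given SHAP version cannot wrap the pipeline object directly, a common
# fallback (a sketch, not used by default) is a wrapper that returns an
# (n_samples, n_classes) score matrix, passed to shap.Explainer together with
# the tokenizer as the masker. Sorting by label name is an assumption that it
# yields a stable class-column order for this model.
def pipeline_scores(texts):
    results = pred([str(t) for t in texts])
    # Each result is a list of {'label': ..., 'score': ...} dicts; sort by
    # label so column order does not depend on the predicted ranking.
    return np.array(
        [[d["score"] for d in sorted(r, key=lambda d: d["label"])] for r in results]
    )

# explainer = shap.Explainer(pipeline_scores, tokenizer)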

# NER pipeline
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")  # pass device=device to run NER on the GPU as well
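# With aggregation_strategy="simple", each detected entity is a dict like
# (values illustrative): {'entity_group': 'Sign_symptom', 'score': 0.99,
#  'word': 'headache', 'start': 38, 'end': 46}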

def adr_predict(x):
    # Ensure input is treated as a string
    text_input = str(x).lower()

    encoded_input = tokenizer(text_input, return_tensors='pt').to(device)
    output = model(**encoded_input)

    # Softmax over the logits, detached and moved to CPU as a numpy 2-vector of
    # class probabilities; index 1 is the "severe" class in this model's label order.
    scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()

    # SHAP explanation. The explainer consumes raw text and calls the pipeline
    # internally; with display=False, shap.plots.text returns the explanation
    # as an HTML string, which suits a gr.HTML output component.
    try:
        shap_values = explainer([text_input])
        local_plot = shap.plots.text(shap_values[0], display=False)
    except Exception as e:
        print(f"SHAP explanation failed: {e}")
        local_plot = "<p>SHAP explanation not available.</p>"  # fallback shown in the UI

    # NER processing
    try:
        res = ner_pipe(text_input)

        entity_colors = {
            'Severity': 'red',
            'Sign_symptom': 'green',
            'Medication': 'lightblue',
            'Age': 'yellow',
            'Sex': 'yellow',
            'Diagnostic_procedure': 'gray',
            'Biological_structure': 'silver'
        }
        htext = ""
        prev_end = 0
        # Sort entities by start position so the highlighted text is rebuilt in order
        res = sorted(res, key=lambda ent: ent['start'])

        for entity in res:
            start = entity['start']
            end = entity['end']
            word = text_input[start:end]  # slice the original text for this entity span
            entity_type = entity['entity_group']
            color = entity_colors.get(entity_type, 'lightgray')  # default for unmapped types

            # Plain text before the entity, then the highlighted entity itself
            htext += text_input[prev_end:start]
            htext += f"<mark style='background-color:{color};'>{word}</mark>"
            prev_end = end
        # Append any remaining text after the last entity
        htext += text_input[prev_end:]
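        # htext now holds the full input with entities wrapped in <mark> tags, e.g.
        # "... had <mark style='background-color:green;'>headache</mark> after ..."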
    except Exception as e:
        print(f"NER processing failed: {e}")
        htext = "<p>NER processing not available.</p>" # Provide a fallback


    # gr.Label expects a dict mapping class names to confidence scores; the
    # three return values match the three output components wired up below.
    label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}

    return label_output, local_plot, htext

def main(prob1):
    # Thin wrapper with a one-input signature matching the Gradio event wiring.
    return adr_predict(prob1)
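
# Quick smoke test (a sketch; values illustrative), handy before launching the UI:
#   label, shap_html, ner_html = main("A 35 year-old male had severe headache after taking Aspirin.")
#   print(label)  # e.g. {'Severe Reaction': 0.9..., 'Non-severe Reaction': 0.0...}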

title = "Welcome to **ADR Detector** 🪐"
description1 = """This app takes text (up to a few sentences) and predicts to what extent it describes a severe (or non-severe) adverse reaction to medications. Please do NOT use it for medical diagnosis."""

# Build the UI with the Blocks context manager
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")

    # Define input and output components
    prob1 = gr.Textbox(label="Enter Your Text Here:",lines=2, placeholder="Type it here ...")

    # Output components matching the return values of the main function
    label = gr.Label(label="Predicted Label")
    local_plot = gr.HTML(label="SHAP Explanation")
    htext = gr.HTML(label="Named Entity Recognition")

    submit_btn = gr.Button("Analyze")

    # Use .click() on the button to define the interaction
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr"  # exposes this event under the name "adr" in the Gradio API
    )

    # Examples section
    gr.Markdown("### Click on any of the examples below to see how it works:")
    # With cache_examples=False, run_on_click=True makes clicking an example
    # run `main` immediately; set cache_examples=True to precompute outputs instead.
    gr.Examples(
        examples=[
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
        ],
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        fn=main,
        cache_examples=False,
        run_on_click=True
    )

# Launch the demo
demo.launch()