File size: 8,624 Bytes
f68c4f8
 
 
e3f321e
 
 
f68c4f8
e3f321e
 
 
f68c4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f321e
f68c4f8
 
e3f321e
f68c4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f321e
f68c4f8
 
e3f321e
f68c4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f321e
f68c4f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f321e
f68c4f8
 
e3f321e
f68c4f8
 
e3f321e
f68c4f8
 
 
e3f321e
f68c4f8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# app.py
import html
import logging
import time

import streamlit as st
from ner_module import NERModel, TextProcessor

# Configure logging (optional, but helpful for debugging Streamlit apps).
# Note: logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
# Model identifier handed to NERModel.get_instance via load_ner_model.
DEFAULT_MODEL = "Davlan/bert-base-multilingual-cased-ner-hrl"
# Alternative models (ensure they are compatible TokenClassification models)
# DEFAULT_MODEL = "dslim/bert-base-NER" # English NER
# DEFAULT_MODEL = "xlm-roberta-large-finetuned-conll03-english" # Another English option

# Sample text pre-filled into the input text area the first time the page loads.
DEFAULT_TEXT = """
Angela Merkel met Emmanuel Macron in Berlin on Tuesday to discuss the future of the European Union.
They visited the Brandenburg Gate and enjoyed some Currywurst. Later, they flew to Paris.
John Doe from New York works at Google LLC.
"""
# Default chunk size for TextProcessor.process_large_text, measured in characters
# (the UI slider is labelled "chars"). The common 512 limit of BERT-style models is
# counted in tokens, not characters, so staying below 512 chars is only a heuristic.
# NOTE(review): confirm a chunk of this size never exceeds the model's token limit.
CHUNK_SIZE_DEFAULT = 500
# Default overlap between chunks, in characters (passed as `overlap=`).
OVERLAP_DEFAULT = 50

# --- Caching ---
@st.cache_resource(show_spinner="Loading NER Model...")
def _load_ner_model_cached(model_name: str):
    """
    Load the NERModel for ``model_name`` via its singleton accessor and cache it.

    Exceptions are deliberately allowed to propagate: ``st.cache_resource``
    only stores successful return values, so a failed load is retried on the
    next rerun instead of a ``None`` result being cached for the lifetime of
    the server process.
    """
    logger.info("Attempting to load model: %s", model_name)
    return NERModel.get_instance(model_name=model_name)


def load_ner_model(model_name: str):
    """
    Return the cached NERModel instance for ``model_name``, or None on failure.

    The heavy lifting is delegated to the cached ``_load_ner_model_cached``.
    This wrapper is intentionally *not* cached, so that failures are reported
    (UI error plus logged traceback) on every rerun and never memoized.

    Args:
        model_name: Model identifier passed to ``NERModel.get_instance``.

    Returns:
        The loaded model instance, or None if loading raised.
    """
    try:
        return _load_ner_model_cached(model_name)
    except Exception as e:  # UI boundary: report any load failure and degrade gracefully
        st.error(f"Failed to load model '{model_name}'. Error: {e}", icon="🚨")
        logger.exception("Fatal error loading model %s: %s", model_name, e)
        return None

# --- Helper Functions ---
# Highlight colours keyed by upper-case entity label (without tag prefix).
_ENTITY_COLORS = {
    "PER": "#faa",   # Light red for Person
    "ORG": "#afa",   # Light green for Organization
    "LOC": "#aaf",   # Light blue for Location
    "MISC": "#ffc",  # Light yellow for Miscellaneous
    # Add more colors as needed based on model's entity types
}
_DEFAULT_ENTITY_COLOR = "#ddd"  # Light grey for labels not in the map

# First letters of token-level tagging-scheme prefixes (IOB / BIOES / BILOU).
_TAG_PREFIX_LETTERS = "BIESLU"


def get_color_for_entity(entity_type: str) -> str:
    """
    Return the background colour used to highlight an entity label.

    Matching is case-insensitive. A leading tagging-scheme prefix such as
    ``B-`` or ``I-`` is ignored, so aggregated labels (``"PER"``) and
    token-level labels (``"B-PER"``) map to the same colour. Unknown labels
    fall back to light grey.

    Args:
        entity_type: Entity label produced by the NER model.

    Returns:
        A CSS colour string.
    """
    label = entity_type.upper()
    # Strip an "X-" tag prefix, e.g. "B-PER" -> "PER".
    if len(label) > 2 and label[1] == "-" and label[0] in _TAG_PREFIX_LETTERS:
        label = label[2:]
    return _ENTITY_COLORS.get(label, _DEFAULT_ENTITY_COLOR)

def highlight_entities(text: str, entities: list) -> str:
    """
    Render ``text`` as HTML, wrapping each entity in a coloured ``<span>``.

    The input text is user-supplied and the result is rendered with
    ``unsafe_allow_html=True``, so every piece of text (plain segments, entity
    text, tooltip and label) is HTML-escaped.

    Entities are emitted in one left-to-right pass over the *original* text,
    so offsets never refer to already-inserted markup. When spans overlap or
    nest (e.g. duplicates produced by overlapping chunks), the earliest-starting
    span wins (the longest one on ties), and any span overlapping it is skipped.
    This keeps the markup well-formed.

    Args:
        text: The original input text.
        entities: Dicts with ``start``, ``end``, ``entity_type`` and ``word``
            keys (``score`` optional); ``start``/``end`` are character offsets
            into ``text``. The list is not modified.

    Returns:
        An HTML string.
    """
    if not entities:
        return html.escape(text)

    # Earliest start first; for equal starts prefer the longer (outer) span.
    ordered = sorted(entities, key=lambda ent: (ent['start'], -ent['end']))

    parts = []
    cursor = 0  # position in `text` up to which output has been emitted
    for entity in ordered:
        start = max(entity['start'], 0)
        end = min(entity['end'], len(text))
        if start < cursor or end <= start:
            # Overlaps a span already highlighted, or is empty/out of range.
            continue

        entity_type = entity['entity_type']
        word = entity['word']  # Extracted word, shown in the tooltip
        color = get_color_for_entity(entity_type)
        tooltip = html.escape(
            f"{entity_type}: {word} (Score: {entity.get('score', 0):.2f})"
        )

        parts.append(html.escape(text[cursor:start]))
        parts.append(
            f'<span style="background-color: {color}; padding: 0.2em 0.3em; '
            f'margin: 0 0.15em; line-height: 1; border-radius: 0.3em;" '
            f'title="{tooltip}">'
            f'{html.escape(text[start:end])}'
            f'<sup style="font-size: 0.7em; font-weight: bold; margin-left: 2px; color: #555;">'
            f'{html.escape(entity_type)}</sup>'
            f'</span>'
        )
        cursor = end

    parts.append(html.escape(text[cursor:]))
    return "".join(parts)


# --- Streamlit App UI ---
# Top-level Streamlit script body. set_page_config must remain the first
# Streamlit command issued.
st.set_page_config(layout="wide", page_title="NER Demo")

st.title("📝 Named Entity Recognition (NER) Demo")
st.markdown("Highlight Persons (PER), Organizations (ORG), Locations (LOC), and Miscellaneous (MISC) entities in text using a Hugging Face Transformer model.")

# Model selection fixed to default for simplicity
model_name = DEFAULT_MODEL

# Load the model (cached); load_ner_model returns None if loading failed.
ner_model = load_ner_model(model_name)

if ner_model:  # Proceed only if the model loaded successfully
    # `icon` must be an actual emoji character; Streamlit rejects other strings.
    st.success(f"Model '{ner_model.model_name}' loaded successfully.", icon="✅")

    # --- Input & Controls ---
    col1, col2 = st.columns([3, 1])  # Input area takes more space

    with col1:
        st.subheader("Input Text")
        # Keep the text area content persistent across reruns via session state.
        if 'text_input' not in st.session_state:
            st.session_state.text_input = DEFAULT_TEXT
        text_input = st.text_area("Enter text here:", value=st.session_state.text_input, height=250, key="text_area_input")
        st.session_state.text_input = text_input

    with col2:
        st.subheader("Options")
        use_chunking = st.checkbox("Process as Large Text (Chunking)", value=True)

        chunk_size = CHUNK_SIZE_DEFAULT
        overlap = OVERLAP_DEFAULT

        if use_chunking:
            chunk_size = st.slider("Chunk Size (chars)", min_value=100, max_value=1024, value=CHUNK_SIZE_DEFAULT, step=10)
            # Overlap is capped at half the chosen chunk size.
            overlap = st.slider("Overlap (chars)", min_value=10, max_value=chunk_size // 2, value=OVERLAP_DEFAULT, step=5)

        process_button = st.button("✨ Analyze Text", type="primary", use_container_width=True)

    # Whitespace-only input is treated the same as empty input.
    has_text = bool(text_input and text_input.strip())

    # --- Processing and Output ---
    if process_button and has_text:
        start_process_time = time.perf_counter()  # monotonic clock for durations
        st.markdown("---")  # Separator
        st.subheader("Analysis Results")

        if use_chunking:
            logger.info("Processing with chunking: size=%s, overlap=%s", chunk_size, overlap)
            effective_chunk_size, effective_overlap = chunk_size, overlap
        else:
            logger.info("Processing without chunking (potential truncation for long text)")
            # One chunk covering the whole text (at least 512 chars); no overlap needed.
            effective_chunk_size, effective_overlap = max(len(text_input), 512), 0

        with st.spinner("Analyzing text... Please wait."):
            entities = TextProcessor.process_large_text(
                text=text_input,
                model=ner_model,
                chunk_size=effective_chunk_size,
                overlap=effective_overlap,
            )

        processing_duration = time.perf_counter() - start_process_time
        st.info(f"Analysis completed in {processing_duration:.2f} seconds. Found {len(entities)} entities.", icon="⏱️")

        if entities:
            # Display highlighted text (HTML rendered via st.markdown)
            st.markdown("#### Highlighted Text:")
            highlighted_html = highlight_entities(text_input, entities)
            st.markdown(highlighted_html, unsafe_allow_html=True)

            # Display entities in a table-like format, in order of appearance.
            st.markdown("#### Extracted Entities:")
            ordered_entities = sorted(entities, key=lambda ent: ent['start'])

            # Distribute entities round-robin over a fixed number of columns.
            cols = st.columns(3)  # Adjust number of columns as needed
            for idx, entity in enumerate(ordered_entities):
                with cols[idx % len(cols)]:
                    st.markdown(
                        f"**{entity['entity_type']}** `{entity.get('score', 0):.2f}`: "
                        f"{entity['word']} ({entity['start']}-{entity['end']})"
                    )

            # Alternative display as an expander with detailed info
            with st.expander("Show Detailed Entity List", expanded=False):
                for entity in ordered_entities:
                    st.write(f"- **{entity['entity_type']}**: {entity['word']} (Score: {entity.get('score', 0):.2f}, Position: {entity['start']}-{entity['end']})")

        else:
            st.warning("No entities found in the provided text.", icon="❓")

    elif process_button:
        # Button pressed but there is no non-whitespace text to analyze.
        st.warning("Please enter some text to analyze.", icon="⚠️")

else:
    # This block runs if the model failed to load
    st.error("NER model could not be loaded. Please check the logs or model name. The application cannot proceed.", icon="🛑")

# Add footer or instructions
st.markdown("---")
st.caption("Powered by Hugging Face Transformers and Streamlit.")