|
|
|
import html
import logging
import time

import streamlit as st

from ner_module import NERModel, TextProcessor
|
|
|
|
|
# Root logging configuration: INFO and above, one "time - LEVEL - message" line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger named after this module; used for all diagnostics in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Hugging Face model identifier loaded by default.
DEFAULT_MODEL = "Davlan/bert-base-multilingual-cased-ner-hrl"

# Sample text pre-filled into the input area on first visit.
DEFAULT_TEXT = (
    "\n"
    "Angela Merkel met Emmanuel Macron in Berlin on Tuesday to discuss the future of the European Union.\n"
    "They visited the Brandenburg Gate and enjoyed some Currywurst. Later, they flew to Paris.\n"
    "John Doe from New York works at Google LLC.\n"
)

# Default chunking parameters (in characters) for the large-text processing path.
CHUNK_SIZE_DEFAULT: int = 500
OVERLAP_DEFAULT: int = 50
|
|
|
|
|
@st.cache_resource(show_spinner="Loading NER Model...")
def _load_ner_model_cached(model_name: str):
    """Load the NERModel instance for *model_name* and cache it across reruns.

    Any exception propagates on purpose: Streamlit does not cache a call that
    raises, so a failed load is retried on the next rerun instead of a ``None``
    result being cached forever.
    """
    logger.info("Attempting to load model: %s", model_name)
    return NERModel.get_instance(model_name=model_name)


def load_ner_model(model_name: str):
    """
    Return the cached NERModel instance for *model_name*, or None on failure.

    The heavy load is delegated to a ``st.cache_resource``-decorated helper.
    On failure, an error is shown in the UI and logged with its traceback,
    and None is returned so the page can degrade gracefully.

    Args:
        model_name: Hugging Face model identifier passed to ``NERModel.get_instance``.

    Returns:
        The model instance, or None if loading raised.
    """
    try:
        return _load_ner_model_cached(model_name)
    except Exception as e:
        # UI feedback lives outside the cached function so it is not replayed
        # from (or stored in) the resource cache.
        st.error(f"Failed to load model '{model_name}'. Error: {e}", icon="🚨")
        logger.exception("Fatal error loading model %s: %s", model_name, e)
        return None
|
|
|
|
|
def get_color_for_entity(entity_type: str) -> str:
    """Return the highlight background colour for an entity label.

    Matching is case-insensitive. PER, ORG, LOC and MISC have dedicated
    colours; any other label falls back to a neutral grey.
    """
    palette = {"PER": "#faa", "ORG": "#afa", "LOC": "#aaf", "MISC": "#ffc"}
    fallback = "#ddd"
    label = entity_type.upper()
    return palette[label] if label in palette else fallback
|
|
|
def highlight_entities(text: str, entities: list) -> str:
    """
    Build an HTML string of *text* with each entity wrapped in a coloured span.

    The result is meant for ``st.markdown(..., unsafe_allow_html=True)``.
    Because *text* is user input, every piece of it (and every entity field
    placed in the markup) is HTML-escaped.

    Entities are walked in ascending ``start`` order, building the output
    left to right. The following entities are skipped, so they can never
    corrupt already-emitted markup:

    - entities overlapping one already emitted (e.g. duplicates produced by
      chunk overlap, or nested spans);
    - empty entities;
    - entities whose offsets fall outside *text*.

    On ties the longer span is kept.

    Args:
        text: The original, unescaped input text that entity offsets refer to.
        entities: Dicts with ``start``, ``end``, ``entity_type``, ``word`` and
            an optional ``score``. The list is not modified.

    Returns:
        HTML-safe string; plain escaped text when there are no entities.
    """
    if not entities:
        return html.escape(text)

    pieces = []
    cursor = 0  # index in `text` up to which output has been emitted
    ordered = sorted(entities, key=lambda ent: (ent['start'], -ent['end']))
    for entity in ordered:
        start = entity['start']
        end = entity['end']
        # Overlapping, empty or out-of-range spans cannot be rendered safely.
        if start < cursor or end <= start or end > len(text):
            continue

        entity_type = entity['entity_type']
        color = get_color_for_entity(entity_type)
        title = html.escape(
            f"{entity_type}: {entity['word']} (Score: {entity.get('score', 0):.2f})"
        )

        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            f'<span style="background-color: {color}; padding: 0.2em 0.3em; '
            f'margin: 0 0.15em; line-height: 1; border-radius: 0.3em;" '
            f'title="{title}">'
            f'{html.escape(text[start:end])}'
            f'<sup style="font-size: 0.7em; font-weight: bold; margin-left: 2px; color: #555;">'
            f'{html.escape(entity_type)}</sup>'
            '</span>'
        )
        cursor = end

    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)
|
|
|
|
|
|
|
# ---- Page setup (set_page_config must be the first Streamlit element call) ----
st.set_page_config(layout="wide", page_title="NER Demo")

st.title("🔍 Named Entity Recognition (NER) Demo")
st.markdown("Highlight Persons (PER), Organizations (ORG), Locations (LOC), and Miscellaneous (MISC) entities in text using a Hugging Face Transformer model.")

model_name = DEFAULT_MODEL

ner_model = load_ner_model(model_name)

if ner_model:
    st.success(f"Model '{ner_model.model_name}' loaded successfully.", icon="✅")

    col1, col2 = st.columns([3, 1])

    with col1:
        st.subheader("Input Text")

        # Seed the persisted text once per session; afterwards keep the latest edit.
        if 'text_input' not in st.session_state:
            st.session_state.text_input = DEFAULT_TEXT
        text_input = st.text_area("Enter text here:", value=st.session_state.text_input, height=250, key="text_area_input")
        st.session_state.text_input = text_input

    with col2:
        st.subheader("Options")
        use_chunking = st.checkbox("Process as Large Text (Chunking)", value=True)

        chunk_size = CHUNK_SIZE_DEFAULT
        overlap = OVERLAP_DEFAULT

        if use_chunking:
            chunk_size = st.slider("Chunk Size (chars)", min_value=100, max_value=1024, value=CHUNK_SIZE_DEFAULT, step=10)
            # Overlap is capped at half the chunk so consecutive chunks always advance.
            overlap = st.slider("Overlap (chars)", min_value=10, max_value=chunk_size // 2, value=OVERLAP_DEFAULT, step=5)

        process_button = st.button("✨ Analyze Text", type="primary", use_container_width=True)

    # Whitespace-only input carries nothing to analyze; treat it like empty input.
    has_text = bool(text_input and text_input.strip())

    if process_button and has_text:
        start_process_time = time.perf_counter()
        st.markdown("---")
        st.subheader("Analysis Results")

        try:
            with st.spinner("Analyzing text... Please wait."):
                if use_chunking:
                    logger.info("Processing with chunking: size=%s, overlap=%s", chunk_size, overlap)
                    entities = TextProcessor.process_large_text(
                        text=text_input,
                        model=ner_model,
                        chunk_size=chunk_size,
                        overlap=overlap,
                    )
                else:
                    logger.info("Processing without chunking (potential truncation for long text)")
                    # One chunk covering the whole text (at least 512 chars), no overlap.
                    entities = TextProcessor.process_large_text(
                        text=text_input,
                        model=ner_model,
                        chunk_size=max(len(text_input), 512),
                        overlap=0,
                    )
        except Exception as e:
            logger.exception("Entity extraction failed")
            st.error(f"Analysis failed. Error: {e}", icon="🚨")
        else:
            # Sort once by position; used by both the highlighted view and the lists.
            entities = sorted(entities or [], key=lambda ent: ent['start'])

            processing_duration = time.perf_counter() - start_process_time
            st.info(f"Analysis completed in {processing_duration:.2f} seconds. Found {len(entities)} entities.", icon="⏱️")

            if entities:
                st.markdown("#### Highlighted Text:")
                highlighted_html = highlight_entities(text_input, entities)
                st.markdown(highlighted_html, unsafe_allow_html=True)

                st.markdown("#### Extracted Entities:")

                # Spread the compact entity list round-robin over three columns.
                cols = st.columns(3)
                for idx, entity in enumerate(entities):
                    with cols[idx % len(cols)]:
                        st.markdown(
                            f"**{entity['entity_type']}** `{entity.get('score', 0):.2f}`: "
                            f"{entity['word']} ({entity['start']}-{entity['end']})"
                        )

                with st.expander("Show Detailed Entity List", expanded=False):
                    for entity in entities:
                        st.write(f"- **{entity['entity_type']}**: {entity['word']} (Score: {entity.get('score', 0):.2f}, Position: {entity['start']}-{entity['end']})")
            else:
                st.warning("No entities found in the provided text.", icon="❌")

    elif process_button:
        st.warning("Please enter some text to analyze.", icon="⚠️")

else:
    st.error("NER model could not be loaded. Please check the logs or model name. The application cannot proceed.", icon="🛑")

# ---- Footer (always rendered) ----
st.markdown("---")
st.caption("Powered by Hugging Face Transformers and Streamlit.")
|
|