# app.py
import html
import logging
import time

import streamlit as st

from ner_module import NERModel, TextProcessor

# Configure logging (optional, but helpful for debugging Streamlit apps)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
DEFAULT_MODEL = "Davlan/bert-base-multilingual-cased-ner-hrl"
# Alternative models (ensure they are compatible TokenClassification models)
# DEFAULT_MODEL = "dslim/bert-base-NER"  # English NER
# DEFAULT_MODEL = "xlm-roberta-large-finetuned-conll03-english"  # Another English option

DEFAULT_TEXT = """
Angela Merkel met Emmanuel Macron in Berlin on Tuesday to discuss the future of the European Union.
They visited the Brandenburg Gate and enjoyed some Currywurst.
Later, they flew to Paris.
John Doe from New York works at Google LLC.
"""

# Chunk size is measured in characters (see the "(chars)" slider below); 500 chars
# keeps each chunk comfortably under the common 512-token model input limit.
CHUNK_SIZE_DEFAULT = 500
OVERLAP_DEFAULT = 50


# --- Caching ---
@st.cache_resource(show_spinner="Loading NER Model...")
def _load_ner_model_cached(model_name: str):
    """Load the NERModel singleton for *model_name* and cache it.

    Exceptions propagate on purpose: Streamlit does not cache a call that
    raises, so a failed load is retried on the next rerun instead of a
    ``None`` result being cached permanently.
    """
    logger.info("Attempting to load model: %s", model_name)
    return NERModel.get_instance(model_name=model_name)


def load_ner_model(model_name: str):
    """Return the cached NERModel for *model_name*, or ``None`` if loading fails.

    Heavy loading is delegated to ``_load_ner_model_cached`` (``st.cache_resource``).
    Errors are reported in the UI and logged here, outside the cache.

    Args:
        model_name: Hugging Face model identifier passed to ``NERModel.get_instance``.

    Returns:
        The model instance, or ``None`` when loading raised an exception.
    """
    try:
        return _load_ner_model_cached(model_name)
    except Exception as e:  # UI boundary: report any load failure instead of crashing
        st.error(f"Failed to load model '{model_name}'. Error: {e}", icon="🚨")
        logger.exception("Fatal error loading model %s: %s", model_name, e)
        return None


# Keep the cached-function API (``load_ner_model.clear()``) available to callers.
load_ner_model.clear = _load_ner_model_cached.clear


# --- Helper Functions ---
def get_color_for_entity(entity_type: str) -> str:
    """Return the background color used to highlight *entity_type*."""
    # Simple color mapping, can be expanded
    colors = {
        "PER": "#faa",   # Light red for Person
        "ORG": "#afa",   # Light green for Organization
        "LOC": "#aaf",   # Light blue for Location
        "MISC": "#ffc",  # Light yellow for Miscellaneous
        # Add more colors as needed based on model's entity types
    }
    # Default color if type not found
    return colors.get(entity_type.upper(), "#ddd")  # Light grey default


def highlight_entities(text: str, entities: list) -> str:
    """Return *text* as HTML with every entity wrapped in a colored, labeled span.

    The result is meant for ``st.markdown(..., unsafe_allow_html=True)``, so all
    of *text* (inside and outside entities) is HTML-escaped. Entities are applied
    in order of start position. An entity overlapping one already emitted
    (nested, or duplicated by chunk overlap) is skipped so the markup stays
    well-formed. Spans with ``start >= end`` or a negative start are ignored.
    The *entities* list is not modified.

    Args:
        text: The original input text the entity offsets refer to.
        entities: Dicts with ``start``, ``end`` and ``entity_type`` keys and an
            optional ``word`` key (used for the tooltip).

    Returns:
        An HTML string.
    """
    if not entities:
        return html.escape(text)

    # Earliest start first; for equal starts prefer the longer (outer) span.
    ordered = sorted(entities, key=lambda ent: (ent["start"], -ent["end"]))

    parts = []
    cursor = 0  # Position in `text` up to which output has been emitted.
    for entity in ordered:
        start = entity["start"]
        end = min(entity["end"], len(text))
        if start < cursor or start >= end:
            continue  # overlaps an emitted span, or empty/invalid range

        entity_type = str(entity["entity_type"])
        color = get_color_for_entity(entity_type)
        word = entity.get("word", text[start:end])
        tooltip = html.escape(f"{word} ({entity_type})", quote=True)

        parts.append(html.escape(text[cursor:start]))
        parts.append(
            f'<span style="background-color: {color}; padding: 0.1em 0.25em; '
            f'margin: 0 0.1em; border-radius: 0.3em;" title="{tooltip}">'  # Tooltip
            f"{html.escape(text[start:end])}"  # Original text slice
            f'<span style="font-size: 0.7em; font-weight: bold; margin-left: 0.3em; '
            f'opacity: 0.7;">{html.escape(entity_type)}</span>'  # Small label
            "</span>"
        )
        cursor = end

    parts.append(html.escape(text[cursor:]))
    return "".join(parts)


# --- Streamlit App UI ---
st.set_page_config(layout="wide", page_title="NER Demo")
st.title("📝 Named Entity Recognition (NER) Demo")
st.markdown("Highlight Persons (PER), Organizations (ORG), Locations (LOC), and Miscellaneous (MISC) entities in text using a Hugging Face Transformer model.")
# Model selection fixed to default for simplicity
model_name = DEFAULT_MODEL

# Load the model (cached); load_ner_model returns None if loading failed.
ner_model = load_ner_model(model_name)

if ner_model:  # Proceed only if the model loaded successfully
    st.success(f"Model '{ner_model.model_name}' loaded successfully.", icon="✅")

    # --- Input & Controls ---
    col1, col2 = st.columns([3, 1])  # Input area takes more space

    with col1:
        st.subheader("Input Text")
        # The widget's own key keeps its content across reruns. Seed it once
        # rather than passing a `value=` that changes every rerun: in some
        # Streamlit versions that changes the widget identity and drops edits.
        if "text_area_input" not in st.session_state:
            st.session_state.text_area_input = st.session_state.get("text_input", DEFAULT_TEXT)
        text_input = st.text_area("Enter text here:", height=250, key="text_area_input")
        # Mirror into the original session key for any code that still reads it.
        st.session_state.text_input = text_input

    with col2:
        st.subheader("Options")
        use_chunking = st.checkbox("Process as Large Text (Chunking)", value=True)
        chunk_size = CHUNK_SIZE_DEFAULT
        overlap = OVERLAP_DEFAULT
        if use_chunking:
            chunk_size = st.slider("Chunk Size (chars)", min_value=100, max_value=1024,
                                   value=CHUNK_SIZE_DEFAULT, step=10)
            max_overlap = chunk_size // 2
            overlap = st.slider("Overlap (chars)", min_value=10, max_value=max_overlap,
                                value=min(OVERLAP_DEFAULT, max_overlap), step=5)

        process_button = st.button("✨ Analyze Text", type="primary", use_container_width=True)

    # --- Processing and Output ---
    has_text = bool(text_input and text_input.strip())  # whitespace-only counts as empty

    if process_button and has_text:
        start_process_time = time.perf_counter()
        st.markdown("---")  # Separator
        st.subheader("Analysis Results")

        if use_chunking:
            logger.info("Processing with chunking: size=%d, overlap=%d", chunk_size, overlap)
            run_chunk_size, run_overlap = chunk_size, overlap
        else:
            logger.info("Processing without chunking (potential truncation for long text)")
            # A single chunk covering the whole text; no overlap needed.
            run_chunk_size, run_overlap = max(len(text_input), 512), 0

        entities = None
        try:
            with st.spinner("Analyzing text... Please wait."):
                entities = TextProcessor.process_large_text(
                    text=text_input,
                    model=ner_model,
                    chunk_size=run_chunk_size,
                    overlap=run_overlap,
                )
        except Exception as e:  # UI boundary: show the failure instead of a traceback page
            logger.exception("Error while analyzing text")
            st.error(f"Analysis failed: {e}", icon="🚨")

        if entities is not None:
            processing_duration = time.perf_counter() - start_process_time
            st.info(f"Analysis completed in {processing_duration:.2f} seconds. Found {len(entities)} entities.", icon="⏱️")

            if entities:
                # Display highlighted text
                st.markdown("#### Highlighted Text:")
                highlighted_html = highlight_entities(text_input, entities)
                st.markdown(highlighted_html, unsafe_allow_html=True)

                # Display entities in a table-like format, in order of appearance
                st.markdown("#### Extracted Entities:")
                ordered_entities = sorted(entities, key=lambda ent: ent['start'])

                cols = st.columns(3)  # Adjust number of columns as needed
                for idx, entity in enumerate(ordered_entities):
                    with cols[idx % len(cols)]:
                        st.markdown(
                            f"**{entity['entity_type']}** `{entity['score']:.2f}`: "
                            f"{entity['word']} ({entity['start']}-{entity['end']})"
                        )

                # Alternative display as an expander with detailed info
                with st.expander("Show Detailed Entity List", expanded=False):
                    for entity in ordered_entities:
                        st.write(f"- **{entity['entity_type']}**: {entity['word']} (Score: {entity['score']:.2f}, Position: {entity['start']}-{entity['end']})")
            else:
                st.warning("No entities found in the provided text.", icon="❓")

    elif process_button:
        st.warning("Please enter some text to analyze.", icon="⚠️")

else:
    # This block runs if the model failed to load
    st.error("NER model could not be loaded. Please check the logs or model name. The application cannot proceed.", icon="🛑")

# Add footer or instructions
st.markdown("---")
st.caption("Powered by Hugging Face Transformers and Streamlit.")