# ner_module.py import torch import time from typing import List, Dict, Any, Tuple from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline import logging # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) class NERModel: """ A singleton class to manage the NER model loading and prediction. Ensures the potentially large model is loaded only once. """ _instance = None _model = None _tokenizer = None _pipeline = None _model_name = None # Store model name used for initialization @classmethod def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"): """ Singleton pattern: Get the existing instance or create a new one. Uses the specified model_name only during the first initialization. """ if cls._instance is None: logger.info(f"Creating new NERModel instance with model: {model_name}") cls._instance = cls(model_name) elif cls._model_name != model_name: logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.") return cls._instance def __init__(self, model_name: str): """ Initialize the model, tokenizer, and pipeline. Private constructor - use get_instance() instead. """ if NERModel._instance is not None: raise Exception("This class is a singleton! Use get_instance() to get the object.") else: self.model_name = model_name NERModel._model_name = model_name # Store the model name self._load_model() NERModel._instance = self # Assign the instance here def _load_model(self): """Load the NER model and tokenizer from Hugging Face.""" logger.info(f"Loading model: {self.model_name}") start_time = time.time() try: # Load tokenizer and model self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) self._model = AutoModelForTokenClassification.from_pretrained(self.model_name) # Check if the model is a PyTorch model for potential optimizations if isinstance(self._model, torch.nn.Module): self._model.eval() # Set model to evaluation mode (important for inference) # Create the NER pipeline self._pipeline = pipeline( "ner", model=self._model, tokenizer=self._tokenizer, # grouped_entities=True # Uncomment if you want to use pipeline's built-in grouping ) load_time = time.time() - start_time logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.") except Exception as e: logger.error(f"Error loading model {self.model_name}: {e}") # Clean up partial loads if necessary self._tokenizer = None self._model = None self._pipeline = None # Re-raise the exception to signal failure raise def predict(self, text: str) -> List[Dict[str, Any]]: """ Run NER prediction on the input text using the loaded pipeline. Args: text: The input string to perform NER on. Returns: A list of dictionaries, where each dictionary represents an entity identified by the pipeline. """ if self._pipeline is None: logger.error("NER pipeline is not initialized. Cannot predict.") return [] # Return empty list or raise an error if not text or not isinstance(text, str): logger.warning("Prediction called with empty or invalid text.") return [] logger.debug(f"Running prediction on text: '{text[:100]}...'") # Log snippet try: # The pipeline handles tokenization and prediction results = self._pipeline(text) logger.debug(f"Prediction results: {results}") return results except Exception as e: logger.error(f"Error during NER prediction: {e}") return [] # Return empty list on error class TextProcessor: """ Provides static methods for processing text, specifically for NER tasks, including combining subword entities and handling large texts via chunking. """ @staticmethod def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]: """ Combine entities that might be split into subword tokens (B-TAG, I-TAG). This method assumes the pipeline did *not* use grouped_entities=True. Args: ner_results: The raw output from the NER pipeline (list of token dictionaries). original_text: The original text input to extract entity words accurately. Returns: A list of dictionaries, each representing a combined entity with 'entity_type', 'start', 'end', 'score', and 'word'. """ if not ner_results: return [] combined_entities = [] current_entity = None for token in ner_results: # Basic validation of token structure if not all(k in token for k in ['entity', 'start', 'end', 'score']): logger.warning(f"Skipping malformed token: {token}") continue # Skip 'O' tags (Outside any entity) if token['entity'] == 'O': # If we were tracking an entity, finalize it before moving on if current_entity: combined_entities.append(current_entity) current_entity = None continue # Extract entity type (e.g., 'PER', 'LOC') removing 'B-' or 'I-' entity_tag = token['entity'] if entity_tag.startswith('B-') or entity_tag.startswith('I-'): entity_type = entity_tag[2:] else: # Handle cases where the tag might not have B-/I- prefix (less common) entity_type = entity_tag # Start of a new entity ('B-') or continuation of a different entity type if entity_tag.startswith('B-') or (entity_tag.startswith('I-') and (not current_entity or current_entity['entity_type'] != entity_type)): # Finalize the previous entity if it exists if current_entity: combined_entities.append(current_entity) # Start the new entity current_entity = { 'entity_type': entity_type, 'start': token['start'], 'end': token['end'], 'score': float(token['score']), 'token_count': 1 # Keep track of tokens for averaging score } # Continuation of the current entity ('I-' and matching type) elif entity_tag.startswith('I-') and current_entity and current_entity['entity_type'] == entity_type: # Extend the end position current_entity['end'] = token['end'] # Update the score (e.g., average) current_entity['score'] = (current_entity['score'] * current_entity['token_count'] + float(token['score'])) / (current_entity['token_count'] + 1) current_entity['token_count'] += 1 # Handle unexpected cases (e.g., I- tag without preceding B- or matching I-) else: logger.warning(f"Encountered unexpected token sequence at: {token}. Starting new entity.") if current_entity: combined_entities.append(current_entity) # Try to create a new entity from this token current_entity = { 'entity_type': entity_type, 'start': token['start'], 'end': token['end'], 'score': float(token['score']), 'token_count': 1 } # Add the last tracked entity if it exists if current_entity: combined_entities.append(current_entity) # Extract the actual text 'word' for each combined entity for entity in combined_entities: try: # Ensure indices are valid start = max(0, min(entity['start'], len(original_text))) end = max(start, min(entity['end'], len(original_text))) entity['word'] = original_text[start:end].strip() # Remove internal helper key if 'token_count' in entity: del entity['token_count'] except Exception as e: logger.error(f"Error extracting word for entity: {entity}, error: {e}") entity['word'] = "[Error extracting word]" # Sort entities by start position combined_entities.sort(key=lambda x: x['start']) logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.") return combined_entities @staticmethod def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]: """ Process large text by splitting it into overlapping chunks, running NER on each chunk, and then combining the results intelligently. Args: text: The large input text string. model: The initialized NERModel instance. chunk_size: The maximum size of each text chunk. overlap: The number of characters to overlap between consecutive chunks. Returns: A list of combined entity dictionaries for the entire text. """ if not text: return [] # Use tokenizer max length if available and smaller than chunk_size if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'): tokenizer_max_len = model._tokenizer.model_max_length if chunk_size > tokenizer_max_len: logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.") chunk_size = tokenizer_max_len # Ensure overlap is reasonable compared to chunk size if overlap >= chunk_size // 2: logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.") overlap = chunk_size // 4 logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}") chunks = TextProcessor._create_chunks(text, chunk_size, overlap) logger.info(f"Split text into {len(chunks)} chunks.") all_raw_results = [] total_processing_time = 0 for i, (chunk_text, start_pos) in enumerate(chunks): logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})") start_time = time.time() # Get raw predictions for the current chunk raw_results_chunk = model.predict(chunk_text) chunk_processing_time = time.time() - start_time total_processing_time += chunk_processing_time logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.") # Adjust entity positions relative to the original text for result in raw_results_chunk: # Check if 'start' and 'end' exist before adjusting if 'start' in result and 'end' in result: result['start'] += start_pos result['end'] += start_pos else: logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}") all_raw_results.extend(raw_results_chunk) logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.") logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}") # Combine entities from all chunks combined_entities = TextProcessor.combine_entities(all_raw_results, text) # Deduplicate entities based on overlapping positions # Two entities are considered duplicates if they have the same type and # overlap by more than 50% of the shorter entity's length unique_entities = [] for entity in combined_entities: is_duplicate = False # Calculate entity length for overlap comparison entity_length = entity['end'] - entity['start'] for existing in unique_entities: if existing['entity_type'] == entity['entity_type']: # Check for significant overlap overlap_start = max(entity['start'], existing['start']) overlap_end = min(entity['end'], existing['end']) if overlap_start < overlap_end: # They overlap overlap_length = overlap_end - overlap_start shorter_length = min(entity_length, existing['end'] - existing['start']) # If overlap is significant (>50% of shorter entity) if overlap_length > 0.5 * shorter_length: is_duplicate = True # Keep the one with higher score if entity['score'] > existing['score']: # Replace the existing entity with this one unique_entities.remove(existing) is_duplicate = False break if not is_duplicate: unique_entities.append(entity) logger.info(f"Final number of unique combined entities: {len(unique_entities)}") return unique_entities @staticmethod def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]: """ Split text into potentially overlapping chunks, trying to respect word boundaries. Args: text: The input text string. chunk_size: The target maximum size of each chunk. overlap: The desired overlap between consecutive chunks. Returns: A list of tuples, where each tuple contains (chunk_text, start_position_in_original_text). """ if not text: return [] if chunk_size <= overlap: raise ValueError("chunk_size must be greater than overlap") if chunk_size <= 0: raise ValueError("chunk_size must be positive") chunks = [] start = 0 text_len = len(text) while start < text_len: # Determine the ideal end position end = min(start + chunk_size, text_len) # If we're at the end of the text, just use what's left if end >= text_len: chunks.append((text[start:], start)) break # Try to find a suitable split point (whitespace) to ensure we don't cut words split_pos = -1 # Search backwards from end to find a whitespace for i in range(end, max(start, end - overlap) - 1, -1): if i < text_len and text[i].isspace(): split_pos = i + 1 # Position after the space break # If no good split found, just use the calculated end if split_pos == -1 or split_pos <= start: actual_end = end else: actual_end = split_pos # Add the chunk chunks.append((text[start:actual_end], start)) # Calculate next start position, ensuring we make progress next_start = start + (actual_end - start - overlap) if next_start <= start: next_start = start + 1 start = next_start return chunks