# ner_module.py
import torch
import time
from typing import List, Dict, Any, Tuple
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import logging
# Configure logging (basicConfig touches the root logger: handy for a script,
# but note it affects any application that imports this module)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class NERModel:
"""
A singleton class to manage the NER model loading and prediction.
Ensures the potentially large model is loaded only once.
"""
_instance = None
_model = None
_tokenizer = None
_pipeline = None
_model_name = None # Store model name used for initialization
@classmethod
def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"):
"""
Singleton pattern: Get the existing instance or create a new one.
Uses the specified model_name only during the first initialization.
"""
if cls._instance is None:
logger.info(f"Creating new NERModel instance with model: {model_name}")
cls._instance = cls(model_name)
elif cls._model_name != model_name:
logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.")
return cls._instance
def __init__(self, model_name: str):
"""
Initialize the model, tokenizer, and pipeline.
Private constructor - use get_instance() instead.
"""
        if NERModel._instance is not None:
            raise RuntimeError("NERModel is a singleton; use NERModel.get_instance() instead of constructing it directly.")
else:
self.model_name = model_name
NERModel._model_name = model_name # Store the model name
self._load_model()
NERModel._instance = self # Assign the instance here
def _load_model(self):
"""Load the NER model and tokenizer from Hugging Face."""
logger.info(f"Loading model: {self.model_name}")
start_time = time.time()
try:
# Load tokenizer and model
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)
            # Put the model in evaluation mode to disable dropout and other
            # training-time behaviour (AutoModelForTokenClassification returns a
            # torch.nn.Module, so this check is just defensive)
            if isinstance(self._model, torch.nn.Module):
                self._model.eval()
# Create the NER pipeline
self._pipeline = pipeline(
"ner",
model=self._model,
tokenizer=self._tokenizer,
                # aggregation_strategy="simple"  # Uncomment for the pipeline's
                # built-in grouping (grouped_entities= is the deprecated spelling)
)
load_time = time.time() - start_time
logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")
except Exception as e:
logger.error(f"Error loading model {self.model_name}: {e}")
# Clean up partial loads if necessary
self._tokenizer = None
self._model = None
self._pipeline = None
# Re-raise the exception to signal failure
raise
def predict(self, text: str) -> List[Dict[str, Any]]:
"""
Run NER prediction on the input text using the loaded pipeline.
Args:
text: The input string to perform NER on.
Returns:
A list of dictionaries, where each dictionary represents an entity
identified by the pipeline.
"""
if self._pipeline is None:
logger.error("NER pipeline is not initialized. Cannot predict.")
return [] # Return empty list or raise an error
if not text or not isinstance(text, str):
logger.warning("Prediction called with empty or invalid text.")
return []
logger.debug(f"Running prediction on text: '{text[:100]}...'") # Log snippet
try:
# The pipeline handles tokenization and prediction
results = self._pipeline(text)
logger.debug(f"Prediction results: {results}")
return results
except Exception as e:
logger.error(f"Error during NER prediction: {e}")
return [] # Return empty list on error
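# For reference, NERModel.predict returns the raw token-level output of the
# transformers "ner" pipeline (no aggregation). A sketch of its shape, with
# illustrative values rather than real model output:
#
#     [{'entity': 'B-PER', 'score': 0.998, 'index': 1,
#       'word': 'Ang', 'start': 0, 'end': 3},
#      {'entity': 'I-PER', 'score': 0.997, 'index': 2,
#       'word': '##ela', 'start': 3, 'end': 6}]
#
# TextProcessor.combine_entities below merges these subword fragments into
# whole entities.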
class TextProcessor:
"""
Provides static methods for processing text, specifically for NER tasks,
including combining subword entities and handling large texts via chunking.
"""
@staticmethod
def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
"""
Combine entities that might be split into subword tokens (B-TAG, I-TAG).
        This method assumes the pipeline did *not* group entities itself
        (i.e. no aggregation_strategy / grouped_entities was set).
Args:
ner_results: The raw output from the NER pipeline (list of token dictionaries).
original_text: The original text input to extract entity words accurately.
Returns:
A list of dictionaries, each representing a combined entity with
'entity_type', 'start', 'end', 'score', and 'word'.
"""
if not ner_results:
return []
combined_entities = []
current_entity = None
for token in ner_results:
# Basic validation of token structure
if not all(k in token for k in ['entity', 'start', 'end', 'score']):
logger.warning(f"Skipping malformed token: {token}")
continue
            # Skip 'O' tags (outside any entity). The pipeline normally filters
            # these out itself (ignore_labels defaults to ["O"]), but guard anyway
            if token['entity'] == 'O':
# If we were tracking an entity, finalize it before moving on
if current_entity:
combined_entities.append(current_entity)
current_entity = None
continue
# Extract entity type (e.g., 'PER', 'LOC') removing 'B-' or 'I-'
entity_tag = token['entity']
if entity_tag.startswith('B-') or entity_tag.startswith('I-'):
entity_type = entity_tag[2:]
else:
# Handle cases where the tag might not have B-/I- prefix (less common)
entity_type = entity_tag
# Start of a new entity ('B-') or continuation of a different entity type
if entity_tag.startswith('B-') or (entity_tag.startswith('I-') and (not current_entity or current_entity['entity_type'] != entity_type)):
# Finalize the previous entity if it exists
if current_entity:
combined_entities.append(current_entity)
# Start the new entity
current_entity = {
'entity_type': entity_type,
'start': token['start'],
'end': token['end'],
'score': float(token['score']),
'token_count': 1 # Keep track of tokens for averaging score
}
# Continuation of the current entity ('I-' and matching type)
elif entity_tag.startswith('I-') and current_entity and current_entity['entity_type'] == entity_type:
# Extend the end position
current_entity['end'] = token['end']
# Update the score (e.g., average)
current_entity['score'] = (current_entity['score'] * current_entity['token_count'] + float(token['score'])) / (current_entity['token_count'] + 1)
current_entity['token_count'] += 1
# Handle unexpected cases (e.g., I- tag without preceding B- or matching I-)
else:
logger.warning(f"Encountered unexpected token sequence at: {token}. Starting new entity.")
if current_entity:
combined_entities.append(current_entity)
# Try to create a new entity from this token
current_entity = {
'entity_type': entity_type,
'start': token['start'],
'end': token['end'],
'score': float(token['score']),
'token_count': 1
}
# Add the last tracked entity if it exists
if current_entity:
combined_entities.append(current_entity)
# Extract the actual text 'word' for each combined entity
for entity in combined_entities:
try:
# Ensure indices are valid
start = max(0, min(entity['start'], len(original_text)))
end = max(start, min(entity['end'], len(original_text)))
entity['word'] = original_text[start:end].strip()
# Remove internal helper key
if 'token_count' in entity:
del entity['token_count']
except Exception as e:
logger.error(f"Error extracting word for entity: {entity}, error: {e}")
entity['word'] = "[Error extracting word]"
# Sort entities by start position
combined_entities.sort(key=lambda x: x['start'])
logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
return combined_entities
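    # Illustrative example (made-up scores): given the raw tokens
    #     {'entity': 'B-LOC', 'start': 0, 'end': 3, 'score': 0.99}   # "New"
    #     {'entity': 'I-LOC', 'start': 4, 'end': 8, 'score': 0.97}   # "York"
    # combine_entities yields the single entity
    #     {'entity_type': 'LOC', 'start': 0, 'end': 8, 'score': 0.98,
    #      'word': 'New York'}
    # with the score averaged over both tokens.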
@staticmethod
def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
"""
Process large text by splitting it into overlapping chunks, running NER
on each chunk, and then combining the results intelligently.
Args:
text: The large input text string.
model: The initialized NERModel instance.
chunk_size: The maximum size of each text chunk.
overlap: The number of characters to overlap between consecutive chunks.
Returns:
A list of combined entity dictionaries for the entire text.
"""
if not text:
return []
        # Cap chunk_size at the tokenizer's advertised limit if it is smaller.
        # Note the units differ: chunk_size is in characters while
        # model_max_length is in tokens, so this cap is a conservative
        # heuristic, not an exact guarantee. (Tokenizers that report a huge
        # sentinel value for model_max_length never trigger it.)
        if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
            tokenizer_max_len = model._tokenizer.model_max_length
            if chunk_size > tokenizer_max_len:
                logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
                chunk_size = tokenizer_max_len
# Ensure overlap is reasonable compared to chunk size
if overlap >= chunk_size // 2:
logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
overlap = chunk_size // 4
logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
logger.info(f"Split text into {len(chunks)} chunks.")
all_raw_results = []
total_processing_time = 0
for i, (chunk_text, start_pos) in enumerate(chunks):
logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
start_time = time.time()
# Get raw predictions for the current chunk
raw_results_chunk = model.predict(chunk_text)
chunk_processing_time = time.time() - start_time
total_processing_time += chunk_processing_time
logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")
# Adjust entity positions relative to the original text
for result in raw_results_chunk:
# Check if 'start' and 'end' exist before adjusting
if 'start' in result and 'end' in result:
result['start'] += start_pos
result['end'] += start_pos
else:
logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")
all_raw_results.extend(raw_results_chunk)
logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")
# Combine entities from all chunks
combined_entities = TextProcessor.combine_entities(all_raw_results, text)
# Deduplicate entities based on overlapping positions
# Two entities are considered duplicates if they have the same type and
# overlap by more than 50% of the shorter entity's length
        unique_entities = []
        for entity in combined_entities:
            is_duplicate = False
            # Calculate entity length for overlap comparison
            entity_length = entity['end'] - entity['start']
            # Iterate over a copy so lower-scoring entities can be removed safely
            for existing in list(unique_entities):
                if existing['entity_type'] != entity['entity_type']:
                    continue
                # Check for significant overlap
                overlap_start = max(entity['start'], existing['start'])
                overlap_end = min(entity['end'], existing['end'])
                if overlap_start >= overlap_end:
                    continue  # No overlap
                overlap_length = overlap_end - overlap_start
                shorter_length = min(entity_length, existing['end'] - existing['start'])
                # If overlap is significant (>50% of the shorter entity)
                if overlap_length > 0.5 * shorter_length:
                    if entity['score'] > existing['score']:
                        # The new entity wins: drop the existing one and keep
                        # scanning, since it may also displace later entries
                        unique_entities.remove(existing)
                    else:
                        is_duplicate = True
                        break
            if not is_duplicate:
                unique_entities.append(entity)
logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
return unique_entities
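    # Worked example of the dedup rule above (hypothetical numbers): two 'ORG'
    # entities spanning [100, 110) and [105, 112) overlap on [105, 110), i.e.
    # 5 characters. The shorter entity is 7 characters long and 5 > 0.5 * 7,
    # so the pair counts as duplicates and only the higher-scoring span is kept.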
@staticmethod
def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
"""
Split text into potentially overlapping chunks, trying to respect word boundaries.
Args:
text: The input text string.
chunk_size: The target maximum size of each chunk.
overlap: The desired overlap between consecutive chunks.
Returns:
A list of tuples, where each tuple contains (chunk_text, start_position_in_original_text).
"""
if not text:
return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if overlap < 0:
            raise ValueError("overlap must be non-negative")
        if chunk_size <= overlap:
            raise ValueError("chunk_size must be greater than overlap")
chunks = []
start = 0
text_len = len(text)
while start < text_len:
# Determine the ideal end position
end = min(start + chunk_size, text_len)
# If we're at the end of the text, just use what's left
if end >= text_len:
chunks.append((text[start:], start))
break
            # Try to find a suitable split point (whitespace) so we don't cut words
            split_pos = -1
            # Search backwards from just before `end`; starting at `end` itself
            # could push the chunk one character past chunk_size
            for i in range(end - 1, max(start, end - overlap) - 1, -1):
                if text[i].isspace():
                    split_pos = i + 1  # Position after the space
                    break
            # If no good split found, just use the calculated end
            if split_pos == -1 or split_pos <= start:
                actual_end = end
            else:
                actual_end = split_pos
# Add the chunk
chunks.append((text[start:actual_end], start))
            # Step forward so the next chunk overlaps the tail of this one,
            # always making at least one character of progress
            next_start = actual_end - overlap
            if next_start <= start:
                next_start = start + 1
            start = next_start
return chunks
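# Minimal usage sketch, illustrative rather than part of the module's API: it
# assumes the default Hugging Face model can be downloaded on first run, and
# the sample sentence is made up.
if __name__ == "__main__":
    ner = NERModel.get_instance()
    sample = "Barack Obama visited Paris and met Angela Merkel."
    raw_results = ner.predict(sample)
    entities = TextProcessor.combine_entities(raw_results, sample)
    for ent in entities:
        print(f"{ent['entity_type']:>4}  {ent['word']!r}  "
              f"[{ent['start']}:{ent['end']}]  score={ent['score']:.3f}")
    # For documents longer than one model window, prefer the chunked path:
    # entities = TextProcessor.process_large_text(long_document, ner)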