import logging
import time
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class NERModel:
    """
    A singleton class to manage the NER model loading and prediction.
    Ensures the potentially large model is loaded only once.
    """

    _instance = None
    _model = None
    _tokenizer = None
    _pipeline = None
    _model_name = None

    @classmethod
    def get_instance(cls, model_name: str = "Davlan/bert-base-multilingual-cased-ner-hrl"):
        """
        Singleton pattern: Get the existing instance or create a new one.
        Uses the specified model_name only during the first initialization.
        """
        if cls._instance is None:
            logger.info(f"Creating new NERModel instance with model: {model_name}")
            cls._instance = cls(model_name)
        elif cls._model_name != model_name:
            logger.warning(f"NERModel already initialized with {cls._model_name}. Ignoring new model name {model_name}.")
        return cls._instance

    def __init__(self, model_name: str):
        """
        Initialize the model, tokenizer, and pipeline.
        Private constructor - use get_instance() instead.
        """
        if NERModel._instance is not None:
            raise Exception("This class is a singleton! Use get_instance() to get the object.")
        self.model_name = model_name
        NERModel._model_name = model_name
        self._load_model()
        NERModel._instance = self

    def _load_model(self):
        """Load the NER model and tokenizer from Hugging Face."""
        logger.info(f"Loading model: {self.model_name}")
        start_time = time.time()

        try:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = AutoModelForTokenClassification.from_pretrained(self.model_name)

            # Put the model in inference mode (disables dropout etc.).
            if isinstance(self._model, torch.nn.Module):
                self._model.eval()

            # No aggregation strategy is passed, so the pipeline returns raw
            # token-level predictions; TextProcessor.combine_entities merges them.
            self._pipeline = pipeline(
                "ner",
                model=self._model,
                tokenizer=self._tokenizer,
            )

            load_time = time.time() - start_time
            logger.info(f"Model '{self.model_name}' loaded successfully in {load_time:.2f} seconds.")
        except Exception as e:
            logger.error(f"Error loading model {self.model_name}: {e}")
            # Clear any partially initialized state before re-raising.
            self._tokenizer = None
            self._model = None
            self._pipeline = None
            raise

    def predict(self, text: str) -> List[Dict[str, Any]]:
        """
        Run NER prediction on the input text using the loaded pipeline.

        Args:
            text: The input string to perform NER on.

        Returns:
            A list of dictionaries, where each dictionary represents an entity
            identified by the pipeline.
        """
        if self._pipeline is None:
            logger.error("NER pipeline is not initialized. Cannot predict.")
            return []

        if not text or not isinstance(text, str):
            logger.warning("Prediction called with empty or invalid text.")
            return []

        logger.debug(f"Running prediction on text: '{text[:100]}...'")
        try:
            results = self._pipeline(text)
            logger.debug(f"Prediction results: {results}")
            return results
        except Exception as e:
            logger.error(f"Error during NER prediction: {e}")
            return []
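
# A minimal usage sketch, not part of the module's API. It assumes the default
# Hugging Face checkpoint can be downloaded; the exact label set (B-PER/I-PER,
# B-LOC/I-LOC, ...) depends on the checkpoint, and the scores shown are
# illustrative only:
#
#     ner = NERModel.get_instance()
#     raw = ner.predict("Angela Merkel met Emmanuel Macron in Paris.")
#     # raw is a list of token-level dicts such as
#     # {'entity': 'B-PER', 'score': 0.99, 'word': 'Angela', 'start': 0, 'end': 6, ...}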


class TextProcessor:
    """
    Provides static methods for processing text, specifically for NER tasks,
    including combining subword entities and handling large texts via chunking.
    """

    @staticmethod
    def combine_entities(ner_results: List[Dict[str, Any]], original_text: str) -> List[Dict[str, Any]]:
        """
        Combine entities that might be split into subword tokens (B-TAG, I-TAG).
        This method assumes the pipeline did *not* use grouped_entities=True.

        Args:
            ner_results: The raw output from the NER pipeline (list of token dictionaries).
            original_text: The original text input to extract entity words accurately.

        Returns:
            A list of dictionaries, each representing a combined entity with
            'entity_type', 'start', 'end', 'score', and 'word'.
        """
        if not ner_results:
            return []

        combined_entities = []
        current_entity = None

        for token in ner_results:
            # Skip tokens that are missing the fields we rely on.
            if not all(k in token for k in ['entity', 'start', 'end', 'score']):
                logger.warning(f"Skipping malformed token: {token}")
                continue

            # An 'O' tag ends any entity currently being built.
            if token['entity'] == 'O':
                if current_entity:
                    combined_entities.append(current_entity)
                current_entity = None
                continue

            entity_tag = token['entity']
            if entity_tag.startswith('B-') or entity_tag.startswith('I-'):
                entity_type = entity_tag[2:]
            else:
                # Some models emit bare tags without the B-/I- prefix.
                entity_type = entity_tag

            if entity_tag.startswith('B-') or (entity_tag.startswith('I-') and (not current_entity or current_entity['entity_type'] != entity_type)):
                # A B- tag, or an I- tag that does not continue the current
                # entity, starts a new entity.
                if current_entity:
                    combined_entities.append(current_entity)

                current_entity = {
                    'entity_type': entity_type,
                    'start': token['start'],
                    'end': token['end'],
                    'score': float(token['score']),
                    'token_count': 1
                }

            elif entity_tag.startswith('I-') and current_entity and current_entity['entity_type'] == entity_type:
                # Continuation of the current entity: extend the span and keep
                # a running average of the token scores.
                current_entity['end'] = token['end']
                current_entity['score'] = (current_entity['score'] * current_entity['token_count'] + float(token['score'])) / (current_entity['token_count'] + 1)
                current_entity['token_count'] += 1

            else:
                logger.warning(f"Encountered unexpected token sequence at: {token}. Starting new entity.")
                if current_entity:
                    combined_entities.append(current_entity)

                current_entity = {
                    'entity_type': entity_type,
                    'start': token['start'],
                    'end': token['end'],
                    'score': float(token['score']),
                    'token_count': 1
                }

        # Flush the last open entity.
        if current_entity:
            combined_entities.append(current_entity)

        # Extract the entity text from the original string using the character offsets.
        for entity in combined_entities:
            try:
                start = max(0, min(entity['start'], len(original_text)))
                end = max(start, min(entity['end'], len(original_text)))
                entity['word'] = original_text[start:end].strip()

                if 'token_count' in entity:
                    del entity['token_count']
            except Exception as e:
                logger.error(f"Error extracting word for entity: {entity}, error: {e}")
                entity['word'] = "[Error extracting word]"

        combined_entities.sort(key=lambda x: x['start'])

        logger.info(f"Combined {len(ner_results)} raw tokens into {len(combined_entities)} entities.")
        return combined_entities
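
    # Illustrative example of the merging behaviour (hypothetical scores and
    # offsets, not real pipeline output):
    #
    #     raw = [
    #         {'entity': 'B-PER', 'score': 0.99, 'start': 0, 'end': 6},
    #         {'entity': 'I-PER', 'score': 0.98, 'start': 7, 'end': 13},
    #     ]
    #     TextProcessor.combine_entities(raw, "Angela Merkel spoke.")
    #     # -> [{'entity_type': 'PER', 'start': 0, 'end': 13,
    #     #      'score': 0.985, 'word': 'Angela Merkel'}]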

    @staticmethod
    def process_large_text(text: str, model: NERModel, chunk_size: int = 512, overlap: int = 50) -> List[Dict[str, Any]]:
        """
        Process large text by splitting it into overlapping chunks, running NER
        on each chunk, and then combining the results intelligently.

        Args:
            text: The large input text string.
            model: The initialized NERModel instance.
            chunk_size: The maximum size of each text chunk, in characters.
            overlap: The number of characters to overlap between consecutive chunks.

        Returns:
            A list of combined entity dictionaries for the entire text.
        """
        if not text:
            return []

        # Note: chunk_size is measured in characters while the tokenizer's
        # model_max_length is in tokens, so this is only a rough sanity check.
        if model._tokenizer and hasattr(model._tokenizer, 'model_max_length'):
            tokenizer_max_len = model._tokenizer.model_max_length
            if chunk_size > tokenizer_max_len:
                logger.warning(f"Requested chunk_size {chunk_size} exceeds model max length {tokenizer_max_len}. Using {tokenizer_max_len}.")
                chunk_size = tokenizer_max_len

        if overlap >= chunk_size // 2:
            logger.warning(f"Overlap {overlap} seems large for chunk_size {chunk_size}. Reducing overlap to {chunk_size // 4}.")
            overlap = chunk_size // 4

        logger.info(f"Processing large text (length {len(text)}) with chunk_size={chunk_size}, overlap={overlap}")
        chunks = TextProcessor._create_chunks(text, chunk_size, overlap)
        logger.info(f"Split text into {len(chunks)} chunks.")

        all_raw_results = []
        total_processing_time = 0

        for i, (chunk_text, start_pos) in enumerate(chunks):
            logger.debug(f"Processing chunk {i+1}/{len(chunks)} (start_pos: {start_pos}, length: {len(chunk_text)})")
            start_time = time.time()

            raw_results_chunk = model.predict(chunk_text)

            chunk_processing_time = time.time() - start_time
            total_processing_time += chunk_processing_time
            logger.debug(f"Chunk {i+1} processed in {chunk_processing_time:.2f}s. Found {len(raw_results_chunk)} raw entities.")

            # Shift character offsets from chunk-local positions to positions
            # in the original text.
            for result in raw_results_chunk:
                if 'start' in result and 'end' in result:
                    result['start'] += start_pos
                    result['end'] += start_pos
                else:
                    logger.warning(f"Skipping position adjustment for malformed result in chunk {i+1}: {result}")

            all_raw_results.extend(raw_results_chunk)

        logger.info(f"Finished processing all chunks in {total_processing_time:.2f} seconds.")
        logger.info(f"Total raw entities found across all chunks: {len(all_raw_results)}")

        combined_entities = TextProcessor.combine_entities(all_raw_results, text)

        # Chunks overlap, so the same entity can be detected twice. Treat two
        # entities of the same type as duplicates when they overlap by more
        # than half of the shorter span, keeping the higher-scoring one.
        unique_entities = []
        for entity in combined_entities:
            is_duplicate = False
            entity_length = entity['end'] - entity['start']

            for existing in unique_entities:
                if existing['entity_type'] == entity['entity_type']:
                    overlap_start = max(entity['start'], existing['start'])
                    overlap_end = min(entity['end'], existing['end'])
                    if overlap_start < overlap_end:
                        overlap_length = overlap_end - overlap_start
                        shorter_length = min(entity_length, existing['end'] - existing['start'])

                        if overlap_length > 0.5 * shorter_length:
                            is_duplicate = True
                            if entity['score'] > existing['score']:
                                # The new entity wins: replace the existing one.
                                unique_entities.remove(existing)
                                is_duplicate = False
                            break

            if not is_duplicate:
                unique_entities.append(entity)

        logger.info(f"Final number of unique combined entities: {len(unique_entities)}")
        return unique_entities

    @staticmethod
    def _create_chunks(text: str, chunk_size: int = 512, overlap: int = 50) -> List[Tuple[str, int]]:
        """
        Split text into potentially overlapping chunks, trying to respect word boundaries.

        Args:
            text: The input text string.
            chunk_size: The target maximum size of each chunk.
            overlap: The desired overlap between consecutive chunks.

        Returns:
            A list of tuples, where each tuple contains (chunk_text, start_position_in_original_text).
        """
        if not text:
            return []
        if chunk_size <= overlap:
            raise ValueError("chunk_size must be greater than overlap")
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")

        chunks = []
        start = 0
        text_len = len(text)

        while start < text_len:
            end = min(start + chunk_size, text_len)

            # Last chunk: take everything that remains.
            if end >= text_len:
                chunks.append((text[start:], start))
                break

            # Look backwards (up to `overlap` characters) for whitespace so the
            # chunk ends on a word boundary where possible.
            split_pos = -1
            for i in range(end, max(start, end - overlap) - 1, -1):
                if i < text_len and text[i].isspace():
                    split_pos = i + 1
                    break

            if split_pos == -1 or split_pos <= start:
                actual_end = end
            else:
                actual_end = split_pos

            chunks.append((text[start:actual_end], start))

            # Advance, keeping `overlap` characters shared with the previous
            # chunk, and always move forward by at least one character.
            next_start = actual_end - overlap
            if next_start <= start:
                next_start = start + 1

            start = next_start

        return chunks
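
    # A small illustration of the chunker's output shape: each tuple pairs the
    # chunk text with its start offset in the original string, and consecutive
    # chunks share roughly `overlap` characters.
    #
    #     TextProcessor._create_chunks("The quick brown fox jumps over the lazy dog",
    #                                  chunk_size=20, overlap=5)
    #     # -> [('The quick brown fox ', 0),
    #     #     (' fox jumps over the ', 15),
    #     #     (' the lazy dog', 30)]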
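

# A small, hedged demo of the intended end-to-end flow. Downloading the default
# checkpoint requires network access and may take a while on first run; the
# sample text and printed format are illustrative only.
if __name__ == "__main__":
    sample_text = (
        "Angela Merkel met Emmanuel Macron in Paris to discuss the "
        "European Union's climate policy with officials from Berlin."
    )
    ner_model = NERModel.get_instance()
    entities = TextProcessor.process_large_text(sample_text, ner_model, chunk_size=512, overlap=50)
    for ent in entities:
        print(f"{ent['entity_type']:>5}  {ent['score']:.3f}  {ent['word']}")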