import os
from typing import Dict, List, Any, Optional, Union

from smolagents import Tool


class NamedEntityRecognitionTool(Tool):
    """smolagents tool that tags named entities with a Hugging Face NER pipeline.

    The pipeline is created lazily on the first ``forward`` call and rebuilt
    whenever a different model name is requested. Results are rendered as
    plain text in one of three layouts: ``simple``, ``grouped`` (default) or
    ``detailed``.
    """

    name = "ner_tool"
    description = """
    Identifies and labels named entities in text using customizable NER models.
    Can recognize entities such as persons, organizations, locations, dates, etc.
    Returns a structured analysis of all entities found in the input text.
    """
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for named entities",
        },
        "model": {
            "type": "string",
            "description": "The NER model to use (default: 'dslim/bert-base-NER')",
            "nullable": True
        },
        "aggregation": {
            "type": "string",
            "description": "How to aggregate entities: 'simple' (just list), 'grouped' (by label), or 'detailed' (with confidence scores)",
            "nullable": True
        },
        "min_score": {
            "type": "number",
            "description": "Minimum confidence score threshold (0.0-1.0) for including entities",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """Initialize the NER Tool with default settings."""
        super().__init__()
        self.default_model = "dslim/bert-base-NER"

        # Model id -> short human-readable description.
        self.available_models = {
            "dslim/bert-base-NER": "Standard NER (English)",
            "jean-baptiste/camembert-ner": "French NER",
            "Davlan/bert-base-multilingual-cased-ner-hrl": "Multilingual NER",
            "Babelscape/wikineural-multilingual-ner": "WikiNeural Multilingual NER",
            "flair/ner-english-ontonotes-large": "OntoNotes English (fine-grained)",
            "elastic/distilbert-base-cased-finetuned-conll03-english": "CoNLL (fast)"
        }

        # Entity tag -> display name shown in the formatted output.
        # NOTE(review): the glyphs were recovered from a copy in which UTF-8
        # had been mis-decoded; PERCENT, PRODUCT and UNKNOWN had lost their
        # distinguishing bytes and are best-effort choices - confirm against
        # the upstream file.
        self.entity_colors = {
            "PER": "🟥 Person",
            "PERSON": "🟥 Person",
            "LOC": "🟨 Location",
            "LOCATION": "🟨 Location",
            "GPE": "🟨 Location",
            "ORG": "🟦 Organization",
            "ORGANIZATION": "🟦 Organization",
            "MISC": "🟩 Miscellaneous",
            "DATE": "🟪 Date",
            "TIME": "🟪 Time",
            "MONEY": "💰 Money",
            "PERCENT": "📊 Percentage",
            "PRODUCT": "🛒 Product",
            "EVENT": "🎫 Event",
            "WORK_OF_ART": "🎨 Work of Art",
            "LAW": "⚖️ Law",
            "LANGUAGE": "🗣️ Language",
            "FAC": "🏢 Facility",

            "O": "Not an entity",
            "UNKNOWN": "🔷 Entity"
        }

        # Lower-cased lookup tables for the untagged-entity heuristic in
        # _get_friendly_label; built once here instead of on every call.
        self._known_places = frozenset([
            "germany", "france", "spain", "italy", "london",
            "paris", "berlin", "rome", "new york", "tokyo",
            "beijing", "moscow", "canada", "australia", "india",
            "china", "japan", "russia", "brazil", "mexico",
        ])
        self._common_first_names = frozenset([
            "john", "mike", "sarah", "david", "michael", "james",
            "robert", "mary", "jennifer", "linda", "william",
            "kristof", "chris", "thomas", "daniel", "matthew", "joseph",
            "donald", "richard", "charles", "paul", "mark", "kevin",
        ])

        # Created on first use by _load_pipeline.
        self._pipeline = None

    def _build_pipeline(self, model_name: str):
        """Create a transformers NER pipeline for ``model_name``.

        Imports are local so that the heavy dependencies are only needed
        once a model is actually requested. Any import or download error
        propagates to the caller.
        """
        from transformers import pipeline
        import torch

        # First CUDA device when available, otherwise CPU.
        device = 0 if torch.cuda.is_available() else -1
        # dslim/bert-base-NER checkpoints are grouped with "first"; every
        # other model uses "simple".
        strategy = "first" if "dslim/bert-base-NER" in model_name else "simple"
        return pipeline(
            "ner",
            model=model_name,
            aggregation_strategy=strategy,
            device=device
        )

    def _load_pipeline(self, model_name: str) -> bool:
        """Load the NER pipeline with the specified model.

        If ``model_name`` cannot be loaded, the default model is tried as a
        fallback. Errors are printed, not raised.

        Returns:
            True when ``self._pipeline`` now holds a usable pipeline (for the
            requested model or for the fallback), False otherwise.
        """
        try:
            self._pipeline = self._build_pipeline(model_name)
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")

        # Retrying the identical load would only repeat the failure.
        if model_name == self.default_model:
            return False

        try:
            self._pipeline = self._build_pipeline(self.default_model)
            return True
        except Exception as fallback_error:
            print(f"Error loading fallback model: {str(fallback_error)}")
            return False

    def _strip_bio_prefix(self, label: str) -> str:
        """Return ``label`` without a leading ``B-`` / ``I-`` tagging prefix."""
        if label.startswith(("B-", "I-")):
            return label[2:]
        return label

    def _entity_surface(self, text: str, entity: Dict[str, Any]) -> str:
        """Return the exact input substring for ``entity``.

        Falls back to the tokenizer's ``word`` when the offsets are unusable.
        """
        start, end = entity["start"], entity["end"]
        if 0 <= start < end <= len(text):
            return text[start:end]
        return entity["word"]

    def _merge_subword_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Normalize raw pipeline output and glue ``##`` word pieces together.

        Each returned dict has the keys ``word``, ``start``, ``end``,
        ``entity`` (the label) and ``score``. Entities are ordered by start
        offset. A word piece starting with ``##`` is appended to the previous
        entity, whose score becomes the mean over all of its pieces.
        """
        merged: List[Dict[str, Any]] = []
        piece_counts: List[int] = []

        for raw in sorted(entities, key=lambda e: e.get("start") or 0):
            word = raw.get("word") or ""
            # Offsets can be missing/None when the tokenizer has no offset
            # mapping; treat them as 0 so sorting and slicing stay safe.
            start = raw.get("start") or 0
            end = raw.get("end") or 0
            score = float(raw.get("score") or 0.0)
            # Aggregated pipelines report the label as "entity_group";
            # "entity" is only present without aggregation.
            label = raw.get("entity_group") or raw.get("entity") or "UNKNOWN"

            if word.startswith("##"):
                fragment = word.replace("##", "")
                if merged:
                    head = merged[-1]
                    head["word"] += fragment
                    head["end"] = end
                    piece_counts[-1] += 1
                    # Incremental mean keeps every piece equally weighted.
                    head["score"] += (score - head["score"]) / piece_counts[-1]
                    continue
                # No preceding entity to attach to: keep it on its own
                # rather than silently dropping it.
                word = fragment

            merged.append({
                "word": word,
                "start": start,
                "end": end,
                "entity": label,
                "score": score
            })
            piece_counts.append(1)

        return merged

    def _get_friendly_label(self, label: str, entity_text: Optional[str] = None) -> str:
        """Convert a technical entity label to its display name.

        Args:
            label: Raw tag, optionally carrying a ``B-``/``I-`` prefix.
            entity_text: Surface text of the entity. Only used when the tag
                is ``UNKNOWN`` or ``O``: a capitalized text matching the
                small built-in place / first-name lists is then shown as a
                location or person.

        Returns:
            The mapped display name, or the cleaned tag with a generic marker
            when the tag is not in ``self.entity_colors``.
        """
        clean_label = self._strip_bio_prefix(label)

        if clean_label in ("UNKNOWN", "O"):
            if entity_text is None:
                # Legacy fallback for callers that stash the text on the instance.
                entity_text = getattr(self, "_current_entity_text", "")
            surface = (entity_text or "").strip()

            # The capitalization test must see the original casing.
            if surface and surface[0].isupper():
                lowered = surface.lower()
                if lowered in self._known_places:
                    return self.entity_colors.get("LOC", "🟨 Location")
                if lowered.split()[0] in self._common_first_names:
                    return self.entity_colors.get("PER", "🟥 Person")

        return self.entity_colors.get(clean_label, f"🔷 {clean_label}")

    def forward(self, text: str, model: Optional[str] = None, aggregation: Optional[str] = None,
                min_score: Optional[float] = None) -> str:
        """
        Perform Named Entity Recognition on the input text.

        Args:
            text: The text to analyze
            model: NER model to use (default: dslim/bert-base-NER)
            aggregation: How to aggregate results (simple, grouped, detailed);
                any other value is treated as grouped
            min_score: Minimum confidence threshold (0.0-1.0), default 0.8

        Returns:
            Formatted string with NER analysis results, or a human-readable
            message when the model is unknown, cannot be loaded, finds
            nothing, or raises during analysis.
        """
        if model is None:
            model = self.default_model
        if aggregation is None:
            aggregation = "grouped"
        if min_score is None:
            min_score = 0.8

        # Only catalogued models and anything under the dslim/ namespace are accepted.
        if model not in self.available_models and not model.startswith("dslim/"):
            return f"Model '{model}' not recognized. Available models: {', '.join(self.available_models.keys())}"

        # Nothing to analyze: skip the (possibly expensive) model load.
        if not text or not text.strip():
            return "No entities were detected in the text with the current settings."

        # (Re)load when no pipeline exists yet or a different model is wanted.
        if self._pipeline is None or self._pipeline.model.name_or_path != model:
            if not self._load_pipeline(model):
                return "Failed to load NER model. Please try a different model."

        try:
            entities = [
                e for e in self._pipeline(text)
                if e.get('score', 0) >= min_score
            ]

            if not entities:
                return "No entities were detected in the text with the current settings."

            formatters = {
                "simple": self._format_simple,
                "detailed": self._format_detailed,
            }
            render = formatters.get(aggregation, self._format_grouped)
            return render(text, entities)

        except Exception as e:
            # Tool output must be a string, so failures are reported as text.
            return f"Error analyzing text: {str(e)}"

    def _format_simple(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities as a simple bullet list, one entity per line."""
        lines = ["Named Entities Found:", ""]

        for entity in self._merge_subword_entities(entities):
            friendly_label = self._get_friendly_label(
                entity["entity"], self._entity_surface(text, entity)
            )
            lines.append(
                f"• {entity['word']} - {friendly_label} (confidence: {entity['score']:.2f})"
            )

        return "\n".join(lines) + "\n"

    def _format_grouped(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities grouped by their display category.

        Words are de-duplicated within a category and listed in order of
        first appearance.
        """
        # category -> insertion-ordered set of words (dict keys).
        grouped: Dict[str, Dict[str, None]] = {}

        for entity in self._merge_subword_entities(entities):
            friendly_label = self._get_friendly_label(
                entity["entity"], self._entity_surface(text, entity)
            )
            grouped.setdefault(friendly_label, {})[entity["word"]] = None

        lines = ["Named Entities by Category:", ""]
        lines.extend(
            f"{friendly_label}: {', '.join(words)}"
            for friendly_label, words in grouped.items()
        )

        return "\n".join(lines) + "\n"

    def _format_detailed(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities with detailed information including position in text.

        The first section reproduces ``text`` verbatim with each entity span
        wrapped as ``[span](LABEL)``; the second lists every entity with its
        display label and confidence.
        """
        merged_entities = self._merge_subword_entities(entities)

        pieces = []
        cursor = 0
        for entity in merged_entities:
            start = entity["start"]
            end = min(entity["end"], len(text))
            # Skip spans without usable offsets or overlapping an earlier one.
            if start < cursor or end <= start:
                continue
            clean_label = self._strip_bio_prefix(entity["entity"])
            pieces.append(text[cursor:start])
            pieces.append(f"[{text[start:end]}]({clean_label})")
            cursor = end
        pieces.append(text[cursor:])
        highlighted_text = "".join(pieces)

        entity_details = []
        for entity in merged_entities:
            friendly_label = self._get_friendly_label(
                entity["entity"], self._entity_surface(text, entity)
            )
            entity_details.append(
                f"• {entity['word']} - {friendly_label} (confidence: {entity['score']:.2f})"
            )

        return (
            "Entity Analysis:\n\n"
            "Text with Entities Marked:\n"
            + highlighted_text + "\n\n"
            + "Entity Details:\n"
            + "\n".join(entity_details)
        )

    def get_available_models(self) -> Dict[str, str]:
        """Return the dictionary of available models with descriptions."""
        return self.available_models