File size: 17,169 Bytes
b04682c
 
 
c5922b9
 
 
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d26c0c
 
 
 
b04682c
 
 
c5922b9
b04682c
 
 
 
7d26c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04682c
 
 
 
 
 
7d26c0c
 
 
 
 
 
 
 
b04682c
 
 
 
c5922b9
b04682c
 
 
 
7d26c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04682c
fb510e6
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d26c0c
 
 
 
 
 
 
 
 
 
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d26c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04682c
 
7d26c0c
b04682c
d25649c
b04682c
 
 
 
 
 
 
 
 
7d26c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04682c
 
 
7d26c0c
d25649c
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d26c0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b04682c
 
 
 
7d26c0c
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28094fc
b04682c
d25649c
b04682c
 
 
 
 
 
 
 
 
 
 
 
7d26c0c
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5922b9
b04682c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
import os
from typing import Dict, List, Any, Optional, Union
from smolagents import Tool

class NamedEntityRecognitionTool(Tool):
    """Named Entity Recognition tool backed by Hugging Face `transformers` NER pipelines.

    Lazily loads a token-classification pipeline (GPU when available), filters
    predictions by a confidence threshold, merges WordPiece sub-tokens back
    into whole words, and renders the results in one of three text formats
    (simple list, grouped by category, or detailed with inline highlighting).
    """

    name = "ner_tool"
    description = """
    Identifies and labels named entities in text using customizable NER models.
    Can recognize entities such as persons, organizations, locations, dates, etc.
    Returns a structured analysis of all entities found in the input text.
    """
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for named entities",
        },
        "model": {
            "type": "string",
            "description": "The NER model to use (default: 'dslim/bert-base-NER')",
            "nullable": True
        },
        "aggregation": {
            "type": "string",
            "description": "How to aggregate entities: 'simple' (just list), 'grouped' (by label), or 'detailed' (with confidence scores)",
            "nullable": True
        },
        "min_score": {
            "type": "number",
            "description": "Minimum confidence score threshold (0.0-1.0) for including entities",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """Initialize the NER Tool with default settings."""
        super().__init__()
        self.default_model = "dslim/bert-base-NER"
        # Curated catalogue shown to the user when model validation fails.
        self.available_models = {
            "dslim/bert-base-NER": "Standard NER (English)",
            "jean-baptiste/camembert-ner": "French NER",
            "Davlan/bert-base-multilingual-cased-ner-hrl": "Multilingual NER",
            "Babelscape/wikineural-multilingual-ner": "WikiNeural Multilingual NER",
            "flair/ner-english-ontonotes-large": "OntoNotes English (fine-grained)",
            "elastic/distilbert-base-cased-finetuned-conll03-english": "CoNLL (fast)"
        }
        # Maps raw model labels (after stripping B-/I- BIO prefixes) to
        # friendly, colour-coded display names.
        self.entity_colors = {
            "PER": "🟥 Person",
            "PERSON": "🟥 Person",
            "LOC": "🟨 Location",
            "LOCATION": "🟨 Location",
            "GPE": "🟨 Location",
            "ORG": "🟦 Organization",
            "ORGANIZATION": "🟦 Organization",
            "MISC": "🟩 Miscellaneous",
            "DATE": "🟪 Date",
            "TIME": "🟪 Time",
            "MONEY": "💰 Money",
            "PERCENT": "📊 Percentage",
            "PRODUCT": "🛒 Product",
            "EVENT": "🎫 Event",
            "WORK_OF_ART": "🎨 Work of Art",
            "LAW": "⚖️ Law",
            "LANGUAGE": "🗣️ Language",
            "FAC": "🏢 Facility",
            # Fix for models that don't properly tag entities
            "O": "Not an entity",
            "UNKNOWN": "🔷 Entity"
        }
        # Pipeline is lazily loaded on first use. _loaded_model records the
        # *requested* model name so a successful fallback load does not
        # retry (and re-fail) the bad model on every forward() call.
        self._pipeline = None
        self._loaded_model = None

    def _load_pipeline(self, model_name: str) -> bool:
        """Load the NER pipeline for ``model_name``, falling back to the default model.

        Returns:
            True if a pipeline was loaded (requested or fallback), False otherwise.
        """
        try:
            from transformers import pipeline
            import torch

            # Try to detect if GPU is available
            device = 0 if torch.cuda.is_available() else -1

            # dslim/bert-base-NER produces cleaner spans with the "first"
            # aggregation strategy; other models use "simple".
            strategy = "first" if "dslim/bert-base-NER" in model_name else "simple"
            self._pipeline = pipeline(
                "ner",
                model=model_name,
                aggregation_strategy=strategy,
                device=device
            )
            self._loaded_model = model_name
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            try:
                # Fall back to default model
                from transformers import pipeline
                import torch
                device = 0 if torch.cuda.is_available() else -1
                self._pipeline = pipeline(
                    "ner",
                    model=self.default_model,
                    aggregation_strategy="first",
                    device=device
                )
                # Record the requested (failed) name so we don't re-attempt
                # the broken model on every subsequent call.
                self._loaded_model = model_name
                return True
            except Exception as fallback_error:
                print(f"Error loading fallback model: {str(fallback_error)}")
                return False

    def _merge_subtokens(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Merge WordPiece '##' continuation pieces into whole-word entities.

        Token-classification output may split rare words into sub-tokens
        ("Hugging", "##face"); this glues them back together, extending the
        character span and keeping the true mean confidence over all pieces.
        (Shared by all three formatters — previously duplicated in each.)

        Args:
            entities: Raw pipeline entity dicts with word/start/end/entity/score.

        Returns:
            A new list of merged entity dicts, sorted by start offset.
        """
        merged: List[Dict[str, Any]] = []
        current: Optional[Dict[str, Any]] = None

        for entity in sorted(entities, key=lambda e: e.get("start", 0)):
            word = entity.get("word", "")
            score = entity.get("score", 0)

            if word.startswith("##"):
                # Continuation piece: extend the current entity, if any.
                if current:
                    current["word"] += word.replace("##", "")
                    current["end"] = entity.get("end", 0)
                    # True running mean; the old (a + b) / 2 progressively
                    # over-weighted the most recent pieces.
                    current["score"] = (
                        current["score"] * current["_pieces"] + score
                    ) / (current["_pieces"] + 1)
                    current["_pieces"] += 1
                continue

            # Start a new entity.
            current = {
                "word": word,
                "start": entity.get("start", 0),
                "end": entity.get("end", 0),
                "entity": entity.get("entity", "UNKNOWN"),
                "score": score,
                "_pieces": 1,
            }
            merged.append(current)

        # Drop the internal piece counter before handing results back.
        for entity in merged:
            entity.pop("_pieces", None)
        return merged

    def _get_friendly_label(self, label: str, entity_text: str = None) -> str:
        """Convert a technical entity label to a friendly, colour-coded description.

        Args:
            label: Raw model label, possibly carrying a B-/I- BIO prefix.
            entity_text: The entity's surface text, used by heuristics when the
                model failed to tag the span ('O'/'UNKNOWN'). Optional for
                backward compatibility with the old no-argument call style.
        """
        # Strip B- or I- prefixes that indicate beginning or inside of entity.
        clean_label = label.replace("B-", "").replace("I-", "")

        # Handle common name and location patterns with heuristics when the
        # model fails to properly tag the span.
        if clean_label in ("UNKNOWN", "O"):
            text = entity_text
            if text is None:
                # Legacy fallback for callers that pre-set this attribute.
                text = getattr(self, "_current_entity_text", "")

            # Capitalized words might be names or places. (The original code
            # lower-cased the text *before* this check, so it never fired.)
            if text and text[0].isupper():
                lowered = text.lower()

                # Countries and major cities
                countries_and_cities = ["germany", "france", "spain", "italy", "london",
                                        "paris", "berlin", "rome", "new york", "tokyo",
                                        "beijing", "moscow", "canada", "australia", "india",
                                        "china", "japan", "russia", "brazil", "mexico"]

                if lowered in countries_and_cities:
                    return self.entity_colors.get("LOC", "🟨 Location")

                # Common first names (add more as needed)
                common_names = ["john", "mike", "sarah", "david", "michael", "james",
                                "robert", "mary", "jennifer", "linda", "william",
                                "kristof", "chris", "thomas", "daniel", "matthew", "joseph",
                                "donald", "richard", "charles", "paul", "mark", "kevin"]

                name_parts = lowered.split()
                if name_parts and name_parts[0] in common_names:
                    return self.entity_colors.get("PER", "🟥 Person")

        return self.entity_colors.get(clean_label, f"🔷 {clean_label}")

    def forward(self, text: str, model: str = None, aggregation: str = None, min_score: float = None) -> str:
        """
        Perform Named Entity Recognition on the input text.

        Args:
            text: The text to analyze
            model: NER model to use (default: dslim/bert-base-NER)
            aggregation: How to aggregate results (simple, grouped, detailed)
            min_score: Minimum confidence threshold (0.0-1.0)

        Returns:
            Formatted string with NER analysis results
        """
        # Set default values if parameters are None
        if model is None:
            model = self.default_model
        if aggregation is None:
            aggregation = "grouped"
        if min_score is None:
            min_score = 0.8

        # Validate model choice
        if model not in self.available_models and not model.startswith("dslim/"):
            return f"Model '{model}' not recognized. Available models: {', '.join(self.available_models.keys())}"

        # Load the model if not already loaded or if different from the one
        # last requested (a fallback load counts as satisfying the request).
        if self._pipeline is None or self._loaded_model != model:
            if not self._load_pipeline(model):
                return "Failed to load NER model. Please try a different model."

        # Perform NER analysis
        try:
            entities = self._pipeline(text)

            # Filter by confidence score
            entities = [e for e in entities if e.get('score', 0) >= min_score]

            # Attach each entity's actual surface text from the input; the
            # pipeline's "word" field may contain sub-token artifacts.
            for entity in entities:
                start = entity.get("start", 0)
                end = entity.get("end", 0)
                entity['actual_text'] = text[start:end]

            if not entities:
                return "No entities were detected in the text with the current settings."

            # Process results based on aggregation method
            if aggregation == "simple":
                return self._format_simple(text, entities)
            elif aggregation == "detailed":
                return self._format_detailed(text, entities)
            else:  # default to grouped
                return self._format_grouped(text, entities)

        except Exception as e:
            return f"Error analyzing text: {str(e)}"

    def _format_simple(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities as a simple bullet list with confidence scores."""
        result = "Named Entities Found:\n\n"

        for entity in self._merge_subtokens(entities):
            word = entity.get("word", "")
            label = entity.get("entity", "UNKNOWN")
            score = entity.get("score", 0)
            # Pass this entity's own text so the tagging heuristics see the
            # right word (previously they saw only the last entity's text).
            friendly_label = self._get_friendly_label(label, word)

            result += f"• {word} - {friendly_label} (confidence: {score:.2f})\n"

        return result

    def _format_grouped(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities grouped by their (BIO-stripped) category."""
        # Group merged entities by label.
        grouped: Dict[str, List[str]] = {}

        for entity in self._merge_subtokens(entities):
            word = entity.get("word", "")
            label = entity.get("entity", "UNKNOWN").replace("B-", "").replace("I-", "")
            grouped.setdefault(label, []).append(word)

        # Build the result string
        result = "Named Entities by Category:\n\n"

        for label, words in grouped.items():
            friendly_label = self._get_friendly_label(label, words[0])
            # Deduplicate while preserving first-seen order; list(set(...))
            # produced a nondeterministic ordering.
            unique_words = list(dict.fromkeys(words))
            result += f"{friendly_label}: {', '.join(unique_words)}\n"

        return result

    def _format_detailed(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities with inline highlighting plus a per-entity detail list."""
        merged_entities = self._merge_subtokens(entities)

        # Build a per-character entity map so we can highlight the whole text.
        character_labels = [None] * len(text)

        # Mark each character with its entity label.
        for entity in merged_entities:
            start = entity.get("start", 0)
            end = entity.get("end", 0)
            label = entity.get("entity", "UNKNOWN")

            for i in range(start, min(end, len(text))):
                character_labels[i] = label

        # Walk the text, emitting [segment](LABEL) for labelled runs and the
        # raw text for unlabelled runs.
        highlighted_text = ""
        current_label = None
        current_segment = ""

        for i, char in enumerate(text):
            label = character_labels[i]

            if label != current_label:
                # End the previous segment if any
                if current_segment:
                    if current_label:
                        clean_label = current_label.replace("B-", "").replace("I-", "")
                        highlighted_text += f"[{current_segment}]({clean_label}) "
                    else:
                        highlighted_text += current_segment + " "

                # Start a new segment
                current_label = label
                current_segment = char
            else:
                current_segment += char

        # Add the final segment
        if current_segment:
            if current_label:
                clean_label = current_label.replace("B-", "").replace("I-", "")
                highlighted_text += f"[{current_segment}]({clean_label})"
            else:
                highlighted_text += current_segment

        # Get entity details
        entity_details = []
        for entity in merged_entities:
            word = entity.get("word", "")
            label = entity.get("entity", "UNKNOWN")
            score = entity.get("score", 0)
            friendly_label = self._get_friendly_label(label, word)

            entity_details.append(f"• {word} - {friendly_label} (confidence: {score:.2f})")

        # Combine into final result
        result = "Entity Analysis:\n\n"
        result += "Text with Entities Marked:\n"
        result += highlighted_text + "\n\n"
        result += "Entity Details:\n"
        result += "\n".join(entity_details)

        return result

    def get_available_models(self) -> Dict[str, str]:
        """Return the dictionary of available models with descriptions."""
        return self.available_models

# Example usage:
# ner_tool = NamedEntityRecognitionTool()
# result = ner_tool("Apple Inc. is planning to open a new store in Paris, France next year.", model="dslim/bert-base-NER")
# print(result)