File size: 3,667 Bytes
1c02c6e
 
47f397e
c2f577f
d4fdfec
c2f577f
d4fdfec
c2f577f
47f397e
 
 
 
d4fdfec
 
 
 
 
47f397e
 
d4fdfec
 
c2f577f
7b89327
 
 
 
 
 
 
 
 
 
 
d4fdfec
 
 
 
 
 
 
47f397e
7b89327
 
17b66cc
d4fdfec
 
 
 
 
 
cc1edab
47f397e
7b89327
 
d4fdfec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b89327
17b66cc
d4fdfec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from typing import Optional

import gradio as gr
from transformers import pipeline
from smolagents import Tool

class SimpleSentimentTool(Tool):
    """smolagents tool that scores the sentiment of a text with a Hugging Face model.

    Several ``text-classification`` checkpoints are selectable by short key (see
    ``models``). Pipelines are created lazily on first use and cached per model ID.
    """

    name = "sentiment_analysis"
    description = "This tool analyzes the sentiment of a given text."

    # Available sentiment analysis models: short key -> Hugging Face Hub model ID.
    # Defined before ``inputs`` so the ``model_key`` description can list the keys.
    # NOTE(review): "deberta" points at an MNLI checkpoint; its labels are
    # presumably entailment classes rather than sentiment — confirm it belongs here.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert",
    }

    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for sentiment",
        },
        "model_key": {
            "type": "string",
            "description": (
                "The model to use for sentiment analysis. One of: "
                + ", ".join(models)
            ),
            # smolagents requires inputs whose forward() parameter has a default
            # to be declared nullable ("default" is not a recognized schema field).
            "nullable": True,
        },
    }
    # "object" is the authorized smolagents type for a dict result; a string such
    # as "dict[str, float]" is rejected by the tool's argument validation.
    output_type = "object"

    def __init__(self, default_model="distilbert", preload=False):
        """Initialize with a default model.

        Args:
            default_model: Key in ``models`` used when no (or an unknown) model key
                is passed to ``forward``. An unknown key here falls back to
                distilbert.
            preload: Whether to load the default model's pipeline immediately
                instead of on the first call.
        """
        super().__init__()
        self.default_model = default_model
        # Loaded pipelines, keyed by Hub model ID.
        self._classifiers = {}

        if default_model not in self.models:
            print(
                f"Warning: Unknown default model '{default_model}', "
                "distilbert will be used instead"
            )

        # Optionally preload the default model; failure here is non-fatal.
        if preload:
            try:
                self._get_classifier(self._resolve_model_id(None))
            except Exception as e:
                print(f"Warning: Failed to preload model: {str(e)}")

    def _resolve_model_id(self, model_key):
        """Map a model key to a Hub model ID.

        A missing or unknown ``model_key`` falls back to ``default_model``. If that
        is unknown too, distilbert is used. The fallback is looked up lazily, so an
        invalid ``default_model`` no longer breaks calls that pass a valid key.
        """
        if model_key in self.models:
            return self.models[model_key]
        return self.models.get(self.default_model, self.models["distilbert"])

    def _get_classifier(self, model_id):
        """Return the cached pipeline for ``model_id``, loading it on first use.

        If loading fails for any model other than distilbert, the distilbert
        pipeline is returned instead. The failed model is not cached, so it is
        attempted again on later calls.

        Raises:
            RuntimeError: if the distilbert fallback itself cannot be loaded.
        """
        if model_id in self._classifiers:
            return self._classifiers[model_id]

        fallback_id = self.models["distilbert"]
        try:
            print(f"Loading model: {model_id}")
            classifier = pipeline(
                "text-classification",
                model=model_id,
                top_k=None,  # Return all scores
            )
        except Exception as e:
            print(f"Error loading model {model_id}: {str(e)}")
            if model_id == fallback_id:
                # Last resort - if even distilbert fails
                print("Critical error: Could not load default model")
                raise RuntimeError(
                    f"Failed to load any sentiment model: {str(e)}"
                ) from e
            print("Falling back to distilbert model...")
            return self._get_classifier(fallback_id)

        self._classifiers[model_id] = classifier
        return classifier

    def forward(self, text: str, model_key: Optional[str] = None) -> dict:
        """Process input text and return sentiment predictions.

        Args:
            text: The text to classify.
            model_key: Optional key from ``models``; unknown keys use the default.

        Returns:
            A mapping of label -> score for every label the model emits. On any
            failure, returns ``{"error": <message>}`` instead of raising, so the
            calling agent receives a result either way.
        """
        try:
            classifier = self._get_classifier(self._resolve_model_id(model_key))
            prediction = classifier(text)

            # For a single string the pipeline may nest the score list one level
            # ([[{label, score}, ...]]) or return it flat ([{...}, ...]) depending
            # on how top_k is supplied; accept both shapes.
            if prediction and isinstance(prediction[0], list):
                scores = prediction[0]
            else:
                scores = prediction

            return {item["label"]: float(item["score"]) for item in scores}
        except Exception as e:
            print(f"Error in sentiment analysis: {str(e)}")
            return {"error": str(e)}