import json

import gradio as gr
from transformers import pipeline
from smolagents import Tool

class SimpleSentimentTool(Tool):
    name = "sentiment_analysis"
    description = "Analyzes the sentiment of a given text and returns a JSON string mapping sentiment labels to scores."

    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for sentiment"
        },
        "model_key": {
            "type": "string",
            "description": "The model to use for sentiment analysis",
            "default": "oliverguhr/german-sentiment-bert",
            "nullable": True
        }
    }
    # Use a standard authorized type
    output_type = "string"  
    
    # Available sentiment analysis models
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert"
    }
    
    def __init__(self, default_model="distilbert", preload=False):
        """Initialize with a default model.
        
        Args:
            default_model: The default model to use if no model is specified
            preload: Whether to preload the default model at initialization
        """
        super().__init__()
        self.default_model = default_model
        self._classifiers = {}
        
        # Optionally preload the default model
        if preload:
            try:
                self._get_classifier(self.models[default_model])
            except Exception as e:
                print(f"Warning: Failed to preload model: {str(e)}")
    
    def _get_classifier(self, model_id):
        """Get or create a classifier for the given model ID."""
        if model_id not in self._classifiers:
            try:
                print(f"Loading model: {model_id}")
                self._classifiers[model_id] = pipeline(
                    "text-classification", 
                    model=model_id, 
                    top_k=None  # Return all scores
                )
            except Exception as e:
                print(f"Error loading model {model_id}: {str(e)}")
                # Fall back to distilbert if available
                if model_id != self.models["distilbert"]:
                    print("Falling back to distilbert model...")
                    return self._get_classifier(self.models["distilbert"])
                else:
                    # Last resort - if even distilbert fails
                    print("Critical error: Could not load default model")
                    raise RuntimeError(f"Failed to load any sentiment model: {str(e)}")
        return self._classifiers[model_id]
    
    def forward(self, text: str, model_key: str = "oliverguhr/german-sentiment-bert"):
        """Process input text and return sentiment predictions as a JSON string."""
        try:
            # Resolve the model: accept a key from self.models or a full model ID,
            # and fall back to the instance default otherwise
            model_key = model_key or self.default_model
            if model_key in self.models:
                model_id = self.models[model_key]
            elif "/" in model_key:
                model_id = model_key
            else:
                model_id = self.models[self.default_model]
            
            # Get the classifier
            classifier = self._get_classifier(model_id)
            
            # Get predictions: a list of {label, score} dicts for the input text
            prediction = classifier(text)
            
            # Format as a dictionary mapping each label to its score
            result = {}
            for item in prediction[0]:
                result[item['label']] = float(item['score'])
            
            # Convert to a JSON string for output (json is imported at module level
            # so it is also available in the except branch below)
            return json.dumps(result, indent=2)
        except Exception as e:
            print(f"Error in sentiment analysis: {str(e)}")
            return json.dumps({"error": str(e)}, indent=2)