Spaces:
Running
Running
File size: 3,667 Bytes
1c02c6e 47f397e c2f577f d4fdfec c2f577f d4fdfec c2f577f 47f397e d4fdfec 47f397e d4fdfec c2f577f 7b89327 d4fdfec 47f397e 7b89327 17b66cc d4fdfec cc1edab 47f397e 7b89327 d4fdfec 7b89327 17b66cc d4fdfec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import gradio as gr
from transformers import pipeline
from smolagents import Tool
class SimpleSentimentTool(Tool):
    """smolagents tool that scores the sentiment of a text with a Hugging Face model.

    Classifiers are built with ``transformers.pipeline`` on first use and cached
    per model id after a successful load.
    """

    name = "sentiment_analysis"
    description = "This tool analyzes the sentiment of a given text."

    # Available sentiment analysis models, keyed by the short name callers pass
    # as ``model_key``. Declared before ``inputs`` so its keys can be listed there.
    models = {
        "multilingual": "nlptown/bert-base-multilingual-uncased-sentiment",
        "deberta": "microsoft/deberta-xlarge-mnli",
        "distilbert": "distilbert-base-uncased-finetuned-sst-2-english",
        "mobilebert": "lordtt13/emo-mobilebert",
        "reviews": "juliensimon/reviews-sentiment-analysis",
        "sbc": "sbcBI/sentiment_analysis_model",
        "german": "oliverguhr/german-sentiment-bert",
    }

    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for sentiment",
        },
        "model_key": {
            "type": "string",
            # List the valid keys so an agent can actually choose one.
            "description": (
                "The model to use for sentiment analysis. One of: "
                + ", ".join(models)
            ),
            "default": None,
            # forward() gives model_key a default, so smolagents treats it as
            # optional and requires it to be declared nullable here.
            "nullable": True,
        },
    }

    # smolagents only accepts its own type names ("string", "number", "object", ...);
    # a dict result must be declared as "object".
    output_type = "object"

    def __init__(self, default_model="distilbert", preload=False):
        """Initialize with a default model.

        Args:
            default_model: Key into ``models`` used when ``forward`` gets no
                ``model_key`` or an unknown one.
            preload: Whether to load the default model now instead of on the
                first call. A failed preload is reported, not raised.
        """
        super().__init__()
        self.default_model = default_model
        # Maps model id -> loaded pipeline; filled by _get_classifier.
        self._classifiers = {}

        if preload:
            try:
                self._get_classifier(self.models[default_model])
            except Exception as e:
                # Best-effort: loading is simply retried on the first forward() call.
                print(f"Warning: Failed to preload model: {str(e)}")

    def _get_classifier(self, model_id):
        """Return the cached classifier for ``model_id``, loading it if needed.

        If loading fails, falls back to the distilbert model.

        Raises:
            RuntimeError: if the distilbert model itself cannot be loaded.
        """
        if model_id in self._classifiers:
            return self._classifiers[model_id]

        fallback_id = self.models["distilbert"]
        try:
            print(f"Loading model: {model_id}")
            classifier = pipeline(
                "text-classification",
                model=model_id,
                top_k=None,  # Return all scores
            )
        except Exception as e:
            print(f"Error loading model {model_id}: {str(e)}")
            if model_id != fallback_id:
                print("Falling back to distilbert model...")
                return self._get_classifier(fallback_id)
            # Last resort - even distilbert failed.
            print("Critical error: Could not load default model")
            raise RuntimeError(f"Failed to load any sentiment model: {str(e)}") from e

        self._classifiers[model_id] = classifier
        return classifier

    def forward(self, text: str, model_key=None):
        """Classify ``text`` and return its sentiment scores.

        Args:
            text: The text to analyze.
            model_key: Key into ``models``. ``None`` or an unknown key selects
                ``self.default_model``.

        Returns:
            A dict mapping each label of the chosen model to its score as a
            float, or ``{"error": message}`` if anything fails.
        """
        try:
            model_key = model_key or self.default_model
            if model_key not in self.models:
                print(f"Unknown model key {model_key!r}; using {self.default_model!r}")
            model_id = self.models.get(model_key, self.models[self.default_model])

            classifier = self._get_classifier(model_id)
            # Truncate to the model's maximum input length instead of failing on long text.
            prediction = classifier(text, truncation=True)

            # NOTE(review): for a single string with top_k=None the pipeline may
            # return either [[{...}, ...]] or a flat [{...}, ...] depending on the
            # transformers release; accept both shapes.
            if prediction and isinstance(prediction[0], list):
                scores = prediction[0]
            else:
                scores = prediction
            return {item["label"]: float(item["score"]) for item in scores}
        except Exception as e:
            print(f"Error in sentiment analysis: {str(e)}")
            return {"error": str(e)}