File size: 3,911 Bytes
18869bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from typing import Any, Dict, Union, Tuple
import gradio as gr
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import logging
from .base import BaseModel

logger = logging.getLogger(__name__)

class TextClassificationModel(BaseModel):
    """Lightweight text classification model using tiny BERT.

    Wraps a Hugging Face ``text-classification`` pipeline around the
    ``prajjwal1/bert-tiny`` checkpoint with a binary head, exposing it both
    as an async ``predict`` coroutine and as a Gradio interface. The
    pipeline is built lazily on first use (see ``_ensure_loaded``).
    """

    def __init__(self):
        super().__init__(
            name="Lightweight Text Classifier",
            description="Fast text classification using a tiny BERT model (4.4MB)"
        )
        self.model_name = "prajjwal1/bert-tiny"
        # Lazily-constructed HF pipeline; None until _ensure_loaded() runs.
        self._model = None

    def load_model(self) -> None:
        """Load the classification model into ``self._model``.

        Builds a CPU text-classification pipeline with a 2-label head.

        NOTE(review): ``prajjwal1/bert-tiny`` is a pre-trained encoder
        without a fine-tuned classification head, so the ``num_labels=2``
        head is randomly initialized — confirm a fine-tuned checkpoint is
        intended before relying on the predictions.

        Raises:
            Exception: re-raises any error from model/tokenizer download
                or pipeline construction after logging it with traceback.
        """
        try:
            logger.info("Loading model: %s", self.model_name)

            # Initialize model with a binary classification head.
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            self._model = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=-1  # CPU; use device=0 for GPU
            )

            # Log model size: parameter count * bytes per element, in MB.
            model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            logger.info("Model loaded successfully. Size: %.2f MB", model_size_mb)

        except Exception:
            # logger.exception records the traceback alongside the message.
            logger.exception("Error loading model")
            raise

    def _ensure_loaded(self) -> None:
        """Build the pipeline on first use (shared lazy-load guard)."""
        if self._model is None:
            self.load_model()

    async def predict(self, text: str) -> Dict[str, Union[str, float]]:
        """Make a prediction using the model.

        Args:
            text: Raw input text to classify.

        Returns:
            Dict with ``"label"`` ("NEGATIVE"/"POSITIVE", or the raw
            pipeline label if unmapped) and ``"confidence"`` (float score).

        Raises:
            Exception: re-raises any pipeline error after logging it.
        """
        try:
            self._ensure_loaded()

            logger.info("Processing text: %s...", text[:50])
            result = self._model(text)[0]

            # Map the pipeline's generic label names to sentiment strings.
            label_map = {
                "LABEL_0": "NEGATIVE",
                "LABEL_1": "POSITIVE"
            }

            prediction = {
                "label": label_map.get(result["label"], result["label"]),
                "confidence": float(result["score"])
            }
            logger.info("Prediction result: %s", prediction)
            return prediction

        except Exception:
            logger.exception("Prediction error")
            raise

    async def predict_for_interface(self, text: str) -> Tuple[str, float]:
        """Make a prediction and return it in a format suitable for the Gradio interface.

        Returns:
            ``(label, confidence)`` tuple matching the interface's two
            output components (Label, Number).
        """
        result = await self.predict(text)
        return result["label"], result["confidence"]

    def create_interface(self) -> gr.Interface:
        """Create a Gradio interface for text classification."""
        self._ensure_loaded()

        examples = [
            ["This movie was fantastic! I really enjoyed it."],
            ["The service was terrible and the food was cold."],
            ["It was an okay experience, nothing special."],
            ["The weather is nice today!"],
            ["I'm feeling sick and tired."]
        ]

        return gr.Interface(
            fn=self.predict_for_interface,  # Use the interface-specific prediction function
            inputs=gr.Textbox(
                lines=3,
                placeholder="Enter text to classify...",
                label="Input Text"
            ),
            outputs=[
                gr.Label(label="Sentiment"),
                gr.Number(label="Confidence", precision=4)
            ],
            title=self.name,
            description=self.description + "\n\nThis model is also available via API!",
            examples=examples,
            api_name="predict"
        )