from typing import Any, Dict, Tuple, Union

import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import logging

from .base import BaseModel

logger = logging.getLogger(__name__)


class TextClassificationModel(BaseModel):
    """Lightweight text classification model using tiny BERT.

    Wraps a Hugging Face ``text-classification`` pipeline around the
    ``prajjwal1/bert-tiny`` checkpoint (binary sentiment head) and exposes
    both a programmatic ``predict`` API and a Gradio interface.
    """

    def __init__(self):
        super().__init__(
            name="Lightweight Text Classifier",
            description="Fast text classification using a tiny BERT model (4.4MB)"
        )
        # Hub checkpoint to load; tiny BERT keeps download/memory footprint small.
        self.model_name = "prajjwal1/bert-tiny"
        # Lazily-initialized HF pipeline; populated by load_model() on first use.
        self._model = None

    def load_model(self) -> None:
        """Load the classification model into ``self._model``.

        Builds a CPU ``text-classification`` pipeline with a 2-label head.
        Logs the parameter footprint in MB after loading.

        Raises:
            Exception: re-raised after logging if the download/load fails.
        """
        try:
            logger.info(f"Loading model: {self.model_name}")

            # Initialize model with binary classification head (LABEL_0 / LABEL_1).
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            self._model = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=-1  # CPU, use device=0 for GPU
            )

            # Log model size: sum of parameter tensor bytes, converted to MB.
            model_size_mb = sum(
                p.numel() * p.element_size() for p in model.parameters()
            ) / (1024 * 1024)
            # FIX: original string literal contained a raw newline mid-string,
            # which is a syntax error; the message is now a single line.
            logger.info(f"Model loaded successfully. Size: {model_size_mb:.2f} MB")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    async def predict(self, text: str) -> Dict[str, Union[str, float]]:
        """Make a prediction using the model.

        Args:
            text: Input text to classify.

        Returns:
            Dict with ``"label"`` ("NEGATIVE"/"POSITIVE", or the raw label if
            unmapped) and ``"confidence"`` (float score in [0, 1]).

        Raises:
            Exception: re-raised after logging if inference fails.
        """
        # NOTE(review): this coroutine performs synchronous, CPU-bound
        # inference and will block the event loop — consider
        # run_in_executor if called from an async server. Left as-is to
        # preserve the existing contract.
        try:
            # Lazy-load on first call so construction stays cheap.
            if self._model is None:
                self.load_model()

            logger.info(f"Processing text: {text[:50]}...")
            result = self._model(text)[0]

            # Map raw head labels to human-readable sentiment.
            label_map = {
                "LABEL_0": "NEGATIVE",
                "LABEL_1": "POSITIVE"
            }
            prediction = {
                "label": label_map.get(result["label"], result["label"]),
                "confidence": float(result["score"])
            }

            logger.info(f"Prediction result: {prediction}")
            return prediction

        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            raise

    async def predict_for_interface(self, text: str) -> Tuple[str, float]:
        """Make a prediction and return it in a format suitable for the Gradio interface.

        Returns:
            ``(label, confidence)`` tuple matching the interface's two outputs.
        """
        result = await self.predict(text)
        return result["label"], result["confidence"]

    def create_interface(self) -> gr.Interface:
        """Create a Gradio interface for text classification.

        Eagerly loads the model so the first user request is not slowed by
        the download, then wires ``predict_for_interface`` to a textbox input
        and (label, confidence) outputs.
        """
        if self._model is None:
            self.load_model()

        examples = [
            ["This movie was fantastic! I really enjoyed it."],
            ["The service was terrible and the food was cold."],
            ["It was an okay experience, nothing special."],
            ["The weather is nice today!"],
            ["I'm feeling sick and tired."]
        ]

        return gr.Interface(
            fn=self.predict_for_interface,  # Use the interface-specific prediction function
            inputs=gr.Textbox(
                lines=3,
                placeholder="Enter text to classify...",
                label="Input Text"
            ),
            outputs=[
                gr.Label(label="Sentiment"),
                gr.Number(label="Confidence", precision=4)
            ],
            title=self.name,
            description=self.description + "\n\nThis model is also available via API!",
            examples=examples,
            api_name="predict"
        )