"""Lightweight text classification model exposed as a Gradio app."""
from typing import Any, Dict, Union, Tuple

import gradio as gr
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import logging

from .base import BaseModel

logger = logging.getLogger(__name__)
class TextClassificationModel(BaseModel):
    """Lightweight text classification model using tiny BERT.

    Wraps a Hugging Face ``text-classification`` pipeline built on
    ``prajjwal1/bert-tiny`` with a 2-label head. The model is loaded lazily
    on first use, and the class exposes both an async ``predict`` API and a
    Gradio interface.
    """

    # Maps the pipeline's raw labels to human-readable sentiment names.
    # Unmapped labels fall through unchanged in predict().
    _LABEL_MAP = {
        "LABEL_0": "NEGATIVE",
        "LABEL_1": "POSITIVE",
    }

    def __init__(self):
        super().__init__(
            name="Lightweight Text Classifier",
            description="Fast text classification using a tiny BERT model (4.4MB)"
        )
        self.model_name = "prajjwal1/bert-tiny"
        # Lazily-created transformers pipeline; see _ensure_loaded().
        self._model = None

    def load_model(self) -> None:
        """Load the classification model and build the inference pipeline.

        Raises:
            Exception: re-raises any error from model/tokenizer loading or
                pipeline construction after logging it.
        """
        try:
            logger.info("Loading model: %s", self.model_name)
            # Initialize model with a binary classification head.
            # NOTE(review): bert-tiny ships no fine-tuned 2-label head, so the
            # classifier weights are presumably randomly initialized — confirm
            # this is intended (demo/placeholder model).
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=-1  # CPU; use device=0 for GPU
            )
            # Log model size for visibility into the memory footprint.
            model_size_mb = sum(
                p.numel() * p.element_size() for p in model.parameters()
            ) / (1024 * 1024)
            logger.info("Model loaded successfully. Size: %.2f MB", model_size_mb)
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def _ensure_loaded(self) -> None:
        """Load the model on first use (lazy initialization)."""
        if self._model is None:
            self.load_model()

    async def predict(self, text: str) -> Dict[str, Union[str, float]]:
        """Classify ``text`` and return the predicted label with confidence.

        Args:
            text: Input string to classify.

        Returns:
            Dict with ``"label"`` ("NEGATIVE"/"POSITIVE", or the raw pipeline
            label if unmapped) and ``"confidence"`` (float score).

        Raises:
            Exception: re-raises any pipeline error after logging it.
        """
        try:
            self._ensure_loaded()
            logger.info("Processing text: %s...", text[:50])
            # Pipeline returns a list of result dicts; take the top one.
            result = self._model(text)[0]
            prediction = {
                "label": self._LABEL_MAP.get(result["label"], result["label"]),
                "confidence": float(result["score"])
            }
            logger.info("Prediction result: %s", prediction)
            return prediction
        except Exception as e:
            logger.error("Prediction error: %s", e)
            raise

    async def predict_for_interface(self, text: str) -> Tuple[str, float]:
        """Return ``(label, confidence)`` in the shape the Gradio outputs expect."""
        result = await self.predict(text)
        return result["label"], result["confidence"]

    def create_interface(self) -> gr.Interface:
        """Create a Gradio interface for text classification."""
        self._ensure_loaded()
        examples = [
            ["This movie was fantastic! I really enjoyed it."],
            ["The service was terrible and the food was cold."],
            ["It was an okay experience, nothing special."],
            ["The weather is nice today!"],
            ["I'm feeling sick and tired."]
        ]
        return gr.Interface(
            fn=self.predict_for_interface,  # Use the interface-specific prediction function
            inputs=gr.Textbox(
                lines=3,
                placeholder="Enter text to classify...",
                label="Input Text"
            ),
            outputs=[
                gr.Label(label="Sentiment"),
                gr.Number(label="Confidence", precision=4)
            ],
            title=self.name,
            description=self.description + "\n\nThis model is also available via API!",
            examples=examples,
            api_name="predict"
        )