# oversai-models/src/models/text_classification.py
import logging
from typing import Dict, Tuple, Union

import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

from .base import BaseModel

logger = logging.getLogger(__name__)

class TextClassificationModel(BaseModel):
    """Lightweight text classification model using tiny BERT."""

    def __init__(self):
        super().__init__(
            name="Lightweight Text Classifier",
            description="Fast text classification using a tiny BERT model (~4.4M parameters)"
        )
        self.model_name = "prajjwal1/bert-tiny"
        self._model = None  # loaded lazily on first use

    def load_model(self) -> None:
        """Load the classification model."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            # Initialize the model with a binary classification head
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )
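            # NOTE: the base checkpoint ships no sequence-classification
            # weights, so this head is newly (randomly) initialized; without
            # fine-tuning, the two labels are not meaningful sentiment classes.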
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=-1  # CPU; use device=0 for GPU
            )
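            # The pipeline bundles tokenization, the forward pass, and the
            # softmax over the two logits into one callable that accepts raw
            # strings.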
            # Log the approximate in-memory model size
            model_size_mb = sum(
                p.numel() * p.element_size() for p in model.parameters()
            ) / (1024 * 1024)
            logger.info(f"Model loaded successfully. Size: {model_size_mb:.2f} MB")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    async def predict(self, text: str) -> Dict[str, Union[str, float]]:
        """Make a prediction using the model."""
        try:
            # Load lazily on first use
            if self._model is None:
                self.load_model()
            logger.info(f"Processing text: {text[:50]}...")
            result = self._model(text)[0]
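            # The pipeline returns a list with one {"label", "score"} dict per
            # input string; [0] selects the result for this single input.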
            # Map the raw model labels to sentiment labels
            label_map = {
                "LABEL_0": "NEGATIVE",
                "LABEL_1": "POSITIVE"
            }
            prediction = {
                "label": label_map.get(result["label"], result["label"]),
                "confidence": float(result["score"])
            }
            logger.info(f"Prediction result: {prediction}")
            return prediction
        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            raise

    async def predict_for_interface(self, text: str) -> Tuple[str, float]:
        """Return the prediction as a (label, confidence) tuple so Gradio can
        map each element to one output component."""
        result = await self.predict(text)
        return result["label"], result["confidence"]

    def create_interface(self) -> gr.Interface:
        """Create a Gradio interface for text classification."""
        if self._model is None:
            self.load_model()
        examples = [
            ["This movie was fantastic! I really enjoyed it."],
            ["The service was terrible and the food was cold."],
            ["It was an okay experience, nothing special."],
            ["The weather is nice today!"],
            ["I'm feeling sick and tired."]
        ]
        return gr.Interface(
            fn=self.predict_for_interface,  # use the interface-specific prediction function
            inputs=gr.Textbox(
                lines=3,
                placeholder="Enter text to classify...",
                label="Input Text"
            ),
            outputs=[
                gr.Label(label="Sentiment"),
                gr.Number(label="Confidence", precision=4)
            ],
            title=self.name,
            description=self.description + "\n\nThis model is also available via API!",
            examples=examples,
            api_name="predict"
        )
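
# A minimal usage sketch, not part of the original module. It assumes this
# file sits in a package providing .base.BaseModel and that gradio and
# gradio_client are installed; the local URL below is hypothetical.
#
#     from src.models.text_classification import TextClassificationModel
#
#     model = TextClassificationModel()
#     model.create_interface().launch()
#
#     # Once running, the named endpoint can be called over the API:
#     from gradio_client import Client
#     client = Client("http://127.0.0.1:7860/")
#     result = client.predict("Great product!", api_name="/predict")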