from typing import Dict, Union, Tuple

import logging

import gradio as gr
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

from .base import BaseModel

logger = logging.getLogger(__name__)
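
# BaseModel (defined in .base, not shown here) is assumed to be a small base
# class whose __init__ accepts `name` and `description` keyword arguments, as
# used by the subclass below; its exact contract is an assumption.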

class TextClassificationModel(BaseModel):
    """Lightweight text classification model using tiny BERT."""

    def __init__(self):
        super().__init__(
            name="Lightweight Text Classifier",
            description="Fast text classification using a tiny BERT model (~4.4M parameters)"
        )
        self.model_name = "prajjwal1/bert-tiny"
        self._model = None

    def load_model(self) -> None:
        """Load the classification model."""
        try:
            logger.info(f"Loading model: {self.model_name}")
            # Initialize the model with a binary classification head. Note that
            # bert-tiny is a base checkpoint: the classification head is
            # randomly initialized unless a fine-tuned checkpoint is used.
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=2
            )
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self._model = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                device=-1  # CPU; use device=0 for the first GPU
            )
            # Log the in-memory parameter size
            model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
            logger.info(f"Model loaded successfully. Size: {model_size_mb:.2f} MB")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    async def predict(self, text: str) -> Dict[str, Union[str, float]]:
        """Make a prediction using the model."""
        try:
            if self._model is None:
                self.load_model()
            logger.info(f"Processing text: {text[:50]}...")
            result = self._model(text)[0]
            # Map the pipeline's raw labels to sentiment names
            label_map = {
                "LABEL_0": "NEGATIVE",
                "LABEL_1": "POSITIVE"
            }
            prediction = {
                "label": label_map.get(result["label"], result["label"]),
                "confidence": float(result["score"])
            }
            logger.info(f"Prediction result: {prediction}")
            return prediction
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise

    async def predict_for_interface(self, text: str) -> Tuple[str, float]:
        """Make a prediction and return it in a format suitable for the Gradio interface."""
        result = await self.predict(text)
        return result["label"], result["confidence"]

    def create_interface(self) -> gr.Interface:
        """Create a Gradio interface for text classification."""
        if self._model is None:
            self.load_model()
        examples = [
            ["This movie was fantastic! I really enjoyed it."],
            ["The service was terrible and the food was cold."],
            ["It was an okay experience, nothing special."],
            ["The weather is nice today!"],
            ["I'm feeling sick and tired."]
        ]
        return gr.Interface(
            fn=self.predict_for_interface,  # interface-specific wrapper around predict()
            inputs=gr.Textbox(
                lines=3,
                placeholder="Enter text to classify...",
                label="Input Text"
            ),
            outputs=[
                gr.Label(label="Sentiment"),
                gr.Number(label="Confidence", precision=4)
            ],
            title=self.name,
            description=self.description + "\n\nThis model is also available via API!",
            examples=examples,
            api_name="predict"
        )
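

# A minimal usage sketch, not part of the original module. Because of the
# relative import above, this file is assumed to live in a package next to
# base.py and would be run as e.g. `python -m mypackage.text_classifier`;
# the package and module names here are hypothetical.
if __name__ == "__main__":
    import asyncio

    classifier = TextClassificationModel()

    # Direct async prediction, bypassing the UI
    prediction = asyncio.run(classifier.predict("This movie was fantastic!"))
    print(prediction)  # {"label": ..., "confidence": ...}

    # Or serve the Gradio interface locally
    classifier.create_interface().launch()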