Spaces:
Running
Running
from fastapi import APIRouter, HTTPException | |
from pydantic import BaseModel | |
from typing import Dict, Union, List | |
from models.text_classification import TextClassificationModel | |
# Router that the endpoint handlers in this module are meant to hang off;
# presumably included into the FastAPI app elsewhere — confirm at the app setup.
router: APIRouter = APIRouter()

# One shared classifier instance, created when this module is imported and
# reused by every handler below via the module-level name `model`.
model: TextClassificationModel = TextClassificationModel()
class TextInput(BaseModel):
    """Request body holding a single text to be classified."""

    text: str  # forwarded as-is to model.predict
class BatchTextInput(BaseModel):
    """Request body holding several texts to be classified in one request."""

    texts: List[str]  # each entry is classified separately, in list order
class PredictionResponse(BaseModel):
    """Shape of one classification result.

    NOTE(review): none of the handlers visible in this module reference this
    model as a return type or response_model — confirm it is actually wired in.
    """

    label: str  # label assigned by the model
    confidence: float  # score reported for `label`; its range is not validated here
class BatchPredictionResponse(BaseModel):
    """Shape of a batch result: one entry per input text.

    NOTE(review): none of the handlers visible in this module reference this
    model as a return type or response_model — confirm it is actually wired in.
    """

    predictions: List[PredictionResponse]  # ordered like the request's texts
async def predict(input_data: TextInput) -> Dict[str, Union[str, float]]:
    """Make a prediction for a single text.

    Args:
        input_data: Request body whose ``text`` is passed to ``model.predict``.

    Returns:
        The value returned by ``model.predict`` for that text.

    Raises:
        HTTPException: passed through unchanged if the model raises one
            itself; any other failure becomes a 500 whose detail carries the
            original error message, with the original exception chained as
            ``__cause__``.
    """
    try:
        return await model.predict(input_data.text)
    except HTTPException:
        # Already an HTTP error with its own status code; re-wrapping it would
        # mask it as a generic 500.
        raise
    except Exception as e:
        # NOTE(review): str(e) is sent to the client and may expose internal
        # details; consider logging it and returning a generic message instead.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {e}",
        ) from e  # keep the original traceback attached for debugging
async def predict_batch(input_data: BatchTextInput) -> Dict[str, List[Dict[str, Union[str, float]]]]:
    """Make predictions for multiple texts.

    Texts are sent to ``model.predict`` one at a time, in request order. The
    batch is all-or-nothing: if any single prediction fails, no partial
    results are returned.

    Args:
        input_data: Request body whose ``texts`` are classified.

    Returns:
        ``{"predictions": [...]}`` with one model result per input text, in
        the same order. An empty ``texts`` list yields an empty list.

    Raises:
        HTTPException: passed through unchanged if the model raises one
            itself; any other failure becomes a 500 whose detail carries the
            original error message, with the original exception chained as
            ``__cause__``.
    """
    try:
        # Awaited sequentially on purpose: nothing here shows that `model` is
        # safe to call concurrently, so no asyncio.gather fan-out.
        predictions = [await model.predict(text) for text in input_data.texts]
    except HTTPException:
        # Already an HTTP error with its own status code; don't mask it.
        raise
    except Exception as e:
        # NOTE(review): str(e) is sent to the client and may expose internal
        # details; consider logging it and returning a generic message instead.
        raise HTTPException(
            status_code=500,
            detail=f"Batch prediction failed: {e}",
        ) from e  # keep the original traceback attached for debugging
    return {"predictions": predictions}
async def get_model_info():
    """Report the metadata exposed by the shared text classifier.

    Simply forwards the result of ``model.get_info()`` (called synchronously).
    No error handling is done here, so any exception it raises propagates to
    the caller.
    """
    info = model.get_info()
    return info