import os
os.environ['HF_HOME'] = '/tmp/.cache/huggingface' # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True) # Ensure directory exists
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import torch
import numpy
from transformers import AutoTokenizer
from huggingface_hub import login
from pydantic import BaseModel
import warnings
from transformers import logging as hf_logging
from qwen_classifier.predict import predict_single
from qwen_classifier.evaluate import evaluate_batch
# Import the module itself (not the names): rebinding names imported with
# `from ... import` would not update qwen_classifier.globals, which
# predict_single/evaluate_batch read at call time.
import qwen_classifier.globals as qwen_globals
from qwen_classifier.model import QwenClassifier
from qwen_classifier.config import HF_REPO
print(numpy.__version__)  # Sanity check: log the NumPy version at startup
app = FastAPI(title="Qwen Classifier")
# Prefer the HF_REPO environment variable; fall back to the configured default.
hf_repo = os.getenv("HF_REPO") or HF_REPO
debug = False
if not debug:
    # Silence the expected (harmless) warning about unused checkpoint weights.
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    hf_logging.set_verbosity_error()
else:
    hf_logging.set_verbosity_info()
    warnings.simplefilter("default")
# Landing page describing the API.
@app.get("/", response_class=HTMLResponse)
def home():
    return """
    <html>
      <head><title>Qwen Classifier</title></head>
      <body>
        <h1>Qwen Classifier API</h1>
        <p>Available endpoints:</p>
        <ul>
          <li><code>POST /predict</code> - Classify text</li>
          <li><code>POST /evaluate</code> - Evaluate batch text prediction from a zip file</li>
          <li><code>GET /health</code> - Check API status</li>
        </ul>
        <p>Try it:</p>
        <pre>curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict -H "Content-Type: application/json" -d '{"text":"your text"}'</pre>
      </body>
    </html>
    """
@app.on_event("startup")  # Note: newer FastAPI prefers lifespan handlers; on_event still works here
async def load_model():
    # Warm up the GPU only if one is available (avoids crashing on CPU-only Spaces).
    if torch.cuda.is_available():
        torch.zeros(1).cuda()
    # Read HF_TOKEN from the Hugging Face Space secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")
    # Authenticate with the Hub
    login(token=hf_token)
    # Load model and tokenizer (cached under HF_HOME, i.e. /tmp/.cache/huggingface)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Assign through the module so predict_single/evaluate_batch see the loaded
    # objects; a bare `global` statement here would only rebind this module's names.
    qwen_globals.global_model = QwenClassifier.from_pretrained(hf_repo).to(device)
    qwen_globals.global_tokenizer = AutoTokenizer.from_pretrained(hf_repo)
    print("Model loaded successfully!")
class PredictionRequest(BaseModel):
    text: str  # Pydantic enforces that 'text' is present and is a string (not that it is non-empty)

class EvaluationRequest(BaseModel):
    file_path: str  # Path to the zip file of texts to evaluate
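# Example request bodies for the endpoints below (values are placeholders):
#   POST /predict  -> {"text": "your text"}
#   POST /evaluate -> {"file_path": "batch.zip"}  # hypothetical path inside the Space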
@app.post("/predict")
async def predict(request: PredictionRequest):  # FastAPI validates the body against PredictionRequest
    return predict_single(request.text, hf_repo, backend="local")
@app.post("/evaluate")
async def evaluate(request: EvaluationRequest):  # FastAPI validates the body against EvaluationRequest
    return str(evaluate_batch(request.file_path, hf_repo, backend="local"))
@app.get("/health")
def health_check():
    # Report whether startup actually loaded the model (assumes
    # qwen_classifier.globals initializes global_model to None).
    model_loaded = qwen_globals.global_model is not None
    return {"status": "healthy" if model_loaded else "starting",
            "model": "loaded" if model_loaded else "not loaded"}