import os

# HF_HOME must be set before transformers / huggingface_hub are imported so
# that they pick up the cache location. /tmp is writable in Spaces.
os.environ['HF_HOME'] = '/tmp/.cache/huggingface'  # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True)  # Ensure directory exists

from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import torch
import numpy
from transformers import AutoTokenizer
from huggingface_hub import login
from pydantic import BaseModel
import warnings
from transformers import logging as hf_logging

from qwen_classifier.predict import predict_single  # Your existing function
from qwen_classifier.evaluate import evaluate_batch  # Your existing function
from qwen_classifier.globals import global_model, global_tokenizer
from qwen_classifier.model import QwenClassifier
from qwen_classifier.config import HF_REPO, DEVICE

print(numpy.__version__)

app = FastAPI(title="Qwen Classifier")

# Model repository: the HF_REPO environment variable takes precedence; an
# unset or empty value falls back to the default from qwen_classifier.config.
hf_repo = os.getenv("HF_REPO") or HF_REPO

debug = False
if not debug:
    # Silence the expected "unused/newly initialized weights" notice and
    # reduce transformers logging to errors only.
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    hf_logging.set_verbosity_error()
else:
    hf_logging.set_verbosity_info()
    warnings.simplefilter("default")


@app.get("/", response_class=HTMLResponse)
def home():
    """Serve a small HTML landing page listing the API endpoints."""
    return """<!DOCTYPE html>
<html>
<head>
<title>Qwen Classifier</title>
</head>
<body>
<h1>Qwen Classifier API</h1>
<p>Available endpoints:</p>
<ul>
<li><code>POST /predict</code> &mdash; classify a single text (JSON body: <code>{"text": "..."}</code>)</li>
<li><code>POST /evaluate</code> &mdash; run batch evaluation (JSON body: <code>{"file_path": "..."}</code>)</li>
<li><code>GET /health</code> &mdash; health check</li>
</ul>
<p>Try it:</p>
<pre><code>curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict -H "Content-Type: application/json" -d '{"text":"your text"}'</code></pre>
</body>
</html>
"""


@app.on_event("startup")
async def load_model():
    """Authenticate with the Hugging Face Hub and load the model and tokenizer.

    The loaded objects are stored in this module's ``global_model`` /
    ``global_tokenizer`` and also published on ``qwen_classifier.globals``.

    Raises:
        ValueError: if the ``HF_TOKEN`` environment variable is not set.
    """
    global global_model, global_tokenizer

    # Local name on purpose: the imported config constant DEVICE is left untouched.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Warm up the GPU only when one exists; calling .cuda() on a CPU-only host
    # raises and would abort startup before the CPU fallback is ever used.
    if device == "cuda":
        torch.zeros(1).cuda()

    # Read HF_TOKEN from Hugging Face Space secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    # Authenticate
    login(token=hf_token)

    # Downloads are cached under HF_HOME, which is set at the top of this module.
    model = QwenClassifier.from_pretrained(hf_repo).to(device)
    tokenizer = AutoTokenizer.from_pretrained(hf_repo)

    # Keep the loaded objects; previously the model was left in a local variable.
    global_model = model
    global_tokenizer = tokenizer

    # `global` only rebinds this module's names. `from qwen_classifier.globals
    # import ...` copied references at import time, so publish the objects on
    # the shared module too for code that reads them from there.
    import qwen_classifier.globals as shared_globals
    shared_globals.global_model = model
    shared_globals.global_tokenizer = tokenizer

    print("Model loaded successfully!")


class PredictionRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # required and validated as a string; empty strings are accepted


class EvaluationRequest(BaseModel):
    """Request body for ``POST /evaluate``."""

    file_path: str  # required path string, passed unchanged to evaluate_batch


@app.post("/predict")
async def predict(request: PredictionRequest):  # body validated by pydantic
    """Classify ``request.text`` using the local backend for ``hf_repo``."""
    # NOTE(review): predict_single is called synchronously inside an async
    # handler; if it does blocking inference (presumably), it holds the event
    # loop for the duration of the call.
    return predict_single(request.text, hf_repo, backend="local")


@app.post("/evaluate")
async def evaluate(request: EvaluationRequest):  # body validated by pydantic
    """Run batch evaluation on ``request.file_path`` and return it as a string."""
    # SECURITY(review): file_path is client-supplied and forwarded unchanged.
    # If evaluate_batch opens it on the server filesystem, any readable path
    # can be targeted; consider restricting it to an allowed directory.
    return str(evaluate_batch(request.file_path, hf_repo, backend="local"))


@app.get("/health")
def health_check():
    """Report that the service is up (reachable only after startup succeeded)."""
    return {"status": "healthy", "model": "loaded"}