import os

os.environ['HF_HOME'] = '/tmp/.cache/huggingface'   # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True)   # Ensure the cache directory exists
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import torch
import numpy
from transformers import AutoTokenizer
from transformers import logging as hf_logging
from huggingface_hub import login
from pydantic import BaseModel
import warnings

from qwen_classifier.predict import predict_single    # Your existing function
from qwen_classifier.evaluate import evaluate_batch   # Your existing function
from qwen_classifier.globals import global_model, global_tokenizer
from qwen_classifier.model import QwenClassifier
from qwen_classifier.config import HF_REPO, DEVICE

print(numpy.__version__)  # Sanity check: confirm which numpy build the Space picked up
app = FastAPI(title="Qwen Classifier")

# Repo to load weights from: env var override, else the packaged default
hf_repo = os.getenv("HF_REPO") or HF_REPO
debug = False
if not debug:
    # Suppress noisy checkpoint warnings and HF logging in normal operation
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    hf_logging.set_verbosity_error()
else:
    hf_logging.set_verbosity_info()
    warnings.simplefilter("default")

# Landing page endpoint: plain HTML describing the API
@app.get("/", response_class=HTMLResponse)
def home():
    return """
    <html>
        <head>
            <title>Qwen Classifier</title>
        </head>
        <body>
            <h1>Qwen Classifier API</h1>
            <p>Available endpoints:</p>
            <ul>
                <li><strong>POST /predict</strong> - Classify text</li>
                <li><strong>POST /evaluate</strong> - Evaluate batch predictions from a zip file</li>
                <li><strong>GET /health</strong> - Check API status</li>
            </ul>
            <p>Try it: <code>curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict -H "Content-Type: application/json" -d '{"text":"your text"}'</code></p>
        </body>
    </html>
    """

# Load the model and tokenizer once at startup so requests don't pay the loading cost
@app.on_event("startup")
async def load_model():
    global global_model, global_tokenizer
    # Warm up the GPU if one is present (CPU-only Spaces have no CUDA)
    if torch.cuda.is_available():
        torch.zeros(1).cuda()
    # Read HF_TOKEN from the Hugging Face Space secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")
    # Authenticate against the Hub
    login(token=hf_token)
    # Load model (will cache under HF_HOME, i.e. /tmp/.cache/huggingface)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global_model = QwenClassifier.from_pretrained(hf_repo).to(device)
    global_tokenizer = AutoTokenizer.from_pretrained(hf_repo)
    print("Model loaded successfully!")

class PredictionRequest(BaseModel):
    text: str  # Pydantic enforces that 'text' is a string

class EvaluationRequest(BaseModel):
    file_path: str  # Pydantic enforces that 'file_path' is a string
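
# Illustrative request bodies for the two models (field values are made-up examples):
#   PredictionRequest:  {"text": "some input to classify"}
#   EvaluationRequest:  {"file_path": "/tmp/batch.zip"}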

@app.post("/predict")
async def predict(request: PredictionRequest):  # FastAPI validates the body automatically
    return predict_single(request.text, hf_repo, backend="local")

@app.post("/evaluate")
async def evaluate(request: EvaluationRequest):  # FastAPI validates the body automatically
    return str(evaluate_batch(request.file_path, hf_repo, backend="local"))
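
# Sketch of a call against the batch endpoint (the zip's expected layout is
# whatever evaluate_batch requires; filename here is illustrative):
#   curl -X POST https://keivanr-qwen-classifier-demo.hf.space/evaluate \
#        -H "Content-Type: application/json" -d '{"file_path": "batch.zip"}'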

@app.get("/health")
def health_check():
    return {"status": "healthy", "model": "loaded"}