import os
os.environ['HF_HOME'] = '/tmp/.cache/huggingface' # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True) # Ensure directory exists
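# Note: on Spaces the container user often cannot write to its home directory,
# so pointing HF_HOME at /tmp keeps model downloads on a writable path.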
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import torch
import numpy
from transformers import AutoTokenizer
from huggingface_hub import login
from pydantic import BaseModel
import warnings
from transformers import logging as hf_logging
from qwen_classifier.predict import predict_single   # single-text prediction helper
from qwen_classifier.evaluate import evaluate_batch  # batch evaluation helper
from qwen_classifier import globals as qc_globals    # module holding the shared model/tokenizer
from qwen_classifier.model import QwenClassifier
from qwen_classifier.config import HF_REPO
print(numpy.__version__)  # Debug: log the NumPy version in the Space's startup logs
app = FastAPI(title="Qwen Classifier")
hf_repo = os.getenv("HF_REPO") or HF_REPO  # env var overrides the packaged default
debug = False
if not debug:
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    hf_logging.set_verbosity_error()
else:
    hf_logging.set_verbosity_info()
    warnings.simplefilter("default")
# Landing page describing the available endpoints
@app.get("/", response_class=HTMLResponse)
def home():
    return """
    <html>
        <head>
            <title>Qwen Classifier</title>
        </head>
        <body>
            <h1>Qwen Classifier API</h1>
            <p>Available endpoints:</p>
            <ul>
                <li><strong>POST /predict</strong> - Classify a single text</li>
                <li><strong>POST /evaluate</strong> - Evaluate batch predictions from a zip file</li>
                <li><strong>GET /health</strong> - Check API status</li>
            </ul>
            <p>Try it: <code>curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict -H "Content-Type: application/json" -d '{"text":"your text"}'</code></p>
        </body>
    </html>
    """
@app.on_event("startup")  # note: on_event is deprecated in newer FastAPI in favor of lifespan handlers
async def load_model():
    # Warm up the GPU only if one is present; calling .cuda() on a CPU-only Space would crash
    if torch.cuda.is_available():
        torch.zeros(1).cuda()
    # Read HF_TOKEN from the Space's secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")
    # Authenticate with the Hugging Face Hub
    login(token=hf_token)
    # Load model and tokenizer (cached under HF_HOME, i.e. /tmp/.cache/huggingface).
    # Assign them as attributes of the shared globals module so predict_single and
    # evaluate_batch see them; rebinding names imported via `from ... import ...` would not.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qc_globals.global_model = QwenClassifier.from_pretrained(hf_repo).to(device)
    qc_globals.global_tokenizer = AutoTokenizer.from_pretrained(hf_repo)
    print("Model loaded successfully!")
class PredictionRequest(BaseModel):
    text: str  # request body must carry a 'text' string field

class EvaluationRequest(BaseModel):
    file_path: str  # request body must carry a 'file_path' string field
@app.post("/predict")
async def predict(request: PredictionRequest):  # Pydantic validates the JSON body
    return predict_single(request.text, hf_repo, backend="local")

@app.post("/evaluate")
async def evaluate(request: EvaluationRequest):  # Pydantic validates the JSON body
    return str(evaluate_batch(request.file_path, hf_repo, backend="local"))
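# Example requests (the zip path below is hypothetical; /evaluate expects a file
# that already exists on the server's filesystem):
#   curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict \
#        -H "Content-Type: application/json" -d '{"text": "your text"}'
#   curl -X POST https://keivanr-qwen-classifier-demo.hf.space/evaluate \
#        -H "Content-Type: application/json" -d '{"file_path": "/tmp/data.zip"}'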
@app.get("/health")
def health_check():
    return {"status": "healthy", "model": "loaded"}
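# Local development entry point - a minimal sketch assuming uvicorn is installed.
# On Hugging Face Spaces the server is normally launched by the Space runtime
# (e.g. a uvicorn command in the Dockerfile), so this block is not required there.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the port Spaces expects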