File size: 2,770 Bytes
318285f
 
6e9d355
 
d3eff8a
f655296
5d27647
65afda8
b820b0a
748a976
65afda8
748a976
d394f04
 
6fe0026
b0cd906
f655296
748a976
65afda8
c4ad33b
 
 
5d27647
 
 
 
6fe0026
5d27647
 
 
 
 
 
 
 
 
 
 
 
c4ad33b
5d27647
 
 
65afda8
 
 
748a976
65afda8
 
d394f04
 
 
 
 
 
 
 
 
748a976
 
c4ad33b
d394f04
748a976
d394f04
f655296
b0cd906
 
 
 
 
c4ad33b
 
 
f655296
b0cd906
c4ad33b
b820b0a
 
c4ad33b
 
5d27647
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os

# Point the Hugging Face cache (HF_HOME) at /tmp and make sure the directory
# exists. NOTE(review): this runs before the fastapi/transformers/
# huggingface_hub/qwen_classifier imports below, presumably because those
# libraries read HF_HOME when they are imported — keep it above them.
# Presumably /tmp is used because it is writable inside a Space — confirm.
os.environ['HF_HOME'] = '/tmp/.cache/huggingface'  # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True)  # Ensure directory exists (no error if already present)

import torch
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from huggingface_hub import login
from pydantic import BaseModel, Field
from transformers import AutoTokenizer

from qwen_classifier.config import HF_REPO
from qwen_classifier.evaluate import evaluate_batch
from qwen_classifier.globals import model, tokenizer
from qwen_classifier.model import QwenClassifier
from qwen_classifier.predict import predict_single


app = FastAPI(title="Qwen Classifier")

# Model repository to serve: a non-empty HF_REPO environment variable takes
# precedence; an unset or empty one falls back to the packaged default.
hf_repo = os.getenv("HF_REPO") or HF_REPO

@app.get("/", response_class=HTMLResponse)
def home():
    """Serve a static HTML landing page that lists the API's endpoints."""
    landing_page = """
    <html>
        <head>
            <title>Qwen Classifier</title>
        </head>
        <body>
            <h1>Qwen Classifier API</h1>
            <p>Available endpoints:</p>
            <ul>
                <li><strong>POST /predict</strong> - Classify text</li>
                <li><strong>POST /evaluate</strong> - Evaluate batch text prediction from zip file</li>
                <li><strong>GET /health</strong> - Check API status</li>
            </ul>
            <p>Try it: <code>curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict -H "Content-Type: application/json" -d '{"text":"your text"}'</code></p>
        </body>
    </html>
    """
    return landing_page

@app.on_event("startup")
async def load_model():
    """Authenticate with the Hugging Face Hub and load the classifier at startup.

    Reads ``HF_TOKEN`` from the environment, logs in with it, then loads the
    ``QwenClassifier`` weights and its tokenizer from ``hf_repo``. The loaded
    objects are bound to this module's ``model`` / ``tokenizer`` globals and
    are also written onto ``qwen_classifier.globals``.

    Raises:
        ValueError: if ``HF_TOKEN`` is unset or empty.
    """
    global model, tokenizer
    # The top-level `from qwen_classifier.globals import model, tokenizer`
    # copied references into this module; rebinding them via `global` does
    # not update the shared module, so its attributes are set explicitly below.
    import qwen_classifier.globals as shared_globals

    # Warm up the GPU only when one is present: an unconditional .cuda()
    # raises on CPU-only hardware and would abort startup.
    if torch.cuda.is_available():
        torch.zeros(1).cuda()

    # Read HF_TOKEN from Hugging Face Space secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    # Authenticate
    login(token=hf_token)

    # Downloads should be cached under HF_HOME, which this file sets to
    # /tmp/.cache/huggingface before the HF imports.
    model = QwenClassifier.from_pretrained(hf_repo)
    tokenizer = AutoTokenizer.from_pretrained(hf_repo)

    # Publish the loaded objects to the shared globals module as well.
    shared_globals.model = model
    shared_globals.tokenizer = tokenizer
    print("Model loaded successfully!")



class PredictionRequest(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    # Text to classify. A bare `str` would accept "", so `min_length=1` makes
    # pydantic reject empty input with a validation error (FastAPI: 422).
    text: str = Field(..., min_length=1)

class EvaluationRequest(BaseModel):
    """JSON body accepted by ``POST /evaluate``."""

    # Path handed straight to `evaluate_batch`; `min_length=1` rejects an
    # empty string with a validation error (FastAPI: 422).
    # NOTE(review): this path is client-supplied and not restricted to any
    # directory — confirm evaluate_batch cannot be used to read arbitrary
    # files on the server.
    file_path: str = Field(..., min_length=1)

@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify the request's text via ``predict_single`` using the "local" backend."""
    text_to_classify = request.text
    # NOTE(review): predict_single is a synchronous call inside an async
    # handler, so the event loop is blocked until it returns.
    return predict_single(text_to_classify, hf_repo, backend="local")

@app.post("/evaluate")
async def evaluate(request: EvaluationRequest):
    """Run ``evaluate_batch`` on the requested file using the "local" backend."""
    batch_path = request.file_path
    # NOTE(review): evaluate_batch is a synchronous call inside an async
    # handler, so the event loop is blocked until it returns.
    return evaluate_batch(batch_path, hf_repo, backend="local")

@app.get("/health")
def health_check():
    """Report service liveness.

    The model status is a fixed string; it is not checked here. Presumably a
    failed model load aborts startup, so serving implies it loaded — confirm.
    """
    payload = {"status": "healthy"}
    payload["model"] = "loaded"
    return payload