import os

os.environ['HF_HOME'] = '/tmp/.cache/huggingface'  # Writable cache dir in Spaces; must be set before transformers is imported
os.makedirs(os.environ['HF_HOME'], exist_ok=True)  # Ensure the cache directory exists

from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import torch
import numpy
from transformers import AutoTokenizer
from huggingface_hub import login
from pydantic import BaseModel
import warnings
from transformers import logging as hf_logging

from qwen_classifier.predict import predict_single   # single-text inference helper
from qwen_classifier.evaluate import evaluate_batch  # batch evaluation over a zip file of texts
from qwen_classifier.globals import global_model, global_tokenizer
from qwen_classifier.model import QwenClassifier
from qwen_classifier.config import HF_REPO, DEVICE

print(numpy.__version__)  # Debug aid: logs the NumPy version at startup

app = FastAPI(title="Qwen Classifier")
# HF_REPO env var (e.g. a Space secret) overrides the config default
hf_repo = os.getenv("HF_REPO") or HF_REPO

debug = False
if not debug:
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    hf_logging.set_verbosity_error()
else:
    hf_logging.set_verbosity_info()
    warnings.simplefilter("default")

# Landing page describing the available endpoints
@app.get("/", response_class=HTMLResponse)
def home():
    return """
    <html>
        <head>
            <title>Qwen Classifier</title>
        </head>
        <body>
            <h1>Qwen Classifier API</h1>
            <p>Available endpoints:</p>
            <ul>
                <li><strong>POST /predict</strong> - Classify text</li>
                <li><strong>POST /evaluate</strong> - Evaluate batch text prediction from zip file</li>
                <li><strong>GET /health</strong> - Check API status</li>
            </ul>
            <p>Try it: <code>curl -X POST https://keivanr-qwen-classifier-demo.hf.space/predict -H "Content-Type: application/json" -d '{"text":"your text"}'</code></p>
        </body>
    </html>
    """

@app.on_event("startup")
async def load_model():
    global global_model, global_tokenizer
    # Warm up the GPU, but only when one is available (a CPU-only Space would crash on .cuda())
    if torch.cuda.is_available():
        torch.zeros(1).cuda()
    # Read HF_TOKEN from the Space's secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    # Authenticate against the Hugging Face Hub
    login(token=hf_token)

    # Load model and tokenizer (cached under HF_HOME, i.e. /tmp/.cache/huggingface).
    # Use a local name so we don't shadow the DEVICE imported from config.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global_model = QwenClassifier.from_pretrained(
        hf_repo,
    ).to(device)
    global_tokenizer = AutoTokenizer.from_pretrained(hf_repo)
    print("Model loaded successfully!")



class PredictionRequest(BaseModel):
    text: str  # Pydantic enforces that 'text' is a string (not that it is non-empty)

class EvaluationRequest(BaseModel):
    file_path: str  # Pydantic enforces that 'file_path' is a string

@app.post("/predict")
async def predict(request: PredictionRequest):  # FastAPI validates the JSON body against PredictionRequest
    return predict_single(request.text, hf_repo, backend="local")

@app.post("/evaluate")
async def evaluate(request: EvaluationRequest):  # FastAPI validates the JSON body against EvaluationRequest
    return str(evaluate_batch(request.file_path, hf_repo, backend="local"))
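
# Example call for /evaluate (the zip path below is hypothetical; use a path
# that exists inside the running Space):
#   curl -X POST https://keivanr-qwen-classifier-demo.hf.space/evaluate \
#        -H "Content-Type: application/json" \
#        -d '{"file_path": "/tmp/test_data.zip"}'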

@app.get("/health")
def health_check():
    return {"status": "healthy", "model": "loaded"}