import os

os.environ['HF_HOME'] = '/tmp/.cache/huggingface'  # Use /tmp in Spaces
os.makedirs(os.environ['HF_HOME'], exist_ok=True)  # Ensure directory exists

from fastapi import FastAPI
from qwen_classifier.predict import predict_single  # Your existing function
import torch
from huggingface_hub import login
from qwen_classifier.model import QwenClassifier

app = FastAPI(title="Qwen Classifier")

@app.on_event("startup")
async def load_model():
    # Warm up GPU if one is available (Spaces hardware may be CPU-only)
    if torch.cuda.is_available():
        torch.zeros(1).cuda()
    # Read HF_TOKEN from Hugging Face Space secrets
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    # Authenticate
    login(token=hf_token)
    
    # Load model (weights are cached under HF_HOME, set to /tmp/.cache/huggingface above)
    app.state.model = QwenClassifier.from_pretrained(
        'KeivanR/Qwen2.5-1.5B-Instruct-MLB-clf_lora-1743189446',
    )
    print("Model loaded successfully!")

@app.post("/predict")
async def predict(text: str):
    # FastAPI reads `text` as a query parameter here; delegate to the existing local-inference helper
    return predict_single(text, backend="local")
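

# --- Minimal local-run sketch (assumptions: this file is app.py and uvicorn is installed;
# --- a Space normally launches the app itself, so this is only for local testing).
# Example request once the server is up:
#   curl -X POST "http://localhost:7860/predict?text=hello%20world"
if __name__ == "__main__":
    import uvicorn

    # Port 7860 matches the default port Hugging Face Spaces expects; adjust as needed.
    uvicorn.run(app, host="0.0.0.0", port=7860)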