File size: 1,531 Bytes
318285f
 
6e9d355
 
d3eff8a
f655296
65afda8
b820b0a
65afda8
d394f04
 
b0cd906
f655296
65afda8
b820b0a
65afda8
 
 
 
 
d394f04
 
 
 
 
 
 
 
 
 
b820b0a
d394f04
 
f655296
b0cd906
 
 
 
 
f655296
b0cd906
b820b0a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os

# Use /tmp in Spaces for the Hugging Face cache. This runs ahead of the
# Hugging Face imports below, presumably so they see HF_HOME when they load.
_hf_cache_dir = '/tmp/.cache/huggingface'
os.environ['HF_HOME'] = _hf_cache_dir
os.makedirs(_hf_cache_dir, exist_ok=True)  # no error if it already exists

import torch
from fastapi import FastAPI
from huggingface_hub import login
from pydantic import BaseModel, Field

from qwen_classifier.evaluate import evaluate_batch  # Your existing function
from qwen_classifier.model import QwenClassifier
from qwen_classifier.predict import predict_single  # Your existing function

# Hugging Face Hub repository id of the classifier checkpoint used by this service.
hf_repo = 'KeivanR/Qwen2.5-1.5B-Instruct-MLB-clf_lora-1743189446'

# ASGI application object; the route and startup decorators below attach to it.
app = FastAPI(title="Qwen Classifier")

@app.on_event("startup")
async def load_model():
    """Authenticate with the Hugging Face Hub and load the classifier at startup.

    The loaded ``QwenClassifier`` is stored on ``app.state.model``.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is unset or empty.
    """
    # Read HF_TOKEN from Hugging Face Space secrets; check it first so a
    # missing credential fails fast, before any GPU work is attempted.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN not found in environment variables")

    # Warm up the GPU only when one is present: an unconditional ``.cuda()``
    # call raises on CPU-only hardware and would abort application startup.
    if torch.cuda.is_available():
        torch.zeros(1).cuda()

    # Authenticate
    login(token=hf_token)

    # Load model. NOTE(review): downloads are presumably cached under HF_HOME,
    # which this module points at /tmp/.cache/huggingface -- confirm.
    app.state.model = QwenClassifier.from_pretrained(hf_repo)
    print("Model loaded successfully!")



class PredictionRequest(BaseModel):
    """Request body accepted by the ``/predict`` and ``/evaluate`` endpoints."""

    # A bare ``str`` annotation accepts the empty string; ``min_length=1`` is
    # what actually enforces that 'text' must be a non-empty string.
    text: str = Field(..., min_length=1)

@app.post("/predict")
async def predict(request: PredictionRequest):
    """Run a single prediction on the text carried by the request body.

    The body is validated against ``PredictionRequest`` by FastAPI before this
    handler runs. The result of ``predict_single`` is returned unchanged.
    """
    # NOTE(review): the model stored on app.state.model at startup is not
    # passed here; whether predict_single reuses it or reloads from hf_repo
    # cannot be seen from this file -- confirm.
    text = request.text
    return predict_single(text, hf_repo, backend="local")

@app.post("/evaluate")
async def evaluate(request: PredictionRequest):
    """Run batch evaluation on the text carried by the request body.

    The body is validated against ``PredictionRequest`` by FastAPI before this
    handler runs. The result of ``evaluate_batch`` is returned unchanged.
    """
    payload = request.text
    return evaluate_batch(payload, backend="local")