Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -5,12 +5,14 @@ os.makedirs(os.environ['HF_HOME'], exist_ok=True) # Ensure directory exists
|
|
5 |
|
6 |
from fastapi import FastAPI
|
7 |
from qwen_classifier.predict import predict_single # Your existing function
|
|
|
8 |
import torch
|
9 |
from huggingface_hub import login
|
10 |
from qwen_classifier.model import QwenClassifier
|
11 |
from pydantic import BaseModel
|
12 |
|
13 |
app = FastAPI(title="Qwen Classifier")
|
|
|
14 |
|
15 |
@app.on_event("startup")
|
16 |
async def load_model():
|
@@ -26,7 +28,7 @@ async def load_model():
|
|
26 |
|
27 |
# Load model (will cache in /home/user/.cache/huggingface)
|
28 |
app.state.model = QwenClassifier.from_pretrained(
|
29 |
-
|
30 |
)
|
31 |
print("Model loaded successfully!")
|
32 |
|
@@ -37,4 +39,8 @@ class PredictionRequest(BaseModel):
|
|
37 |
|
38 |
@app.post("/predict")
|
39 |
async def predict(request: PredictionRequest): # ← Validates input automatically
|
40 |
-
return predict_single(request.text, backend="local")
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
from fastapi import FastAPI
|
7 |
from qwen_classifier.predict import predict_single # Your existing function
|
8 |
+
from qwen_classifier.evaluate import evaluate_batch # Your existing function
|
9 |
import torch
|
10 |
from huggingface_hub import login
|
11 |
from qwen_classifier.model import QwenClassifier
|
12 |
from pydantic import BaseModel
|
13 |
|
14 |
app = FastAPI(title="Qwen Classifier")
|
15 |
+
hf_repo = 'KeivanR/Qwen2.5-1.5B-Instruct-MLB-clf_lora-1743189446'
|
16 |
|
17 |
@app.on_event("startup")
|
18 |
async def load_model():
|
|
|
28 |
|
29 |
# Load model (will cache in /home/user/.cache/huggingface)
|
30 |
app.state.model = QwenClassifier.from_pretrained(
|
31 |
+
hf_repo,
|
32 |
)
|
33 |
print("Model loaded successfully!")
|
34 |
|
|
|
39 |
|
40 |
@app.post("/predict")
|
41 |
async def predict(request: PredictionRequest): # ← Validates input automatically
|
42 |
+
return predict_single(request.text, hf_repo, backend="local")
|
43 |
+
|
44 |
+
@app.post("/evaluate")
|
45 |
+
async def evaluate(request: PredictionRequest): # ← Validates input automatically
|
46 |
+
return evaluate_batch(request.text, backend="local")
|