KeivanR commited on
Commit
b820b0a
·
verified ·
1 Parent(s): c18aa66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -5,12 +5,14 @@ os.makedirs(os.environ['HF_HOME'], exist_ok=True) # Ensure directory exists
5
 
6
  from fastapi import FastAPI
7
  from qwen_classifier.predict import predict_single # Your existing function
 
8
  import torch
9
  from huggingface_hub import login
10
  from qwen_classifier.model import QwenClassifier
11
  from pydantic import BaseModel
12
 
13
  app = FastAPI(title="Qwen Classifier")
 
14
 
15
  @app.on_event("startup")
16
  async def load_model():
@@ -26,7 +28,7 @@ async def load_model():
26
 
27
  # Load model (will cache in /home/user/.cache/huggingface)
28
  app.state.model = QwenClassifier.from_pretrained(
29
- 'KeivanR/Qwen2.5-1.5B-Instruct-MLB-clf_lora-1743189446',
30
  )
31
  print("Model loaded successfully!")
32
 
@@ -37,4 +39,8 @@ class PredictionRequest(BaseModel):
37
 
38
  @app.post("/predict")
39
  async def predict(request: PredictionRequest): # ← Validates input automatically
40
- return predict_single(request.text, backend="local")
 
 
 
 
 
5
 
6
  from fastapi import FastAPI
7
  from qwen_classifier.predict import predict_single # Your existing function
8
+ from qwen_classifier.evaluate import evaluate_batch # Your existing function
9
  import torch
10
  from huggingface_hub import login
11
  from qwen_classifier.model import QwenClassifier
12
  from pydantic import BaseModel
13
 
14
  app = FastAPI(title="Qwen Classifier")
15
+ hf_repo = 'KeivanR/Qwen2.5-1.5B-Instruct-MLB-clf_lora-1743189446'
16
 
17
  @app.on_event("startup")
18
  async def load_model():
 
28
 
29
  # Load model (will cache in /home/user/.cache/huggingface)
30
  app.state.model = QwenClassifier.from_pretrained(
31
+ hf_repo,
32
  )
33
  print("Model loaded successfully!")
34
 
 
39
 
40
  @app.post("/predict")
41
  async def predict(request: PredictionRequest): # ← Validates input automatically
42
+ return predict_single(request.text, hf_repo, backend="local")
43
+
44
+ @app.post("/evaluate")
45
+ async def evaluate(request: PredictionRequest): # ← Validates input automatically
46
+ return evaluate_batch(request.text, backend="local")