KeivanR commited on
Commit
b0cd906
·
1 Parent(s): 2f3df87

text input as json

Browse files
Files changed (4) hide show
  1. app.py +8 -3
  2. qwen_classifier/evaluate.py +2 -3
  3. qwen_classifier/predict.py +11 -7
  4. setup.py +2 -1
app.py CHANGED
@@ -8,7 +8,7 @@ from qwen_classifier.predict import predict_single # Your existing function
8
  import torch
9
  from huggingface_hub import login
10
  from qwen_classifier.model import QwenClassifier
11
- import os
12
 
13
  app = FastAPI(title="Qwen Classifier")
14
 
@@ -30,6 +30,11 @@ async def load_model():
30
  )
31
  print("Model loaded successfully!")
32
 
 
 
 
 
 
33
  @app.post("/predict")
34
- async def predict(text: str):
35
- return predict_single(text, backend="local")
 
8
  import torch
9
  from huggingface_hub import login
10
  from qwen_classifier.model import QwenClassifier
11
+ from pydantic import BaseModel
12
 
13
  app = FastAPI(title="Qwen Classifier")
14
 
 
30
  )
31
  print("Model loaded successfully!")
32
 
33
+
34
+
35
+ class PredictionRequest(BaseModel):
36
+ text: str # ← Enforces that 'text' must be a non-empty string
37
+
38
  @app.post("/predict")
39
+ async def predict(request: PredictionRequest): # ← Validates input automatically
40
+ return predict_single(request.text, backend="local")
qwen_classifier/evaluate.py CHANGED
@@ -43,7 +43,6 @@ def preprocessing(df):
43
 
44
 
45
  def evaluate_model(test_data_path):
46
- # Load your test data
47
- # Implement evaluation logic
48
- # Return metrics like precision, recall, f1-score
49
  return metrics
 
43
 
44
 
45
  def evaluate_model(test_data_path):
46
+ df = load_data(test_data_path)
47
+ df = preprocessing(df)
 
48
  return metrics
qwen_classifier/predict.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import requests
3
- from .config import TAG_NAMES
4
 
5
  # Local model setup (only load if needed)
6
  local_model = None
@@ -31,18 +31,22 @@ def _predict_local(text, hf_repo):
31
  return _process_output(logits)
32
 
33
  def _predict_hf_api(text, hf_token=None):
34
- # Use your Space endpoint instead of direct model API
35
- SPACE_URL = "https://keivanr-qwen-classifier-demo.hf.space"
36
 
37
  try:
38
  response = requests.post(
39
  f"{SPACE_URL}/predict",
40
- json={"text": text},
41
- headers={"Authorization": f"Bearer {hf_token}"} if hf_token else {}
 
 
 
 
42
  )
 
43
  return response.json()
44
- except Exception as e:
45
- raise ValueError(f"Space API Error: {str(e)}")
46
 
47
  def _process_output(logits):
48
  probs = torch.sigmoid(logits)
 
1
  import torch
2
  import requests
3
+ from .config import TAG_NAMES, SPACE_URL
4
 
5
  # Local model setup (only load if needed)
6
  local_model = None
 
31
  return _process_output(logits)
32
 
33
  def _predict_hf_api(text, hf_token=None):
34
+
 
35
 
36
  try:
37
  response = requests.post(
38
  f"{SPACE_URL}/predict",
39
+ json={"text": text}, # This matches the Pydantic model
40
+ headers={
41
+ "Authorization": f"Bearer {hf_token}",
42
+ "Content-Type": "application/json"
43
+ } if hf_token else {"Content-Type": "application/json"},
44
+ timeout=10
45
  )
46
+ response.raise_for_status() # Raise HTTP errors
47
  return response.json()
48
+ except requests.exceptions.RequestException as e:
49
+ raise ValueError(f"API Error: {str(e)}\nResponse: {e.response.text if hasattr(e, 'response') else ''}")
50
 
51
  def _process_output(logits):
52
  probs = torch.sigmoid(logits)
setup.py CHANGED
@@ -11,7 +11,8 @@ setup(
11
  'scikit-learn',
12
  'huggingface_hub',
13
  'requests',
14
- 'pandas'
 
15
  ],
16
  entry_points={
17
  'console_scripts': [
 
11
  'scikit-learn',
12
  'huggingface_hub',
13
  'requests',
14
+ 'pandas',
15
+ 'pydantic'
16
  ],
17
  entry_points={
18
  'console_scripts': [