KeivanR commited on
Commit
0e5f0cf
·
1 Parent(s): a6bcf7e

numpy issue eval

Browse files
Files changed (1) hide show
  1. qwen_classifier/evaluate.py +2 -1
qwen_classifier/evaluate.py CHANGED
@@ -141,7 +141,8 @@ def _evaluate_local(test_data_path, hf_repo):
141
 
142
  logits = global_model(batch["input_ids"], batch["attention_mask"])
143
 
144
- preds = torch.sigmoid(logits).cpu().numpy() > 0.5
 
145
  labels = labels.cpu().numpy()
146
 
147
  all_preds.extend(preds)
 
141
 
142
  logits = global_model(batch["input_ids"], batch["attention_mask"])
143
 
144
+ preds = torch.sigmoid(logits).cpu() > 0.5 # Keeps as PyTorch tensor
145
+ preds = preds.float() # Convert to 0.0/1.0 if needed
146
  labels = labels.cpu().numpy()
147
 
148
  all_preds.extend(preds)