radub23 commited on
Commit
059897f
·
1 Parent(s): 7ca2a93

Fix tensor index conversion in predict function

Browse files
Files changed (1) hide show
  1. app.py +3 -8
app.py CHANGED
@@ -48,19 +48,14 @@ def detect_warning_lamp(image, history: list[tuple[str, str]], system_message):
48
  # Get model prediction
49
  pred_class, pred_idx, probs = learn_inf.predict(img)
50
 
51
- # Convert tensor outputs to Python types
52
- pred_class = str(pred_class) # Convert class name to string
53
- pred_idx = int(pred_idx) # Convert index to integer
54
- probs = [float(p) for p in probs] # Convert probabilities to float list
55
-
56
  # Format the prediction results
57
- confidence = probs[pred_idx] # Get confidence for predicted class
58
  response = f"Detected Warning Lamp: {pred_class}\nConfidence: {confidence:.2%}"
59
 
60
  # Add probabilities for all classes
61
  response += "\n\nProbabilities for all classes:"
62
- for cls, prob in zip(learn_inf.dls.vocab, probs):
63
- response += f"\n- {cls}: {prob:.2%}"
64
 
65
  # Update chat history
66
  history.append((None, response))
 
48
  # Get model prediction
49
  pred_class, pred_idx, probs = learn_inf.predict(img)
50
 
 
 
 
 
 
51
  # Format the prediction results
52
+ confidence = float(probs[int(pred_idx)]) # Convert to float for better formatting
53
  response = f"Detected Warning Lamp: {pred_class}\nConfidence: {confidence:.2%}"
54
 
55
  # Add probabilities for all classes
56
  response += "\n\nProbabilities for all classes:"
57
+ for i, (cls, prob) in enumerate(zip(learn_inf.dls.vocab, probs)):
58
+ response += f"\n- {cls}: {float(prob):.2%}"
59
 
60
  # Update chat history
61
  history.append((None, response))