radub23
commited on
Commit
·
059897f
1
Parent(s):
7ca2a93
Fix tensor index conversion in predict function
Browse files
app.py
CHANGED
@@ -48,19 +48,14 @@ def detect_warning_lamp(image, history: list[tuple[str, str]], system_message):
|
|
48 |
# Get model prediction
|
49 |
pred_class, pred_idx, probs = learn_inf.predict(img)
|
50 |
|
51 |
-
# Convert tensor outputs to Python types
|
52 |
-
pred_class = str(pred_class) # Convert class name to string
|
53 |
-
pred_idx = int(pred_idx) # Convert index to integer
|
54 |
-
probs = [float(p) for p in probs] # Convert probabilities to float list
|
55 |
-
|
56 |
# Format the prediction results
|
57 |
-
confidence = probs[pred_idx] #
|
58 |
response = f"Detected Warning Lamp: {pred_class}\nConfidence: {confidence:.2%}"
|
59 |
|
60 |
# Add probabilities for all classes
|
61 |
response += "\n\nProbabilities for all classes:"
|
62 |
-
for cls, prob in zip(learn_inf.dls.vocab, probs):
|
63 |
-
response += f"\n- {cls}: {prob:.2%}"
|
64 |
|
65 |
# Update chat history
|
66 |
history.append((None, response))
|
|
|
48 |
# Get model prediction
|
49 |
pred_class, pred_idx, probs = learn_inf.predict(img)
|
50 |
|
|
|
|
|
|
|
|
|
|
|
51 |
# Format the prediction results
|
52 |
+
confidence = float(probs[int(pred_idx)]) # Convert to float for better formatting
|
53 |
response = f"Detected Warning Lamp: {pred_class}\nConfidence: {confidence:.2%}"
|
54 |
|
55 |
# Add probabilities for all classes
|
56 |
response += "\n\nProbabilities for all classes:"
|
57 |
+
for i, (cls, prob) in enumerate(zip(learn_inf.dls.vocab, probs)):
|
58 |
+
response += f"\n- {cls}: {float(prob):.2%}"
|
59 |
|
60 |
# Update chat history
|
61 |
history.append((None, response))
|