Update app.py
Browse files
app.py
CHANGED
@@ -352,7 +352,7 @@ PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
|
|
352 |
MODEL = "gpt-4o"
|
353 |
openai = OpenAI()
|
354 |
|
355 |
-
# Initialize VisionAgent
|
356 |
agent = VisionAgentCoderV2(verbose=False)
|
357 |
|
358 |
system_message = """You are an expert in object detection. When users mention:
|
@@ -360,10 +360,8 @@ system_message = """You are an expert in object detection. When users mention:
|
|
360 |
2. "detect [object(s)]" - Same as count
|
361 |
3. "show [object(s)]" - Same as count
|
362 |
|
363 |
-
Always use object detection tool when counting/detecting is mentioned.
|
364 |
-
|
365 |
-
system_message += "Always be accurate. If you don't know the answer, say so."
|
366 |
-
|
367 |
|
368 |
class State:
|
369 |
def __init__(self):
|
@@ -398,38 +396,28 @@ def detect_objects(query_text):
|
|
398 |
# Clean query text to get the object name
|
399 |
object_name = query_text[0].replace("a photo of ", "").strip()
|
400 |
|
401 |
-
#
|
402 |
-
# Create agent message for object detection
|
403 |
-
agent_message = [
|
404 |
-
AgentMessage(
|
405 |
-
role="user",
|
406 |
-
content=f"Count the number of {object_name} in this image. Only show detections with high confidence (>0.75).",
|
407 |
-
media=[image_path]
|
408 |
-
)
|
409 |
-
]
|
410 |
-
|
411 |
-
# Generate code using VisionAgent
|
412 |
-
code_context = agent.generate_code(agent_message)
|
413 |
-
|
414 |
-
# Load the image for visualization
|
415 |
image = T.load_image(image_path)
|
416 |
|
417 |
-
# Use
|
418 |
-
# First try the specialized detector
|
419 |
detections = T.countgd_object_detection(object_name, image, conf_threshold=0.55)
|
|
|
|
|
420 |
|
421 |
-
# If no
|
422 |
if not detections:
|
423 |
-
# Try a different model with the same high threshold
|
424 |
try:
|
425 |
detections = T.grounding_dino_detection(object_name, image, box_threshold=0.55)
|
426 |
-
|
427 |
-
|
|
|
|
|
|
|
428 |
|
429 |
-
# Only keep
|
430 |
-
high_conf_detections = [det for det in detections if det.get("score", 0)
|
431 |
|
432 |
-
# Visualize
|
433 |
result_image = T.overlay_bounding_boxes(
|
434 |
image,
|
435 |
high_conf_detections,
|
@@ -442,7 +430,7 @@ def detect_objects(query_text):
|
|
442 |
return {
|
443 |
"count": len(high_conf_detections),
|
444 |
"confidence": [det["score"] for det in high_conf_detections],
|
445 |
-
"message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (
|
446 |
}
|
447 |
except Exception as e:
|
448 |
print(f"Error in detect_objects: {str(e)}")
|
@@ -539,8 +527,6 @@ def chat(message, image, history):
|
|
539 |
|
540 |
# Extract objects to detect from user message
|
541 |
objects_to_detect = message.lower()
|
542 |
-
|
543 |
-
# Format query for object detection - keep it simple and direct
|
544 |
cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()
|
545 |
query = ["a photo of " + cleaned_query]
|
546 |
|
@@ -559,11 +545,13 @@ def chat(message, image, history):
|
|
559 |
max_tokens=300
|
560 |
)
|
561 |
|
|
|
562 |
if response.choices[0].finish_reason == "tool_calls":
|
563 |
-
|
564 |
-
messages.append(
|
565 |
|
566 |
-
|
|
|
567 |
if tool_call.function.name == "detect_objects":
|
568 |
results = detect_objects(query)
|
569 |
else:
|
@@ -604,8 +592,8 @@ with gr.Blocks() as demo:
|
|
604 |
output_image = gr.Image(type="numpy", label="Detection Results")
|
605 |
|
606 |
def process_interaction(message, image, history):
|
607 |
-
|
608 |
-
history.append((message,
|
609 |
return "", pred_image, history
|
610 |
|
611 |
def reset_interface():
|
@@ -636,4 +624,4 @@ Examples:
|
|
636 |
- "What species is this plant?"
|
637 |
""")
|
638 |
|
639 |
-
demo.launch(share=True)
|
|
|
352 |
MODEL = "gpt-4o"
|
353 |
openai = OpenAI()
|
354 |
|
355 |
+
# Initialize VisionAgent (kept for potential future use, though not used directly in detection below)
|
356 |
agent = VisionAgentCoderV2(verbose=False)
|
357 |
|
358 |
system_message = """You are an expert in object detection. When users mention:
|
|
|
360 |
2. "detect [object(s)]" - Same as count
|
361 |
3. "show [object(s)]" - Same as count
|
362 |
|
363 |
+
Always use object detection tool when counting/detecting is mentioned.
|
364 |
+
Always be accurate. If you don't know the answer, say so."""
|
|
|
|
|
365 |
|
366 |
class State:
|
367 |
def __init__(self):
|
|
|
396 |
# Clean query text to get the object name
|
397 |
object_name = query_text[0].replace("a photo of ", "").strip()
|
398 |
|
399 |
+
# Load the image for detection and visualization
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
image = T.load_image(image_path)
|
401 |
|
402 |
+
# Use the specialized detector first with a threshold of 0.55
|
|
|
403 |
detections = T.countgd_object_detection(object_name, image, conf_threshold=0.55)
|
404 |
+
if detections is None:
|
405 |
+
detections = []
|
406 |
|
407 |
+
# If no detections, try the more general grounding_dino detector
|
408 |
if not detections:
|
|
|
409 |
try:
|
410 |
detections = T.grounding_dino_detection(object_name, image, box_threshold=0.55)
|
411 |
+
if detections is None:
|
412 |
+
detections = []
|
413 |
+
except Exception as e:
|
414 |
+
print(f"Error in grounding_dino_detection: {str(e)}")
|
415 |
+
detections = []
|
416 |
|
417 |
+
# Only keep detections with confidence higher than 0.55
|
418 |
+
high_conf_detections = [det for det in detections if det.get("score", 0) >= 0.55]
|
419 |
|
420 |
+
# Visualize the high confidence detections with clear labeling
|
421 |
result_image = T.overlay_bounding_boxes(
|
422 |
image,
|
423 |
high_conf_detections,
|
|
|
430 |
return {
|
431 |
"count": len(high_conf_detections),
|
432 |
"confidence": [det["score"] for det in high_conf_detections],
|
433 |
+
"message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>=0.55)"
|
434 |
}
|
435 |
except Exception as e:
|
436 |
print(f"Error in detect_objects: {str(e)}")
|
|
|
527 |
|
528 |
# Extract objects to detect from user message
|
529 |
objects_to_detect = message.lower()
|
|
|
|
|
530 |
cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()
|
531 |
query = ["a photo of " + cleaned_query]
|
532 |
|
|
|
545 |
max_tokens=300
|
546 |
)
|
547 |
|
548 |
+
# Check if a tool call is required based on the response
|
549 |
if response.choices[0].finish_reason == "tool_calls":
|
550 |
+
message_obj = response.choices[0].message
|
551 |
+
messages.append(message_obj)
|
552 |
|
553 |
+
# Process each tool call from the message
|
554 |
+
for tool_call in message_obj.tool_calls:
|
555 |
if tool_call.function.name == "detect_objects":
|
556 |
results = detect_objects(query)
|
557 |
else:
|
|
|
592 |
output_image = gr.Image(type="numpy", label="Detection Results")
|
593 |
|
594 |
def process_interaction(message, image, history):
|
595 |
+
response_text, pred_image = chat(message, image, history)
|
596 |
+
history.append((message, response_text))
|
597 |
return "", pred_image, history
|
598 |
|
599 |
def reset_interface():
|
|
|
624 |
- "What species is this plant?"
|
625 |
""")
|
626 |
|
627 |
+
demo.launch(share=True)
|