obichimav committed (verified)
Commit 2bd484d · 1 Parent(s): 7f884bb

Update app.py

Files changed (1)
  1. app.py +25 -37
app.py CHANGED
@@ -352,7 +352,7 @@ PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
 MODEL = "gpt-4o"
 openai = OpenAI()
 
-# Initialize VisionAgent
+# Initialize VisionAgent (kept for potential future use, though not used directly in detection below)
 agent = VisionAgentCoderV2(verbose=False)
 
 system_message = """You are an expert in object detection. When users mention:
@@ -360,10 +360,8 @@ system_message = """You are an expert in object detection. When users mention:
 2. "detect [object(s)]" - Same as count
 3. "show [object(s)]" - Same as count
 
-Always use object detection tool when counting/detecting is mentioned."""
-
-system_message += "Always be accurate. If you don't know the answer, say so."
-
+Always use object detection tool when counting/detecting is mentioned.
+Always be accurate. If you don't know the answer, say so."""
 
 class State:
     def __init__(self):
@@ -398,38 +396,28 @@ def detect_objects(query_text):
         # Clean query text to get the object name
         object_name = query_text[0].replace("a photo of ", "").strip()
 
-        # Let VisionAgent handle the detection with its agent-based approach
-        # Create agent message for object detection
-        agent_message = [
-            AgentMessage(
-                role="user",
-                content=f"Count the number of {object_name} in this image. Only show detections with high confidence (>0.75).",
-                media=[image_path]
-            )
-        ]
-
-        # Generate code using VisionAgent
-        code_context = agent.generate_code(agent_message)
-
-        # Load the image for visualization
+        # Load the image for detection and visualization
         image = T.load_image(image_path)
 
-        # Use multiple models for detection and get high confidence results
-        # First try the specialized detector
+        # Use the specialized detector first with a threshold of 0.55
        detections = T.countgd_object_detection(object_name, image, conf_threshold=0.55)
+        if detections is None:
+            detections = []
 
-        # If no high-confidence detections, try the more general object detector
+        # If no detections, try the more general grounding_dino detector
         if not detections:
-            # Try a different model with the same high threshold
             try:
                 detections = T.grounding_dino_detection(object_name, image, box_threshold=0.55)
-            except:
-                pass
+                if detections is None:
+                    detections = []
+            except Exception as e:
+                print(f"Error in grounding_dino_detection: {str(e)}")
+                detections = []
 
-        # Only keep high confidence detections
-        high_conf_detections = [det for det in detections if det.get("score", 0) > 0.55]
+        # Only keep detections with confidence higher than 0.55
+        high_conf_detections = [det for det in detections if det.get("score", 0) >= 0.55]
 
-        # Visualize only high confidence results with clear labeling
+        # Visualize the high confidence detections with clear labeling
         result_image = T.overlay_bounding_boxes(
             image,
             high_conf_detections,
@@ -442,7 +430,7 @@ def detect_objects(query_text):
         return {
             "count": len(high_conf_detections),
             "confidence": [det["score"] for det in high_conf_detections],
-            "message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>0.75)"
+            "message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>=0.55)"
         }
     except Exception as e:
         print(f"Error in detect_objects: {str(e)}")
@@ -539,8 +527,6 @@ def chat(message, image, history):
 
     # Extract objects to detect from user message
    objects_to_detect = message.lower()
-
-    # Format query for object detection - keep it simple and direct
     cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()
     query = ["a photo of " + cleaned_query]
 
@@ -559,11 +545,13 @@
         max_tokens=300
     )
 
+    # Check if a tool call is required based on the response
     if response.choices[0].finish_reason == "tool_calls":
-        message = response.choices[0].message
-        messages.append(message)
+        message_obj = response.choices[0].message
+        messages.append(message_obj)
 
-        for tool_call in message.tool_calls:
+        # Process each tool call from the message
+        for tool_call in message_obj.tool_calls:
             if tool_call.function.name == "detect_objects":
                 results = detect_objects(query)
             else:
@@ -604,8 +592,8 @@ with gr.Blocks() as demo:
     output_image = gr.Image(type="numpy", label="Detection Results")
 
     def process_interaction(message, image, history):
-        response, pred_image = chat(message, image, history)
-        history.append((message, response))
+        response_text, pred_image = chat(message, image, history)
+        history.append((message, response_text))
         return "", pred_image, history
 
     def reset_interface():
@@ -636,4 +624,4 @@
 - "What species is this plant?"
 """)
 
-demo.launch(share=True)
+demo.launch(share=True)
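
For readers who want to try the revised detection flow outside the Gradio app, here is a minimal sketch of the fallback pattern this commit introduces. It assumes the app's `vision_agent.tools` module is importable as `T` (as in app.py); the image path, object name, and threshold value are illustrative placeholders.

    # Minimal sketch of the commit's fallback-detection pattern; assumes
    # vision_agent.tools is available as T, as in app.py. The path and
    # object name below are placeholders for illustration.
    import vision_agent.tools as T

    def count_objects(image_path, object_name, threshold=0.55):
        image = T.load_image(image_path)

        # Primary detector; normalize a possible None result to an empty list.
        detections = T.countgd_object_detection(object_name, image, conf_threshold=threshold) or []

        # Fall back to the more general grounding DINO detector if nothing was found.
        if not detections:
            try:
                detections = T.grounding_dino_detection(object_name, image, box_threshold=threshold) or []
            except Exception as e:
                print(f"Fallback detector failed: {e}")
                detections = []

        # Keep only detections at or above the confidence threshold.
        return [det for det in detections if det.get("score", 0) >= threshold]

    print(len(count_objects("example.jpg", "apple")))

The `or []` normalization is a compact equivalent of the commit's explicit `if detections is None` checks; either way, the fallback condition and the final comprehension never see `None`.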
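
The rename from `message` to `message_obj` in `chat` presumably avoids rebinding the function's own `message` parameter when the model requests a tool call. A minimal, self-contained sketch of that dispatch pattern follows; the tool schema and the `detect_objects` stub are illustrative stand-ins for the ones defined elsewhere in app.py.

    # Minimal sketch of the tool-call dispatch pattern after this commit.
    # The tool schema and detect_objects stub are illustrative stand-ins.
    import json
    from openai import OpenAI

    client = OpenAI()

    tools = [{
        "type": "function",
        "function": {
            "name": "detect_objects",
            "description": "Detect and count objects in the current image.",
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
            },
        },
    }]

    def detect_objects(query):  # stand-in for the app's real detector
        return {"count": 0, "message": f"no {query} found"}

    messages = [{"role": "user", "content": "Count the cars in the photo."}]
    response = client.chat.completions.create(model="gpt-4o", messages=messages, tools=tools)

    # Bind the assistant message to a fresh name (message_obj) instead of
    # reusing `message`, mirroring the rename in this commit.
    if response.choices[0].finish_reason == "tool_calls":
        message_obj = response.choices[0].message
        messages.append(message_obj)
        for tool_call in message_obj.tool_calls:
            if tool_call.function.name == "detect_objects":
                args = json.loads(tool_call.function.arguments)
                results = detect_objects(args["query"])
                # Return the tool result to the model for the follow-up turn.
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": json.dumps(results),
                })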