Spaces:
Running
Running
import time
import traceback

import cv2
import gradio as gr
import numpy as np
import torch
from ultralyticsplus import YOLO, render_result
# System Configuration: at startup, log the PyTorch version and whether a
# CUDA device is visible, framed by separator lines.
print(
    "\n" + "=" * 40,
    f"PyTorch: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
    "=" * 40 + "\n",
    sep="\n",
)
# Load the leaf detector by its model id, then replace its default inference
# settings with ones tuned for counting many small, overlapping leaves.
model = YOLO('foduucom/plant-leaf-detection-and-classification')
model.overrides.update(
    conf=0.15,  # low confidence floor: favour recall over precision
    iou=0.25,  # low NMS IoU: overlapping leaves suppress each other sooner
    imgsz=1280,  # large inference size so small leaves keep enough pixels
    agnostic_nms=False,
    max_det=300,  # allow dense foliage to yield many detections
    device='cuda' if torch.cuda.is_available() else 'cpu',
    classes=None,  # no class filter: keep every class the model predicts
    half=torch.cuda.is_available(),  # FP16 only when a GPU is present
)
def _as_rgb(image):
    """Return *image* as a contiguous HxWx3 array suitable for ``COLOR_RGB2LAB``.

    A grayscale (HxW) input is replicated across three channels and an alpha
    channel (HxWx4) is dropped; a 3-channel input passes through unchanged.
    NOTE(review): assumes an 8-bit image, as Gradio's numpy ``gr.Image`` gives.
    """
    rgb = np.asarray(image)
    if rgb.ndim == 2:
        rgb = np.repeat(rgb[:, :, None], 3, axis=2)
    elif rgb.ndim == 3 and rgb.shape[2] == 4:
        rgb = rgb[:, :, :3]
    return np.ascontiguousarray(rgb)


def _enhance_contrast(rgb, clip_limit=3.0, tile_grid_size=(8, 8)):
    """Apply CLAHE to the L (lightness) channel of *rgb* and return RGB again."""
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
    lightness, a_chan, b_chan = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    merged = cv2.merge((clahe.apply(lightness), a_chan, b_chan))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)


def _nms(boxes, scores, iou_thres):
    """Greedy class-agnostic non-maximum suppression.

    Args:
        boxes: ``(N, 4)`` array-like of ``x1, y1, x2, y2`` corners (N may be 0).
        scores: ``(N,)`` confidence per box.
        iou_thres: a box is dropped when its IoU with an already-kept,
            higher-scoring box is strictly greater than this value.

    Returns:
        Integer index array of the kept boxes, ordered by descending score.
    """
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    scores = np.asarray(scores, dtype=float).reshape(-1)
    x1, y1, x2, y2 = boxes.T
    areas = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    order = np.argsort(-scores, kind="stable")
    keep = []
    while order.size:
        best, rest = order[0], order[1:]
        keep.append(best)
        inter_w = np.clip(np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]), 0, None)
        inter_h = np.clip(np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]), 0, None)
        inter = inter_w * inter_h
        union = areas[best] + areas[rest] - inter
        # Degenerate (zero-area) pairs get IoU 0 instead of a divide-by-zero.
        iou = np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)
        order = rest[iou <= iou_thres]
    return np.asarray(keep, dtype=int)


def count_leaves(image):
    """Count leaves in *image* and return an annotated copy.

    Pipeline:
    1. CLAHE contrast enhancement on the LAB lightness channel.
    2. Model prediction with test-time augmentation.
    3. Removal of detections whose width or height is not above
       ``min_box_size`` pixels.
    4. An extra, stricter class-agnostic NMS pass, on top of the NMS the
       model already applies.

    Args:
        image: the value of the Gradio ``gr.Image`` input, or ``None`` when
            nothing was uploaded.

    Returns:
        ``(annotated_image, leaf_count)``, with the surviving boxes drawn in
        green on the enhanced image. Returns ``(None, 0)`` for missing input.
        On any processing error, the input is returned unchanged with a count
        of 0 and the traceback is printed.
    """
    # Detections not wider AND taller than this (pixels) are ignored; tune to
    # the expected leaf size.
    min_box_size = 20
    # IoU for the second NMS pass. It is stricter than the model-level ``iou``
    # override, so near-duplicate boxes on the same leaf collapse to one.
    nms_iou = 0.15

    if image is None:
        return None, 0
    try:
        start_time = time.perf_counter()
        enhanced_img = _enhance_contrast(_as_rgb(image))

        results = model.predict(
            source=enhanced_img,
            augment=True,  # test-time augmentation
            verbose=False,
            agnostic_nms=False,
        )
        boxes = results[0].boxes
        xyxy = boxes.xyxy.cpu().numpy().reshape(-1, 4)
        scores = boxes.conf.cpu().numpy().reshape(-1)

        widths = xyxy[:, 2] - xyxy[:, 0]
        heights = xyxy[:, 3] - xyxy[:, 1]
        large_enough = (widths > min_box_size) & (heights > min_box_size)
        xyxy, scores = xyxy[large_enough], scores[large_enough]

        final_boxes = xyxy[_nms(xyxy, scores, nms_iou)]
        num_leaves = len(final_boxes)

        # Visual validation: outline every counted leaf.
        debug_img = enhanced_img.copy()
        for x1, y1, x2, y2 in final_boxes.astype(int):
            cv2.rectangle(debug_img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)

        print(f"Processing time: {time.perf_counter() - start_time:.2f}s")
        return debug_img, num_leaves
    except Exception as e:
        # Best-effort UI: report the failure in the logs but keep the app alive.
        print(f"Error: {str(e)}")
        traceback.print_exc()
        return image, 0
# Gradio interface with visualization: one image in; the annotated image and
# the leaf count produced by ``count_leaves`` out.
interface = gr.Interface(
    fn=count_leaves,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Detection Visualization"),
        gr.Number(label="Estimated Leaf Count"),
    ],
    # The title previously began with a mis-encoded character ("π").
    title="Advanced Leaf Counter",
    description="Specialized for overlapping leaves and dense foliage",
    # Example files are resolved relative to the working directory.
    examples=[
        ["sample_leaf1.jpg"],
        ["sample_leaf2.jpg"],
    ],
)
if __name__ == "__main__":
    # Serve on port 7860 without creating a public share link.
    interface.launch(share=False, server_port=7860)