# leaf-counter / app.py
# muskangoyal06's picture
# Update app.py
# 4f80d37 verified
import gradio as gr
from ultralyticsplus import YOLO, render_result
import numpy as np
import time
import torch
# System Configuration
# Emit a startup banner reporting the PyTorch version and whether CUDA is
# visible, framed above and below by a 40-character rule and a blank line.
print(
    "\n" + "=" * 40,
    f"PyTorch: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
    "=" * 40 + "\n",
    sep="\n",
)
# Load the detector used for leaf counting.
model = YOLO('foduucom/plant-leaf-detection-and-classification')

# Prediction overrides applied to every call on this model.
model.overrides.update(
    conf=0.15,           # low confidence threshold, chosen to favour recall
    iou=0.25,            # IoU threshold used by the model's own NMS
    imgsz=1280,          # large inference size so small leaves stay visible
    agnostic_nms=False,  # keep NMS per class
    max_det=300,         # allow many detections per image
    device='cuda' if torch.cuda.is_available() else 'cpu',
    classes=None,        # no class filter: keep every class
    half=torch.cuda.is_available(),  # half precision only when CUDA is present
)
# Post-processing settings used by count_leaves.
_MIN_BOX_SIDE_PX = 20        # boxes whose width or height is <= this are dropped
_MIN_CONFIDENCE = 0.1        # detections scoring below this are dropped
_MERGE_IOU_THRESHOLD = 0.15  # IoU above which a weaker overlapping box is dropped


def _suppress_overlaps(xyxy, scores, iou_thres):
    """Run greedy, class-agnostic non-maximum suppression.

    Args:
        xyxy: float array of shape (N, 4) with x1, y1, x2, y2 per row.
        scores: float array of shape (N,) with one confidence per box.
        iou_thres: a box is discarded when its IoU with an already kept,
            higher-scoring box is greater than this value.

    Returns:
        List of row indices into ``xyxy`` to keep, highest score first.
    """
    areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    order = np.argsort(-scores, kind="stable")
    kept = []
    while order.size:
        best, rest = order[0], order[1:]
        kept.append(int(best))
        # Intersection of the current best box with every remaining box.
        inter_w = np.clip(
            np.minimum(xyxy[best, 2], xyxy[rest, 2])
            - np.maximum(xyxy[best, 0], xyxy[rest, 0]),
            0,
            None,
        )
        inter_h = np.clip(
            np.minimum(xyxy[best, 3], xyxy[rest, 3])
            - np.maximum(xyxy[best, 1], xyxy[rest, 1]),
            0,
            None,
        )
        inter = inter_w * inter_h
        union = areas[best] + areas[rest] - inter
        # Guard the division so degenerate (zero-area) boxes cannot divide by 0.
        iou = inter / np.maximum(union, 1e-9)
        order = rest[iou <= iou_thres]
    return kept


def count_leaves(image):
    """Detect leaves in an image and return an annotated copy plus the count.

    The image is contrast-enhanced (CLAHE on the L channel of LAB), passed to
    the module-level ``model``, and the detections are then filtered by
    minimum box size and confidence and thinned with a second, stricter
    overlap-suppression pass. Each surviving box is drawn in green.

    Args:
        image: the input picture, treated as an RGB array (the colour
            conversions below assume RGB channel order). ``None`` is accepted
            and yields ``(None, 0)``.

    Returns:
        Tuple ``(annotated_image, num_leaves)``. If processing raises, the
        error is printed and ``(image, 0)`` is returned instead.
    """
    if image is None:
        # Nothing was submitted; there is nothing to enhance or annotate.
        return None, 0

    # OpenCV is used throughout this function but is not imported at the top
    # of the file, so bring it into scope here. Kept outside the try block so
    # a missing installation is reported loudly instead of as a zero count.
    import cv2

    try:
        start_time = time.time()

        # Preprocessing - enhance contrast on the lightness channel only.
        image = np.array(image)
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        enhanced_lab = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
        enhanced_img = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

        # Ultralytics documents numpy sources as BGR, so convert for
        # inference while keeping the RGB copy for the returned visualization.
        results = model.predict(
            source=cv2.cvtColor(enhanced_img, cv2.COLOR_RGB2BGR),
            augment=True,  # Test time augmentation
            verbose=False,
            agnostic_nms=False,
            overlap_mask=False,
        )

        # Pull all boxes and scores out at once; cast to float32 because the
        # model may run in half precision.
        boxes = results[0].boxes
        xyxy = boxes.xyxy.float().cpu().numpy()
        scores = boxes.conf.float().cpu().numpy()

        # Drop tiny or low-confidence detections (adjust to your leaf sizes).
        widths = xyxy[:, 2] - xyxy[:, 0]
        heights = xyxy[:, 3] - xyxy[:, 1]
        usable = (
            (widths > _MIN_BOX_SIDE_PX)
            & (heights > _MIN_BOX_SIDE_PX)
            & (scores >= _MIN_CONFIDENCE)
        )
        xyxy, scores = xyxy[usable], scores[usable]

        # Second, stricter suppression pass so one leaf is not counted twice.
        kept = _suppress_overlaps(xyxy, scores, _MERGE_IOU_THRESHOLD)
        final_boxes = xyxy[np.asarray(kept, dtype=int)]
        num_leaves = len(final_boxes)

        # Visual validation: outline each counted leaf in green.
        debug_img = enhanced_img.copy()
        for box in final_boxes:
            x1, y1, x2, y2 = (int(v) for v in box)
            cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)

        print(f"Processing time: {time.time()-start_time:.2f}s")
        return debug_img, num_leaves
    except Exception as e:
        # Best-effort boundary for the UI: report the failure and fall back
        # to the unannotated input with a zero count.
        print(f"Error: {str(e)}")
        return image, 0
# Gradio interface with visualization: one image in, the annotated image and
# the leaf count out, both produced by count_leaves.
interface = gr.Interface(
    fn=count_leaves,
    inputs=gr.Image(label="Input Image"),
    outputs=[
        gr.Image(label="Detection Visualization"),
        gr.Number(label="Estimated Leaf Count"),
    ],
    # \U0001F343 is the leaf emoji; written as an escape so the title cannot
    # be corrupted again by a wrong file encoding.
    title="\U0001F343 Advanced Leaf Counter",
    description="Specialized for overlapping leaves and dense foliage",
    # Example inputs, given as paths relative to the working directory.
    examples=[
        ["sample_leaf1.jpg"],
        ["sample_leaf2.jpg"],
    ],
)
# Script entry point: start the Gradio server only when run directly.
if __name__ == "__main__":
    # Fixed port, with share disabled.
    launch_options = {"server_port": 7860, "share": False}
    interface.launch(**launch_options)