|
import json
import os
import tempfile
from datetime import datetime

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFilter
from transformers import CLIPModel, CLIPProcessor
|
|
|
|
|
# --- Configuration ---------------------------------------------------------
# Hugging Face model id of the CLIP checkpoint used for zero-shot labelling.
MODEL_NAME = "openai/clip-vit-base-patch16"

# JSON stores; plain relative paths, so they resolve against the current
# working directory of the process.
HISTORY_FILE = "waste_history.json"
FEEDBACK_FILE = "user_feedback.json"

# Candidate labels. classify_waste passes this list verbatim to CLIP as the
# text prompts, and the index of the best-scoring prompt selects the label.
WASTE_CATEGORIES = [
    "General Waste",
    "Recyclable Waste",
]

# --- Model (loaded once, at import time) -----------------------------------
# Evaluated left to right: the processor is loaded before the model.
processor, model = (
    CLIPProcessor.from_pretrained(MODEL_NAME),
    CLIPModel.from_pretrained(MODEL_NAME),
)
|
|
|
|
|
def load_json(path, default):
    """Load and return the JSON document stored at *path*.

    This is deliberately best-effort: when the file is missing, unreadable, or
    does not contain valid JSON, *default* is returned instead of raising, so
    the app can start with an empty store.

    Args:
        path: Filesystem path of the JSON file.
        default: Value returned when the file cannot be loaded.

    Returns:
        The decoded JSON value, or *default*.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: missing / unreadable file. json.JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses. Anything else
        # (KeyboardInterrupt, programming errors) now propagates instead of
        # being silently swallowed by a bare ``except:``.
        return default
|
|
|
def save_json(path, data):
    """Atomically write *data* to *path* as pretty-printed JSON.

    The document is first written to a temporary file in the same directory
    and then moved over *path* with ``os.replace``. A crash or a serialisation
    error mid-write therefore leaves the previous file intact. Writing in place
    would leave a truncated file that ``load_json`` silently treats as empty,
    losing all stored history.

    Args:
        path: Destination file path.
        data: JSON-serialisable value to store.

    Raises:
        TypeError: if *data* is not JSON-serialisable (the target is untouched).
        OSError: if the temporary file cannot be created or renamed.
    """
    directory = os.path.dirname(os.path.abspath(path))
    # The temp file must live on the same filesystem as *path* for
    # os.replace to be an atomic rename. Note: mkstemp creates it owner-only
    # (0600), which the replaced file inherits.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial temp file, then let the original error surface.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
|
|
|
|
|
def detect_and_focus(img_array, box_frac=0.6, blur_radius=10):
    """Highlight a fixed, centred region of an image and blur everything else.

    Despite the name, no object detection is performed. The "focus" region is
    always a centred box whose width and height are ``box_frac`` of the image's.

    Args:
        img_array: Image as a NumPy array accepted by ``PIL.Image.fromarray``
            (presumably the HxWxC uint8 array Gradio's ``type="numpy"``
            produces — confirm).
        box_frac: Fraction in (0, 1] of the width/height covered by the box.
            Defaults to the previously hard-coded 0.6.
        blur_radius: Gaussian blur radius applied outside the box. Defaults to
            the previously hard-coded 10.

    Returns:
        ``(focused, boxed, bbox)``:
        - ``focused``: array of the image blurred everywhere except the box;
        - ``boxed``: array of the original image with a green box outline;
        - ``bbox``: ``(left, top, right, bottom)`` pixel coordinates of the box.

    Raises:
        ValueError: if ``box_frac`` is not in (0, 1].
    """
    if not 0 < box_frac <= 1:
        raise ValueError(f"box_frac must be in (0, 1], got {box_frac!r}")

    img = Image.fromarray(img_array)
    w, h = img.size
    bw, bh = int(w * box_frac), int(h * box_frac)
    left, top = (w - bw) // 2, (h - bh) // 2
    right, bottom = left + bw, top + bh
    bbox = (left, top, right, bottom)

    # Visualisation copy shown to the user.
    boxed = img.copy()
    ImageDraw.Draw(boxed).rectangle(bbox, outline="green", width=3)

    # Model input: blur the whole frame, then paste the sharp box back on top.
    focused = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    focused.paste(img.crop(bbox), (left, top))

    return np.array(focused), np.array(boxed), bbox
|
|
|
|
|
def classify_waste(image): |
|
if image is None: |
|
return "No image", [], None, "No data yet.", None |
|
focused, boxed, bbox = detect_and_focus(image) |
|
inputs = processor(text=WASTE_CATEGORIES, |
|
images=Image.fromarray(focused), |
|
return_tensors="pt", padding=True) |
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0] |
|
idx = int(probs.argmax()) |
|
category = WASTE_CATEGORIES[idx] |
|
result = f"**{category}** — {probs[idx]*100:.1f}%" |
|
bar_data = [{"category": c, "confidence": float(probs[i])} |
|
for i, c in enumerate(WASTE_CATEGORIES)] |
|
history = load_json(HISTORY_FILE, {"classifications":[]}) |
|
history["classifications"].append({ |
|
"timestamp": datetime.now().isoformat(), |
|
"category": category, |
|
"confidence": float(probs[idx]) |
|
}) |
|
save_json(HISTORY_FILE, history) |
|
|
|
counts = {c:0 for c in WASTE_CATEGORIES} |
|
for e in history["classifications"]: |
|
counts[e["category"]] += 1 |
|
total = sum(counts.values()) |
|
if total: |
|
summary = "### History\n" + "\n".join( |
|
f"- **{c}**: {cnt} ({cnt/total*100:.1f}%)" for c, cnt in counts.items() |
|
) |
|
else: |
|
summary = "No history yet." |
|
return result, bar_data, boxed, summary, bbox |
|
|
|
|
|
def handle_feedback(sketch, correct_cat, bbox):
    """Persist a user's category correction to FEEDBACK_FILE.

    Args:
        sketch: Value of the feedback Sketchpad. It is only checked for
            presence; the drawing itself is not stored.
        correct_cat: Category selected in the radio group, or None.
        bbox: Focus box ``(left, top, right, bottom)`` from the most recent
            classification, or None if there was none.

    Returns:
        A status message for the UI.
    """
    # NOTE(review): depending on the Gradio version, an untouched Sketchpad may
    # arrive as a dict (background/layers/composite) rather than None — confirm
    # this guard actually catches an empty canvas.
    if sketch is None or correct_cat is None:
        return "Draw on the canvas and select a category."

    fb = load_json(FEEDBACK_FILE, {"feedback": []})
    # A file that parsed as JSON but lacks the key previously raised KeyError.
    fb.setdefault("feedback", []).append({
        "timestamp": datetime.now().isoformat(),
        "correct_category": correct_cat,
        "bbox": bbox,
    })
    save_json(FEEDBACK_FILE, fb)
    return f"Saved as **{correct_cat}**!"
|
|
|
|
|
with gr.Blocks(css=".gradio-container{background:#F9FAFB}") as demo:
    gr.Markdown("# 🌱 EcoCan Waste Classifier")

    # One shared per-session slot for the focus box of the latest
    # classification. It is written by the Classify event and read by the
    # Feedback event. Creating gr.State() inline in each event (as before)
    # gives each event its own unrelated slot, so the bbox never reached
    # handle_feedback.
    bbox_state = gr.State()

    with gr.Tabs():
        with gr.Tab("Classify"):
            inp = gr.Image(type="numpy", label="Upload Waste Image")
            btn = gr.Button("🔍 Classify", variant="primary")
            out_t = gr.Textbox(label="Result")
            out_b = gr.BarPlot(
                x="category",
                y="confidence",
                y_lim=[0, 1],
                title="Confidence Scores",
            )
            out_i = gr.Image(label="Detected Object", interactive=False)

        with gr.Tab("Feedback"):
            sketchpad = gr.Sketchpad(label="Circle Object")
            radio = gr.Radio(WASTE_CATEGORIES, label="Correct Category")
            fb_btn = gr.Button("Submit Feedback")
            fb_out = gr.Textbox(label="Status")

        with gr.Tab("History"):
            history_md = gr.Markdown("No history yet.", elem_id="history-markdown")

    # Events are wired after every tab exists, so the History tab's Markdown
    # can receive classify_waste's summary output (it was never updated before).
    btn.click(
        classify_waste,
        inputs=inp,
        outputs=[out_t, out_b, out_i, history_md, bbox_state],
    )
    fb_btn.click(
        handle_feedback,
        inputs=[sketchpad, radio, bbox_state],
        outputs=fb_out,
    )

demo.launch()
|
|