File size: 5,210 Bytes
4224840
 
 
 
 
dae28cf
4224840
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os, json
from datetime import datetime

import numpy as np
from PIL import Image, ImageFilter, ImageDraw
import torch
from transformers import CLIPProcessor, CLIPModel

import gradio as gr

# ─── Config ─────────────────────────────
MODEL_NAME       = "openai/clip-vit-base-patch16"
HISTORY_FILE     = "waste_history.json"
FEEDBACK_FILE    = "user_feedback.json"
WASTE_CATEGORIES = ["General Waste", "Recyclable Waste"]

# ─── Load CLIP ────────────────────────────
# Loaded once at module import; classify_waste reuses these module-level globals.
processor = CLIPProcessor.from_pretrained(MODEL_NAME)  # tokenizer + image preprocessor
model     = CLIPModel.from_pretrained(MODEL_NAME)      # CLIP ViT-B/16 vision + text encoders

# ─── JSON Helpers ─────────────────────────
def load_json(path, default):
    """Load JSON from *path*; return *default* if the file is missing or invalid.

    Catches only (OSError, json.JSONDecodeError) — the original bare
    ``except:`` also swallowed KeyboardInterrupt/SystemExit and hid real bugs.
    """
    try:
        with open(path, "r") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return default

def save_json(path, data):
    """Serialize *data* to *path* as pretty-printed (indent=2) JSON, overwriting."""
    with open(path, "w") as fh:
        json.dump(data, fh, indent=2)

# ─── Image Focus + Blur ───────────────────
def detect_and_focus(img_array):
    """Simulate detection with a fixed center crop: 60% of each dimension.

    Returns a tuple of
      - focused: the frame Gaussian-blurred everywhere except the sharp center crop,
      - boxed:   a copy of the frame with a green rectangle over the crop,
      - bbox:    the crop bounds as (left, top, right, bottom),
    the first two as numpy arrays.
    """
    original = Image.fromarray(img_array)
    width, height = original.size
    box_w = int(width * 0.6)
    box_h = int(height * 0.6)
    left = (width - box_w) // 2
    top = (height - box_h) // 2
    bbox = (left, top, left + box_w, top + box_h)

    # Annotated copy: green rectangle marking the "detected" region.
    annotated = original.copy()
    ImageDraw.Draw(annotated).rectangle(bbox, outline="green", width=3)

    # Focus effect: blur the whole frame, then paste the sharp crop back over it.
    focused = original.filter(ImageFilter.GaussianBlur(radius=10))
    focused.paste(original.crop(bbox), (left, top))

    return np.array(focused), np.array(annotated), bbox

# ─── Classification Logic ─────────────────
def classify_waste(image):
    """Zero-shot classify a waste image into one of WASTE_CATEGORIES with CLIP.

    Parameters:
        image: numpy HxWx3 uint8 array from the Gradio Image component, or None.

    Returns a 5-tuple (result_markdown, bar_data, boxed_image, history_summary,
    bbox); placeholder values when *image* is None. Side effect: appends the
    classification to HISTORY_FILE.
    """
    if image is None:
        return "No image", [], None, "No data yet.", None

    focused, boxed, bbox = detect_and_focus(image)

    # Score the focused frame against every category label in one forward pass.
    inputs = processor(text=WASTE_CATEGORIES,
                       images=Image.fromarray(focused),
                       return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]

    idx      = int(probs.argmax())
    category = WASTE_CATEGORIES[idx]
    result   = f"**{category}** — {probs[idx]*100:.1f}%"
    bar_data = [{"category": c, "confidence": float(probs[i])}
                for i, c in enumerate(WASTE_CATEGORIES)]

    # Persist this classification.
    history = load_json(HISTORY_FILE, {"classifications": []})
    history["classifications"].append({
        "timestamp": datetime.now().isoformat(),
        "category": category,
        "confidence": float(probs[idx]),
    })
    save_json(HISTORY_FILE, history)

    # Build the per-category history summary. Skip (rather than KeyError on)
    # entries whose category isn't in the current WASTE_CATEGORIES — the old
    # `counts[e["category"]] += 1` crashed on history written by other versions.
    counts = {c: 0 for c in WASTE_CATEGORIES}
    for entry in history["classifications"]:
        cat = entry.get("category")
        if cat in counts:
            counts[cat] += 1
    total = sum(counts.values())
    if total:
        summary = "### History\n" + "\n".join(
            f"- **{c}**: {cnt} ({cnt/total*100:.1f}%)" for c, cnt in counts.items()
        )
    else:
        summary = "No history yet."

    return result, bar_data, boxed, summary, bbox

# ─── Feedback Logic ───────────────────────
def handle_feedback(sketch, correct_cat, bbox):
    """Append a user correction (category + bounding box) to FEEDBACK_FILE.

    Returns a status string for the UI; prompts the user when either the
    sketch or the category selection is missing.
    """
    if sketch is None or correct_cat is None:
        return "Draw on the canvas and select a category."

    entry = {
        "timestamp": datetime.now().isoformat(),
        "correct_category": correct_cat,
        "bbox": bbox,
    }
    feedback = load_json(FEEDBACK_FILE, {"feedback": []})
    feedback["feedback"].append(entry)
    save_json(FEEDBACK_FILE, feedback)

    return f"Saved as **{correct_cat}**!"

# ─── Gradio Interface ─────────────────────
with gr.Blocks(css=".gradio-container{background:#F9FAFB}") as demo:
    gr.Markdown("# 🌱 EcoCan Waste Classifier")
    # Shared State components: the original passed fresh inline gr.State()
    # instances to both .click() calls, which are DISTINCT components — the
    # bbox produced by classification never reached the feedback handler.
    # Declaring them once and reusing them wires the value through correctly.
    summary_state = gr.State()
    bbox_state    = gr.State()
    with gr.Tabs():
        with gr.Tab("Classify"):
            inp   = gr.Image(type="numpy", label="Upload Waste Image")
            btn   = gr.Button("🔍 Classify", variant="primary")
            out_t = gr.Textbox(label="Result")
            out_b = gr.BarPlot(x="category", y="confidence", y_lim=[0, 1],
                               title="Confidence Scores")
            out_i = gr.Image(label="Detected Object", interactive=False)
            btn.click(classify_waste, inputs=inp,
                      outputs=[out_t, out_b, out_i, summary_state, bbox_state])
        with gr.Tab("Feedback"):
            sketchpad = gr.Sketchpad(label="Circle Object")
            radio     = gr.Radio(WASTE_CATEGORIES, label="Correct Category")
            fb_btn    = gr.Button("Submit Feedback")
            fb_out    = gr.Textbox(label="Status")
            fb_btn.click(handle_feedback,
                         inputs=[sketchpad, radio, bbox_state],
                         outputs=fb_out)
        with gr.Tab("History"):
            # Static placeholder; classify_waste returns the live summary.
            gr.Markdown("No history yet.", elem_id="history-markdown")
    demo.launch()