# ecocan / app.py
# Author: Aarookie
# Last commit: "Update app.py" (4224840, verified)
import json
import os
from datetime import datetime

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFilter
from transformers import CLIPModel, CLIPProcessor

import gradio as gr
# ─── Config ─────────────────────────────
# Hugging Face Hub id of the CLIP checkpoint used for zero-shot classification.
MODEL_NAME: str = "openai/clip-vit-base-patch16"

# JSON files (paths relative to the working directory): the log of every
# classification and the log of user corrections.
HISTORY_FILE: str = "waste_history.json"
FEEDBACK_FILE: str = "user_feedback.json"

# Candidate labels. They are also fed verbatim to CLIP as text prompts, so the
# list order defines which index a prediction maps back to.
WASTE_CATEGORIES = [
    "General Waste",
    "Recyclable Waste",
]
# ─── Load CLIP ────────────────────────────
# Loaded once at import time and shared by every request.
# - processor: bundles the text tokenizer and the image preprocessor.
# - model: the paired vision and text encoders (ViT-B/16, per MODEL_NAME).
processor, model = (
    CLIPProcessor.from_pretrained(MODEL_NAME),
    CLIPModel.from_pretrained(MODEL_NAME),
)
# ─── JSON Helpers ─────────────────────────
def load_json(path, default):
    """Load and return the JSON document stored at *path*.

    Best-effort: if the file is missing, unreadable, not valid UTF-8, or not
    valid JSON, *default* is returned instead. Only those I/O and decoding
    errors are caught. Anything else (e.g. KeyboardInterrupt) still propagates,
    unlike the previous bare ``except:``.

    Args:
        path: File to read.
        default: Value returned when the file cannot be loaded.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: missing file, permission problem, path is a directory, ...
        # ValueError: json.JSONDecodeError and UnicodeDecodeError subclass it.
        return default
def save_json(path, data):
    """Write *data* to *path* as indented JSON, replacing the file atomically.

    The document is first written to a sibling ``<path>.tmp`` file, which is
    then moved over *path* with ``os.replace``. A crash or serialization error
    mid-write therefore leaves the previous file intact.

    Truncating *path* directly would instead leave a corrupt file. A tolerant
    loader would then fall back to its default, and the next save would
    discard every earlier record.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, path)
# ─── Image Focus + Blur ───────────────────
def detect_and_focus(img_array, box_frac=0.6, blur_radius=10):
    """Highlight a fixed, centred region of an image and blur the rest.

    No object detection is performed. The focus region is always the centred
    rectangle covering ``box_frac`` of the image's width and height.

    Args:
        img_array: Image as a NumPy array accepted by ``PIL.Image.fromarray``.
        box_frac: Fraction of width/height covered by the focus box, in (0, 1].
            Defaults to 0.6, the previously hard-coded value.
        blur_radius: Gaussian blur radius applied outside the box (default 10).

    Returns:
        ``(focused, boxed, bbox)``:
        focused -- array with everything outside the box blurred;
        boxed -- copy of the input with a green rectangle drawn on the box;
        bbox -- ``(left, top, right, bottom)`` pixel coordinates of the box.

    Raises:
        ValueError: if *box_frac* is not in (0, 1].
    """
    if not 0 < box_frac <= 1:
        raise ValueError(f"box_frac must be in (0, 1], got {box_frac!r}")

    img = Image.fromarray(img_array)
    w, h = img.size
    # max(1, ...) keeps the box non-empty for very small images.
    bw, bh = max(1, int(w * box_frac)), max(1, int(h * box_frac))
    left, top = (w - bw) // 2, (h - bh) // 2
    bbox = (left, top, left + bw, top + bh)

    boxed = img.copy()
    ImageDraw.Draw(boxed).rectangle(bbox, outline="green", width=3)

    focused = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    # Paste the sharp centre back over the blurred copy.
    focused.paste(img.crop(bbox), (left, top))
    return np.array(focused), np.array(boxed), bbox
# ─── Classification Logic ─────────────────
def _predict(image_array):
    """Return CLIP softmax probabilities for *image_array* over WASTE_CATEGORIES.

    Each category string is used directly as a text prompt. The result is a
    1-D NumPy array aligned index-for-index with WASTE_CATEGORIES.
    """
    inputs = processor(
        text=WASTE_CATEGORIES,
        images=Image.fromarray(image_array),
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**inputs)
    # logits_per_image has one row per image and one column per text prompt.
    # A single image is passed, so row 0 holds its scores.
    return outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]


def _record_classification(category, confidence):
    """Append one prediction to HISTORY_FILE and return the updated entry list.

    A history file that is not a JSON object, or whose "classifications" value
    is not a list, is treated as holding no entries instead of raising.
    """
    history = load_json(HISTORY_FILE, {"classifications": []})
    if not isinstance(history, dict):
        history = {}
    entries = history.get("classifications")
    if not isinstance(entries, list):
        entries = history["classifications"] = []
    entries.append({
        "timestamp": datetime.now().isoformat(),
        "category": category,
        "confidence": confidence,
    })
    save_json(HISTORY_FILE, history)
    return entries


def _summarize_history(entries):
    """Render per-category counts of *entries* as a Markdown summary.

    Entries that are not dicts, or whose "category" is not in WASTE_CATEGORIES
    (e.g. left over from an older label set), are skipped rather than raising.
    """
    counts = {c: 0 for c in WASTE_CATEGORIES}
    for entry in entries:
        category = entry.get("category") if isinstance(entry, dict) else None
        if isinstance(category, str) and category in counts:
            counts[category] += 1
    total = sum(counts.values())
    if not total:
        return "No history yet."
    return "### History\n" + "\n".join(
        f"- **{c}**: {cnt} ({cnt/total*100:.1f}%)" for c, cnt in counts.items()
    )


def classify_waste(image):
    """Classify *image* into one of WASTE_CATEGORIES and log the result.

    Args:
        image: Uploaded image as a NumPy array, or None when nothing was uploaded.

    Returns:
        A 5-tuple ``(result, bar_data, boxed, summary, bbox)``:
        result -- Markdown string with the top category and its confidence;
        bar_data -- pandas DataFrame with "category"/"confidence" columns, the
            value type gr.BarPlot expects (None when there is no image);
        boxed -- image array with the focus box drawn (None when no image);
        summary -- Markdown per-category counts over the whole history;
        bbox -- ``(left, top, right, bottom)`` of the focus box (None when no
            image).

    Side effects:
        Appends the prediction to HISTORY_FILE.
    """
    if image is None:
        # None clears the plot; a plain list is not a valid BarPlot value.
        return "No image", None, None, "No data yet.", None

    focused, boxed, bbox = detect_and_focus(image)
    probs = _predict(focused)
    idx = int(probs.argmax())
    category = WASTE_CATEGORIES[idx]
    result = f"**{category}** — {probs[idx]*100:.1f}%"

    bar_data = pd.DataFrame({
        "category": WASTE_CATEGORIES,
        "confidence": [float(p) for p in probs],
    })

    entries = _record_classification(category, float(probs[idx]))
    return result, bar_data, boxed, _summarize_history(entries), bbox
# ─── Feedback Logic ───────────────────────
def handle_feedback(sketch, correct_cat, bbox):
    """Record a user's correction for the last classification in FEEDBACK_FILE.

    Args:
        sketch: Sketchpad value. It is only checked for presence; its contents
            are not persisted.
        correct_cat: Category the user selected, or None if none was chosen.
        bbox: Focus box ``(left, top, right, bottom)`` from the last
            classification. It may be None if no classification has populated
            it yet, and is then stored as JSON null.

    Returns:
        Markdown status message for the UI.
    """
    if sketch is None or correct_cat is None:
        return "Draw on the canvas and select a category."

    fb = load_json(FEEDBACK_FILE, {"feedback": []})
    # Tolerate a malformed feedback file instead of raising KeyError/TypeError.
    if not isinstance(fb, dict):
        fb = {}
    entries = fb.get("feedback")
    if not isinstance(entries, list):
        entries = fb["feedback"] = []

    entries.append({
        "timestamp": datetime.now().isoformat(),
        "correct_category": correct_cat,
        "bbox": bbox,
    })
    save_json(FEEDBACK_FILE, fb)
    return f"Saved as **{correct_cat}**!"
# ─── Gradio Interface ─────────────────────
with gr.Blocks(css=".gradio-container{background:#F9FAFB}") as demo:
    gr.Markdown("# 🌱 EcoCan Waste Classifier")

    # One State shared by both events: classify_waste writes the focus box and
    # handle_feedback reads it. A gr.State() created inline inside each event
    # listener would be a separate, disconnected value every time.
    bbox_state = gr.State()

    with gr.Tabs():
        with gr.Tab("Classify"):
            inp = gr.Image(type="numpy", label="Upload Waste Image")
            btn = gr.Button("🔍 Classify", variant="primary")
            out_t = gr.Textbox(label="Result")
            out_b = gr.BarPlot(x="category", y="confidence", y_lim=[0, 1],
                               title="Confidence Scores")
            out_i = gr.Image(label="Detected Object", interactive=False)
        with gr.Tab("Feedback"):
            sketchpad = gr.Sketchpad(label="Circle Object")
            radio = gr.Radio(WASTE_CATEGORIES, label="Correct Category")
            fb_btn = gr.Button("Submit Feedback")
            fb_out = gr.Textbox(label="Status")
        with gr.Tab("History"):
            history_md = gr.Markdown("No history yet.", elem_id="history-markdown")

    # Events are wired after all tabs exist so the History markdown can be an
    # output target. Output order matches classify_waste's return tuple:
    # (result, bar_data, boxed, summary, bbox).
    btn.click(classify_waste, inputs=inp,
              outputs=[out_t, out_b, out_i, history_md, bbox_state])
    fb_btn.click(handle_feedback,
                 inputs=[sketchpad, radio, bbox_state],
                 outputs=fb_out)
demo.launch()