Iris_class / app.py
Rausda6's picture
Update app.py
5b7678f verified
raw
history blame contribute delete
2.02 kB
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import numpy as np
import spaces
import logging
# Configure the root logger at DEBUG so the app's diagnostic messages (and
# those of every library that logs) are emitted.
logging.basicConfig(level="DEBUG")
logger = logging.getLogger(__name__)

# Fine-tuned AutoTrain checkpoint on the Hugging Face Hub. Both the model
# weights and the matching image-preprocessing config are loaded from it.
MODEL_REPO = "Rausda6/autotrain-uo2t1-gvgzu"
logger.debug("Loading model from: %s", MODEL_REPO)
model = AutoModelForImageClassification.from_pretrained(MODEL_REPO)
processor = AutoImageProcessor.from_pretrained(MODEL_REPO)
# Class-index -> label-name mapping stored in the checkpoint's config.
labels = model.config.id2label
@spaces.GPU
def classify_image(img: Image.Image):
    """Classify one uploaded image with the fine-tuned model.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input. Gradio passes
            ``None`` when the form is submitted without an image.

    Returns:
        ``(top_label, probabilities)``: the highest-scoring label, and a dict
        mapping every label to its softmax probability, inserted from most to
        least likely.

    Raises:
        gr.Error: If no image was provided (shown to the user in the UI).
        Exception: Any preprocessing or inference failure is logged with its
            traceback and re-raised unchanged.
    """
    logger.debug("Received image for classification.")
    if img is None:
        # Without this guard the processor fails deep inside transformers;
        # gr.Error surfaces a readable message in the UI instead.
        raise gr.Error("Please upload an image.")
    try:
        # The processor normalizes with per-channel (RGB) statistics, so
        # RGBA / palette / grayscale images must be converted first. Depending
        # on the gr.Image settings this may already be RGB; the check is cheap.
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Keep the tensors on the same device as the model, so inference still
        # works if the model is placed on a GPU (e.g. the one requested via
        # @spaces.GPU). On CPU this is a no-op.
        inputs = processor(images=img, return_tensors="pt").to(model.device)
        logger.debug("Processed inputs: %s", inputs)
        with torch.no_grad():
            outputs = model(**inputs)
        logger.debug("Model outputs: %s", outputs)
        # A single image was passed, so take row 0 of the batch and move it to
        # the CPU once instead of converting element by element.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu()
        logger.debug("Probabilities: %s", probs)
        # id2label maps class index -> label name.
        probs_dict = {labels[i]: p for i, p in enumerate(probs.tolist())}
        sorted_probs = sorted(probs_dict.items(), key=lambda kv: kv[1], reverse=True)
        top_label, top_score = sorted_probs[0]
        logger.debug(
            "Top prediction: %s with confidence %.2f%%", top_label, top_score * 100
        )
        return top_label, dict(sorted_probs)
    except Exception:
        logger.exception("Error during classification")
        raise  # bare raise re-raises the original exception and traceback
# Gradio interface: one PIL image in, two label panels out. The first panel
# shows the top class; the second shows the ranked per-class probabilities.
# Together they match the (top_label, probabilities) pair classify_image returns.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(label="Top Prediction"),
        gr.Label(num_top_classes=6, label="Class Probabilities"),
    ],
    title="Image Classification with AutoTrain Model",
    description="Upload a JPG image to classify it using the fine-tuned model.",
)

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script.
    # Importing the module just builds `demo` without launching it.
    demo.launch()