|
import streamlit as st |
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
|
|
# Maps the integer class id found in the classification service's response
# to the human-readable label shown to the user.
# NOTE(review): names/order look like the STL-10 class list; confirm that the
# served model uses the same indexing.
CLASS_LABELS = {
    0: "airplane",
    1: "bird",
    2: "car",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "horse",
    7: "monkey",
    8: "ship",
    9: "truck",
}
|
|
|
def get_classification(
    image_bytes,
    *,
    url="http://localhost:5000/classify",
    timeout=30,
):
    """Ask the classification service for the label of an image.

    The image is uploaded as the multipart form field ``file`` and the
    ``classification`` entry of the JSON reply is translated into a label
    through ``CLASS_LABELS``.

    Args:
        image_bytes: Encoded image content to upload.
        url: Endpoint of the classification service. Defaults to the local
            development server used by this app.
        timeout: Seconds to wait for the service before giving up. Without
            a timeout ``requests`` would block indefinitely on a stalled
            server.

    Returns:
        The label string from ``CLASS_LABELS`` for the returned class id.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            service answers with an HTTP error status.
        KeyError: If the reply has no ``classification`` field or the class
            id is not present in ``CLASS_LABELS``.
    """
    response = requests.post(url, files={"file": image_bytes}, timeout=timeout)
    # Surface HTTP errors directly instead of failing later while parsing an
    # error page as JSON.
    response.raise_for_status()
    class_id = response.json()["classification"]
    return CLASS_LABELS[class_id]
|
|
|
# --- Streamlit page: upload an image, preview it, classify it on demand ---
st.title("Image Classification")
st.write("Upload an image to classify")

uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Classify"):
        # PIL has already read from the upload stream to render the preview,
        # so read() would only return whatever is left after the current
        # position. getvalue() returns the whole buffer regardless of it.
        img_bytes = uploaded_file.getvalue()
        try:
            label = get_classification(img_bytes)
        except requests.RequestException as exc:
            # Connection problems, timeouts, and HTTP error statuses.
            st.error(f"Classification request failed: {exc}")
        except (KeyError, ValueError) as exc:
            # Reply was not valid JSON, lacked the expected field, or named
            # a class id that has no label.
            st.error(f"Unexpected response from the classification service: {exc}")
        else:
            st.write("Prediction:", label)
|
|
|
|