SparshSG commited on
Commit
d1e0895
·
verified ·
1 Parent(s): 986ea8d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import tensorflow_hub as hub
5
+ import cv2
6
+ from PIL import Image
7
+ import time
8
+
9
+ # Set page title and favicon
10
+ st.set_page_config(page_title="Cat and Dog Classifier", page_icon="🐱🐶")
11
+
12
+ # Load the pre-trained model
13
+ mobilenet_model = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'
14
+ pretrained_model = hub.KerasLayer(mobilenet_model, input_shape=(224, 224, 3), trainable=False)
15
+ num_of_classes = 2
16
+ model = tf.keras.Sequential([
17
+ pretrained_model,
18
+ tf.keras.layers.Dense(num_of_classes)
19
+ ])
20
+
21
+ model.compile(
22
+ optimizer='adam',
23
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
24
+ metrics=['acc']
25
+ )
26
+
27
+ model.load_weights("cat_dog_classifier.h5")
28
+
29
+ # Define functions for image resizing and classification
30
+ def preprocess_image(image):
31
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32
+ image = cv2.resize(image, (224, 224))
33
+ image = image / 255.0
34
+ image = np.expand_dims(image, axis=0)
35
+ return image
36
+
37
+ def classify_image(image):
38
+ image = preprocess_image(image)
39
+ prediction = model.predict(image)
40
+ return np.argmax(prediction)
41
+
42
+ # Sidebar
43
+ st.sidebar.header("Cat and Dog Classifier")
44
+ uploaded_image = st.sidebar.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
45
+
46
+ # Main content
47
+ st.title("Cat/Dog Image Classification")
48
+
49
+ if uploaded_image:
50
+ with st.spinner("Uploading image..."):
51
+ time.sleep(2) # Simulate image upload process, replace with actual image upload logic
52
+
53
+ st.success("Image upload complete!")
54
+
55
+ image = Image.open(uploaded_image)
56
+ st.image(image, caption="Uploaded Image", use_column_width=True)
57
+
58
+ if st.button("Classify"):
59
+ image = np.array(image)
60
+ pred_label = classify_image(image)
61
+
62
+ if pred_label == 0:
63
+ st.write('<div style="font-size: 24px; color: white;">Prediction: It\'s a Cat 😺</div>', unsafe_allow_html=True)
64
+ else:
65
+ st.write('<div style="font-size: 24px; color: white;">Prediction: It\'s a Dog 🐶</div>', unsafe_allow_html=True)
66
+
67
+ # Add a footer with CSS for positioning
68
+ st.markdown(
69
+ """
70
+ <div style="position: fixed; bottom: 0; right: 10px; padding: 10px; color: white;">
71
+ <a href="https://github.com/sg-sparsh-goyal" target="_blank" style="color: white; text-decoration: none;">
72
+ ✨ Github
73
+ </a><br>
74
+ By Sparsh
75
+ </div>
76
+ """,
77
+ unsafe_allow_html=True
78
+ )