Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,19 @@ import io
|
|
10 |
import gdown
|
11 |
from transformers import TFSegformerForSemanticSegmentation
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
st.set_page_config(
|
14 |
page_title="Pet Segmentation with SegFormer",
|
15 |
page_icon="🐶",
|
@@ -32,11 +45,12 @@ NUM_CLASSES = len(ID2LABEL)
|
|
32 |
def download_model_from_drive():
|
33 |
# Create a models directory
|
34 |
os.makedirs("models", exist_ok=True)
|
35 |
-
model_path = "models/
|
36 |
|
37 |
if not os.path.exists(model_path):
|
38 |
-
#
|
39 |
-
|
|
|
40 |
try:
|
41 |
gdown.download(url, model_path, quiet=False)
|
42 |
st.success("Model downloaded successfully from Google Drive.")
|
@@ -126,8 +140,18 @@ def preprocess_image(image):
|
|
126 |
# Store original image for display
|
127 |
original_img = img_array.copy()
|
128 |
|
129 |
-
# Resize to target size
|
130 |
-
img_resized = tf.image.resize(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
# Normalize
|
133 |
img_normalized = normalize_image(img_resized)
|
@@ -150,8 +174,21 @@ def create_mask(pred_mask):
|
|
150 |
Returns:
|
151 |
Processed mask (2D array)
|
152 |
"""
|
|
|
153 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
|
|
|
|
154 |
pred_mask = tf.squeeze(pred_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
return pred_mask.numpy()
|
156 |
|
157 |
def colorize_mask(mask):
|
@@ -308,21 +345,46 @@ def main():
|
|
308 |
|
309 |
# Preprocess and predict
|
310 |
with st.spinner("Generating segmentation mask..."):
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
# Display results
|
328 |
with col2:
|
|
|
10 |
import gdown
|
11 |
from transformers import TFSegformerForSemanticSegmentation
|
12 |
|
13 |
+
|
14 |
+
try:
|
15 |
+
# Limit GPU memory growth
|
16 |
+
gpus = tf.config.experimental.list_physical_devices('GPU')
|
17 |
+
if gpus:
|
18 |
+
for gpu in gpus:
|
19 |
+
tf.config.experimental.set_memory_growth(gpu, True)
|
20 |
+
st.sidebar.success(f"GPU available: {len(gpus)} device(s)")
|
21 |
+
else:
|
22 |
+
st.sidebar.warning("No GPU detected, using CPU")
|
23 |
+
except Exception as e:
|
24 |
+
st.sidebar.error(f"GPU config error: {e}")
|
25 |
+
|
26 |
st.set_page_config(
|
27 |
page_title="Pet Segmentation with SegFormer",
|
28 |
page_icon="🐶",
|
|
|
45 |
def download_model_from_drive():
|
46 |
# Create a models directory
|
47 |
os.makedirs("models", exist_ok=True)
|
48 |
+
model_path = "models/tf_model.h5"
|
49 |
|
50 |
if not os.path.exists(model_path):
|
51 |
+
# Extract the file ID from the sharing URL
|
52 |
+
file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
|
53 |
+
url = f"https://drive.google.com/uc?id={file_id}"
|
54 |
try:
|
55 |
gdown.download(url, model_path, quiet=False)
|
56 |
st.success("Model downloaded successfully from Google Drive.")
|
|
|
140 |
# Store original image for display
|
141 |
original_img = img_array.copy()
|
142 |
|
143 |
+
# Resize to target size with preserve_aspect_ratio=False
|
144 |
+
img_resized = tf.image.resize(
|
145 |
+
img_array,
|
146 |
+
(IMAGE_SIZE, IMAGE_SIZE),
|
147 |
+
method='bilinear',
|
148 |
+
preserve_aspect_ratio=False, # Ensure exact dimensions
|
149 |
+
antialias=True
|
150 |
+
)
|
151 |
+
|
152 |
+
# Verify dimensions with assertion
|
153 |
+
tf.debugging.assert_equal(tf.shape(img_resized)[0:2], [IMAGE_SIZE, IMAGE_SIZE],
|
154 |
+
message="Image dimensions don't match expected size")
|
155 |
|
156 |
# Normalize
|
157 |
img_normalized = normalize_image(img_resized)
|
|
|
174 |
Returns:
|
175 |
Processed mask (2D array)
|
176 |
"""
|
177 |
+
# Take argmax along the class dimension (axis=1 for batch data)
|
178 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
179 |
+
|
180 |
+
# Remove batch dimension and convert to numpy
|
181 |
pred_mask = tf.squeeze(pred_mask)
|
182 |
+
|
183 |
+
# Resize to match original image size if needed
|
184 |
+
if pred_mask.shape[0] != IMAGE_SIZE or pred_mask.shape[1] != IMAGE_SIZE:
|
185 |
+
pred_mask = tf.image.resize(
|
186 |
+
tf.expand_dims(pred_mask, axis=-1),
|
187 |
+
(IMAGE_SIZE, IMAGE_SIZE),
|
188 |
+
method='nearest'
|
189 |
+
)
|
190 |
+
pred_mask = tf.squeeze(pred_mask)
|
191 |
+
|
192 |
return pred_mask.numpy()
|
193 |
|
194 |
def colorize_mask(mask):
|
|
|
345 |
|
346 |
# Preprocess and predict
|
347 |
with st.spinner("Generating segmentation mask..."):
|
348 |
+
try:
|
349 |
+
# Preprocess the image
|
350 |
+
img_tensor, original_img = preprocess_image(image)
|
351 |
+
|
352 |
+
# Print shape to debug
|
353 |
+
st.write(f"DEBUG - Input tensor shape: {img_tensor.shape}")
|
354 |
+
|
355 |
+
# Make prediction with error handling
|
356 |
+
try:
|
357 |
+
outputs = model(pixel_values=img_tensor, training=False)
|
358 |
+
logits = outputs.logits
|
359 |
+
|
360 |
+
# Create visualization mask
|
361 |
+
mask = create_mask(logits)
|
362 |
+
|
363 |
+
# Colorize the mask
|
364 |
+
colorized_mask = colorize_mask(mask)
|
365 |
+
|
366 |
+
# Create overlay
|
367 |
+
overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
|
368 |
+
except Exception as inference_error:
|
369 |
+
st.error(f"Inference error: {inference_error}")
|
370 |
+
st.write("Trying alternative approach...")
|
371 |
+
|
372 |
+
# Alternative: resize to exactly 512x512 with crop_or_pad
|
373 |
+
img_resized = tf.image.resize_with_crop_or_pad(
|
374 |
+
original_img, IMAGE_SIZE, IMAGE_SIZE
|
375 |
+
)
|
376 |
+
img_normalized = normalize_image(img_resized)
|
377 |
+
img_transposed = tf.transpose(img_normalized, (2, 0, 1))
|
378 |
+
img_tensor = tf.expand_dims(img_transposed, axis=0)
|
379 |
+
|
380 |
+
outputs = model(pixel_values=img_tensor, training=False)
|
381 |
+
logits = outputs.logits
|
382 |
+
mask = create_mask(logits)
|
383 |
+
colorized_mask = colorize_mask(mask)
|
384 |
+
overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
|
385 |
+
except Exception as e:
|
386 |
+
st.error(f"Failed to process image: {e}")
|
387 |
+
st.stop()
|
388 |
|
389 |
# Display results
|
390 |
with col2:
|