kishoreb4 committed on
Commit
822fcd2
·
verified ·
1 Parent(s): 7b71af3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -20
app.py CHANGED
@@ -10,6 +10,19 @@ import io
10
  import gdown
11
  from transformers import TFSegformerForSemanticSegmentation
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  st.set_page_config(
14
  page_title="Pet Segmentation with SegFormer",
15
  page_icon="🐶",
@@ -32,11 +45,12 @@ NUM_CLASSES = len(ID2LABEL)
32
  def download_model_from_drive():
33
  # Create a models directory
34
  os.makedirs("models", exist_ok=True)
35
- model_path = "models/best_model"
36
 
37
  if not os.path.exists(model_path):
38
- # Fixed Google Drive URL format for gdown
39
- url = "https://drive.google.com/file/d/1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3/view?usp=sharing"
 
40
  try:
41
  gdown.download(url, model_path, quiet=False)
42
  st.success("Model downloaded successfully from Google Drive.")
@@ -126,8 +140,18 @@ def preprocess_image(image):
126
  # Store original image for display
127
  original_img = img_array.copy()
128
 
129
- # Resize to target size
130
- img_resized = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
 
 
 
 
 
 
 
 
 
 
131
 
132
  # Normalize
133
  img_normalized = normalize_image(img_resized)
@@ -150,8 +174,21 @@ def create_mask(pred_mask):
150
  Returns:
151
  Processed mask (2D array)
152
  """
 
153
  pred_mask = tf.math.argmax(pred_mask, axis=1)
 
 
154
  pred_mask = tf.squeeze(pred_mask)
 
 
 
 
 
 
 
 
 
 
155
  return pred_mask.numpy()
156
 
157
  def colorize_mask(mask):
@@ -308,21 +345,46 @@ def main():
308
 
309
  # Preprocess and predict
310
  with st.spinner("Generating segmentation mask..."):
311
- # Preprocess the image
312
- img_tensor, original_img = preprocess_image(image)
313
-
314
- # Make prediction
315
- outputs = model(pixel_values=img_tensor, training=False)
316
- logits = outputs.logits
317
-
318
- # Create visualization mask
319
- mask = create_mask(logits)
320
-
321
- # Colorize the mask
322
- colorized_mask = colorize_mask(mask)
323
-
324
- # Create overlay
325
- overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  # Display results
328
  with col2:
 
10
  import gdown
11
  from transformers import TFSegformerForSemanticSegmentation
12
 
13
+
14
+ try:
15
+ # Enable on-demand GPU memory growth (allocate as needed instead of reserving all memory up front)
16
+ gpus = tf.config.experimental.list_physical_devices('GPU')
17
+ if gpus:
18
+ for gpu in gpus:
19
+ tf.config.experimental.set_memory_growth(gpu, True)
20
+ st.sidebar.success(f"GPU available: {len(gpus)} device(s)")
21
+ else:
22
+ st.sidebar.warning("No GPU detected, using CPU")
23
+ except Exception as e:
24
+ st.sidebar.error(f"GPU config error: {e}")
25
+
26
  st.set_page_config(
27
  page_title="Pet Segmentation with SegFormer",
28
  page_icon="🐶",
 
45
  def download_model_from_drive():
46
  # Create a models directory
47
  os.makedirs("models", exist_ok=True)
48
+ model_path = "models/tf_model.h5"
49
 
50
  if not os.path.exists(model_path):
51
+ # File ID copied from the Google Drive sharing URL (hard-coded, not parsed at runtime)
52
+ file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
53
+ url = f"https://drive.google.com/uc?id={file_id}"
54
  try:
55
  gdown.download(url, model_path, quiet=False)
56
  st.success("Model downloaded successfully from Google Drive.")
 
140
  # Store original image for display
141
  original_img = img_array.copy()
142
 
143
+ # Resize to target size with preserve_aspect_ratio=False
144
+ img_resized = tf.image.resize(
145
+ img_array,
146
+ (IMAGE_SIZE, IMAGE_SIZE),
147
+ method='bilinear',
148
+ preserve_aspect_ratio=False, # Ensure exact dimensions
149
+ antialias=True
150
+ )
151
+
152
+ # Verify dimensions with assertion
153
+ tf.debugging.assert_equal(tf.shape(img_resized)[0:2], [IMAGE_SIZE, IMAGE_SIZE],
154
+ message="Image dimensions don't match expected size")
155
 
156
  # Normalize
157
  img_normalized = normalize_image(img_resized)
 
174
  Returns:
175
  Processed mask (2D array)
176
  """
177
+ # Take argmax along the class dimension (axis=1 for batch data)
178
  pred_mask = tf.math.argmax(pred_mask, axis=1)
179
+
180
+ # Remove batch dimension and convert to numpy
181
  pred_mask = tf.squeeze(pred_mask)
182
+
183
+ # Resize the mask to the model input size (IMAGE_SIZE x IMAGE_SIZE) if needed
184
+ if pred_mask.shape[0] != IMAGE_SIZE or pred_mask.shape[1] != IMAGE_SIZE:
185
+ pred_mask = tf.image.resize(
186
+ tf.expand_dims(pred_mask, axis=-1),
187
+ (IMAGE_SIZE, IMAGE_SIZE),
188
+ method='nearest'
189
+ )
190
+ pred_mask = tf.squeeze(pred_mask)
191
+
192
  return pred_mask.numpy()
193
 
194
  def colorize_mask(mask):
 
345
 
346
  # Preprocess and predict
347
  with st.spinner("Generating segmentation mask..."):
348
+ try:
349
+ # Preprocess the image
350
+ img_tensor, original_img = preprocess_image(image)
351
+
352
+ # Print shape to debug
353
+ st.write(f"DEBUG - Input tensor shape: {img_tensor.shape}")
354
+
355
+ # Make prediction with error handling
356
+ try:
357
+ outputs = model(pixel_values=img_tensor, training=False)
358
+ logits = outputs.logits
359
+
360
+ # Create visualization mask
361
+ mask = create_mask(logits)
362
+
363
+ # Colorize the mask
364
+ colorized_mask = colorize_mask(mask)
365
+
366
+ # Create overlay
367
+ overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
368
+ except Exception as inference_error:
369
+ st.error(f"Inference error: {inference_error}")
370
+ st.write("Trying alternative approach...")
371
+
372
+ # Alternative: crop or pad (no rescaling) to exactly IMAGE_SIZE x IMAGE_SIZE
373
+ img_resized = tf.image.resize_with_crop_or_pad(
374
+ original_img, IMAGE_SIZE, IMAGE_SIZE
375
+ )
376
+ img_normalized = normalize_image(img_resized)
377
+ img_transposed = tf.transpose(img_normalized, (2, 0, 1))
378
+ img_tensor = tf.expand_dims(img_transposed, axis=0)
379
+
380
+ outputs = model(pixel_values=img_tensor, training=False)
381
+ logits = outputs.logits
382
+ mask = create_mask(logits)
383
+ colorized_mask = colorize_mask(mask)
384
+ overlay = create_overlay(original_img, colorized_mask, alpha=overlay_opacity)
385
+ except Exception as e:
386
+ st.error(f"Failed to process image: {e}")
387
+ st.stop()
388
 
389
  # Display results
390
  with col2: