kishoreb4 commited on
Commit
e8dad3c
·
1 Parent(s): 5095c72
Files changed (1) hide show
  1. app.py +29 -64
app.py CHANGED
@@ -30,56 +30,29 @@ NUM_CLASSES = len(ID2LABEL)
30
  @st.cache_resource
31
  def download_model_from_drive():
32
  """
33
- Download model from Google Drive
34
-
35
- Returns:
36
- Path to downloaded model
37
  """
38
- # Define paths
39
  model_dir = os.path.join("models", "saved_models")
40
  os.makedirs(model_dir, exist_ok=True)
41
- model_path = os.path.join(model_dir, "t5_model.h5")
42
-
43
- # Check if model already exists
 
 
 
44
  if not os.path.exists(model_path):
45
- with st.spinner("Downloading model from Google Drive..."):
46
- try:
47
- # Direct download URL for t5_model.h5
48
- url = "https://drive.google.com/uc?id=1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
49
-
50
- st.info(f"Downloading model from {url}")
51
-
52
- # Try direct download with gdown using direct URL
53
- try:
54
- output = gdown.download(url, model_path, quiet=False)
55
- if output is None:
56
- raise Exception("gdown failed with direct URL")
57
- st.success(f"Model downloaded to {output}")
58
- except Exception as e1:
59
- st.warning(f"First download attempt failed: {str(e1)}")
60
-
61
- # Try with just the file ID
62
- file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
63
- st.info(f"Trying alternative download with file ID: {file_id}")
64
- output = gdown.download(f"https://drive.google.com/uc?id={file_id}",
65
- model_path, quiet=False)
66
- if output is None:
67
- raise Exception("gdown failed with file ID")
68
- st.success("Model downloaded successfully!")
69
-
70
- # Verify the file exists and has content
71
- if os.path.exists(model_path) and os.path.getsize(model_path) > 0:
72
- st.success(f"Verified file at {model_path} ({os.path.getsize(model_path)/1024/1024:.2f} MB)")
73
- else:
74
- st.error(f"Downloaded file is missing or empty: {model_path}")
75
- return None
76
-
77
- except Exception as e:
78
- st.error(f"Error downloading model: {str(e)}")
79
- st.info("Trying to use the HuggingFace model instead...")
80
- return None
81
  else:
82
- st.info(f"Model already exists at {model_path}")
83
 
84
  return model_path
85
 
@@ -192,36 +165,29 @@ def create_mask(pred_mask):
192
  Convert model prediction to displayable mask
193
 
194
  Args:
195
- pred_mask: Prediction from model
196
 
197
  Returns:
198
- Processed mask for visualization
199
  """
200
- # Get the class with highest probability (argmax along class dimension)
201
  pred_mask = tf.math.argmax(pred_mask, axis=1)
202
-
203
- # Add channel dimension
204
- pred_mask = tf.expand_dims(pred_mask, -1)
205
-
206
- # Resize to original image size
207
- pred_mask = tf.image.resize(
208
- pred_mask,
209
- (IMAGE_SIZE, IMAGE_SIZE),
210
- method="nearest"
211
- )
212
-
213
- return pred_mask[0]
214
 
215
  def colorize_mask(mask):
216
  """
217
  Apply colors to segmentation mask
218
 
219
  Args:
220
- mask: Segmentation mask
221
 
222
  Returns:
223
- Colorized mask
224
  """
 
 
 
 
225
  # Define colors for each class (RGB)
226
  colors = [
227
  [0, 0, 0], # Background (black)
@@ -233,8 +199,7 @@ def colorize_mask(mask):
233
  rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
234
 
235
  for i, color in enumerate(colors):
236
- # Find pixels of this class and assign color
237
- class_mask = np.where(mask == i, 1, 0).astype(np.uint8)
238
  for c in range(3):
239
  rgb_mask[:, :, c] += class_mask * color[c]
240
 
 
30
  @st.cache_resource
31
  def download_model_from_drive():
32
  """
33
+ Download model from Google Drive and return the local path.
 
 
 
34
  """
 
35
  model_dir = os.path.join("models", "saved_models")
36
  os.makedirs(model_dir, exist_ok=True)
37
+ model_path = os.path.join(model_dir, "segformer_model")
38
+
39
+ # Google Drive file ID
40
+ file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
41
+ url = f"https://drive.google.com/uc?id={file_id}"
42
+
43
  if not os.path.exists(model_path):
44
+ try:
45
+ # Use gdown to download the file
46
+ with st.spinner("Downloading model from Google Drive..."):
47
+ gdown.download(url, model_path, quiet=False)
48
+ if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
49
+ raise Exception("Downloaded file is empty or missing.")
50
+ st.success("Model downloaded successfully!")
51
+ except Exception as e:
52
+ st.error(f"Error downloading model: {e}")
53
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  else:
55
+ st.info("Model already exists locally.")
56
 
57
  return model_path
58
 
 
165
  Convert model prediction to displayable mask
166
 
167
  Args:
168
+ pred_mask: Prediction logits from the model
169
 
170
  Returns:
171
+ Processed mask (2D array)
172
  """
 
173
  pred_mask = tf.math.argmax(pred_mask, axis=1)
174
+ pred_mask = tf.squeeze(pred_mask, axis=0) # Remove batch dimension
175
+ return pred_mask.numpy()
 
 
 
 
 
 
 
 
 
 
176
 
177
  def colorize_mask(mask):
178
  """
179
  Apply colors to segmentation mask
180
 
181
  Args:
182
+ mask: Segmentation mask (2D array)
183
 
184
  Returns:
185
+ Colorized mask (3D RGB array)
186
  """
187
+ # Ensure the mask is 2D
188
+ if len(mask.shape) > 2:
189
+ mask = np.squeeze(mask, axis=-1)
190
+
191
  # Define colors for each class (RGB)
192
  colors = [
193
  [0, 0, 0], # Background (black)
 
199
  rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
200
 
201
  for i, color in enumerate(colors):
202
+ class_mask = (mask == i).astype(np.uint8)
 
203
  for c in range(3):
204
  rgb_mask[:, :, c] += class_mask * color[c]
205