Spaces:
Sleeping
Sleeping
add
Browse files
app.py
CHANGED
@@ -30,56 +30,29 @@ NUM_CLASSES = len(ID2LABEL)
|
|
30 |
@st.cache_resource
|
31 |
def download_model_from_drive():
|
32 |
"""
|
33 |
-
Download model from Google Drive
|
34 |
-
|
35 |
-
Returns:
|
36 |
-
Path to downloaded model
|
37 |
"""
|
38 |
-
# Define paths
|
39 |
model_dir = os.path.join("models", "saved_models")
|
40 |
os.makedirs(model_dir, exist_ok=True)
|
41 |
-
model_path = os.path.join(model_dir, "
|
42 |
-
|
43 |
-
#
|
|
|
|
|
|
|
44 |
if not os.path.exists(model_path):
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
url
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
if output is None:
|
56 |
-
raise Exception("gdown failed with direct URL")
|
57 |
-
st.success(f"Model downloaded to {output}")
|
58 |
-
except Exception as e1:
|
59 |
-
st.warning(f"First download attempt failed: {str(e1)}")
|
60 |
-
|
61 |
-
# Try with just the file ID
|
62 |
-
file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
|
63 |
-
st.info(f"Trying alternative download with file ID: {file_id}")
|
64 |
-
output = gdown.download(f"https://drive.google.com/uc?id={file_id}",
|
65 |
-
model_path, quiet=False)
|
66 |
-
if output is None:
|
67 |
-
raise Exception("gdown failed with file ID")
|
68 |
-
st.success("Model downloaded successfully!")
|
69 |
-
|
70 |
-
# Verify the file exists and has content
|
71 |
-
if os.path.exists(model_path) and os.path.getsize(model_path) > 0:
|
72 |
-
st.success(f"Verified file at {model_path} ({os.path.getsize(model_path)/1024/1024:.2f} MB)")
|
73 |
-
else:
|
74 |
-
st.error(f"Downloaded file is missing or empty: {model_path}")
|
75 |
-
return None
|
76 |
-
|
77 |
-
except Exception as e:
|
78 |
-
st.error(f"Error downloading model: {str(e)}")
|
79 |
-
st.info("Trying to use the HuggingFace model instead...")
|
80 |
-
return None
|
81 |
else:
|
82 |
-
st.info(
|
83 |
|
84 |
return model_path
|
85 |
|
@@ -192,36 +165,29 @@ def create_mask(pred_mask):
|
|
192 |
Convert model prediction to displayable mask
|
193 |
|
194 |
Args:
|
195 |
-
pred_mask: Prediction from model
|
196 |
|
197 |
Returns:
|
198 |
-
Processed mask
|
199 |
"""
|
200 |
-
# Get the class with highest probability (argmax along class dimension)
|
201 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
202 |
-
|
203 |
-
|
204 |
-
pred_mask = tf.expand_dims(pred_mask, -1)
|
205 |
-
|
206 |
-
# Resize to original image size
|
207 |
-
pred_mask = tf.image.resize(
|
208 |
-
pred_mask,
|
209 |
-
(IMAGE_SIZE, IMAGE_SIZE),
|
210 |
-
method="nearest"
|
211 |
-
)
|
212 |
-
|
213 |
-
return pred_mask[0]
|
214 |
|
215 |
def colorize_mask(mask):
|
216 |
"""
|
217 |
Apply colors to segmentation mask
|
218 |
|
219 |
Args:
|
220 |
-
mask: Segmentation mask
|
221 |
|
222 |
Returns:
|
223 |
-
Colorized mask
|
224 |
"""
|
|
|
|
|
|
|
|
|
225 |
# Define colors for each class (RGB)
|
226 |
colors = [
|
227 |
[0, 0, 0], # Background (black)
|
@@ -233,8 +199,7 @@ def colorize_mask(mask):
|
|
233 |
rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
|
234 |
|
235 |
for i, color in enumerate(colors):
|
236 |
-
|
237 |
-
class_mask = np.where(mask == i, 1, 0).astype(np.uint8)
|
238 |
for c in range(3):
|
239 |
rgb_mask[:, :, c] += class_mask * color[c]
|
240 |
|
|
|
30 |
@st.cache_resource
|
31 |
def download_model_from_drive():
|
32 |
"""
|
33 |
+
Download model from Google Drive and return the local path.
|
|
|
|
|
|
|
34 |
"""
|
|
|
35 |
model_dir = os.path.join("models", "saved_models")
|
36 |
os.makedirs(model_dir, exist_ok=True)
|
37 |
+
model_path = os.path.join(model_dir, "segformer_model")
|
38 |
+
|
39 |
+
# Google Drive file ID
|
40 |
+
file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
|
41 |
+
url = f"https://drive.google.com/uc?id={file_id}"
|
42 |
+
|
43 |
if not os.path.exists(model_path):
|
44 |
+
try:
|
45 |
+
# Use gdown to download the file
|
46 |
+
with st.spinner("Downloading model from Google Drive..."):
|
47 |
+
gdown.download(url, model_path, quiet=False)
|
48 |
+
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
|
49 |
+
raise Exception("Downloaded file is empty or missing.")
|
50 |
+
st.success("Model downloaded successfully!")
|
51 |
+
except Exception as e:
|
52 |
+
st.error(f"Error downloading model: {e}")
|
53 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
else:
|
55 |
+
st.info("Model already exists locally.")
|
56 |
|
57 |
return model_path
|
58 |
|
|
|
165 |
Convert model prediction to displayable mask
|
166 |
|
167 |
Args:
|
168 |
+
pred_mask: Prediction logits from the model
|
169 |
|
170 |
Returns:
|
171 |
+
Processed mask (2D array)
|
172 |
"""
|
|
|
173 |
pred_mask = tf.math.argmax(pred_mask, axis=1)
|
174 |
+
pred_mask = tf.squeeze(pred_mask, axis=0) # Remove batch dimension
|
175 |
+
return pred_mask.numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
def colorize_mask(mask):
|
178 |
"""
|
179 |
Apply colors to segmentation mask
|
180 |
|
181 |
Args:
|
182 |
+
mask: Segmentation mask (2D array)
|
183 |
|
184 |
Returns:
|
185 |
+
Colorized mask (3D RGB array)
|
186 |
"""
|
187 |
+
# Ensure the mask is 2D
|
188 |
+
if len(mask.shape) > 2:
|
189 |
+
mask = np.squeeze(mask, axis=-1)
|
190 |
+
|
191 |
# Define colors for each class (RGB)
|
192 |
colors = [
|
193 |
[0, 0, 0], # Background (black)
|
|
|
199 |
rgb_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
|
200 |
|
201 |
for i, color in enumerate(colors):
|
202 |
+
class_mask = (mask == i).astype(np.uint8)
|
|
|
203 |
for c in range(3):
|
204 |
rgb_mask[:, :, c] += class_mask * color[c]
|
205 |
|