Spaces:
Sleeping
Sleeping
add
Browse files
app.py
CHANGED
@@ -10,7 +10,6 @@ import io
|
|
10 |
import gdown
|
11 |
from transformers import TFSegformerForSemanticSegmentation
|
12 |
|
13 |
-
# Set page configuration
|
14 |
st.set_page_config(
|
15 |
page_title="Pet Segmentation with SegFormer",
|
16 |
page_icon="🐶",
|
@@ -39,24 +38,48 @@ def download_model_from_drive():
|
|
39 |
# Define paths
|
40 |
model_dir = os.path.join("models", "saved_models")
|
41 |
os.makedirs(model_dir, exist_ok=True)
|
42 |
-
model_path = os.path.join(model_dir, "
|
43 |
|
44 |
# Check if model already exists
|
45 |
if not os.path.exists(model_path):
|
46 |
with st.spinner("Downloading model from Google Drive..."):
|
47 |
try:
|
48 |
-
#
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
# Download the model file
|
52 |
-
url = f"https://drive.google.com/uc?id={file_id}"
|
53 |
-
gdown.download(url, model_path, quiet=False)
|
54 |
-
st.success("Model downloaded successfully!")
|
55 |
except Exception as e:
|
56 |
st.error(f"Error downloading model: {str(e)}")
|
|
|
57 |
return None
|
58 |
else:
|
59 |
-
st.info("Model already exists
|
60 |
|
61 |
return model_path
|
62 |
|
@@ -83,8 +106,28 @@ def load_model():
|
|
83 |
ignore_mismatched_sizes=True
|
84 |
)
|
85 |
else:
|
86 |
-
#
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
return model
|
90 |
except Exception as e:
|
|
|
10 |
import gdown
|
11 |
from transformers import TFSegformerForSemanticSegmentation
|
12 |
|
|
|
13 |
st.set_page_config(
|
14 |
page_title="Pet Segmentation with SegFormer",
|
15 |
page_icon="🐶",
|
|
|
38 |
# Define paths
|
39 |
model_dir = os.path.join("models", "saved_models")
|
40 |
os.makedirs(model_dir, exist_ok=True)
|
41 |
+
model_path = os.path.join(model_dir, "t5_model.h5")
|
42 |
|
43 |
# Check if model already exists
|
44 |
if not os.path.exists(model_path):
|
45 |
with st.spinner("Downloading model from Google Drive..."):
|
46 |
try:
|
47 |
+
# Direct download URL for t5_model.h5
|
48 |
+
url = "https://drive.google.com/uc?id=1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
|
49 |
+
|
50 |
+
st.info(f"Downloading model from {url}")
|
51 |
+
|
52 |
+
# Try direct download with gdown using direct URL
|
53 |
+
try:
|
54 |
+
output = gdown.download(url, model_path, quiet=False)
|
55 |
+
if output is None:
|
56 |
+
raise Exception("gdown failed with direct URL")
|
57 |
+
st.success(f"Model downloaded to {output}")
|
58 |
+
except Exception as e1:
|
59 |
+
st.warning(f"First download attempt failed: {str(e1)}")
|
60 |
+
|
61 |
+
# Try with just the file ID
|
62 |
+
file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
|
63 |
+
st.info(f"Trying alternative download with file ID: {file_id}")
|
64 |
+
output = gdown.download(f"https://drive.google.com/uc?id={file_id}",
|
65 |
+
model_path, quiet=False)
|
66 |
+
if output is None:
|
67 |
+
raise Exception("gdown failed with file ID")
|
68 |
+
st.success("Model downloaded successfully!")
|
69 |
+
|
70 |
+
# Verify the file exists and has content
|
71 |
+
if os.path.exists(model_path) and os.path.getsize(model_path) > 0:
|
72 |
+
st.success(f"Verified file at {model_path} ({os.path.getsize(model_path)/1024/1024:.2f} MB)")
|
73 |
+
else:
|
74 |
+
st.error(f"Downloaded file is missing or empty: {model_path}")
|
75 |
+
return None
|
76 |
|
|
|
|
|
|
|
|
|
77 |
except Exception as e:
|
78 |
st.error(f"Error downloading model: {str(e)}")
|
79 |
+
st.info("Trying to use the HuggingFace model instead...")
|
80 |
return None
|
81 |
else:
|
82 |
+
st.info(f"Model already exists at {model_path}")
|
83 |
|
84 |
return model_path
|
85 |
|
|
|
106 |
ignore_mismatched_sizes=True
|
107 |
)
|
108 |
else:
|
109 |
+
# Check if this is a Keras .h5 model or a HuggingFace model directory
|
110 |
+
if model_path.endswith('.h5'):
|
111 |
+
st.info("Loading Keras H5 model...")
|
112 |
+
# For a Keras .h5 file, use tf.keras.models.load_model
|
113 |
+
try:
|
114 |
+
model = tf.keras.models.load_model(model_path)
|
115 |
+
st.success("Keras model loaded successfully")
|
116 |
+
except Exception as ke:
|
117 |
+
st.error(f"Error loading Keras model: {str(ke)}")
|
118 |
+
st.warning("Falling back to pretrained model")
|
119 |
+
model = TFSegformerForSemanticSegmentation.from_pretrained(
|
120 |
+
"nvidia/mit-b0",
|
121 |
+
num_labels=NUM_CLASSES,
|
122 |
+
id2label=ID2LABEL,
|
123 |
+
label2id={label: id for id, label in ID2LABEL.items()},
|
124 |
+
ignore_mismatched_sizes=True
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
# For a HuggingFace model directory
|
128 |
+
st.info("Loading HuggingFace model...")
|
129 |
+
model = TFSegformerForSemanticSegmentation.from_pretrained(model_path)
|
130 |
+
st.success("HuggingFace model loaded successfully")
|
131 |
|
132 |
return model
|
133 |
except Exception as e:
|