kishoreb4 commited on
Commit
5095c72
·
1 Parent(s): a9ad3bf
Files changed (1) hide show
  1. app.py +54 -11
app.py CHANGED
@@ -10,7 +10,6 @@ import io
10
  import gdown
11
  from transformers import TFSegformerForSemanticSegmentation
12
 
13
- # Set page configuration
14
  st.set_page_config(
15
  page_title="Pet Segmentation with SegFormer",
16
  page_icon="🐶",
@@ -39,24 +38,48 @@ def download_model_from_drive():
39
  # Define paths
40
  model_dir = os.path.join("models", "saved_models")
41
  os.makedirs(model_dir, exist_ok=True)
42
- model_path = os.path.join(model_dir, "segformer_model")
43
 
44
  # Check if model already exists
45
  if not os.path.exists(model_path):
46
  with st.spinner("Downloading model from Google Drive..."):
47
  try:
48
- # Google Drive file ID from the shared link
49
- file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- # Download the model file
52
- url = f"https://drive.google.com/uc?id={file_id}"
53
- gdown.download(url, model_path, quiet=False)
54
- st.success("Model downloaded successfully!")
55
  except Exception as e:
56
  st.error(f"Error downloading model: {str(e)}")
 
57
  return None
58
  else:
59
- st.info("Model already exists locally.")
60
 
61
  return model_path
62
 
@@ -83,8 +106,28 @@ def load_model():
83
  ignore_mismatched_sizes=True
84
  )
85
  else:
86
- # Load downloaded model
87
- model = TFSegformerForSemanticSegmentation.from_pretrained(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  return model
90
  except Exception as e:
 
10
  import gdown
11
  from transformers import TFSegformerForSemanticSegmentation
12
 
 
13
  st.set_page_config(
14
  page_title="Pet Segmentation with SegFormer",
15
  page_icon="🐶",
 
38
  # Define paths
39
  model_dir = os.path.join("models", "saved_models")
40
  os.makedirs(model_dir, exist_ok=True)
41
+ model_path = os.path.join(model_dir, "t5_model.h5")
42
 
43
  # Check if model already exists
44
  if not os.path.exists(model_path):
45
  with st.spinner("Downloading model from Google Drive..."):
46
  try:
47
+ # Direct download URL for t5_model.h5
48
+ url = "https://drive.google.com/uc?id=1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
49
+
50
+ st.info(f"Downloading model from {url}")
51
+
52
+ # Try direct download with gdown using direct URL
53
+ try:
54
+ output = gdown.download(url, model_path, quiet=False)
55
+ if output is None:
56
+ raise Exception("gdown failed with direct URL")
57
+ st.success(f"Model downloaded to {output}")
58
+ except Exception as e1:
59
+ st.warning(f"First download attempt failed: {str(e1)}")
60
+
61
+ # Try with just the file ID
62
+ file_id = "1XObpqG8qZ7YUyiRKbpVvxX11yQSK8Y_3"
63
+ st.info(f"Trying alternative download with file ID: {file_id}")
64
+ output = gdown.download(f"https://drive.google.com/uc?id={file_id}",
65
+ model_path, quiet=False)
66
+ if output is None:
67
+ raise Exception("gdown failed with file ID")
68
+ st.success("Model downloaded successfully!")
69
+
70
+ # Verify the file exists and has content
71
+ if os.path.exists(model_path) and os.path.getsize(model_path) > 0:
72
+ st.success(f"Verified file at {model_path} ({os.path.getsize(model_path)/1024/1024:.2f} MB)")
73
+ else:
74
+ st.error(f"Downloaded file is missing or empty: {model_path}")
75
+ return None
76
 
 
 
 
 
77
  except Exception as e:
78
  st.error(f"Error downloading model: {str(e)}")
79
+ st.info("Trying to use the HuggingFace model instead...")
80
  return None
81
  else:
82
+ st.info(f"Model already exists at {model_path}")
83
 
84
  return model_path
85
 
 
106
  ignore_mismatched_sizes=True
107
  )
108
  else:
109
+ # Check if this is a Keras .h5 model or a HuggingFace model directory
110
+ if model_path.endswith('.h5'):
111
+ st.info("Loading Keras H5 model...")
112
+ # For a Keras .h5 file, use tf.keras.models.load_model
113
+ try:
114
+ model = tf.keras.models.load_model(model_path)
115
+ st.success("Keras model loaded successfully")
116
+ except Exception as ke:
117
+ st.error(f"Error loading Keras model: {str(ke)}")
118
+ st.warning("Falling back to pretrained model")
119
+ model = TFSegformerForSemanticSegmentation.from_pretrained(
120
+ "nvidia/mit-b0",
121
+ num_labels=NUM_CLASSES,
122
+ id2label=ID2LABEL,
123
+ label2id={label: id for id, label in ID2LABEL.items()},
124
+ ignore_mismatched_sizes=True
125
+ )
126
+ else:
127
+ # For a HuggingFace model directory
128
+ st.info("Loading HuggingFace model...")
129
+ model = TFSegformerForSemanticSegmentation.from_pretrained(model_path)
130
+ st.success("HuggingFace model loaded successfully")
131
 
132
  return model
133
  except Exception as e: