andy-wyx commited on
Commit
0aa9379
·
1 Parent(s): 0e1410c

add new beit model

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -120,6 +120,13 @@ def get_model(model_name):
120
  embedding_depth = 2,
121
  n_classes = n_classes)
122
  model.load_weights('model_classification/fossil-new.h5')
 
 
 
 
 
 
 
123
  else:
124
  raise ValueError(f"Model name '{model_name}' is not recognized")
125
  return model,n_classes
@@ -151,6 +158,11 @@ def classify_image(input_image, model_name):
151
  model,n_classes = get_model(model_name)
152
  result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
153
  return result
 
 
 
 
 
154
  return None
155
 
156
  def get_embeddings(input_image,model_name):
@@ -174,6 +186,11 @@ def get_embeddings(input_image,model_name):
174
  model,n_classes = get_model(model_name)
175
  result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
176
  return result
 
 
 
 
 
177
  return None
178
 
179
 
@@ -301,9 +318,9 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
301
 
302
  with gr.Column():
303
  model_name = gr.Dropdown(
304
- ["Mummified 170", "Rock 170","Fossils 142","Fossils new"],
305
  multiselect=False,
306
- value="Fossils new", # default option
307
  label="Model",
308
  interactive=True,
309
  info="Choose the model you'd like to use"
 
120
  embedding_depth = 2,
121
  n_classes = n_classes)
122
  model.load_weights('model_classification/fossil-new.h5')
123
+ elif model_name == 'Fossils':
124
+ n_classes = 142
125
+ model = get_triplet_model_beit(input_shape = (384, 384, 3),
126
+ embedding_units = 256,
127
+ embedding_depth = 2,
128
+ n_classes = n_classes)
129
+ model.load_weights('model_classification/fossil-model.h5')
130
  else:
131
  raise ValueError(f"Model name '{model_name}' is not recognized")
132
  return model,n_classes
 
158
  model,n_classes = get_model(model_name)
159
  result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
160
  return result
161
+ elif 'Fossils' ==model_name:
162
+ from inference_beit import inference_resnet_finer_beit
163
+ model,n_classes = get_model(model_name)
164
+ result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
165
+ return result
166
  return None
167
 
168
  def get_embeddings(input_image,model_name):
 
186
  model,n_classes = get_model(model_name)
187
  result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
188
  return result
189
+ elif 'Fossils' ==model_name:
190
+ from inference_beit import inference_resnet_embedding_beit
191
+ model,n_classes = get_model(model_name)
192
+ result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
193
+ return result
194
  return None
195
 
196
 
 
318
 
319
  with gr.Column():
320
  model_name = gr.Dropdown(
321
+ ["Mummified 170", "Rock 170","Fossils 142","Fossils new","Fossils"],
322
  multiselect=False,
323
+ value="Fossils", # default option
324
  label="Model",
325
  interactive=True,
326
  info="Choose the model you'd like to use"