Yuxiang Wang committed on
Commit c5343e6 · 1 Parent(s): af9c1e6

explanations,closest sample

Files changed (4)
  1. app.py +64 -22
  2. closest_sample.py +16 -9
  3. explanations.py +8 -3
  4. inference_beit.py +203 -0
app.py CHANGED
@@ -46,9 +46,19 @@ def get_model(model_name):
                            nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
         model.load_weights('model_classification/rock-170.h5')
     else:
-        return 'Error'
+        raise ValueError(f"Model name '{model_name}' is not recognized")
     return model,n_classes
 
+    '''
+    elif model_name == 'Fossils 19':
+        n_classes = 19 or 23?
+        model = get_beit_model(input_shape=(600, 600, 3),
+                               num_labels=n_classes,
+                               load_weights=False,
+                               )
+        model.load_weights('model_classification/beit-fossils-19.h5')
+    '''
+
 def segment_image(input_image):
     img = segmentation_sam(input_image)
     return img
@@ -67,7 +77,8 @@ def classify_image(input_image, model_name):
     if 'Fossils 19' ==model_name:
         from inference_beit import inference_dino
         model,n_classes = get_model(model_name)
-        return inference_dino(input_image,model_name)
+        result = inference_dino(input_image,model_name)
+        return result
     return None
 
 def get_embeddings(input_image,model_name):
@@ -84,21 +95,26 @@ def get_embeddings(input_image,model_name):
     if 'Fossils 19' ==model_name:
         from inference_beit import inference_dino
         model,n_classes = get_model(model_name)
-        return inference_dino(input_image,model_name)
+        result = inference_dino(input_image,model_name)
+        #TODO
+        #result = inference_beit_embedding
+        return result
     return None
 
 
 def find_closest(input_image,model_name):
     embedding = get_embeddings(input_image,model_name)
-    paths = get_images(embedding)
-    return paths
+    classes, paths = get_images(embedding)
+    #outputs = classes+paths
+    return classes,paths
 
 def explain_image(input_image,model_name):
     model,n_classes= get_model(model_name)
-    saliency, integrated, smoothgrad = explain(model,input_image,n_classes=n_classes)
+    #saliency, integrated, smoothgrad,
+    rise = explain(model,input_image,n_classes=n_classes)
     #original = saliency + integrated + smoothgrad
     print('done')
-    return saliency, integrated, smoothgrad,
+    return rise
 
 #minimalist theme
 with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
@@ -118,7 +134,7 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
 
         with gr.Column():
             model_name = gr.Dropdown(
-                ["Mummified 170", "Rock 170"],
+                ["Mummified 170", "Rock 170","Fossils 19"],
                 multiselect=False,
                 value="Rock 170", # default option
                 label="Model",
@@ -142,32 +158,61 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
        # with gr.Column():
        #     class_predicted2 = gr.Label(label='Class Predicted from diffuser')
        #     classify_button = gr.Button("Classify Image")
-
+
 
     with gr.Accordion("Explanations "):
         gr.Markdown("Computing Explanations from the model")
         with gr.Row():
             #original_input = gr.Image(label="Original Frame")
-            saliency = gr.Image(label="saliency")
-            gradcam = gr.Image(label='integraged gradients')
-            guided_gradcam = gr.Image(label='gradcam')
+            #saliency = gr.Image(label="saliency")
+            #gradcam = gr.Image(label='integraged gradients')
+            #guided_gradcam = gr.Image(label='gradcam')
             #guided_backprop = gr.Image(label='guided backprop')
+            rise = gr.Image(label = 'Rise')
         generate_explanations = gr.Button("Generate Explanations")
 
+    # with gr.Accordion('Closest Images'):
+    #     gr.Markdown("Finding the closest images in the dataset")
+    #     with gr.Row():
+    #         with gr.Column():
+    #             label_closest_image_0 = gr.Markdown('')
+    #             closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200)
+    #         with gr.Column():
+    #             label_closest_image_1 = gr.Markdown('')
+    #             closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200)
+    #         with gr.Column():
+    #             label_closest_image_2 = gr.Markdown('')
+    #             closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200)
+    #         with gr.Column():
+    #             label_closest_image_3 = gr.Markdown('')
+    #             closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200)
+    #         with gr.Column():
+    #             label_closest_image_4 = gr.Markdown('')
+    #             closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
+    #     find_closest_btn = gr.Button("Find Closest Images")
     with gr.Accordion('Closest Images'):
         gr.Markdown("Finding the closest images in the dataset")
+
         with gr.Row():
-            closest_image_0 = gr.Image(label='Closest Image')
-            closest_image_1 = gr.Image(label='Second Closest Image')
-            closest_image_2 = gr.Image(label='Third Closest Image')
-            closest_image_3 = gr.Image(label='Forth Closest Image')
-            closest_image_4 = gr.Image(label='Fifth Closest Image')
+            gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
+            #.style(grid=[1, 5], height=200, width=200)
+
         find_closest_btn = gr.Button("Find Closest Images")
 
     segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
     classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
-    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[saliency,gradcam,guided_gradcam])
-    find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
+    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[rise]) #saliency,gradcam,guided_gradcam
+    #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
+    def update_outputs(input_image,model_name):
+        labels, images = find_closest(input_image,model_name)
+        #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
+        #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
+        image_caption=[]
+        for i in range(5):
+            image_caption.append((images[i],labels[i]))
+        return image_caption
+
+    find_closest_btn.click(fn=update_outputs, inputs=[input_image,model_name], outputs=[gallery])
     #classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
 
 demo.queue() # manage multiple incoming requests
@@ -176,6 +221,3 @@ if os.getenv('SYSTEM') == 'spaces':
     demo.launch(width='40%',auth=(os.environ.get('USERNAME'), os.environ.get('PASSWORD')))
 else:
     demo.launch()
-
-
-
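
The closest-images panel now uses a single gr.Gallery instead of five fixed gr.Image slots: update_outputs pairs each returned image with its class label, because Gradio's Gallery accepts a list of (image, caption) tuples as its value. A minimal self-contained sketch of the same pattern, where fetch_neighbors and its random outputs are placeholders standing in for find_closest():

import numpy as np
import gradio as gr

def fetch_neighbors(query_image):
    # Placeholder for find_closest(): returns (labels, images).
    labels = [f"class_{i}" for i in range(5)]
    images = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(5)]
    return labels, images

def update_outputs(query_image):
    labels, images = fetch_neighbors(query_image)
    # gr.Gallery accepts a list of (image, caption) tuples.
    return [(img, lbl) for img, lbl in zip(images, labels)]

with gr.Blocks() as demo:
    inp = gr.Image()
    gallery = gr.Gallery(label="Closest Images", columns=[5], rows=[1])
    gr.Button("Find Closest Images").click(update_outputs, inputs=[inp], outputs=[gallery])
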
 
 
 
closest_sample.py CHANGED
@@ -50,10 +50,8 @@ def download_public_image(url, destination_path):
         with open(destination_path, 'wb') as f:
             f.write(response.content)
         print(f"Downloaded image to {destination_path}")
-        return True
     else:
         print(f"Failed to download image from bucket. Status code: {response.status_code}")
-        return False
 
 def get_images(embedding):
 
@@ -69,14 +67,23 @@ def get_images(embedding):
     folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
     folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
 
+    local_paths = []
+    classes = []
     for i, path in enumerate(paths):
         local_file_path = f'image_{i}.jpg'
-        public_path_florissant = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
-        success = download_public_image(public_path_florissant, local_file_path)
-
-        if not success:
-            public_path_general = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_general)
-            download_public_image(public_path_general, local_file_path)
+        if 'Florissant_Fossil/512/full/jpg/' in path:
+            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
+        elif 'General_Fossil/512/full/jpg/' in path:
+            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
+        else:
+            print("no match found")
+        download_public_image(public_path, local_file_path)
+        names = []
+        parts = [part for part in public_path.split('/') if part]
+        part = parts[-2]
+        classes.append(part)
+        local_paths.append(local_file_path)
     #paths= [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
     #                     '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
-    return paths
+
+    return classes, local_paths
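
get_images() now also derives a class label from the public URL: assuming the bucket layout keeps the class folder directly above the file name, the second-to-last non-empty path segment (parts[-2]) is the class. A small sketch of that extraction with a hypothetical URL and class name:

# Label extraction as used in get_images(); the example URL and class are hypothetical.
public_path = ('https://storage.googleapis.com/serrelab/prj_fossils/2024/'
               'Florissant_Fossil_v2.0/Salicaceae/some_leaf.jpg')
parts = [part for part in public_path.split('/') if part]
class_name = parts[-2]   # second-to-last non-empty segment -> 'Salicaceae'
print(class_name)
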
explanations.py CHANGED
@@ -50,10 +50,13 @@ def explain(model, input_image,size=600, n_classes=171) :
     class_model = tf.keras.Model(model.input, model.output[1])
 
     explainers = [
-                  Saliency(class_model),
-                  IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
-                  SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
+                  #Saliency(class_model),
+                  #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
+                  #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
                   #GradCAM(class_model),
+                  Rise(class_model,nb_samples = 50, batch_size = BATCH_SIZE,grid_size=7,
+                       preservation_probability=0.5)
+                  #
                  ]
     cropped,repetitions = _clever_crop(input_image,(size,size))
     size_repetitions = int(size//(repetitions.numpy()+1))
@@ -74,6 +77,8 @@ def explain(model, input_image,size=600, n_classes=171) :
 
         plt.savefig(f'phi_{e}.png')
         explanations.append(f'phi_{e}.png')
+    print(type(explanations))
+    print(len(explanations))
 
     print('Done')
 
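
explain() now runs a single RISE explainer in place of Saliency, IntegratedGradients and SmoothGrad. The constructor arguments match xplique's attributions.Rise, although the import is not shown in this hunk, so that origin is an assumption. A minimal sketch of calling such an explainer, using a tiny stand-in classifier instead of the repo's class_model:

import numpy as np
import tensorflow as tf
from xplique.attributions import Rise  # assumed source of Rise; not shown in this diff

# Stand-in classifier: 600x600x3 input, 171 classes, purely for illustration.
model = tf.keras.Sequential([
    tf.keras.layers.Input((600, 600, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(171, activation='softmax'),
])

explainer = Rise(model, nb_samples=50, batch_size=8,
                 grid_size=7, preservation_probability=0.5)

images = np.random.rand(1, 600, 600, 3).astype('float32')
targets = tf.one_hot([0], depth=171)          # one-hot targets for the class to explain
heatmaps = explainer.explain(images, targets)  # one importance map per input image
print(heatmaps.shape)
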
 
inference_beit.py CHANGED
@@ -0,0 +1,203 @@
+import tensorflow as tf
+gpu_devices = tf.config.experimental.list_physical_devices('GPU')
+if gpu_devices:
+    tf.config.experimental.set_memory_growth(gpu_devices[0], True)
+else:
+    print(f"TensorFlow device: {gpu_devices}")
+
+import os
+import numpy as np
+import keras
+from PIL import Image
+import keras_cv
+from keras_cv_attention_models import beit
+import matplotlib.pyplot as plt
+
+
+#preprocessing
+#TODO
+num_classes = len(class_names)
+AUTO = tf.data.AUTOTUNE
+rand_augment = keras_cv.layers.RandAugment(value_range = (-1, 1), augmentations_per_image = 3, magnitude=0.5)
+
+SIZE = 384
+debug = None
+
+def augmentations(x, crop_size=22, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
+    x = tf.cast(x, tf.float32)
+    x = tf.image.random_crop(x, (tf.shape(x)[0], 100, 100, 3))
+    x = tf.image.random_brightness(x, max_delta=brightness)
+    x = tf.image.random_contrast(x, lower=1.0-contrast, upper=1+contrast)
+    x = tf.image.random_saturation(x, lower=1.0-saturation, upper=1.0+saturation)
+    x = tf.image.random_hue(x, max_delta=hue)
+    x = tf.image.resize(x, (128, 128))
+    x = tf.clip_by_value(x, 0.0, 255.0)
+    x = tf.keras.applications.resnet_v2.preprocess_input(x)
+    return x
+
+
+def pad_gt(x):
+    h, w = x.shape[-2:]
+    padh = sam.image_encoder.img_size - h
+    padw = sam.image_encoder.img_size - w
+    x = F.pad(x, (0, padw, 0, padh))
+    return x
+
+def preprocess(img):
+
+    img = np.array(img).astype(np.uint8)
+
+    #assert img.max() > 127.0
+
+    img_preprocess = predictor.transform.apply_image(img)
+    intermediate_shape = img_preprocess.shape
+
+    img_preprocess = torch.as_tensor(img_preprocess).cuda()
+    img_preprocess = img_preprocess.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+    img_preprocess = sam.preprocess(img_preprocess)
+    if len(intermediate_shape) == 3:
+        intermediate_shape = intermediate_shape[:2]
+    elif len(intermediate_shape) == 4:
+        intermediate_shape = intermediate_shape[1:3]
+
+    return img_preprocess, intermediate_shape
+
+
+
+def normalize(img):
+    img = img - tf.math.reduce_min(img)
+    img = img / tf.math.reduce_max(img)
+    img = img * 2.0 - 1.0
+    return img
+
+def smooth_mask(mask, ds=20):
+    shape = tf.shape(mask)
+    w, h = shape[0], shape[1]
+    return tf.image.resize(tf.image.resize(mask, (ds, ds), method="bicubic"), (w, h), method="bicubic")
+
+def resize(img):
+    # default resize function for all pi outputs
+    return tf.image.resize(img, (SIZE, SIZE), method="bicubic")
+
+def pi(img, mask):
+    img = tf.cast(img, tf.float32)
+
+    shape = tf.shape(img)
+    w, h = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)
+
+    mask = smooth_mask(mask)
+    mask = tf.reduce_mean(mask, -1)
+
+    img = img * tf.cast(mask > 0.1, tf.float32)[:, :, None]
+
+    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
+    img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
+
+    # building 2 anchors
+    anchors = tf.where(mask > 0.15)
+    anchor_xmin = tf.math.reduce_min(anchors[:, 0])
+    anchor_xmax = tf.math.reduce_max(anchors[:, 0])
+    anchor_ymin = tf.math.reduce_min(anchors[:, 1])
+    anchor_ymax = tf.math.reduce_max(anchors[:, 1])
+
+    if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:
+
+        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])
+
+        delta_x = (anchor_xmax - anchor_xmin) // 4
+        delta_y = (anchor_ymax - anchor_ymin) // 4
+        img_anchor_2 = img[anchor_xmin+delta_x:anchor_xmax-delta_x,
+                           anchor_ymin+delta_y:anchor_ymax-delta_y]
+        img_anchor_2 = resize(img_anchor_2)
+    else:
+        img_anchor_1 = img_resize
+        img_anchor_2 = img_pad
+
+    # building the anchors max
+    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
+    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]
+
+    img_max_zoom1 = img[tf.math.maximum(anchor_max_x-SIZE, 0): tf.math.minimum(anchor_max_x+SIZE, w),
+                        tf.math.maximum(anchor_max_y-SIZE, 0): tf.math.minimum(anchor_max_y+SIZE, h)]
+
+    img_max_zoom1 = resize(img_max_zoom1)
+    img_max_zoom2 = img[anchor_max_x-SIZE//2:anchor_max_x+SIZE//2,
+                        anchor_max_y-SIZE//2:anchor_max_y+SIZE//2]
+    img_max_zoom2 = img[tf.math.maximum(anchor_max_x-SIZE//2, 0): tf.math.minimum(anchor_max_x+SIZE//2, w),
+                        tf.math.maximum(anchor_max_y-SIZE//2, 0): tf.math.minimum(anchor_max_y+SIZE//2, h)]
+    #tf.print(img_max_zoom2.shape)
+    #img_max_zoom2 = resize(img_max_zoom2)
+
+    return tf.cast(img_resize, tf.float32)
+
+def parse_img(element, split, randaugment,maskaugment=True):
+    #global debug
+    path, class_id = element[0], element[1]
+
+    data = tf.io.read_file(path)
+    img = tf.io.decode_jpeg(data)
+    img = tf.cast(img, tf.uint8)
+    img = normalize(img)
+    shape = tf.shape(img)
+
+    # data_mask = tf.io.read_file(path_mask)
+    # mask = tf.io.decode_jpeg(data_mask)
+
+    class_id = tf.strings.to_number(class_id)
+    class_id = tf.cast(class_id, tf.int32)
+
+    label = tf.one_hot(class_id, num_classes)
+
+    # img = pi(img, mask)
+    img = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
+
+    return tf.cast(img, tf.float32), tf.cast(label, tf.int32)
+
+SIZE = 384
+wsize=hsize=SIZE
+def resize_images(batch_x, width=224, height=224):
+    return tf.image.resize(batch_x, (width, height))
+
+def load_img(image_path,gray=False):
+    img = tf.io.read_file(image_path)
+    img = tf.image.decode_jpeg(img, channels=3)
+    img = tf.image.convert_image_dtype(img, tf.float32)
+    if gray:
+        img = tf.image.rgb_to_grayscale(img)
+        img = tf.image.grayscale_to_rgb(img)
+    img = tf.image.resize(img,(wsize,hsize))
+    return img
+
+LR = 1e-3
+
+optimizer = tf.keras.optimizers.Adam(LR)
+cce = tf.keras.losses.categorical_crossentropy
+
+model_path = '/content/drive/MyDrive/Gg_Fossils_data_shared_copy/Fossils/models/model-13.h5'
+model = keras.models.load_model(model_path, custom_objects = {'cce': cce})
+
+outputs = model.predict(images)
+
+predictions = tf.math.top_k(outputs[1], k = 5)
+cid = 1
+dataset = np.array(dataset)
+final_predictions = []
+for ele in predictions[1]:
+    if cid in ele:
+        final_predictions.append(cid)
+    else:
+        final_predictions.append(cid+10)
+final_predictions = np.array(final_predictions)
+images2 = images[final_predictions == cid]
+image2_paths = dataset[final_predictions == cid][:,0]
+print(images2.shape)
+
+def get_beit_model(input_shape, num_labels, load_weights=False, ...):
+    pass
+
+def inference_dino(input_image, model_name):
+    pass
+
+def inference_beit_embedding(input_image, model, size=600):
+    pass
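
get_beit_model, inference_dino and inference_beit_embedding are committed as stubs, while app.py already calls inference_dino for the "Fossils 19" model. A hedged sketch of one way the inference_dino stub could wrap the helpers defined earlier in this file (load_img, the loaded model, tf.math.top_k); class_names is still a TODO above, so a placeholder list is used here, and the two-headed model output (outputs[1] holding class scores) is an assumption carried over from the prediction code above:

import tensorflow as tf

class_names = [f"class_{i}" for i in range(19)]   # placeholder, not the real taxonomy

def inference_dino(input_image, model_name):
    # Accept either a file path (handled by load_img above) or an in-memory array.
    img = load_img(input_image) if isinstance(input_image, str) \
          else tf.image.resize(tf.cast(input_image, tf.float32), (SIZE, SIZE))
    outputs = model.predict(tf.expand_dims(img, 0))
    # Assumption: the loaded model has two heads and outputs[1] holds the class scores.
    scores = outputs[1][0] if isinstance(outputs, (list, tuple)) else outputs[0]
    values, indices = tf.math.top_k(scores, k=5)
    # Return a {label: score} dict, the format gr.Label accepts in app.py.
    return {class_names[int(i)]: float(v) for v, i in zip(values, indices)}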