taesiri commited on
Commit
4edc505
·
1 Parent(s): e859cf6
Files changed (1) hide show
  1. app.py +68 -20
app.py CHANGED
@@ -19,8 +19,8 @@ import torchvision
19
  from huggingface_hub import HfApi, login, snapshot_download
20
  from PIL import Image
21
 
22
- session_token = os.environ.get("SessionToken")
23
- login(token=session_token)
24
 
25
  csv.field_size_limit(sys.maxsize)
26
 
@@ -100,22 +100,24 @@ def generate_dataset(username):
100
 
101
  NUMBER_OF_IMAGES = len(bad_items)
102
 
 
 
 
103
  if NUMBER_OF_IMAGES == 0:
104
  return []
105
 
106
- random_indices = remaining
107
- random_images = [imagenet_hard[int(i)]["image"] for i in random_indices]
108
- random_gt_ids = [imagenet_hard[int(i)]["label"] for i in random_indices]
109
- random_gt_labels = [imagenet_hard[int(x)]["english_label"] for x in random_indices]
110
 
111
  data = []
112
- for i, image in enumerate(random_images):
113
  data.append(
114
  {
115
- "id": random_indices[i],
116
- "image": image,
117
- "correct_label": random_gt_labels[i],
118
- "original_id": int(random_indices[i]),
119
  }
120
  )
121
  return data
@@ -153,16 +155,22 @@ def get_training_samples(qid):
153
 
154
  def load_sample(data, current_index):
155
  image_id = data[current_index]["id"]
156
- qimage = data[current_index]["image"]
 
 
 
 
157
 
158
- labels = data[current_index]["correct_label"]
159
  return qimage, labels
160
 
161
 
162
  def preprocessing(data, current_index, history, username):
163
  data = generate_dataset(username)
164
 
165
- if len(data) == 0:
 
 
 
166
  fake_plot = string_to_image("No more images to review")
167
  empty_image = Image.new("RGB", (224, 224))
168
  return (
@@ -172,6 +180,7 @@ def preprocessing(data, current_index, history, username):
172
  history,
173
  data,
174
  None,
 
175
  )
176
 
177
  current_index = 0
@@ -186,7 +195,15 @@ def preprocessing(data, current_index, history, username):
186
  labels = ", ".join(labels)
187
  label_plot = string_to_image(labels)
188
 
189
- return qimage, label_plot, current_index, history, data, training_samples_image
 
 
 
 
 
 
 
 
190
 
191
 
192
  def update_app(decision, data, current_index, history, username):
@@ -194,7 +211,7 @@ def update_app(decision, data, current_index, history, username):
194
  if current_index == -1:
195
  fake_plot = string_to_image("Please Enter your username and load samples")
196
  empty_image = Image.new("RGB", (224, 224))
197
- return empty_image, fake_plot, current_index, history, data, None
198
 
199
  if current_index == NUMBER_OF_IMAGES - 1:
200
  time_stamp = int(time.time())
@@ -226,7 +243,19 @@ def update_app(decision, data, current_index, history, username):
226
 
227
  fake_plot = string_to_image("Thank you for your time!")
228
  empty_image = Image.new("RGB", (224, 224))
229
- return empty_image, fake_plot, current_index, history, data, None
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  if current_index >= 0 and current_index < NUMBER_OF_IMAGES - 1:
232
  time_stamp = int(time.time())
@@ -270,7 +299,18 @@ def update_app(decision, data, current_index, history, username):
270
  labels = ", ".join(labels)
271
  label_plot = string_to_image(labels)
272
 
273
- return qimage, label_plot, current_index, history, data, training_samples_image
 
 
 
 
 
 
 
 
 
 
 
274
 
275
 
276
  newcss = """
@@ -313,7 +353,11 @@ with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
313
  )
314
 
315
  with gr.Column():
316
- username = gr.Textbox(label="Username", value=f"user-{random_str}")
 
 
 
 
317
  prepare_btn = gr.Button(value="Load Samples")
318
 
319
  with gr.Column():
@@ -341,6 +385,7 @@ with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
341
  history,
342
  data_gr,
343
  training_samples,
 
344
  ],
345
  )
346
  myabe_btn.click(
@@ -353,6 +398,7 @@ with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
353
  history,
354
  data_gr,
355
  training_samples,
 
356
  ],
357
  )
358
 
@@ -366,6 +412,7 @@ with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
366
  history,
367
  data_gr,
368
  training_samples,
 
369
  ],
370
  )
371
 
@@ -379,7 +426,8 @@ with gr.Blocks(css=newcss, theme=gr.themes.Soft()) as demo:
379
  history,
380
  data_gr,
381
  training_samples,
 
382
  ],
383
  )
384
 
385
- demo.launch()
 
19
  from huggingface_hub import HfApi, login, snapshot_download
20
  from PIL import Image
21
 
22
+ # session_token = os.environ.get("SessionToken")
23
+ # login(token=session_token)
24
 
25
  csv.field_size_limit(sys.maxsize)
26
 
 
100
 
101
  NUMBER_OF_IMAGES = len(bad_items)
102
 
103
+ print(f"NUMBER_OF_IMAGES: {NUMBER_OF_IMAGES}")
104
+ print(f"Remaining: {len(remaining)}")
105
+
106
  if NUMBER_OF_IMAGES == 0:
107
  return []
108
 
109
+ # random_indices = remaining
110
+ # random_images = [imagenet_hard[int(i)]["image"] for i in random_indices]
111
+ # random_gt_ids = [imagenet_hard[int(i)]["label"] for i in random_indices]
112
+ # random_gt_labels = [imagenet_hard[int(x)]["english_label"] for x in random_indices]
113
 
114
  data = []
115
+ for i, image in enumerate(remaining):
116
  data.append(
117
  {
118
+ "id": remaining[i],
119
+ # "correct_label": random_gt_labels[i],
120
+ # "original_id": int(random_indices[i]),
 
121
  }
122
  )
123
  return data
 
155
 
156
  def load_sample(data, current_index):
157
  image_id = data[current_index]["id"]
158
+ qimage = imagenet_hard[int(image_id)]["image"]
159
+ # labels = data[current_index]["correct_label"]
160
+ labels = imagenet_hard[int(image_id)]["english_label"]
161
+ # print(f"Image ID: {image_id}")
162
+ # print(f"Labels: {labels}")
163
 
 
164
  return qimage, labels
165
 
166
 
167
  def preprocessing(data, current_index, history, username):
168
  data = generate_dataset(username)
169
 
170
+ remaining_images = len(data)
171
+ labeled_images = len(bad_items) - remaining_images
172
+
173
+ if remaining_images == 0:
174
  fake_plot = string_to_image("No more images to review")
175
  empty_image = Image.new("RGB", (224, 224))
176
  return (
 
180
  history,
181
  data,
182
  None,
183
+ labeled_images,
184
  )
185
 
186
  current_index = 0
 
195
  labels = ", ".join(labels)
196
  label_plot = string_to_image(labels)
197
 
198
+ return (
199
+ qimage,
200
+ label_plot,
201
+ current_index,
202
+ history,
203
+ data,
204
+ training_samples_image,
205
+ labeled_images,
206
+ )
207
 
208
 
209
  def update_app(decision, data, current_index, history, username):
 
211
  if current_index == -1:
212
  fake_plot = string_to_image("Please Enter your username and load samples")
213
  empty_image = Image.new("RGB", (224, 224))
214
+ return empty_image, fake_plot, current_index, history, data, None, 0
215
 
216
  if current_index == NUMBER_OF_IMAGES - 1:
217
  time_stamp = int(time.time())
 
243
 
244
  fake_plot = string_to_image("Thank you for your time!")
245
  empty_image = Image.new("RGB", (224, 224))
246
+
247
+ remaining_images = len(data)
248
+ labeled_images = (len(bad_items) - remaining_images) + current_index
249
+
250
+ return (
251
+ empty_image,
252
+ fake_plot,
253
+ current_index,
254
+ history,
255
+ data,
256
+ None,
257
+ labeled_images + 1,
258
+ )
259
 
260
  if current_index >= 0 and current_index < NUMBER_OF_IMAGES - 1:
261
  time_stamp = int(time.time())
 
299
  labels = ", ".join(labels)
300
  label_plot = string_to_image(labels)
301
 
302
+ remaining_images = len(data)
303
+ labeled_images = (len(bad_items) - remaining_images) + current_index
304
+
305
+ return (
306
+ qimage,
307
+ label_plot,
308
+ current_index,
309
+ history,
310
+ data,
311
+ training_samples_image,
312
+ labeled_images,
313
+ )
314
 
315
 
316
  newcss = """
 
353
  )
354
 
355
  with gr.Column():
356
+ with gr.Row():
357
+ username = gr.Textbox(label="Username", value=f"user-{random_str}")
358
+ labeled_images = gr.Textbox(label="Labeled Images", value="0")
359
+ total_images = gr.Textbox(label="Total Images", value=len(bad_items))
360
+
361
  prepare_btn = gr.Button(value="Load Samples")
362
 
363
  with gr.Column():
 
385
  history,
386
  data_gr,
387
  training_samples,
388
+ labeled_images,
389
  ],
390
  )
391
  myabe_btn.click(
 
398
  history,
399
  data_gr,
400
  training_samples,
401
+ labeled_images,
402
  ],
403
  )
404
 
 
412
  history,
413
  data_gr,
414
  training_samples,
415
+ labeled_images,
416
  ],
417
  )
418
 
 
426
  history,
427
  data_gr,
428
  training_samples,
429
+ labeled_images,
430
  ],
431
  )
432
 
433
+ demo.launch(debug=True)