prowseed committed on
Commit
a81a404
·
verified ·
1 Parent(s): 7dbf341

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -113
app.py CHANGED
@@ -3,142 +3,239 @@ from functools import partial
3
  from pandas import DataFrame
4
  import earthview as ev
5
  import utils
6
- import gradio as gr
7
  import tqdm
8
  import os
 
9
 
10
- DEBUG = False # False, "random", "samples"
 
11
 
12
- if DEBUG == "random":
13
- import numpy as np
 
 
14
 
15
- def open_dataset(dataset, subset, split, batch_size, shard, only_rgb, state):
 
 
 
16
 
17
- nshards = ev.get_nshards(subset)
18
-
19
- if shard == -1:
20
- shards = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  else:
22
- shards = [shard]
 
 
 
23
 
 
 
24
  if DEBUG == "random":
25
- ds = range(batch_size)
 
26
  elif DEBUG == "samples":
27
- ds = ev.load_parquet(subset, batch_size=batch_size)
 
 
 
 
28
  elif not DEBUG:
29
- ds = ev.load_dataset(subset, dataset=dataset, split=split, shards=shards, cache_dir="dataset")
30
-
31
- dsi = iter(ds)
32
-
33
- state["subset"] = subset
34
- state["dsi"] = dsi
35
- return (
36
- gr.update(label=f"Shard (max {nshards})", value=shard, maximum=nshards),
37
- *get_images(batch_size, only_rgb, state),
38
- state
39
- )
40
-
41
- def get_images(batch_size, only_rgb, state):
42
- try:
43
- subset = state["subset"]
44
- except KeyError:
45
- raise gr.Error("You need to load a Dataset first")
 
 
 
 
 
 
 
 
 
46
 
 
 
 
 
 
 
 
 
 
 
47
  images = []
48
  metadatas = []
49
-
50
- for i in tqdm.trange(batch_size, desc=f"Getting images"):
 
51
  if DEBUG == "random":
52
- images.append(np.random.randint(0,255,(384,384,3)))
 
 
53
  if not only_rgb:
54
- images.append(np.random.randint(0,255,(100,100,3)))
55
-
56
- metadatas.append({"bounds":[[1,1,4,4]], })
57
  else:
58
  try:
59
- item = next(state["dsi"])
 
60
  except StopIteration:
61
- break
62
- item = ev.item_to_images(subset, item)
63
- metadata = item["metadata"]
64
-
65
- if subset == "satellogic":
66
- images.extend(item["rgb"])
67
- if not only_rgb:
68
- images.extend(item["1m"])
69
- elif subset == "sentinel_1":
70
- images.extend(item["10m"])
71
- elif subset == "sentinel_2":
72
- images.extend(item["rgb"])
73
- if not only_rgb:
74
- images.extend(item["10m"])
75
- images.extend(item["20m"])
76
- images.extend(item["scl"])
77
- elif subset == "neon":
78
- images.extend(item["rgb"])
79
- if not only_rgb:
80
- images.extend(item["chm"])
81
- images.extend(item["1m"])
82
-
83
- metadata["map"] = f'<a href="{utils.get_google_map_link(item, subset)}" target="about:_blank">🧭</a>'
84
- metadatas.append(metadata)
85
-
86
- return images, DataFrame(metadatas)
87
-
88
- def update_shape(columns):
89
- return gr.update(columns=columns)
90
-
91
- def new_state():
92
- return gr.State({})
93
-
94
- if __name__ == "__main__":
95
- with gr.Blocks(title="EarthView Viewer", fill_height = True) as demo:
96
- state = new_state()
97
 
98
- gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")
99
- batch_size = gr.Number(10, label = "Batch Size", render=False)
100
- shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
101
- table = gr.DataFrame(render = False, datatype="html")
102
- # headers=["Index","TimeStamp","Bounds","CRS"],
103
-
104
- gallery = gr.Gallery(
105
- label=ev.DATASET,
106
- interactive=False,
107
- object_fit="scale-down",
108
- columns=5, render=False)
109
-
110
- with gr.Row():
111
- dataset = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
112
- subset = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic", )
113
- split = gr.Textbox(label="Split", value="train")
114
- initial_shard = gr.Number(label = "Initial shard", value=10, info="-1 for whole dataset")
115
- only_rgb = gr.Checkbox(label="Only RGB", value=True)
116
-
117
- gr.Button("Load (minutes)").click(
118
- open_dataset,
119
- inputs=[dataset, subset, split, batch_size, initial_shard, only_rgb, state],
120
- outputs=[shard, gallery, table, state])
121
-
122
- gallery.render()
123
-
124
- with gr.Row():
125
- batch_size.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- columns = gr.Number(5, label="Columns")
 
128
 
129
- columns.change(update_shape, [columns], [gallery])
130
 
131
  with gr.Row():
132
- shard.render()
133
- shard.release(
134
- open_dataset,
135
- inputs=[dataset, subset, split, batch_size, shard, only_rgb, state],
136
- outputs=[shard, gallery, table, state])
137
-
138
- btn = gr.Button("Next Batch (same shard)", scale=0)
139
- btn.click(get_images, [batch_size, only_rgb, state], [gallery, table])
140
- btn.click()
141
-
142
- table.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  demo.launch(show_api=False)
 
3
  from pandas import DataFrame
4
  import earthview as ev
5
  import utils
6
+ import gradio as gr
7
  import tqdm
8
  import os
9
+ import numpy as np
10
 
11
# Set DEBUG to False for normal operation, "random" for random data,
# "samples" for local parquet samples.
DEBUG = False

# Viewer state written by open_dataset() and read by get_images().
# NOTE(review): this is a module-level dict, so every session served by this
# process shares the same iterator -- confirm that is acceptable.
app_state = {
    "dsi": None,  # Dataset iterator
    "subset": None,  # Currently loaded subset
}


def open_dataset(dataset, subset, split, batch_size, shard_value, only_rgb):
    """Load a dataset subset/shard, reset the iterator and fetch the first batch.

    Args:
        dataset (str): Name of the main dataset.
        subset (str): Name of the subset to load.
        split (str): Data split (e.g., 'train', 'test').
        batch_size (int): Number of items to fetch per batch.
        shard_value (int): The specific shard index to load (-1 for all).
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: Updated components/values for the Gradio interface:
            (updated_shard_slider, initial_gallery_images, initial_metadata_table).

    Raises:
        gr.Error: If the numeric inputs are not whole numbers, or if the shard
            count or the dataset itself cannot be loaded.
        ValueError: If DEBUG holds an unsupported value.
    """
    # Values coming from gr.Number are not guaranteed to be ints (a float, or
    # None for a cleared field, can arrive); range() and the shard list below
    # need real integers.
    try:
        batch_size = int(batch_size)
        shard_value = int(shard_value)
    except (TypeError, ValueError) as e:
        raise gr.Error("Batch size and shard must be whole numbers.") from e

    print(f"Loading dataset: {dataset}, subset: {subset}, split: {split}, shard: {shard_value}")

    try:
        nshards = ev.get_nshards(subset)  # Total number of shards for the subset
    except Exception as e:
        raise gr.Error(f"Failed to get shard count for subset '{subset}': {e}") from e

    # Highest valid shard index; never negative so the slider stays well-formed.
    max_shard = max(nshards - 1, 0)

    # Determine which shards to load
    if shard_value == -1:
        shards_to_load = None  # Load all shards
        print("Loading all shards.")
    else:
        # Clamp the requested shard into the valid range
        shard_value = max(0, min(shard_value, max_shard))
        shards_to_load = [shard_value]
        print(f"Loading shard {shard_value} out of {nshards}.")

    # Load the dataset based on DEBUG configuration
    if DEBUG == "random":
        print("DEBUG MODE: Using random data.")
        ds = range(batch_size * 2)  # Generate enough for a couple of batches
    elif DEBUG == "samples":
        print("DEBUG MODE: Using local Parquet samples.")
        try:
            ds = ev.load_parquet(subset, batch_size=batch_size * 2)
        except Exception as e:
            raise gr.Error(f"Failed to load Parquet samples for '{subset}': {e}") from e
    elif not DEBUG:
        print("Loading dataset from source...")
        try:
            ds = ev.load_dataset(
                subset,
                dataset=dataset,
                split=split,
                shards=shards_to_load,
                cache_dir="dataset",
            )
        except Exception as e:
            raise gr.Error(f"Failed to load dataset '{dataset}/{subset}': {e}") from e
    else:
        raise ValueError("Invalid DEBUG setting.")

    # Create an iterator and store it in the shared state (mutating the dict
    # in place, so no `global` statement is needed).
    app_state["dsi"] = iter(ds)
    app_state["subset"] = subset

    print("Dataset loaded, fetching initial batch...")
    images, metadata_df = get_images(batch_size, only_rgb)

    # The slider's minimum is 0, so "-1 / all shards" is displayed as 0 rather
    # than handing the component a value below its range.
    updated_shard_slider = gr.Slider(
        label=f"Shard (0 to {max_shard})",
        value=max(shard_value, 0),
        maximum=max_shard,
    )

    return updated_shard_slider, images, metadata_df
85
+
86
# Image keys to show for each known subset:
# (keys always shown, extra keys shown only when only_rgb is off).
_SUBSET_IMAGE_KEYS = {
    "satellogic": (("rgb",), ("1m",)),
    "sentinel_1": (("10m",), ()),
    "sentinel_2": (("rgb",), ("10m", "20m", "scl")),
    "neon": (("rgb",), ("chm", "1m")),
}


def _collect_item_images(item_data, subset, only_rgb):
    """Return the gallery images of one processed item, in display order."""
    try:
        always, extra = _SUBSET_IMAGE_KEYS[subset]
    except KeyError:
        # Handle potential unknown subsets gracefully
        print(f"Warning: Image extraction logic not defined for subset '{subset}'. Trying 'rgb'.")
        always, extra = ("rgb",), ()

    keys = always if only_rgb else always + extra
    collected = []
    for key in keys:
        collected.extend(item_data.get(key, []))
    return collected


def get_images(batch_size, only_rgb):
    """Fetch the next batch of images and metadata from the current iterator.

    Args:
        batch_size (int): Number of items to fetch.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: (list_of_images, pandas_dataframe_of_metadata)

    Raises:
        gr.Error: If no dataset has been loaded yet, or if batch_size is not a
            whole number.
    """
    if app_state.get("dsi") is None or app_state.get("subset") is None:
        raise gr.Error("You need to load a Dataset first using the 'Load' button.")

    # A gr.Number value is not guaranteed to be an int; trange() needs one.
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError) as e:
        raise gr.Error("Batch size must be a whole number.") from e

    subset = app_state["subset"]
    dsi = app_state["dsi"]
    images = []
    metadatas = []

    print(f"Fetching next {batch_size} images...")
    for i in tqdm.trange(batch_size, desc=f"Getting images for {subset}"):
        if DEBUG == "random":
            # Generate random image and basic metadata for debugging
            images.append(np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8))
            if not only_rgb:
                images.append(np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8))
            metadatas.append({"id": f"random_{i}", "bounds": [[1, 1, 4, 4]], "map": "N/A"})
            continue

        try:
            # Get the next item from the iterator
            item = next(dsi)
        except StopIteration:
            print("End of dataset iterator reached.")
            gr.Warning("End of dataset shard reached.")  # Inform user
            break  # Stop fetching if iterator is exhausted

        try:
            # Process the item to extract images and metadata
            item_data = ev.item_to_images(subset, item)
            metadata = item_data["metadata"]

            images.extend(_collect_item_images(item_data, subset, only_rgb))

            map_link = utils.get_google_map_link(item_data, subset)
            metadata["map"] = f'<a href="{map_link}" target="_blank">🧭 View Map</a>' if map_link else "N/A"
            metadatas.append(metadata)
        except Exception as e:
            # Deliberate best effort: one bad item must not abort the batch.
            # The id lookup is guarded so the handler cannot itself fail on an
            # item that has no .get(); only the id is logged instead of the
            # whole raw item.
            getter = getattr(item, "get", None)
            item_id = getter("id", "Error") if callable(getter) else "Error"
            print(f"Error processing item: {item_id}. Error: {e}")
            metadatas.append({"id": item_id, "error": str(e), "map": "Error"})

    print(f"Fetched {len(metadatas)} items for the batch.")
    # Convert metadata list to a Pandas DataFrame
    metadata_df = DataFrame(metadatas)
    return images, metadata_df
168
+
169
def update_gallery_columns(columns):
    """Update the number of columns in the image gallery.

    Args:
        columns (int): The desired number of columns. Values below 1, and
            values that cannot be converted to an int (e.g. None), fall back
            to 1.

    Returns:
        gr.Gallery: A gallery component carrying only the new ``columns``
            setting, returned as the update for the output gallery.
    """
    print(f"Updating gallery columns to: {columns}")
    try:
        # Ensure columns is an int of at least 1
        columns = max(1, int(columns))
    except (TypeError, ValueError):
        # Presumably a cleared Number field arrives as None; use the smallest
        # valid layout instead of surfacing a conversion error to the user.
        columns = 1
    # Return the updated component configuration
    return gr.Gallery(columns=columns)
185
 
186
if __name__ == "__main__":
    # Build the viewer UI: a narrow control column on the left, and the image
    # gallery plus metadata table on the right.
    with gr.Blocks(title="EarthView Viewer v5 fork", fill_height=True, theme=gr.themes.Default()) as demo:

        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset (Gradio 5)")

        with gr.Row():
            with gr.Column(scale=1):
                dataset_name = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
                subset_select = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic")
                split_name = gr.Textbox(label="Split", value="train")
                # precision=0 makes the Number components hand integers to the
                # callbacks, which use these values for range()/trange() and
                # as shard indices.
                initial_shard_input = gr.Number(
                    label="Load Shard",
                    value=10,
                    minimum=-1,
                    step=1,
                    precision=0,
                    info="Enter shard index (0-based) or -1 for all shards",
                )
                only_rgb_checkbox = gr.Checkbox(label="Only RGB Images", value=True)
                batch_size_input = gr.Number(value=10, label="Batch Size", minimum=1, step=1, precision=0)
                load_button = gr.Button("Load Dataset / Shard", variant="primary")

                # The slider range is a placeholder; open_dataset() returns an
                # updated slider once the shard count is known.
                shard_slider = gr.Slider(label="Shard", minimum=0, maximum=1, step=1, value=0)
                gallery_columns_input = gr.Number(value=5, label="Gallery Columns", minimum=1, step=1, precision=0)

                next_batch_button = gr.Button("Next Batch (from current shard)", scale=0)

            with gr.Column(scale=4):
                image_gallery = gr.Gallery(
                    label="Dataset Images",
                    interactive=False,
                    object_fit="scale-down",
                    columns=5,
                    height="600px",
                    show_label=False,
                )
                metadata_table = gr.DataFrame(datatype="html", wrap=True)

        # "Load" uses the numeric shard field as the shard to open.
        load_button.click(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, initial_shard_input, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table],
        )

        # Releasing the slider reloads the dataset at the slider's shard.
        shard_slider.release(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, shard_slider, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table],
        )

        gallery_columns_input.change(
            fn=update_gallery_columns,
            inputs=[gallery_columns_input],
            outputs=[image_gallery],
        )

        # Continue reading from the iterator opened by the last load.
        next_batch_button.click(
            fn=get_images,
            inputs=[batch_size_input, only_rgb_checkbox],
            outputs=[image_gallery, metadata_table],
        )

    demo.launch(show_api=False)