Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -3,142 +3,239 @@ from functools import partial
|
|
3 |
from pandas import DataFrame
|
4 |
import earthview as ev
|
5 |
import utils
|
6 |
-
import gradio as gr
|
7 |
import tqdm
|
8 |
import os
|
|
|
9 |
|
10 |
-
DEBUG
|
|
|
11 |
|
12 |
-
|
13 |
-
|
|
|
|
|
14 |
|
15 |
-
def open_dataset(dataset, subset, split, batch_size,
|
|
|
|
|
|
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
else:
|
22 |
-
|
|
|
|
|
|
|
23 |
|
|
|
|
|
24 |
if DEBUG == "random":
|
25 |
-
|
|
|
26 |
elif DEBUG == "samples":
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
elif not DEBUG:
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
images = []
|
48 |
metadatas = []
|
49 |
-
|
50 |
-
|
|
|
51 |
if DEBUG == "random":
|
52 |
-
|
|
|
|
|
53 |
if not only_rgb:
|
54 |
-
|
55 |
-
|
56 |
-
metadatas.append({"bounds":[[1,1,4,4]], })
|
57 |
else:
|
58 |
try:
|
59 |
-
item
|
|
|
60 |
except StopIteration:
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
if subset == "satellogic":
|
66 |
-
images.extend(item["rgb"])
|
67 |
-
if not only_rgb:
|
68 |
-
images.extend(item["1m"])
|
69 |
-
elif subset == "sentinel_1":
|
70 |
-
images.extend(item["10m"])
|
71 |
-
elif subset == "sentinel_2":
|
72 |
-
images.extend(item["rgb"])
|
73 |
-
if not only_rgb:
|
74 |
-
images.extend(item["10m"])
|
75 |
-
images.extend(item["20m"])
|
76 |
-
images.extend(item["scl"])
|
77 |
-
elif subset == "neon":
|
78 |
-
images.extend(item["rgb"])
|
79 |
-
if not only_rgb:
|
80 |
-
images.extend(item["chm"])
|
81 |
-
images.extend(item["1m"])
|
82 |
-
|
83 |
-
metadata["map"] = f'<a href="{utils.get_google_map_link(item, subset)}" target="about:_blank">🧭</a>'
|
84 |
-
metadatas.append(metadata)
|
85 |
-
|
86 |
-
return images, DataFrame(metadatas)
|
87 |
-
|
88 |
-
def update_shape(columns):
    """Build a Gradio update payload that sets the ``columns`` property.

    Args:
        columns: Value forwarded unchanged as the ``columns`` keyword.

    Returns:
        The object produced by ``gr.update(columns=...)``.
    """
    # NOTE(review): presumably wired to the image gallery; the event wiring is
    # not visible in this view, so confirm against the caller.
    column_patch = gr.update(columns=columns)
    return column_patch
|
90 |
-
|
91 |
-
def new_state():
    """Create a fresh ``gr.State`` holder initialised with an empty dict.

    Returns:
        A new ``gr.State`` wrapping ``{}``.
    """
    empty_payload = {}
    return gr.State(empty_payload)
|
93 |
-
|
94 |
-
if __name__ == "__main__":
|
95 |
-
with gr.Blocks(title="EarthView Viewer", fill_height = True) as demo:
|
96 |
-
state = new_state()
|
97 |
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
-
|
|
|
128 |
|
129 |
-
|
130 |
|
131 |
with gr.Row():
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
demo.launch(show_api=False)
|
|
|
3 |
from pandas import DataFrame
|
4 |
import earthview as ev
|
5 |
import utils
|
6 |
+
import gradio as gr
|
7 |
import tqdm
|
8 |
import os
|
9 |
+
import numpy as np
|
10 |
|
11 |
+
# Debug switch read by open_dataset() and get_images():
#   False     -> normal operation (dataset loaded through ev.load_dataset)
#   "random"  -> randomly generated images instead of dataset items
#   "samples" -> local Parquet samples loaded through ev.load_parquet
DEBUG = False

# Module-level viewer state: written by open_dataset(), read by get_images().
# NOTE(review): as a module global this single iterator is shared by every
# caller in the process -- confirm that is acceptable when several users
# browse at the same time.
app_state = {
    "dsi": None,     # Dataset iterator
    "subset": None,  # Currently loaded subset
}
|
18 |
|
19 |
+
def open_dataset(dataset, subset, split, batch_size, shard_value, only_rgb):
    """
    Loads the specified dataset subset and shard, initializes the iterator,
    and returns initial images and metadata.

    Args:
        dataset (str): Name of the main dataset.
        subset (str): Name of the subset to load.
        split (str): Data split (e.g., 'train', 'test').
        batch_size (int): Number of items to fetch per batch. Numeric UI
            values that arrive as floats are truncated to int.
        shard_value (int): The specific shard index to load (-1 for all).
            Other values are clamped into the valid shard range.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: Updated components/values for the Gradio interface:
               (updated_shard_slider, initial_gallery_images, initial_metadata_table).

    Raises:
        gr.Error: If the numeric inputs are missing/invalid, the shard count
            cannot be determined, or the dataset fails to load.
        ValueError: If DEBUG holds an unsupported value.
    """
    # Number/Slider components may deliver floats (or None when the field is
    # cleared); range(), shard lists and tqdm all need real ints.
    try:
        batch_size = int(batch_size)
        shard_value = int(shard_value)
    except (TypeError, ValueError) as e:
        raise gr.Error("Batch size and shard must be whole numbers.") from e

    print(f"Loading dataset: {dataset}, subset: {subset}, split: {split}, shard: {shard_value}")

    try:
        nshards = ev.get_nshards(subset)  # Get total number of shards for the subset
    except Exception as e:
        raise gr.Error(f"Failed to get shard count for subset '{subset}': {e}") from e

    # Highest selectable shard index; never negative, even for an empty subset.
    max_shard = max(nshards - 1, 0)

    # Determine which shards to load
    if shard_value == -1:
        shards_to_load = None  # Load all shards
        print("Loading all shards.")
    else:
        # Ensure the selected shard is within the valid range
        shard_value = max(0, min(shard_value, max_shard))
        shards_to_load = [shard_value]
        print(f"Loading shard {shard_value} out of {nshards}.")

    # Load the dataset based on DEBUG configuration
    if DEBUG == "random":
        print("DEBUG MODE: Using random data.")
        ds = range(batch_size * 2)  # Generate enough for a couple of batches
    elif DEBUG == "samples":
        print("DEBUG MODE: Using local Parquet samples.")
        try:
            ds = ev.load_parquet(subset, batch_size=batch_size * 2)
        except Exception as e:
            raise gr.Error(f"Failed to load Parquet samples for '{subset}': {e}") from e
    elif not DEBUG:
        print("Loading dataset from source...")
        try:
            ds = ev.load_dataset(subset, dataset=dataset, split=split, shards=shards_to_load, cache_dir="dataset")
        except Exception as e:
            raise gr.Error(f"Failed to load dataset '{dataset}/{subset}': {e}") from e
    else:
        raise ValueError("Invalid DEBUG setting.")

    # Create an iterator and store it in the shared state for get_images()
    app_state["dsi"] = iter(ds)
    app_state["subset"] = subset

    print("Dataset loaded, fetching initial batch...")
    images, metadata_df = get_images(batch_size, only_rgb)

    # The slider's minimum is 0, so the "all shards" sentinel (-1) must not be
    # pushed into it as a value.
    updated_shard_slider = gr.Slider(
        label=f"Shard (0 to {max_shard})",
        value=max(shard_value, 0),
        maximum=max_shard,
    )

    return updated_shard_slider, images, metadata_df
|
85 |
+
|
86 |
+
def get_images(batch_size, only_rgb):
    """
    Fetches the next batch of images and metadata from the current dataset iterator.

    Args:
        batch_size (int): Number of items to fetch. Numeric UI values that
            arrive as floats are truncated to int.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: (list_of_images, pandas_dataframe_of_metadata)

    Raises:
        gr.Error: If no dataset has been loaded yet or batch_size is not a
            usable number.
    """
    dsi = app_state.get("dsi")
    subset = app_state.get("subset")
    if dsi is None or subset is None:
        raise gr.Error("You need to load a Dataset first using the 'Load' button.")

    # tqdm.trange() needs a real int; number inputs may deliver float or None.
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError) as e:
        raise gr.Error("Batch size must be a whole number.") from e

    # Per subset: (keys always shown, keys shown only when only_rgb is off).
    bands_by_subset = {
        "satellogic": (("rgb",), ("1m",)),
        "sentinel_1": (("10m",), ()),
        "sentinel_2": (("rgb",), ("10m", "20m", "scl")),
        "neon": (("rgb",), ("chm", "1m")),
    }

    images = []
    metadatas = []

    print(f"Fetching next {batch_size} images...")
    for i in tqdm.trange(batch_size, desc=f"Getting images for {subset}"):
        if DEBUG == "random":
            # Generate random image and basic metadata for debugging.
            # randint's upper bound is exclusive, so 256 covers the full uint8 range.
            images.append(np.random.randint(0, 256, (384, 384, 3), dtype=np.uint8))
            if not only_rgb:
                images.append(np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8))
            metadatas.append({"id": f"random_{i}", "bounds": [[1, 1, 4, 4]], "map": "N/A"})
            continue

        try:
            # Get the next item from the iterator
            item = next(dsi)
        except StopIteration:
            print("End of dataset iterator reached.")
            gr.Warning("End of dataset shard reached.")  # Inform user
            break  # Stop fetching if iterator is exhausted

        try:
            # Process the item to extract images and metadata
            item_data = ev.item_to_images(subset, item)
            metadata = item_data["metadata"]

            # Pick the image keys for this subset and the only_rgb flag
            if subset in bands_by_subset:
                always_keys, extra_keys = bands_by_subset[subset]
            else:
                # Handle potential unknown subsets gracefully
                print(f"Warning: Image extraction logic not defined for subset '{subset}'. Trying 'rgb'.")
                always_keys, extra_keys = ("rgb",), ()
            wanted_keys = always_keys if only_rgb else always_keys + extra_keys
            for key in wanted_keys:
                images.extend(item_data.get(key, []))

            map_link = utils.get_google_map_link(item_data, subset)
            metadata["map"] = f'<a href="{map_link}" target="_blank">🧭 View Map</a>' if map_link else "N/A"
            metadatas.append(metadata)

        except Exception as e:
            # Best effort: one bad item must not abort the batch. Log only the
            # id (the item itself carries image arrays) and never assume the
            # item is dict-like while already handling an error.
            item_id = item.get("id", "Error") if hasattr(item, "get") else "Error"
            print(f"Error processing item: {item_id}. Error: {e}")
            metadatas.append({"id": item_id, "error": str(e), "map": "Error"})

    print(f"Fetched {len(metadatas)} items for the batch.")
    # Convert metadata list to a Pandas DataFrame
    metadata_df = DataFrame(metadatas)
    return images, metadata_df
|
168 |
+
|
169 |
+
def update_gallery_columns(columns):
    """
    Updates the number of columns in the image gallery.

    Args:
        columns (int | float | None): The desired number of columns. Values
            below 1 are raised to 1; None (an emptied number field) leaves
            the gallery untouched.

    Returns:
        A ``gr.Gallery`` carrying the new ``columns`` value (in Gradio 5 a
        handler returns the component constructor with the new args), or an
        empty ``gr.update()`` when there is nothing to change.
    """
    if columns is None:
        # int(None) would raise TypeError; an empty field is not an error.
        print("Gallery columns input is empty; leaving gallery unchanged.")
        return gr.update()

    print(f"Updating gallery columns to: {columns}")
    # Ensure columns is at least 1
    columns = max(1, int(columns))
    # Return the updated component configuration
    return gr.Gallery(columns=columns)
|
185 |
|
186 |
+
# Script entry point: build the Gradio Blocks UI, wire its events, launch it.
# NOTE(review): nesting below is reconstructed from a view that lost its
# indentation -- confirm which column each control belongs to.
if __name__ == "__main__":
    with gr.Blocks(title="EarthView Viewer v5 fork", fill_height=True, theme=gr.themes.Default()) as demo:

        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset (Gradio 5)")

        with gr.Row():
            # Narrow column: dataset selection and viewing controls.
            with gr.Column(scale=1):
                dataset_name = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
                subset_select = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic")
                split_name = gr.Textbox(label="Split", value="train")
                initial_shard_input = gr.Number(label="Load Shard", value=10, minimum=-1, step=1, info="Enter shard index (0-based) or -1 for all shards")
                only_rgb_checkbox = gr.Checkbox(label="Only RGB Images", value=True)
                batch_size_input = gr.Number(value=10, label="Batch Size", minimum=1, step=1)
                load_button = gr.Button("Load Dataset / Shard", variant="primary")

                # Placeholder range; open_dataset() returns a Slider with the
                # real maximum once the shard count is known.
                shard_slider = gr.Slider(label="Shard", minimum=0, maximum=1, step=1, value=0)
                gallery_columns_input = gr.Number(value=5, label="Gallery Columns", minimum=1, step=1)

                next_batch_button = gr.Button("Next Batch (from current shard)", scale=0)

            # Wide column: image gallery plus the per-item metadata table.
            with gr.Column(scale=4):
                image_gallery = gr.Gallery(
                    label="Dataset Images",
                    interactive=False,
                    object_fit="scale-down",
                    columns=5,
                    height="600px",
                    show_label=False
                )
                # datatype="html": get_images() stores an <a> tag in the "map" column.
                metadata_table = gr.DataFrame(datatype="html", wrap=True)

        # Load button: open the shard typed into the number field.
        load_button.click(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, initial_shard_input, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table]
        )

        # Releasing the slider reloads with the slider's shard instead of the number field.
        shard_slider.release(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, shard_slider, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table]
        )

        # Changing the column count only re-lays-out the gallery.
        gallery_columns_input.change(
            fn=update_gallery_columns,
            inputs=[gallery_columns_input],
            outputs=[image_gallery]
        )

        # Next batch: continue from the iterator kept in app_state.
        next_batch_button.click(
            fn=get_images,
            inputs=[batch_size_input, only_rgb_checkbox],
            outputs=[image_gallery, metadata_table]
        )

    demo.launch(show_api=False)
|