File size: 9,897 Bytes
5129aaa
6d787c4
 
196b164
f453698
a81a404
6d787c4
c670c44
a81a404
6d787c4
a81a404
 
6d787c4
a81a404
 
 
 
6d787c4
a81a404
 
 
 
6d787c4
a81a404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d787c4
a81a404
 
 
 
5129aaa
a81a404
 
5129aaa
a81a404
 
5129aaa
a81a404
 
 
 
 
5129aaa
a81a404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d787c4
a81a404
 
 
 
 
 
 
 
 
 
2ef57a2
 
a81a404
 
 
5129aaa
a81a404
 
 
5129aaa
a81a404
 
 
6d787c4
 
a81a404
 
6d787c4
a81a404
 
 
2ef57a2
a81a404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ef57a2
a81a404
 
2ef57a2
a81a404
2ef57a2
 
a81a404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ef57a2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
from datasets import load_dataset
from functools import partial
from pandas import DataFrame
import earthview as ev
import utils
import gradio as gr
import tqdm
import os
import numpy as np

# DEBUG chooses where images come from:
#   False     -> the real dataset, loaded through ev.load_dataset (normal operation)
#   "random"  -> synthetic random images generated in get_images
#   "samples" -> local parquet samples loaded through ev.load_parquet
DEBUG = False

# Module-wide state shared between the Gradio callbacks.
app_state = dict(
    dsi=None,     # iterator over the currently loaded dataset
    subset=None,  # name of the subset that iterator was created for
)

def open_dataset(dataset, subset, split, batch_size, shard_value, only_rgb):
    """
    Loads the specified dataset subset and shard, initializes the iterator,
    and returns initial images and metadata.

    Args:
        dataset (str): Name of the main dataset.
        subset (str): Name of the subset to load.
        split (str): Data split (e.g., 'train', 'test').
        batch_size (int | float): Number of items to fetch per batch. Gradio
            ``Number`` inputs may deliver floats, so the value is coerced to
            ``int``.
        shard_value (int | float): The specific shard index to load (-1 for
            all). Other indices are clamped into ``[0, nshards - 1]``.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: Updated components/values for the Gradio interface:
               (updated_shard_slider, initial_gallery_images, initial_metadata_table).

    Raises:
        gr.Error: If an input is empty, the shard count cannot be determined,
            or loading the data fails.
        ValueError: If ``DEBUG`` holds an unsupported value.
    """
    if batch_size is None or shard_value is None:
        raise gr.Error("Batch size and shard index must both be set.")
    # Gradio Number inputs may arrive as floats (e.g. 10.0); range() and shard
    # indexing below need real ints.
    batch_size = int(batch_size)
    shard_value = int(shard_value)

    print(f"Loading dataset: {dataset}, subset: {subset}, split: {split}, shard: {shard_value}")

    try:
        nshards = ev.get_nshards(subset)  # Total number of shards for the subset
    except Exception as e:
        raise gr.Error(f"Failed to get shard count for subset '{subset}': {e}") from e

    # Determine which shards to load
    if shard_value == -1:
        shards_to_load = None  # Load all shards
        print("Loading all shards.")
    else:
        # Ensure the selected shard is within the valid range
        shard_value = max(0, min(shard_value, nshards - 1))
        shards_to_load = [shard_value]
        print(f"Loading shard {shard_value} out of {nshards}.")

    # Load the dataset based on DEBUG configuration
    if DEBUG == "random":
        print("DEBUG MODE: Using random data.")
        # get_images synthesizes the pixels in this mode; this range is only a placeholder iterable.
        ds = range(batch_size * 2)
    elif DEBUG == "samples":
        print("DEBUG MODE: Using local Parquet samples.")
        try:
            ds = ev.load_parquet(subset, batch_size=batch_size * 2)
        except Exception as e:
            raise gr.Error(f"Failed to load Parquet samples for '{subset}': {e}") from e
    elif not DEBUG:
        print("Loading dataset from source...")
        try:
            ds = ev.load_dataset(subset, dataset=dataset, split=split, shards=shards_to_load, cache_dir="dataset")
        except Exception as e:
            raise gr.Error(f"Failed to load dataset '{dataset}/{subset}': {e}") from e
    else:
        raise ValueError("Invalid DEBUG setting.")

    # Create an iterator and store it in the shared state for get_images.
    app_state["dsi"] = iter(ds)
    app_state["subset"] = subset

    print("Dataset loaded, fetching initial batch...")
    images, metadata_df = get_images(batch_size, only_rgb)

    # Never advertise a negative upper bound (nshards may be 0).
    max_shard = max(nshards - 1, 0)
    slider_kwargs = {"label": f"Shard (0 to {max_shard})", "maximum": max_shard}
    if shards_to_load is not None:
        # Only move the slider when a single shard was loaded: -1 ("all shards")
        # lies below the slider's minimum of 0 and is not a valid position.
        slider_kwargs["value"] = shard_value
    updated_shard_slider = gr.Slider(**slider_kwargs)

    return updated_shard_slider, images, metadata_df

# Per-subset image keys of ev.item_to_images output:
# (keys always shown, extra keys shown only when only_rgb is False).
_SUBSET_IMAGE_KEYS = {
    "satellogic": (("rgb",), ("1m",)),
    "sentinel_1": (("10m",), ()),
    "sentinel_2": (("rgb",), ("10m", "20m", "scl")),
    "neon": (("rgb",), ("chm", "1m")),
}


def _select_images(subset, item_data, only_rgb):
    """Return, in display order, the images from *item_data* to show for *subset*."""
    keys = _SUBSET_IMAGE_KEYS.get(subset)
    if keys is None:
        # Handle potential unknown subsets gracefully by falling back to RGB only.
        print(f"Warning: Image extraction logic not defined for subset '{subset}'. Trying 'rgb'.")
        keys = (("rgb",), ())
    always, extra = keys
    selected = []
    for key in always if only_rgb else always + extra:
        # `or []` also guards against a key present with a None value.
        selected.extend(item_data.get(key) or [])
    return selected


def get_images(batch_size, only_rgb):
    """
    Fetches the next batch of images and metadata from the current dataset iterator.

    Args:
        batch_size (int | float): Number of items to fetch; coerced to ``int``
            because Gradio ``Number`` inputs may arrive as floats.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: (list_of_images, pandas_dataframe_of_metadata). One item can
        contribute several images, so the image list is not index-aligned
        with the metadata rows.

    Raises:
        gr.Error: If no dataset has been loaded yet or ``batch_size`` is empty.
    """
    dsi = app_state.get("dsi")
    subset = app_state.get("subset")
    if dsi is None or subset is None:
        raise gr.Error("You need to load a Dataset first using the 'Load' button.")
    if batch_size is None:
        raise gr.Error("Batch size must be set.")
    batch_size = int(batch_size)  # trange/range reject floats

    images = []
    metadatas = []

    print(f"Fetching next {batch_size} images...")
    for i in tqdm.trange(batch_size, desc=f"Getting images for {subset}"):
        if DEBUG == "random":
            # randint's upper bound is exclusive: 256 covers the full uint8 range.
            images.append(np.random.randint(0, 256, (384, 384, 3), dtype=np.uint8))
            if not only_rgb:
                images.append(np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8))
            metadatas.append({"id": f"random_{i}", "bounds": [[1, 1, 4, 4]], "map": "N/A"})
            continue

        try:
            item = next(dsi)
        except StopIteration:
            print("End of dataset iterator reached.")
            gr.Warning("End of dataset shard reached.")  # Inform user
            break  # Stop fetching if iterator is exhausted

        try:
            item_data = ev.item_to_images(subset, item)
            metadata = item_data["metadata"]
            images.extend(_select_images(subset, item_data, only_rgb))

            map_link = utils.get_google_map_link(item_data, subset)
            metadata["map"] = f'<a href="{map_link}" target="_blank">🧭 View Map</a>' if map_link else "N/A"
            metadatas.append(metadata)
        except Exception as e:
            # Best effort: record the failure as a row and keep going. Items may
            # not be dict-like, and may carry large raw payloads, so log only the id.
            getter = getattr(item, "get", None)
            item_id = getter("id", "Error") if callable(getter) else "Error"
            print(f"Error processing item {item_id!r}. Error: {e}")
            metadatas.append({"id": item_id, "error": str(e), "map": "Error"})

    print(f"Fetched {len(metadatas)} items for the batch.")
    # Convert metadata list to a Pandas DataFrame
    metadata_df = DataFrame(metadatas)
    return images, metadata_df

def update_gallery_columns(columns):
    """
    Updates the number of columns in the image gallery.

    Args:
        columns (int | float | None): The desired number of columns. Values
            below 1 are raised to 1 and fractional values are truncated.

    Returns:
        A ``gr.Gallery`` constructed with only the new ``columns`` setting,
        which Gradio 5 applies as an update to the existing gallery. If
        ``columns`` is empty (``None``), a no-op ``gr.update()`` is returned
        so the current layout is kept.
    """
    if columns is None:
        # The Number box was cleared; int(None) would raise, so leave the gallery unchanged.
        return gr.update()
    print(f"Updating gallery columns to: {columns}")
    # Ensure columns is at least 1
    columns = max(1, int(columns))
    return gr.Gallery(columns=columns)

if __name__ == "__main__":
    # Build the viewer UI: inputs/controls on the left, gallery + metadata on the right.
    with gr.Blocks(title="EarthView Viewer v5 fork", fill_height=True, theme=gr.themes.Default()) as demo:

        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset (Gradio 5)")

        with gr.Row():
            with gr.Column(scale=1):
                dataset_name = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
                subset_select = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic")
                split_name = gr.Textbox(label="Split", value="train")
                # precision=0 makes the Number components hand ints (not floats) to the callbacks.
                initial_shard_input = gr.Number(label="Load Shard", value=10, minimum=-1, step=1, precision=0, info="Enter shard index (0-based) or -1 for all shards")
                only_rgb_checkbox = gr.Checkbox(label="Only RGB Images", value=True)
                batch_size_input = gr.Number(value=10, label="Batch Size", minimum=1, step=1, precision=0)
                load_button = gr.Button("Load Dataset / Shard", variant="primary")

                # Range is updated by open_dataset once the subset's shard count is known.
                shard_slider = gr.Slider(label="Shard", minimum=0, maximum=1, step=1, value=0)
                gallery_columns_input = gr.Number(value=5, label="Gallery Columns", minimum=1, step=1, precision=0)

                next_batch_button = gr.Button("Next Batch (from current shard)", scale=0)

            with gr.Column(scale=4):
                image_gallery = gr.Gallery(
                    label="Dataset Images",
                    interactive=False,
                    object_fit="scale-down",
                    columns=5,
                    height="600px",
                    show_label=False
                )
                # datatype="html" lets the "map" column render as a clickable link.
                metadata_table = gr.DataFrame(datatype="html", wrap=True)

        load_button.click(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, initial_shard_input, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table]
        )

        # Releasing the slider reloads using the slider's shard instead of the Number input.
        shard_slider.release(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, shard_slider, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table]
        )

        gallery_columns_input.change(
            fn=update_gallery_columns,
            inputs=[gallery_columns_input],
            outputs=[image_gallery]
        )

        next_batch_button.click(
            fn=get_images,
            inputs=[batch_size_input, only_rgb_checkbox],
            outputs=[image_gallery, metadata_table]
        )

    demo.launch(show_api=False)