# prowseed's picture
# Update app.py
# a81a404 verified
from datasets import load_dataset
from functools import partial
from pandas import DataFrame
import earthview as ev
import utils
import gradio as gr
import tqdm
import os
import numpy as np
# Data-source switch read by open_dataset() and get_images():
#   False      -> normal operation: load the real dataset via ev.load_dataset
#   "random"   -> generate random images instead of reading any data
#   "samples"  -> read local Parquet samples via ev.load_parquet
# Any other truthy value makes open_dataset() raise ValueError.
DEBUG = False

# Mutable state shared between the Gradio callbacks: open_dataset() fills it,
# get_images() reads it to continue from where the previous batch stopped.
# NOTE(review): being module-level, this is a single iterator for the whole
# process, not one per browser session -- confirm that is intended.
app_state = {
    "dsi": None,  # Dataset iterator
    "subset": None,  # Currently loaded subset
}
def open_dataset(dataset, subset, split, batch_size, shard_value, only_rgb):
    """
    Load a dataset subset/shard, reset the shared iterator and fetch the first batch.

    The data source depends on the module-level DEBUG flag: random data
    ("random"), local Parquet samples ("samples"), or the real dataset (falsy).
    The new iterator and subset name are stored in ``app_state`` so that
    ``get_images`` can keep reading from the same position.

    Args:
        dataset (str): Name of the main dataset.
        subset (str): Name of the subset to load.
        split (str): Data split (e.g., 'train', 'test').
        batch_size (int | float): Number of items to fetch per batch. Coerced
            to int because UI number widgets may deliver floats.
        shard_value (int | float): The specific shard index to load (-1 for
            all). Other values are clamped into the valid shard range.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: Updated components/values for the Gradio interface:
            (updated_shard_slider, initial_gallery_images, initial_metadata_table).

    Raises:
        gr.Error: If the numeric inputs are not numbers, or if the shard count
            or the dataset itself cannot be loaded.
        ValueError: If DEBUG holds an unsupported value.
    """
    # Number widgets can hand over floats (10.0) or None (cleared field);
    # range() and the shard list below need real integers.
    try:
        batch_size = int(batch_size)
        shard_value = int(shard_value)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Batch size and shard must be whole numbers: {e}") from e

    print(f"Loading dataset: {dataset}, subset: {subset}, split: {split}, shard: {shard_value}")

    try:
        nshards = ev.get_nshards(subset)  # Total number of shards for the subset
    except Exception as e:
        raise gr.Error(f"Failed to get shard count for subset '{subset}': {e}") from e

    # Highest valid shard index; never negative so the slider stays well-formed
    # even when the subset reports zero shards.
    last_shard = max(nshards - 1, 0)

    # Determine which shards to load
    if shard_value == -1:
        shards_to_load = None  # Load all shards
        print("Loading all shards.")
    else:
        # Ensure the selected shard is within the valid range
        shard_value = max(0, min(shard_value, last_shard))
        shards_to_load = [shard_value]
        print(f"Loading shard {shard_value} out of {nshards}.")

    # Load the dataset based on DEBUG configuration
    if DEBUG == "random":
        print("DEBUG MODE: Using random data.")
        ds = range(batch_size * 2)  # Generate enough for a couple of batches
    elif DEBUG == "samples":
        print("DEBUG MODE: Using local Parquet samples.")
        try:
            ds = ev.load_parquet(subset, batch_size=batch_size * 2)
        except Exception as e:
            raise gr.Error(f"Failed to load Parquet samples for '{subset}': {e}") from e
    elif not DEBUG:
        print("Loading dataset from source...")
        try:
            ds = ev.load_dataset(
                subset,
                dataset=dataset,
                split=split,
                shards=shards_to_load,
                cache_dir="dataset",
            )
        except Exception as e:
            raise gr.Error(f"Failed to load dataset '{dataset}/{subset}': {e}") from e
    else:
        raise ValueError("Invalid DEBUG setting.")

    # Create an iterator and store it in the state (mutation only, so no
    # `global` statement is needed).
    app_state["dsi"] = iter(ds)
    app_state["subset"] = subset

    print("Dataset loaded, fetching initial batch...")
    images, metadata_df = get_images(batch_size, only_rgb)

    # The slider's minimum is 0, so the "all shards" marker (-1) is shown as 0
    # instead of pushing an out-of-range value into the component.
    updated_shard_slider = gr.Slider(
        label=f"Shard (0 to {last_shard})",
        value=max(shard_value, 0),
        maximum=last_shard,
    )
    return updated_shard_slider, images, metadata_df
# Image keys to display per subset, as (always shown, shown only when
# only_rgb is off). Order within each tuple is the gallery order.
_SUBSET_IMAGE_KEYS = {
    "satellogic": (("rgb",), ("1m",)),
    "sentinel_1": (("10m",), ()),
    "sentinel_2": (("rgb",), ("10m", "20m", "scl")),
    "neon": (("rgb",), ("chm", "1m")),
}


def _collect_item_images(subset, item_data, only_rgb):
    """Return the list of images of one processed item that should be displayed."""
    keys = _SUBSET_IMAGE_KEYS.get(subset)
    if keys is None:
        # Handle potential unknown subsets gracefully
        print(f"Warning: Image extraction logic not defined for subset '{subset}'. Trying 'rgb'.")
        keys = (("rgb",), ())
    always_keys, extra_keys = keys
    selected = always_keys if only_rgb else always_keys + extra_keys
    collected = []
    for key in selected:
        collected.extend(item_data.get(key, []))
    return collected


def get_images(batch_size, only_rgb):
    """
    Fetch the next batch of images and metadata from the current dataset iterator.

    Reads the iterator and subset name that ``open_dataset`` stored in
    ``app_state``. In DEBUG "random" mode no data is read and random images
    are generated instead. An item that fails to process is reported as an
    error row in the metadata table rather than aborting the batch.

    Args:
        batch_size (int | float): Number of items to fetch. Coerced to int
            because UI number widgets may deliver floats.
        only_rgb (bool): Whether to load only RGB images.

    Returns:
        tuple: (list_of_images, pandas_dataframe_of_metadata). Fewer than
            ``batch_size`` items are returned when the iterator runs out.

    Raises:
        gr.Error: If no dataset has been loaded yet, or ``batch_size`` is not
            a number.
    """
    if app_state.get("dsi") is None or app_state.get("subset") is None:
        raise gr.Error("You need to load a Dataset first using the 'Load' button.")

    # tqdm.trange() wraps range(), which rejects floats and None.
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Batch size must be a whole number: {e}") from e

    subset = app_state["subset"]
    dsi = app_state["dsi"]
    images = []
    metadatas = []

    print(f"Fetching next {batch_size} images...")
    for i in tqdm.trange(batch_size, desc=f"Getting images for {subset}"):
        if DEBUG == "random":
            # Generate random image and basic metadata for debugging.
            # randint's upper bound is exclusive, so 256 covers the full uint8 range.
            images.append(np.random.randint(0, 256, (384, 384, 3), dtype=np.uint8))
            if not only_rgb:
                images.append(np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8))
            metadatas.append({"id": f"random_{i}", "bounds": [[1, 1, 4, 4]], "map": "N/A"})
            continue

        try:
            # Get the next item from the iterator
            item = next(dsi)
        except StopIteration:
            print("End of dataset iterator reached.")
            gr.Warning("End of dataset shard reached.")  # Inform user
            break  # Stop fetching if iterator is exhausted

        try:
            # Process the item to extract images and metadata
            item_data = ev.item_to_images(subset, item)
            metadata = item_data["metadata"]
            images.extend(_collect_item_images(subset, item_data, only_rgb))

            map_link = utils.get_google_map_link(item_data, subset)
            metadata["map"] = f'<a href="{map_link}" target="_blank">🧭 View Map</a>' if map_link else "N/A"
            metadatas.append(metadata)
        except Exception as e:
            # Best effort: keep the batch going and surface the failure as a
            # table row. The item may not be dict-like, so look up its id
            # defensively, and log only the id rather than the whole item
            # (which can carry large image payloads).
            getter = getattr(item, "get", None)
            item_id = getter("id", "Error") if callable(getter) else "Error"
            print(f"Error processing item {item_id!r}. Error: {e}")
            metadatas.append({"id": item_id, "error": str(e), "map": "Error"})

    print(f"Fetched {len(metadatas)} items for the batch.")
    # Convert metadata list to a Pandas DataFrame
    metadata_df = DataFrame(metadatas)
    return images, metadata_df
def update_gallery_columns(columns):
    """
    Update the number of columns in the image gallery.

    Args:
        columns (int | float | None): The desired number of columns. Values
            below 1, and values that are not numbers (e.g. None from a
            cleared input field), fall back to 1.

    Returns:
        gr.Gallery: A gallery component carrying only the new ``columns``
            setting; returning it from the callback updates the gallery
            wired as the event's output.
    """
    print(f"Updating gallery columns to: {columns}")
    try:
        # Ensure columns is at least 1
        columns = max(1, int(columns))
    except (TypeError, ValueError, OverflowError):
        # A cleared number field delivers None; NaN/inf cannot be converted.
        columns = 1
    # Return the updated component configuration
    return gr.Gallery(columns=columns)
# Script entry point: build the Gradio UI, wire the callbacks, start the server.
if __name__ == "__main__":
    with gr.Blocks(title="EarthView Viewer v5 fork", fill_height=True, theme=gr.themes.Default()) as demo:
        # Link fixed: the host previously had a stray trailing dot ("huggingface.co.").
        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset (Gradio 5)")

        with gr.Row():
            # Left column: dataset selection and browsing controls.
            with gr.Column(scale=1):
                dataset_name = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
                subset_select = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic")
                split_name = gr.Textbox(label="Split", value="train")
                # precision=0 makes the number inputs hand integers to the
                # callbacks, which use them in range() and as shard indices.
                initial_shard_input = gr.Number(
                    label="Load Shard",
                    value=10,
                    minimum=-1,
                    step=1,
                    precision=0,
                    info="Enter shard index (0-based) or -1 for all shards",
                )
                only_rgb_checkbox = gr.Checkbox(label="Only RGB Images", value=True)
                batch_size_input = gr.Number(value=10, label="Batch Size", minimum=1, step=1, precision=0)
                load_button = gr.Button("Load Dataset / Shard", variant="primary")
                # The real maximum is filled in by open_dataset once the
                # shard count of the chosen subset is known.
                shard_slider = gr.Slider(label="Shard", minimum=0, maximum=1, step=1, value=0)
                gallery_columns_input = gr.Number(value=5, label="Gallery Columns", minimum=1, step=1, precision=0)
                next_batch_button = gr.Button("Next Batch (from current shard)", scale=0)

            # Right column: image gallery with the metadata table below it.
            with gr.Column(scale=4):
                image_gallery = gr.Gallery(
                    label="Dataset Images",
                    interactive=False,
                    object_fit="scale-down",
                    columns=5,
                    height="600px",
                    show_label=False,
                )
                metadata_table = gr.DataFrame(datatype="html", wrap=True)

        # Load button: (re)open the dataset at the shard typed into the number box.
        load_button.click(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, initial_shard_input, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table],
        )
        # Slider release: same as Load, but the shard comes from the slider.
        shard_slider.release(
            fn=open_dataset,
            inputs=[dataset_name, subset_select, split_name, batch_size_input, shard_slider, only_rgb_checkbox],
            outputs=[shard_slider, image_gallery, metadata_table],
        )
        gallery_columns_input.change(
            fn=update_gallery_columns,
            inputs=[gallery_columns_input],
            outputs=[image_gallery],
        )
        # Next batch: continue reading from the iterator opened by the last load.
        next_batch_button.click(
            fn=get_images,
            inputs=[batch_size_input, only_rgb_checkbox],
            outputs=[image_gallery, metadata_table],
        )

    demo.launch(show_api=False)