iv0's picture
Fix typo in app.py
8b5bf2b verified
raw
history blame
4.54 kB
from datasets import load_dataset
from functools import partial
from pandas import DataFrame
import earthview as ev
import gradio as gr
import tqdm
import os
# Development switch selecting where images come from:
#   False     -> the real dataset, opened with ev.load_dataset
#   "random"  -> synthetic noise images generated with numpy
#   "samples" -> items read through ev.load_parquet
DEBUG = False
if DEBUG == "random":
    # numpy is only needed to fabricate the random images.
    import numpy as np
def open_dataset(dataset, subset, split, batch_size, shard, only_rgb, state):
    """Open an iterator over the chosen subset and fetch its first batch.

    Args:
        dataset: Dataset name forwarded to ``ev.load_dataset``.
        subset: Subset name; also stored in ``state["subset"]``.
        split: Split name forwarded to ``ev.load_dataset``.
        batch_size: Number of items to fetch for the first batch.
        shard: Shard index to open, or ``-1`` to pass ``shards=None``
            (no shard filter) to ``ev.load_dataset``.
        only_rgb: Forwarded to ``get_images``.
        state: Mutable dict; ``"subset"`` and ``"dsi"`` (the iterator)
            are written into it.

    Returns:
        ``(slider_update, images, metadata_frame, state)``. The first item
        updates the shard slider's label/value/maximum from
        ``ev.get_nshards(subset)``; the middle two come from ``get_images``.

    Raises:
        ValueError: If ``DEBUG`` is set to an unrecognized mode. Previously
            this surfaced as an ``UnboundLocalError`` on ``ds``.
    """
    nshards = ev.get_nshards(subset)
    shards = None if shard == -1 else [shard]

    if DEBUG == "random":
        # Placeholder iterator: the "random" branch of get_images never
        # consumes it, it only needs state["dsi"] to exist.
        ds = range(batch_size)
    elif DEBUG == "samples":
        ds = ev.load_parquet(subset, batch_size=batch_size)
    elif not DEBUG:
        ds = ev.load_dataset(
            subset, dataset=dataset, split=split, shards=shards, cache_dir="dataset"
        )
    else:
        raise ValueError(
            f"Unknown DEBUG mode {DEBUG!r}; expected False, 'random' or 'samples'"
        )

    state["subset"] = subset
    state["dsi"] = iter(ds)
    return (
        gr.update(label=f"Shard (max {nshards})", value=shard, maximum=nshards),
        *get_images(batch_size, only_rgb, state),
        state,
    )
def get_images(batch_size, only_rgb, state):
    """Pull up to ``batch_size`` items from the open iterator as images.

    Args:
        batch_size: Number of items (loop iterations) to fetch. Each item
            can contribute several images.
        only_rgb: If true, skip the non-RGB images ("1m" for satellogic,
            "chm"/"1m" for neon). sentinel_1 always adds its "10m" images.
        state: Dict populated by ``open_dataset`` with ``"subset"`` and
            ``"dsi"`` (the item iterator).

    Returns:
        ``(images, DataFrame)``: the flat list of images for the gallery and
        one metadata row per fetched item. The batch is shorter if the
        iterator runs out first.

    Raises:
        gr.Error: If no dataset has been loaded into ``state`` yet.
    """
    try:
        subset = state["subset"]
    except KeyError:
        raise gr.Error("You need to load a Dataset first") from None

    images = []
    metadatas = []
    for _ in tqdm.trange(batch_size, desc="Getting images"):
        if DEBUG == "random":
            images.append(np.random.randint(0, 255, (384, 384, 3)))
            if not only_rgb:
                images.append(np.random.randint(0, 255, (100, 100, 3)))
            metadatas.append({"bounds": [[1, 1, 4, 4]]})
            continue

        try:
            item = next(state["dsi"])
        except StopIteration:
            # Iterator exhausted: return whatever was collected so far.
            break

        item = ev.item_to_images(subset, item)
        # NOTE(review): subsets other than these three add a metadata row
        # but no images; confirm whether that is intended.
        if subset == "satellogic":
            images.extend(item["rgb"])
            if not only_rgb:
                images.extend(item["1m"])
        elif subset == "sentinel_1":
            images.extend(item["10m"])
        elif subset == "neon":
            images.extend(item["rgb"])
            if not only_rgb:
                images.extend(item["chm"])
                images.extend(item["1m"])
        metadatas.append(item["metadata"])
    return images, DataFrame(metadatas)
def update_shape(rows, columns):
    """Build a component update that resizes the gallery grid.

    Args:
        rows: Number of grid rows to display.
        columns: Number of grid columns to display.
    """
    grid = {"rows": rows, "columns": columns}
    return gr.update(**grid)
def new_state():
    """Create the ``gr.State`` component that holds the app's mutable dict.

    The dict starts empty; ``open_dataset`` later writes the selected
    subset and the dataset iterator into it.
    """
    initial = {}
    return gr.State(value=initial)
if __name__ == "__main__":
    # Build the viewer UI. Several components are created with render=False
    # so they can be placed later in the layout via .render().
    with gr.Blocks(title="EarthView Viewer", fill_height=True) as demo:
        state = new_state()
        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")
        # precision=0 makes Gradio hand back ints: batch_size feeds range()/
        # tqdm.trange(), which reject floats.
        batch_size = gr.Number(10, label="Batch Size", precision=0, render=False)
        shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
        table = gr.DataFrame(render=False)
        gallery = gr.Gallery(
            label=ev.DATASET,
            interactive=False,
            object_fit="scale-down",
            columns=5, rows=2, render=False)

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
            subset = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic")
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(label="Initial shard", value=10, precision=0, info="-1 for whole dataset")
            only_rgb = gr.Checkbox(label="Only RGB", value=True)
            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, subset, split, batch_size, initial_shard, only_rgb, state],
                outputs=[shard, gallery, table, state])

        gallery.render()

        with gr.Row():
            batch_size.render()
            rows = gr.Number(2, label="Rows", precision=0)
            columns = gr.Number(5, label="Columns", precision=0)
            rows.change(update_shape, [rows, columns], [gallery])
            columns.change(update_shape, [rows, columns], [gallery])

        with gr.Row():
            shard.render()
            # Releasing the slider reopens the dataset at the chosen shard.
            shard.release(
                open_dataset,
                inputs=[dataset, subset, split, batch_size, shard, only_rgb, state],
                outputs=[shard, gallery, table, state])
            btn = gr.Button("Next Batch (same shard)", scale=0)
            # Single listener; the former stray `btn.click()` registered an
            # extra empty event and has been removed.
            btn.click(get_images, [batch_size, only_rgb, state], [gallery, table])

        table.render()

    demo.launch(show_api=False)