gera-richarte committed on
Commit
5129aaa
·
1 Parent(s): 3a8d69b

moved code to earthview.py (and other things)

Browse files
Files changed (2) hide show
  1. app.py +42 -44
  2. earthview.py +75 -9
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from datasets import load_dataset, get_dataset_config_names
2
  from functools import partial
3
  from pandas import DataFrame
4
  import earthview as ev
@@ -6,73 +6,70 @@ import gradio as gr
6
  import tqdm
7
  import os
8
 
9
- DEBUG = False
10
 
11
- if DEBUG:
12
  import numpy as np
13
 
14
- def open_dataset(dataset, set_name, split, batch_size, state, shard = -1):
15
- if shard == -1:
16
- # Trick to open the whole dataset
17
- data_files = None
18
- shards = 100
19
- else:
20
- config = ev.sets[set_name].get("config", set_name)
21
- shards = ev.sets[set_name]["shards"]
22
- path = ev.sets[set_name].get("path", set_name)
23
- data_files = {"train":[f"{path}/{split}-{shard:05d}-of-{shards:05d}.parquet"]}
24
 
25
- if DEBUG:
26
- dsi = range(100)
 
 
27
  else:
28
- ds = load_dataset(
29
- dataset,
30
- config,
31
- split=split,
32
- cache_dir="dataset",
33
- data_files=data_files,
34
- streaming=True,
35
- token=os.environ.get("HF_TOKEN", None))
36
 
37
- dsi = iter(ds)
38
 
39
- state["config"] = config
40
  state["dsi"] = dsi
41
  return (
42
- gr.update(label=f"Shards (max {shards})", value=shard, maximum=shards),
43
- *get_images(batch_size, state),
44
  state
45
  )
46
 
47
- def get_images(batch_size, state):
48
- config = state["config"]
49
 
50
  images = []
51
  metadatas = []
52
 
53
  for i in tqdm.trange(batch_size, desc=f"Getting images"):
54
- if DEBUG:
55
  images.append(np.random.randint(0,255,(384,384,3)))
56
- images.append(np.random.randint(0,255,(100,100,3)))
 
57
 
58
- metadata = {"bounds":[[1,1,4,4]], }
59
  else:
60
  try:
61
  item = next(state["dsi"])
62
  except StopIteration:
63
  break
64
  metadata = item["metadata"]
65
- item = ev.item_to_images(config, item)
66
 
67
- if config == "satellogic":
68
  images.extend(item["rgb"])
69
- # images.extend(item["1m"])
70
- if config == "sentinel_1":
 
71
  images.extend(item["10m"])
72
- if config == "default":
73
  images.extend(item["rgb"])
74
- images.extend(item["chm"])
75
- images.extend(item["1m"])
 
 
76
  metadatas.append(item["metadata"])
77
 
78
  return images, DataFrame(metadatas)
@@ -84,7 +81,7 @@ def new_state():
84
  return gr.State({})
85
 
86
  if __name__ == "__main__":
87
- with gr.Blocks(title="Dataset Explorer", fill_height = True) as demo:
88
  state = new_state()
89
 
90
  gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")
@@ -101,13 +98,14 @@ if __name__ == "__main__":
101
 
102
  with gr.Row():
103
  dataset = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
104
- config = gr.Dropdown(choices=ev.get_sets(), label="Config", value="satellogic", )
105
  split = gr.Textbox(label="Split", value="train")
106
  initial_shard = gr.Number(label = "Initial shard", value=10, info="-1 for whole dataset")
 
107
 
108
  gr.Button("Load (minutes)").click(
109
  open_dataset,
110
- inputs=[dataset, config, split, batch_size, state, initial_shard],
111
  outputs=[shard, gallery, table, state])
112
 
113
  gallery.render()
@@ -125,11 +123,11 @@ if __name__ == "__main__":
125
  shard.render()
126
  shard.release(
127
  open_dataset,
128
- inputs=[dataset, config, split, batch_size, state, shard],
129
  outputs=[shard, gallery, table, state])
130
 
131
  btn = gr.Button("Next Batch (same shard)", scale=0)
132
- btn.click(get_images, [batch_size, state], [gallery, table])
133
  btn.click()
134
 
135
  table.render()
 
1
+ from datasets import load_dataset
2
  from functools import partial
3
  from pandas import DataFrame
4
  import earthview as ev
 
6
  import tqdm
7
  import os
8
 
9
+ DEBUG = "samples" # False, "random", "samples"
10
 
11
+ if DEBUG == "random":
12
  import numpy as np
13
 
14
def open_dataset(dataset, subset, split, batch_size, shard, only_rgb, state):
    """Open a data source for ``subset`` and return the first batch for the UI.

    The source depends on the module-level ``DEBUG`` switch: a plain ``range``
    for ``"random"``, a local sample parquet file for ``"samples"``, or the
    real dataset when ``DEBUG`` is falsy. An iterator over the source is kept
    in ``state["dsi"]`` so ``get_images`` can continue pulling items from it.

    Args:
        dataset: Dataset identifier forwarded to ``ev.load_dataset``.
        subset: Subset name; also stored in ``state["subset"]``.
        split: Split name forwarded to ``ev.load_dataset``.
        batch_size: Number of items to fetch for the first batch.
        shard: Shard index to open, or ``-1`` to open the whole subset.
        only_rgb: Forwarded unchanged to ``get_images``.
        state: Mutable per-session dict; updated in place and returned.

    Returns:
        A tuple ``(shard_update, images, metadata_table, state)`` matching the
        ``[shard, gallery, table, state]`` outputs wired up in the UI.

    Raises:
        ValueError: If ``DEBUG`` holds an unsupported mode.
    """
    # Numeric widgets may deliver floats; the shard index ends up in a
    # "{shard:05d}" file-name pattern inside ev.load_dataset, which needs an int.
    shard = int(shard)
    nshards = ev.get_nshards(subset)

    # -1 is the UI convention for "do not restrict to a single shard".
    shards = None if shard == -1 else [shard]

    if DEBUG == "random":
        ds = range(batch_size)
    elif DEBUG == "samples":
        ds = ev.load_parquet(subset, batch_size=batch_size)
    elif not DEBUG:
        ds = ev.load_dataset(subset, dataset=dataset, split=split, shards=shards, cache_dir="dataset")
    else:
        # Previously this fell through and failed later with an unbound `ds`.
        raise ValueError(f"Unsupported DEBUG mode: {DEBUG!r}")

    state["subset"] = subset
    state["dsi"] = iter(ds)
    return (
        gr.update(label=f"Shard (max {nshards})", value=shard, maximum=nshards),
        *get_images(batch_size, only_rgb, state),
        state,
    )
39
 
40
def get_images(batch_size, only_rgb, state):
    """Collect up to ``batch_size`` items' images and metadata for display.

    In ``DEBUG == "random"`` mode random arrays and a fixed dummy metadata
    record are produced. Otherwise items are pulled from the iterator stored
    in ``state["dsi"]``, converted with ``ev.item_to_images`` and their images
    are gathered according to the subset stored in ``state["subset"]``.

    Args:
        batch_size: Maximum number of items to consume.
        only_rgb: When true, only the RGB images are collected for the
            "satellogic" and "neon" subsets (and for random debug data).
        state: Per-session dict populated by ``open_dataset``.

    Returns:
        A tuple ``(images, metadata)`` where ``images`` is a list and
        ``metadata`` is a ``DataFrame`` built from the per-item metadata.
        Both are empty when nothing has been loaded yet or the iterator is
        already exhausted.
    """
    images = []
    metadatas = []

    subset = state.get("subset")
    dsi = state.get("dsi")
    if DEBUG != "random" and dsi is None:
        # "Next Batch" was requested before any dataset was opened; return an
        # empty batch instead of failing with a KeyError on the empty state.
        return images, DataFrame(metadatas)

    for _ in tqdm.trange(batch_size, desc="Getting images"):
        if DEBUG == "random":
            images.append(np.random.randint(0, 255, (384, 384, 3)))
            if not only_rgb:
                images.append(np.random.randint(0, 255, (100, 100, 3)))
            metadatas.append({"bounds": [[1, 1, 4, 4]]})
            continue

        try:
            item = next(dsi)
        except StopIteration:
            # Source exhausted: return whatever was gathered so far.
            break
        item = ev.item_to_images(subset, item)

        if subset == "satellogic":
            images.extend(item["rgb"])
            if not only_rgb:
                images.extend(item["1m"])
        elif subset == "sentinel_1":
            images.extend(item["10m"])
        elif subset == "neon":
            images.extend(item["rgb"])
            if not only_rgb:
                images.extend(item["chm"])
                images.extend(item["1m"])

        metadatas.append(item["metadata"])

    return images, DataFrame(metadatas)
 
81
  return gr.State({})
82
 
83
  if __name__ == "__main__":
84
+ with gr.Blocks(title="EarthView Viewer", fill_height = True) as demo:
85
  state = new_state()
86
 
87
  gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")
 
98
 
99
  with gr.Row():
100
  dataset = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
101
+ subset = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic", )
102
  split = gr.Textbox(label="Split", value="train")
103
  initial_shard = gr.Number(label = "Initial shard", value=10, info="-1 for whole dataset")
104
+ only_rgb = gr.Checkbox(label="Only RGB", value=True)
105
 
106
  gr.Button("Load (minutes)").click(
107
  open_dataset,
108
+ inputs=[dataset, subset, split, batch_size, initial_shard, only_rgb, state],
109
  outputs=[shard, gallery, table, state])
110
 
111
  gallery.render()
 
123
  shard.render()
124
  shard.release(
125
  open_dataset,
126
+ inputs=[dataset, subset, split, batch_size, shard, only_rgb, state],
127
  outputs=[shard, gallery, table, state])
128
 
129
  btn = gr.Button("Next Batch (same shard)", scale=0)
130
+ btn.click(get_images, [batch_size, only_rgb, state], [gallery, table])
131
  btn.click()
132
 
133
  table.render()
earthview.py CHANGED
@@ -1,7 +1,13 @@
 
 
1
  from PIL import Image
2
  import numpy as np
3
  import json
4
 
 
 
 
 
5
  DATASET = "satellogic/EarthView"
6
 
7
  sets = {
@@ -18,10 +24,58 @@ sets = {
18
  }
19
  }
20
 
21
- def get_sets():
22
  return sets.keys()
23
 
24
- def item_to_images(config, item):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  metadata = item["metadata"]
26
  if type(metadata) == str:
27
  metadata = json.loads(metadata)
@@ -33,16 +87,24 @@ def item_to_images(config, item):
33
  }
34
  item["metadata"] = metadata
35
 
36
- if config == "satellogic":
37
- item["rgb"] = [
38
- Image.fromarray(image.transpose(1,2,0))
39
- for image in item["rgb"]
40
- ]
 
 
 
 
 
 
 
41
  item["1m"] = [
42
  Image.fromarray(image[0,:,:])
43
  for image in item["1m"]
44
  ]
45
- elif config == "sentinel_1":
 
46
  # Mapping of V and H to RGB. May not be correct
47
  # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
48
  i10m = item["10m"]
@@ -59,7 +121,8 @@ def item_to_images(config, item):
59
  Image.fromarray(image.transpose(1,2,0))
60
  for image in i10m
61
  ]
62
- elif config == "default":
 
63
  item["rgb"] = [
64
  Image.fromarray(image.transpose(1,2,0))
65
  for image in item["rgb"]
@@ -80,5 +143,8 @@ def item_to_images(config, item):
80
  ,2).astype("uint8"))
81
  for image in item["1m"]
82
  ]
 
 
 
83
  return item
84
 
 
1
+ from datasets import load_dataset as _load_dataset
2
+ from os import environ
3
  from PIL import Image
4
  import numpy as np
5
  import json
6
 
7
+ from pyarrow.parquet import ParquetFile
8
+ from pyarrow import Table as pa_Table
9
+ from datasets import Dataset
10
+
11
  DATASET = "satellogic/EarthView"
12
 
13
  sets = {
 
24
  }
25
  }
26
 
27
def get_subsets():
    """Return the names of all subsets registered in ``sets`` (a keys view)."""
    subset_names = sets.keys()
    return subset_names
29
 
30
def get_nshards(subset):
    """Return the ``"shards"`` entry registered for ``subset`` in ``sets``."""
    entry = sets[subset]
    return entry["shards"]
32
+
33
def get_path(subset):
    """Return the ``"path"`` entry for ``subset``, defaulting to its own name."""
    entry = sets[subset]
    return entry.get("path", subset)
35
+
36
def get_config(subset):
    """Return the ``"config"`` entry for ``subset``, defaulting to its own name."""
    entry = sets[subset]
    return entry.get("config", subset)
38
+
39
def load_dataset(subset, dataset="satellogic/EarthView", split="train", shards = None, streaming=True, **kwargs):
    """Load one subset of the dataset, optionally restricted to some shards.

    Args:
        subset: Subset name; must be a key of ``sets``.
        dataset: Dataset identifier passed as ``path`` to the underlying loader.
        split: Split name; also used in the shard file names.
        shards: Iterable of shard indices to load, or ``None`` to place no
            restriction on the data files.
        streaming: Forwarded to the underlying loader.
        **kwargs: Extra keyword arguments forwarded to the underlying loader.

    Returns:
        Whatever the underlying ``datasets`` loader returns for these arguments.
    """
    config = get_config(subset)
    nshards = get_nshards(subset)
    path = get_path(subset)

    if shards is None:
        data_files = None
    else:
        # int() guards against float indices (e.g. coming from a UI widget),
        # which the ":05d" format spec would reject.
        shard_files = [
            f"{path}/{split}-{int(shard):05d}-of-{nshards:05d}.parquet"
            for shard in shards
        ]
        data_files = {split: shard_files}

    return _load_dataset(
        path=dataset,
        name=config,
        save_infos=True,
        split=split,
        data_files=data_files,
        streaming=streaming,
        token=environ.get("HF_TOKEN", None),
        **kwargs)
60
+
61
def load_parquet(subset_or_filename, batch_size=100):
    """Build a ``Dataset`` from a local parquet file.

    Args:
        subset_or_filename: Either a known subset name, in which case
            ``dataset/<subset>/sample.parquet`` is read, or the path of a
            parquet file. An already opened ``ParquetFile`` is also accepted.
        batch_size: Row count per record batch while reading the file.

    Returns:
        A ``Dataset`` wrapping the table assembled from the record batches
        yielded by the file.
    """
    if isinstance(subset_or_filename, ParquetFile):
        pqfile = subset_or_filename
    elif subset_or_filename in get_subsets():
        pqfile = ParquetFile(f"dataset/{subset_or_filename}/sample.parquet")
    else:
        # A plain path must be opened too; previously the string itself was
        # used as the file object and ``iter_batches`` failed on it.
        pqfile = ParquetFile(subset_or_filename)

    batches = pqfile.iter_batches(batch_size=batch_size)
    return Dataset(pa_Table.from_batches(batches))
69
+
70
+ def item_to_images(subset, item):
71
+ """
72
+ Converts the images within an item (arrays), as retrieved from the dataset to proper PIL.Image
73
+
74
+ subset: The name of the Subset, one of "satellogic", "sentinel_1", "neon"
75
+ item: The item as retrieved from the subset
76
+
77
+ returns the item, with arrays converted to PIL.Image
78
+ """
79
  metadata = item["metadata"]
80
  if type(metadata) == str:
81
  metadata = json.loads(metadata)
 
87
  }
88
  item["metadata"] = metadata
89
 
90
+ if subset == "satellogic":
91
+ # item["rgb"] = [
92
+ # Image.fromarray(np.average(image.transpose(1,2,0), 2).astype("uint8"))
93
+ # for image in item["rgb"]
94
+ # ]
95
+ rgbs = []
96
+ for rgb in item["rgb"]:
97
+ rgbs.append(Image.fromarray(rgb.transpose(1,2,0)))
98
+ # rgbs.append(Image.fromarray(rgb[0,:,:])) # Red
99
+ # rgbs.append(Image.fromarray(rgb[1,:,:])) # Green
100
+ # rgbs.append(Image.fromarray(rgb[2,:,:])) # Blue
101
+ item["rgb"] = rgbs
102
  item["1m"] = [
103
  Image.fromarray(image[0,:,:])
104
  for image in item["1m"]
105
  ]
106
+ count = len(item["1m"])
107
+ elif subset == "sentinel_1":
108
  # Mapping of V and H to RGB. May not be correct
109
  # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
110
  i10m = item["10m"]
 
121
  Image.fromarray(image.transpose(1,2,0))
122
  for image in i10m
123
  ]
124
+ count = len(item["10m"])
125
+ elif subset == "neon":
126
  item["rgb"] = [
127
  Image.fromarray(image.transpose(1,2,0))
128
  for image in item["rgb"]
 
143
  ,2).astype("uint8"))
144
  for image in item["1m"]
145
  ]
146
+ count = len(item["rgb"])
147
+
148
+ item["metadata"]["count"] = count
149
  return item
150