LeroyWaa committed
Commit 246c106 · 1 Parent(s): 1d541d9
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +40 -13
  2. app.py +69 -0
  3. build.sh +12 -0
  4. common/__init__.py +0 -0
  5. common/calculate_fvd.py +80 -0
  6. common/data_sampler.py +334 -0
  7. common/eval_utils.py +105 -0
  8. common/fid_score.py +382 -0
  9. common/fvd/styleganv/fvd.py +90 -0
  10. common/fvd/styleganv/i3d_torchscript.pt +3 -0
  11. common/fvd/videogpt/fvd.py +137 -0
  12. common/fvd/videogpt/i3d_pretrained_400.pt +3 -0
  13. common/fvd/videogpt/pytorch_i3d.py +322 -0
  14. common/inception.py +344 -0
  15. common/plot/__init__.py +0 -0
  16. common/plot/aggregated_output.csv +18 -0
  17. common/plot/plot_arch_ablation.py +60 -0
  18. common/plot/plot_arch_ablation_deltapsnr.py +49 -0
  19. common/plot/plot_dataset_scale.py +69 -0
  20. common/plot/plot_dataset_traj_scale.py +48 -0
  21. common/plot/plot_dynamics_ablation.py +56 -0
  22. common/plot/plot_dynamics_ablation_deltapsnr.py +51 -0
  23. common/plot/plot_from_wandb.py +185 -0
  24. common/plot/plot_from_wandb_singledataset.py +144 -0
  25. common/plot/plot_model_scale.py +64 -0
  26. common/plot/plot_pretrain_ablation.py +44 -0
  27. common/plot/plot_pretrain_ablation_mar.py +45 -0
  28. cont_data.py +245 -0
  29. data.py +240 -0
  30. datasets/.DS_Store +0 -0
  31. datasets/__init__.py +0 -0
  32. datasets/encode_extern_dataset.py +291 -0
  33. datasets/encode_openx_dataset.py +459 -0
  34. datasets/extern/__init__.py +0 -0
  35. datasets/extern/ego4d.py +193 -0
  36. datasets/extern/egoexo4d.py +186 -0
  37. datasets/extern/epic_kitchen.py +115 -0
  38. datasets/extern/frodobot.py +128 -0
  39. datasets/extern/robomimic.py +108 -0
  40. datasets/merge_shards.py +113 -0
  41. datasets/utils.py +244 -0
  42. experiments/.DS_Store +0 -0
  43. experiments/datasplit/.DS_Store +0 -0
  44. experiments/datasplit/dataset1.yaml +2 -0
  45. experiments/datasplit/dataset10.yaml +11 -0
  46. experiments/datasplit/dataset15.yaml +16 -0
  47. experiments/datasplit/dataset15_vae.yaml +16 -0
  48. experiments/datasplit/dataset20.yaml +21 -0
  49. experiments/datasplit/dataset20_vae.yaml +21 -0
  50. experiments/datasplit/dataset25.yaml +26 -0
README.md CHANGED
@@ -1,13 +1,40 @@
1
- ---
2
- title: Hma
3
- emoji: 📉
4
- colorFrom: gray
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.6.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Heterogeneous World Modeling with Actions
2
+
3
+ Progress in video generation may soon make it possible to evaluate robot policies in a completely learned world model.
4
+ Adapted from [1xgpt](https://github.com/1x-technologies/1xgpt).
5
+
6
+ ## Getting Started
7
+ We require `Python 3.10` or later. This code was tested with `Python 3.10.12`.
8
+
9
+ ```
10
+ # Install dependencies and download data
11
+ ./build.sh
12
+
13
+ # Activate the Python virtual environment
14
+ source venv/bin/activate
15
+ ```
16
+
17
+ ## File Structure
18
+ ```
19
+ ├── ...
20
+ ├── HPT-Video
21
+ | |── data # cached token datasets and model checkpoints
22
+ | |── genie # main modeling code
23
+ | | |── diffusion # diffusion loss related
24
+ | | |── evaluate.py # evaluate a trained model
25
+ | | |── st_mar.py # spatial-time MAR
26
+ | | |── generate.py # generate tokens from trained model
27
+ | | |── st_maskgit.py # spatial-time maskgit
28
+ | |── magvit # magvit code
29
+ | |── sim # simulation related codebase
30
+ | |── experiments
31
+ | | |── cmd # handy commands
32
+ | | |── datasplit # dataset split
33
+ | | |── scripts # ablation and training scripts.
34
+ | |── common # common utility and plot scripts
35
+ | |── train.py # train using magvit
36
+ | |── train_diffusion.py # train using mar
37
+ | |── train_multi.py # train on multiple datasets jointly
38
+ | |── visualize.py # visualize generated tokens
39
+ └── ...
40
+ ```
app.py ADDED
@@ -0,0 +1,69 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import cv2
5
+ from sim.simulator import GenieSimulator
6
+
7
+ RES = 512
8
+ image = Image.open("sim/assets/langtable_prompt/frame_06.png")
9
+ genie = GenieSimulator(
10
+ image_encoder_type='temporalvae',
11
+ image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
12
+ quantize=False,
13
+ backbone_type='stmar',
14
+ backbone_ckpt='data/mar_ckpt/langtable',
15
+ prompt_horizon=11,
16
+ action_stride=1,
17
+ domain='language_table',
18
+ )
19
+ prompt_image = np.tile(
20
+ np.array(image), (genie.prompt_horizon, 1, 1, 1)
21
+ ).astype(np.uint8)
22
+ prompt_action = np.zeros(
23
+ (genie.prompt_horizon, genie.action_stride, 2)
24
+ ).astype(np.float32)
25
+ genie.set_initial_state((prompt_image, prompt_action))
26
+ image = genie.reset()
27
+ image = cv2.resize(image, (RES, RES))
28
+ image = Image.fromarray(image)
29
+
30
+ # Example model: takes a direction and returns a random image
31
+ def model(direction: str, genie=genie):
32
+ if direction == 'right':
33
+ action = np.array([0, 0.05])
34
+ elif direction == 'left':
35
+ action = np.array([0, -0.05])
36
+ elif direction == 'down':
37
+ action = np.array([0.05, 0])
38
+ elif direction == 'up':
39
+ action = np.array([-0.05, 0])
40
+ else:
41
+ raise ValueError(f"Invalid direction: {direction}")
42
+ next_image = genie.step(action)['pred_next_frame']
43
+ next_image = cv2.resize(next_image, (RES, RES))
44
+ return Image.fromarray(next_image)
45
+
46
+ # Gradio function to handle user input
47
+ def handle_input(direction):
48
+ print(f"User clicked: {direction}")
49
+ new_image = model(direction) # Get a new image from the model
50
+ return new_image
51
+
52
+ if __name__ == '__main__':
53
+ with gr.Blocks() as demo:
54
+ with gr.Row():
55
+ image_display = gr.Image(value=image, type="pil", label="Generated Image")
56
+ with gr.Row():
57
+ up = gr.Button("↑ Up")
58
+ with gr.Row():
59
+ left = gr.Button("← Left")
60
+ down = gr.Button("↓ Down")
61
+ right = gr.Button("→ Right")
62
+
63
+ # Define button interactions
64
+ up.click(fn=lambda: handle_input("up"), outputs=image_display)
65
+ down.click(fn=lambda: handle_input("down"), outputs=image_display)
66
+ left.click(fn=lambda: handle_input("left"), outputs=image_display)
67
+ right.click(fn=lambda: handle_input("right"), outputs=image_display)
68
+
69
+ demo.launch()
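For a quick smoke test without the Gradio UI, the `model` helper above can be stepped directly. A minimal sketch (not part of this commit), assuming the checkpoints referenced in `app.py` are available locally so that importing `app` succeeds:

```python
# Headless rollout sketch: importing app builds the GenieSimulator at module load,
# but does not launch the demo (demo.launch() is guarded by __main__).
from app import model

directions = ["right", "right", "up", "left"]
for i, direction in enumerate(directions):
    frame = model(direction)            # PIL.Image of size RES x RES
    frame.save(f"rollout_{i:02d}.png")  # save frames for visual inspection
```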
build.sh ADDED
@@ -0,0 +1,12 @@
1
+ #!/usr/bin/bash
2
+
3
+ python3 -m venv venv
4
+ source venv/bin/activate
5
+ python -m pip install -r requirements.txt
6
+ FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE python -m pip install flash-attn==2.5.8 --no-build-isolation
7
+
8
+ # Download datasets to data/train_v1.1, data/val_v1.1
9
+ huggingface-cli download 1x-technologies/worldmodel --repo-type dataset --local-dir data
10
+
11
+ mv data/val_v1.1 data/1x_humanoid_magvit_traj1000_val
12
+ mv data/train_v1.1 data/1x_humanoid_magvit_traj1000_train
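After `./build.sh` finishes, a short sanity check (a sketch, not part of this commit) can confirm that the environment and the renamed data directories are in place:

```python
# Environment sanity check after running ./build.sh.
import os
import torch

print("CUDA available:", torch.cuda.is_available())
for split in ("train", "val"):
    path = f"data/1x_humanoid_magvit_traj1000_{split}"
    print(path, "exists:", os.path.isdir(path))
```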
common/__init__.py ADDED
File without changes
common/calculate_fvd.py ADDED
@@ -0,0 +1,80 @@
1
+ # Code adapted from https://github.com/JunyaoHu/common_metrics_on_video_quality
2
+ import numpy as np
3
+ import torch
4
+ from tqdm import tqdm
5
+
6
+ def trans(x):
7
+ # if grayscale, repeat the single channel to get 3 channels
8
+ if x.shape[-3] == 1:
9
+ x = x.repeat(1, 1, 3, 1, 1)
10
+
11
+ # permute BTCHW -> BCTHW
12
+ x = x.permute(0, 2, 1, 3, 4)
13
+
14
+ return x
15
+
16
+ def calculate_fvd(videos1, videos2, device="cuda", method='styleganv'):
17
+
18
+ if method == 'styleganv':
19
+ from .fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
20
+ elif method == 'videogpt':
21
+ from .fvd.videogpt.fvd import load_i3d_pretrained
22
+ from .fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
23
+ from .fvd.videogpt.fvd import frechet_distance
24
+
25
+
26
+ # videos [batch_size, timestamps, channel, h, w]
27
+
28
+ assert videos1.shape == videos2.shape
29
+
30
+ i3d = load_i3d_pretrained(device=device)
31
+ fvd_results = []
32
+
33
+ # support grayscale input, if grayscale -> channel*3
34
+ # BTCHW -> BCTHW
35
+ # videos -> [batch_size, channel, timestamps, h, w]
36
+
37
+ videos1 = trans(videos1)
38
+ videos2 = trans(videos2)
39
+
40
+ # fvd_results = {}
41
+
42
+ # to compute FVD, each clip must be at least 10 frames long
43
+ for clip_timestamp in tqdm(range(10, videos1.shape[-3]+1)):
44
+ # print("clip_timestamp", clip_timestamp)
45
+ # get a video clip
46
+ # videos_clip [batch_size, channel, timestamps[:clip], h, w]
47
+ videos_clip1 = videos1[:, :, : clip_timestamp]
48
+ videos_clip2 = videos2[:, :, : clip_timestamp]
49
+
50
+ # get FVD features
51
+ feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
52
+ feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)
53
+
54
+ # calculate FVD when timestamps[:clip]
55
+ fvd_results.append(frechet_distance(feats1, feats2))
56
+
57
+
58
+ return fvd_results[-1] # only the last step
59
+
60
+ # test code / using example
61
+
62
+ def main():
63
+ NUMBER_OF_VIDEOS = 8
64
+ VIDEO_LENGTH = 50
65
+ CHANNEL = 3
66
+ SIZE = 64
67
+ videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
68
+ videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
69
+ device = torch.device("cuda")
70
+ # device = torch.device("cpu")
71
+
72
+ import json
73
+ result = calculate_fvd(videos1, videos2, device, method='videogpt')
74
+ print(json.dumps(result, indent=4))
75
+
76
+ result = calculate_fvd(videos1, videos2, device, method='styleganv')
77
+ print(json.dumps(result, indent=4))
78
+
79
+ if __name__ == "__main__":
80
+ main()
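Real frames usually arrive as `uint8` arrays, whereas `calculate_fvd` expects float tensors in `[0, 1]` with layout `(B, T, C, H, W)` and at least 10 frames per clip. A minimal conversion-and-call sketch (dummy data; the helper name is illustrative):

```python
import numpy as np
import torch

from common.calculate_fvd import calculate_fvd


def to_btchw_float(frames_uint8: np.ndarray) -> torch.Tensor:
    """(B, T, H, W, C) uint8 in [0, 255] -> (B, T, C, H, W) float in [0, 1]."""
    videos = torch.from_numpy(frames_uint8).float() / 255.0
    return videos.permute(0, 1, 4, 2, 3).contiguous()


# Dummy batch: 4 clips of 16 RGB frames at 64x64 (>= 10 frames required).
fake = to_btchw_float(np.random.randint(0, 256, (4, 16, 64, 64, 3), dtype=np.uint8))
real = to_btchw_float(np.random.randint(0, 256, (4, 16, 64, 64, 3), dtype=np.uint8))

# Uses the bundled I3D checkpoint under common/fvd/styleganv/.
print("FVD:", calculate_fvd(fake, real, device="cuda", method="styleganv"))
```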
common/data_sampler.py ADDED
@@ -0,0 +1,334 @@
1
+ import copy
2
+ import os
3
+ import random
4
+ from operator import itemgetter
5
+ from typing import Optional, List
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ import torch.distributed as dist
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset, Sampler
13
+ from torch.utils.data import Sampler, DistributedSampler
14
+
15
+
16
+ def chunk_indices(indices: list[int], size: int) -> tuple[torch.Tensor, ...]:
17
+ return torch.split(torch.tensor(indices), size)
18
+
19
+
20
+ class CombinedDataLoader:
21
+ def __init__(self, dataloaders, reinit=True):
22
+ """
23
+ :param dataloaders: list of pytorch dataloaders
24
+ """
25
+ self.dataloaders = dataloaders
26
+ self.reinit = reinit
27
+ self.dataloader_idx = 0
28
+ self.loader_iters = [iter(dataloader) for dataloader in self.dataloaders]
29
+
30
+ def __iter__(self):
31
+ return self
32
+
33
+ def __next__(self):
34
+ # Take the next batch from the current dataloader
35
+ chosen_loader_iter = self.loader_iters[self.dataloader_idx]
36
+
37
+ try:
38
+ data = next(chosen_loader_iter)
39
+ return data
40
+ except StopIteration:
41
+ # Current dataloader is exhausted; advance to the next one (stop after the last).
42
+ self.dataloader_idx = self.dataloader_idx + 1
43
+ if self.dataloader_idx == len(self.loader_iters):
44
+ self.dataloader_idx = 0 # reset
45
+ raise StopIteration
46
+ return self.__next__()
47
+
48
+ def __len__(self):
49
+ return sum([len(dataloader) for dataloader in self.dataloaders])
50
+
51
+
52
+ class CombinedBatchSampler(torch.utils.data.Sampler):
53
+ # For validation dataloaders.
54
+ def __init__(self, datasets, batch_size, num_processes=1, shuffle=False):
55
+ super().__init__() # no-op
56
+ prev_idx = 0
57
+ all_batches = []
58
+
59
+ for dataset in datasets:
60
+ indices = list(range(prev_idx, prev_idx + len(dataset)))
61
+ if shuffle:
62
+ random.shuffle(indices)
63
+
64
+ # exclude remainder, if necessary
65
+ remainder = len(indices) % (batch_size * num_processes)
66
+ if remainder > 0:
67
+ indices = indices[:-remainder] # exclude last
68
+
69
+ chunk_i = chunk_indices(indices, batch_size) # equally sized
70
+ all_batches += chunk_i
71
+
72
+ # add the new indices without the last batch
73
+ prev_idx += len(chunk_i) * batch_size # len(dataset)
74
+
75
+ if shuffle:
76
+ random.shuffle(all_batches)
77
+
78
+ self.all_batches = all_batches
79
+
80
+ def __iter__(self):
81
+ return iter(self.all_batches)
82
+
83
+ def __len__(self):
84
+ return len(self.all_batches)
85
+
86
+
87
+ # https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py
88
+ class DatasetFromSampler(Dataset):
89
+ """Dataset to create indexes from `Sampler`.
90
+
91
+ Args:
92
+ sampler: PyTorch sampler
93
+ """
94
+
95
+ def __init__(self, sampler: Sampler):
96
+ """Initialisation for DatasetFromSampler."""
97
+ self.sampler = sampler
98
+ self.sampler_list = None
99
+
100
+ def __getitem__(self, index: int):
101
+ """Gets element of the dataset.
102
+
103
+ Args:
104
+ index: index of the element in the dataset
105
+
106
+ Returns:
107
+ Single element by index
108
+ """
109
+ if self.sampler_list is None:
110
+ self.sampler_list = list(self.sampler)
111
+ return self.sampler_list[index]
112
+
113
+ def __len__(self) -> int:
114
+ """
115
+ Returns:
116
+ int: length of the dataset
117
+ """
118
+ return len(self.sampler)
119
+
120
+
121
+ class DistributedSamplerWrapper(DistributedSampler):
122
+ """
123
+ Wrapper over `Sampler` for distributed training.
124
+ Allows you to use any sampler in distributed mode.
125
+
126
+ It is especially useful in conjunction with
127
+ `torch.nn.parallel.DistributedDataParallel`. In such case, each
128
+ process can pass a DistributedSamplerWrapper instance as a DataLoader
129
+ sampler, and load a subset of subsampled data of the original dataset
130
+ that is exclusive to it.
131
+
132
+ .. note::
133
+ Sampler is assumed to be of constant size.
134
+ """
135
+
136
+ def __init__(
137
+ self,
138
+ sampler,
139
+ num_replicas: Optional[int] = None,
140
+ rank: Optional[int] = None,
141
+ shuffle: bool = True,
142
+ ):
143
+ """
144
+
145
+ Args:
146
+ sampler: Sampler used for subsampling
147
+ num_replicas (int, optional): Number of processes participating in
148
+ distributed training
149
+ rank (int, optional): Rank of the current process
150
+ within ``num_replicas``
151
+ shuffle (bool, optional): If true (default),
152
+ sampler will shuffle the indices
153
+ """
154
+ super(DistributedSamplerWrapper, self).__init__(
155
+ DatasetFromSampler(sampler),
156
+ num_replicas=num_replicas,
157
+ rank=rank,
158
+ shuffle=shuffle,
159
+ )
160
+ self.sampler = sampler
161
+
162
+ def __iter__(self):
163
+ """Iterate over sampler.
164
+
165
+ Returns:
166
+ python iterator
167
+ """
168
+ self.dataset = DatasetFromSampler(self.sampler)
169
+ indexes_of_indexes = super().__iter__()
170
+ subsampler_indexes = self.dataset
171
+ return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
172
+
173
+
174
+ # https://github.com/rabeehk/hyperformer/blob/main/hyperformer/data/multitask_sampler.py
175
+ class MultiTaskBatchSampler(Sampler):
176
+ """Defines a sampler to sample multiple datasets with temperature sampling
177
+ in a distributed fashion."""
178
+
179
+ def __init__(
180
+ self,
181
+ dataset_sizes: List[int],
182
+ batch_size: int,
183
+ temperature: float,
184
+ dataset_groups=[],
185
+ num_replicas: Optional[int] = 1,
186
+ rank: Optional[int] = 0,
187
+ seed: int = 0,
188
+ shuffle: bool = True,
189
+ shuffle_task: bool = True,
190
+ ) -> None:
191
+ """Constructor for MultiTaskBatchSampler.
192
+ Args:
193
+ dataset_sizes: a list of integers, specifies the number of samples in
194
+ each dataset.
195
+ batch_size: integer, specifies the batch size.
196
+ temperature: float, temperature used for temperature sampling. Larger
197
+ values sample the datasets more uniformly; a value of 1 samples each
198
+ dataset in proportion to its number of samples.
199
+ num_replicas: integer, specifies the number of processes.
200
+ rank: integer, specifies the rank of the current process.
201
+ seed: integer, random seed.
202
+ shuffle: bool, if set to true, the datasets will be shuffled in each epoch.
203
+ """
204
+
205
+ if num_replicas is None:
206
+ if not dist.is_available():
207
+ raise RuntimeError("Requires distributed package to be available")
208
+ num_replicas = dist.get_world_size()
209
+ if rank is None:
210
+ if not dist.is_available():
211
+ raise RuntimeError("Requires distributed package to be available")
212
+ rank = dist.get_rank()
213
+ print("data sampler rank:", rank)
214
+
215
+ if rank >= num_replicas or rank < 0:
216
+ raise ValueError(
217
+ "Invalid rank {}, rank should be in the interval" " [0, {}]".format(rank, num_replicas - 1)
218
+ )
219
+
220
+ self.dataset_groups = dataset_groups
221
+ print("dataset groups:", self.dataset_groups)
222
+
223
+ self.num_replicas = num_replicas
224
+ self.shuffle_task = shuffle_task
225
+ self.rank = rank
226
+ self.batch_size = batch_size
227
+ self.dataset_sizes = dataset_sizes
228
+
229
+ # By default we drop the last elements if dataset is not divisible by the number of ranks.
230
+ self.rank_dataset_sizes = [dataset_size // self.num_replicas for dataset_size in self.dataset_sizes]
231
+ self.dataset_offsets = torch.cumsum(torch.LongTensor([0] + dataset_sizes), 0)
232
+ self.total_sizes = [
233
+ (dataset_size // self.num_replicas) * self.num_replicas for dataset_size in self.dataset_sizes
234
+ ]
235
+ self.temperature = temperature
236
+ self.seed = seed
237
+ self.epoch = 0
238
+ self.num_batches_per_epoch = (
239
+ (np.sum(dataset_sizes) + self.batch_size - 1) // self.batch_size // self.num_replicas
240
+ )
241
+ self.shuffle = shuffle
242
+ print(f"{num_replicas=} {rank=} {self.num_batches_per_epoch=} {self.total_sizes=} self.weights={self.generate_tasks_distribution()}")
243
+
244
+ def generate_tasks_distribution(self):
245
+ """Given the dataset sizes computes the weights to sample each dataset
246
+ according to the temperature sampling."""
247
+ if len(self.dataset_groups) > 0:
248
+ # normalize across groups first
249
+ weights = []
250
+ num_groups = len(self.dataset_groups)
251
+ for group in self.dataset_groups:
252
+ lo, hi = group
253
+ dataset_sizes = [self.dataset_sizes[idx] for idx in range(lo, hi)]
254
+ total_size = sum(dataset_sizes)
255
+ group_weights = np.array([(size / total_size) ** (1.0 / self.temperature) for size in dataset_sizes])
256
+ group_weights = group_weights / np.sum(group_weights) / num_groups
257
+ weights = np.concatenate((weights, group_weights))
258
+
259
+ else:
260
+ total_size = sum(self.dataset_sizes)
261
+ weights = np.array([(size / total_size) ** (1.0 / self.temperature) for size in self.dataset_sizes])
262
+ weights = weights / np.sum(weights)
263
+ return torch.as_tensor(weights, dtype=torch.double)
264
+
265
+ def __iter__(self):
266
+ # Defines torch generator, to make random choices consistent across cores in
267
+ # different epochs, the seed needs to be set based on seed and epoch.
268
+ generator = torch.Generator()
269
+ generator.manual_seed(self.seed + self.epoch)
270
+
271
+ # Shuffles the datasets if shuffle is set to true.
272
+ indices = []
273
+ for dataset_size in self.dataset_sizes:
274
+ if self.shuffle:
275
+ indices.append(torch.randperm(dataset_size, generator=generator).tolist())
276
+ else:
277
+ indices.append(list(range(dataset_size)))
278
+
279
+ # Shards the datasets across the all processes.
280
+ self.rank_indices = []
281
+ for i in range(len(self.dataset_sizes)):
282
+ self.rank_indices.append(indices[i][self.rank : self.total_sizes[i] : self.num_replicas])
283
+
284
+ # To make the model consistent across different processes, since the
285
+ # model is based on tasks, we need to make sure the same task is selected
286
+ # across different processes.
287
+ tasks_distribution: torch.Tensor = self.generate_tasks_distribution()
288
+
289
+ # Chooses the tasks which will be used in each batch in one epoch.
290
+ # With passing generator, we make sure this choice is consistent across
291
+ # different processes.
292
+
293
+ # want them to be different.
294
+ if self.shuffle_task:
295
+ generator.manual_seed(self.seed + self.epoch + self.rank)
296
+ batch_task_assignments = torch.multinomial(
297
+ tasks_distribution, self.num_batches_per_epoch, replacement=True, generator=generator
298
+ )
299
+
300
+ for batch_task in batch_task_assignments:
301
+ # Gets the number of samples of the selected datasets available for the current rank.
302
+ num_task_samples = self.rank_dataset_sizes[batch_task]
303
+ # Computes the random samples from the chosen dataset.
304
+ indices = torch.randint(low=0, high=num_task_samples, size=(self.batch_size,), generator=generator).tolist()
305
+ # Converts the selected indices to the global indices on the given dataset.
306
+ results = (self.dataset_offsets[batch_task] + torch.tensor(self.rank_indices[batch_task])[indices]).tolist()
307
+ yield results
308
+
309
+ def __len__(self):
310
+ return self.num_batches_per_epoch
311
+
312
+ def set_epoch(self, epoch):
313
+ self.epoch = epoch
314
+
315
+ def make_dataset_pie_plot(domains, traj_nums):
316
+ """draw the dataset mixture as a pie plot"""
317
+ new_domains = []
318
+ for idx, domain in enumerate(domains):
319
+ new_domains.append(domain)
320
+ plt.cla()
321
+ fig1, ax1 = plt.subplots(figsize=(40, 40))
322
+ traj_prob = np.array(traj_nums) / np.sum(traj_nums)
323
+ tab20 = plt.get_cmap("tab20").colors
324
+ tab20b = plt.get_cmap("tab20b").colors
325
+ tab20c = plt.get_cmap("tab20c").colors
326
+
327
+ # Combine them to get 60 distinct colors
328
+ colors = tab20 + tab20b + tab20c
329
+ patches, _ = ax1.pie(traj_prob, startangle=90, colors=colors[: len(traj_prob)])
330
+ ax1.axis("equal")
331
+ ax1.legend(patches, new_domains, loc="center left", bbox_to_anchor=(0.8, 0.5), prop={"size": 32})
332
+ fig1.canvas.draw()
333
+
334
+ return Image.frombytes("RGB", fig1.canvas.get_width_height(), fig1.canvas.tostring_rgb())
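A minimal single-process sketch (dummy datasets, illustrative values) of plugging `MultiTaskBatchSampler` into a `DataLoader` over concatenated datasets:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from common.data_sampler import MultiTaskBatchSampler

# Two dummy datasets of different sizes.
datasets = [
    TensorDataset(torch.arange(100).float().unsqueeze(1)),
    TensorDataset(torch.arange(400).float().unsqueeze(1)),
]

sampler = MultiTaskBatchSampler(
    dataset_sizes=[len(d) for d in datasets],
    batch_size=8,
    temperature=2.0,  # larger values flatten the mixture toward uniform
    num_replicas=1,
    rank=0,
)
sampler.set_epoch(0)

loader = DataLoader(ConcatDataset(datasets), batch_sampler=sampler)
batch, = next(iter(loader))
print(batch.shape)  # torch.Size([8, 1]); each batch is drawn from a single dataset
```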
common/eval_utils.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torchvision.transforms.functional as transforms_f
5
+ from einops import rearrange
6
+
7
+ from genie.factorization_utils import factorize_labels
8
+
9
+
10
+ class AvgMetric:
11
+ """ Records a running sum and count to compute the mean. """
12
+ def __init__(self):
13
+ self.total = 0
14
+ self.count = 0
15
+
16
+ def update(self, val, batch_size=1):
17
+ self.total += val * batch_size
18
+ self.count += batch_size
19
+
20
+ def update_list(self, flat_vals):
21
+ self.total += sum(flat_vals)
22
+ self.count += len(flat_vals)
23
+
24
+ def mean(self):
25
+ if self.count == 0:
26
+ return 0
27
+ return self.total / self.count
28
+
29
+
30
+ def decode_tokens(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
31
+ """
32
+ Converts quantized latent space tokens to images.
33
+
34
+ Args:
35
+ reshaped_token_ids: shape (B, T, H, W).
36
+ decode_latents: instance of `decode_latents_wrapper()`
37
+
38
+ Returns:
39
+ (B, T, 3, 256, 256)
40
+ """
41
+ decoded_imgs = decode_latents(rearrange(reshaped_token_ids, "b t h w -> (b t) h w").cpu().numpy())
42
+ decoded_tensor = torch.stack([transforms_f.pil_to_tensor(pred_img) for pred_img in decoded_imgs])
43
+ return rearrange(decoded_tensor, "(b t) c H W -> b t c H W", b=reshaped_token_ids.size(0))
44
+
45
+ def decode_features(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
46
+ """
47
+ Converts latent features to images.
48
+
49
+ Args:
50
+ reshaped_token_ids: shape (B, T, H, W, C).
51
+ decode_latents: instance of `decode_latents_wrapper()`
52
+
53
+ Returns:
54
+ (B, T, 3, 256, 256)
55
+ """
56
+ decoded_imgs = decode_latents(rearrange(reshaped_token_ids, "b t h w c -> (b t) c h w").cpu().numpy())
57
+ decoded_tensor = torch.stack([transforms_f.pil_to_tensor(pred_img) for pred_img in decoded_imgs])
58
+ return rearrange(decoded_tensor, "(b t) c H W -> b t c H W", b=reshaped_token_ids.size(0))
59
+
60
+
61
+ def compute_loss(
62
+ labels_flat: torch.LongTensor,
63
+ factored_logits: torch.FloatTensor,
64
+ num_factored_vocabs: int = 2,
65
+ factored_vocab_size: int = 512,
66
+ ) -> float:
67
+ """
68
+ If applicable (model returns logits), compute the cross entropy loss.
69
+ In the case of a factorized vocabulary, sums the cross entropy losses for each vocabulary.
70
+
71
+ Assuming all submissions use the parametrization of num_factored_vocabs = 2, factored_vocab_size = 512
72
+
73
+ Args:
74
+ labels_flat: size (B, T*H*W) corresponding to flattened, tokenized images.
75
+ factored_logits: size (B, factored_vocab_size, num_factored_vocabs, T-1, H, W).
76
+ E.g. output of `genie.evaluate.GenieEvaluator.predict_zframe_logits()`
77
+ num_factored_vocabs: Should be 2 for v1.0 of the challenge.
78
+ factored_vocab_size: Should be 512 for v1.0 of the challenge.
79
+ Returns:
80
+ Cross entropy loss
81
+ """
82
+ assert factored_logits.dim() == 6 \
83
+ and factored_logits.size()[:3] == (labels_flat.size(0), factored_vocab_size, num_factored_vocabs), \
84
+ f"Shape of `logits` should be (B, {factored_vocab_size}, {num_factored_vocabs}, T-1, H, W)"
85
+ t = factored_logits.size(3) + 1
86
+ h, w = factored_logits.size()[-2:]
87
+ assert t * h * w == labels_flat.size(1), "Shape of `factored_logits` does not match flattened latent image size."
88
+
89
+ labels_THW = rearrange(labels_flat, "b (t h w) -> b t h w", t=t, h=h, w=w)
90
+ labels_THW = labels_THW[:, 1:].to(factored_logits.device)
91
+
92
+ factored_labels = factorize_labels(labels_THW, num_factored_vocabs, factored_vocab_size)
93
+ return torch.nn.functional.cross_entropy(factored_logits, factored_labels, reduction="none")\
94
+ .sum(dim=1).mean().item() # Final loss is the sum of the two losses across the size-512 vocabularies
95
+
96
+
97
+ def compute_lpips(frames_a: torch.ByteTensor, frames_b: torch.ByteTensor, lpips_func: Callable) -> list:
98
+ """
99
+ Given two batches of video data, of shape (B, T, 3, 256, 256), computes the LPIPS score on frame-by-frame level.
100
+ Cannot use `lpips_func` directly because it expects at most 4D input.
101
+ """
102
+ # LPIPS expects pixel values between [-1, 1]
103
+ flattened_a, flattened_b = [rearrange(frames / 127.5 - 1, "b t c H W -> (b t) c H W")
104
+ for frames in (frames_a, frames_b)]
105
+ return lpips_func(flattened_a, flattened_b).flatten().tolist()
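A minimal sketch of the LPIPS helper, assuming the `lpips` package is installed (dummy uint8 frames of shape `(B, T, 3, 256, 256)`):

```python
import torch
import lpips  # pip install lpips (assumed dependency)

from common.eval_utils import AvgMetric, compute_lpips

lpips_fn = lpips.LPIPS(net="alex")

# Dummy videos: (B, T, 3, 256, 256) uint8.
frames_a = torch.randint(0, 256, (2, 4, 3, 256, 256), dtype=torch.uint8)
frames_b = torch.randint(0, 256, (2, 4, 3, 256, 256), dtype=torch.uint8)

metric = AvgMetric()
metric.update_list(compute_lpips(frames_a, frames_b, lpips_fn))
print("mean LPIPS:", metric.mean())
```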
common/fid_score.py ADDED
@@ -0,0 +1,382 @@
1
+ """Calculates the Frechet Inception Distance (FID) to evaluate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ # code adapted from https://github.com/mseitzer/pytorch-fid/tree/master
35
+
36
+ import os
37
+ import pathlib
38
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
39
+
40
+ import numpy as np
41
+ import torch
42
+ import torchvision.transforms as TF
43
+ from PIL import Image
44
+ from scipy import linalg
45
+ from torch.nn.functional import adaptive_avg_pool2d
46
+
47
+ try:
48
+ from tqdm import tqdm
49
+ except ImportError:
50
+ # If tqdm is not available, provide a mock version of it
51
+ def tqdm(x):
52
+ return x
53
+
54
+
55
+ from .inception import InceptionV3
56
+
57
+ IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}
58
+
59
+
60
+ class ImagePathDataset(torch.utils.data.Dataset):
61
+ def __init__(self, files, transforms=None):
62
+ self.files = files
63
+ self.transforms = transforms
64
+
65
+ def __len__(self):
66
+ return len(self.files)
67
+
68
+ def __getitem__(self, i):
69
+ path = self.files[i]
70
+ img = Image.open(path).convert("RGB")
71
+ if self.transforms is not None:
72
+ img = self.transforms(img)
73
+ return img
74
+
75
+
76
+ def get_activations(
77
+ files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
78
+ ):
79
+ """Calculates the activations of the pool_3 layer for all images.
80
+
81
+ Params:
82
+ -- files : List of image files paths
83
+ -- model : Instance of inception model
84
+ -- batch_size : Batch size of images for the model to process at once.
85
+ Make sure that the number of samples is a multiple of
86
+ the batch size, otherwise some samples are ignored. This
87
+ behavior is retained to match the original FID score
88
+ implementation.
89
+ -- dims : Dimensionality of features returned by Inception
90
+ -- device : Device to run calculations
91
+ -- num_workers : Number of parallel dataloader workers
92
+
93
+ Returns:
94
+ -- A numpy array of dimension (num images, dims) that contains the
95
+ activations of the given tensor when feeding inception with the
96
+ query tensor.
97
+ """
98
+ model.eval()
99
+
100
+ if batch_size > len(files):
101
+ print(
102
+ (
103
+ "Warning: batch size is bigger than the data size. "
104
+ "Setting batch size to data size"
105
+ )
106
+ )
107
+ batch_size = len(files)
108
+
109
+ dataset = ImagePathDataset(files, transforms=TF.ToTensor())
110
+ dataloader = torch.utils.data.DataLoader(
111
+ dataset,
112
+ batch_size=batch_size,
113
+ shuffle=False,
114
+ drop_last=False,
115
+ num_workers=num_workers,
116
+ )
117
+
118
+ pred_arr = np.empty((len(files), dims))
119
+
120
+ start_idx = 0
121
+
122
+ for batch in tqdm(dataloader):
123
+ batch = batch.to(device)
124
+
125
+ with torch.no_grad():
126
+ pred = model(batch)[0]
127
+
128
+ # If model output is not scalar, apply global spatial average pooling.
129
+ # This happens if you choose a dimensionality not equal 2048.
130
+ if pred.size(2) != 1 or pred.size(3) != 1:
131
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
132
+
133
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
134
+
135
+ pred_arr[start_idx : start_idx + pred.shape[0]] = pred
136
+
137
+ start_idx = start_idx + pred.shape[0]
138
+
139
+ return pred_arr
140
+
141
+ def get_activations_images(
142
+ dataset, model, batch_size=50, dims=2048, device="cpu", num_workers=0
143
+ ):
144
+ """Calculates the activations of the pool_3 layer for all images.
145
+
146
+ Params:
147
+ -- dataset : Video tensor of shape (B, T, C, H, W) with values in [0, 1]
148
+ -- model : Instance of inception model
149
+ -- batch_size : Batch size of images for the model to process at once.
150
+ Make sure that the number of samples is a multiple of
151
+ the batch size, otherwise some samples are ignored. This
152
+ behavior is retained to match the original FID score
153
+ implementation.
154
+ -- dims : Dimensionality of features returned by Inception
155
+ -- device : Device to run calculations
156
+ -- num_workers : Number of parallel dataloader workers
157
+
158
+ Returns:
159
+ -- A numpy array of dimension (num images, dims) that contains the
160
+ activations of the given tensor when feeding inception with the
161
+ query tensor.
162
+ """
163
+ model.eval()
164
+ # import IPython; IPython.embed()
165
+ # combine batch and temporal
166
+ dataset = torch.cat([dataset[:, i] for i in range(dataset.shape[1])], dim=0).to("cpu")
167
+ dataloader = torch.utils.data.DataLoader(
168
+ dataset,
169
+ batch_size=batch_size,
170
+ shuffle=False,
171
+ drop_last=True,
172
+ num_workers=num_workers,
173
+ )
174
+
175
+ pred_arr = np.empty((len(dataset), dims))
176
+
177
+ start_idx = 0
178
+
179
+ for batch in tqdm(dataloader):
180
+ batch = batch.to(device)
181
+
182
+ with torch.no_grad():
183
+ pred = model(batch)[0]
184
+
185
+ # If model output is not scalar, apply global spatial average pooling.
186
+ # This happens if you choose a dimensionality not equal 2048.
187
+ if pred.size(2) != 1 or pred.size(3) != 1:
188
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
189
+
190
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
191
+
192
+ pred_arr[start_idx : start_idx + pred.shape[0]] = pred
193
+
194
+ start_idx = start_idx + pred.shape[0]
195
+
196
+ return pred_arr
197
+
198
+
199
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
200
+ """Numpy implementation of the Frechet Distance.
201
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
202
+ and X_2 ~ N(mu_2, C_2) is
203
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
204
+
205
+ Stable version by Dougal J. Sutherland.
206
+
207
+ Params:
208
+ -- mu1 : Numpy array containing the activations of a layer of the
209
+ inception net (like returned by the function 'get_predictions')
210
+ for generated samples.
211
+ -- mu2 : The sample mean over activations, precalculated on an
212
+ representative data set.
213
+ -- sigma1: The covariance matrix over activations for generated samples.
214
+ -- sigma2: The covariance matrix over activations, precalculated on an
215
+ representative data set.
216
+
217
+ Returns:
218
+ -- : The Frechet Distance.
219
+ """
220
+
221
+ mu1 = np.atleast_1d(mu1)
222
+ mu2 = np.atleast_1d(mu2)
223
+
224
+ sigma1 = np.atleast_2d(sigma1)
225
+ sigma2 = np.atleast_2d(sigma2)
226
+
227
+ assert (
228
+ mu1.shape == mu2.shape
229
+ ), "Training and test mean vectors have different lengths"
230
+ assert (
231
+ sigma1.shape == sigma2.shape
232
+ ), "Training and test covariances have different dimensions"
233
+
234
+ diff = mu1 - mu2
235
+
236
+ # Product might be almost singular
237
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
238
+ if not np.isfinite(covmean).all():
239
+ msg = (
240
+ "fid calculation produces singular product; "
241
+ "adding %s to diagonal of cov estimates"
242
+ ) % eps
243
+ print(msg)
244
+ offset = np.eye(sigma1.shape[0]) * eps
245
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
246
+
247
+ # Numerical error might give slight imaginary component
248
+ if np.iscomplexobj(covmean):
249
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
250
+ m = np.max(np.abs(covmean.imag))
251
+ raise ValueError("Imaginary component {}".format(m))
252
+ covmean = covmean.real
253
+
254
+ tr_covmean = np.trace(covmean)
255
+
256
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
257
+
258
+
259
+ def calculate_activation_statistics(
260
+ images, model, batch_size=50, dims=2048, device="cpu", num_workers=1
261
+ ):
262
+ """Calculation of the statistics used by the FID.
263
+ Params:
264
+ -- images : Video tensor of shape (B, T, C, H, W) with values in [0, 1]
265
+ -- model : Instance of inception model
266
+ -- batch_size : The images numpy array is split into batches with
267
+ batch size batch_size. A reasonable batch size
268
+ depends on the hardware.
269
+ -- dims : Dimensionality of features returned by Inception
270
+ -- device : Device to run calculations
271
+ -- num_workers : Number of parallel dataloader workers
272
+
273
+ Returns:
274
+ -- mu : The mean over samples of the activations of the pool_3 layer of
275
+ the inception model.
276
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
277
+ the inception model.
278
+ """
279
+ act = get_activations_images(images, model, batch_size, dims, device, num_workers)
280
+ mu = np.mean(act, axis=0)
281
+ sigma = np.cov(act, rowvar=False)
282
+ return mu, sigma
283
+
284
+
285
+ def compute_statistics(images, model, batch_size, dims, device, num_workers=1):
286
+ m, s = calculate_activation_statistics(
287
+ images, model, batch_size, dims, device, num_workers
288
+ )
289
+ return m, s
290
+
291
+ def calculate_fid(pred_images, gt_images, batch_size=16, device="cuda", dims=2048, num_workers=1):
292
+ """Calculates the FID of two paths"""
293
+
294
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
295
+
296
+ model = InceptionV3([block_idx]).to(device)
297
+
298
+ m1, s1 = compute_statistics(
299
+ pred_images, model, batch_size, dims, device, num_workers
300
+ )
301
+ m2, s2 = compute_statistics(
302
+ gt_images, model, batch_size, dims, device, num_workers
303
+ )
304
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
305
+
306
+ return fid_value
307
+
308
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
309
+ """Calculates the FID of two paths"""
310
+ for p in paths:
311
+ if not os.path.exists(p):
312
+ raise RuntimeError("Invalid path: %s" % p)
313
+
314
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
315
+
316
+ model = InceptionV3([block_idx]).to(device)
317
+
318
+ m1, s1 = compute_statistics_of_path(
319
+ paths[0], model, batch_size, dims, device, num_workers
320
+ )
321
+ m2, s2 = compute_statistics_of_path(
322
+ paths[1], model, batch_size, dims, device, num_workers
323
+ )
324
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
325
+
326
+ return fid_value
327
+
328
+
329
+ def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
330
+ """Saves FID statistics of one path"""
331
+ if not os.path.exists(paths[0]):
332
+ raise RuntimeError("Invalid path: %s" % paths[0])
333
+
334
+ if os.path.exists(paths[1]):
335
+ raise RuntimeError("Existing output file: %s" % paths[1])
336
+
337
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
338
+
339
+ model = InceptionV3([block_idx]).to(device)
340
+
341
+ print(f"Saving statistics for {paths[0]}")
342
+
343
+ m1, s1 = compute_statistics_of_path(
344
+ paths[0], model, batch_size, dims, device, num_workers
345
+ )
346
+
347
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
348
+
349
+
350
+ def main():
351
+ args = parser.parse_args()
352
+
353
+ if args.device is None:
354
+ device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
355
+ else:
356
+ device = torch.device(args.device)
357
+
358
+ if args.num_workers is None:
359
+ try:
360
+ num_cpus = len(os.sched_getaffinity(0))
361
+ except AttributeError:
362
+ # os.sched_getaffinity is not available under Windows, use
363
+ # os.cpu_count instead (which may not return the *available* number
364
+ # of CPUs).
365
+ num_cpus = os.cpu_count()
366
+
367
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
368
+ else:
369
+ num_workers = args.num_workers
370
+
371
+ if args.save_stats:
372
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
373
+ return
374
+
375
+ fid_value = calculate_fid_given_paths(
376
+ args.path, args.batch_size, device, args.dims, num_workers
377
+ )
378
+ print("FID: ", fid_value)
379
+
380
+
381
+ if __name__ == "__main__":
382
+ main()
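For the in-memory entry point, `calculate_fid` takes video batches of shape `(B, T, C, H, W)` with values in `[0, 1]` and pools all frames into a single image set before extracting Inception features. A minimal sketch (dummy data; real FID estimates need far more samples, and the InceptionV3 weights are downloaded on first use):

```python
import torch

from common.fid_score import calculate_fid

# Dummy "videos": 4 clips of 8 RGB frames at 256x256, values in [0, 1].
pred = torch.rand(4, 8, 3, 256, 256)
real = torch.rand(4, 8, 3, 256, 256)

print("FID:", calculate_fid(pred, real, batch_size=16, device="cuda", dims=2048))
```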
common/fvd/styleganv/fvd.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ import os
3
+ import math
4
+ import torch.nn.functional as F
5
+
6
+ # https://github.com/universome/fvd-comparison
7
+
8
+
9
+ def load_i3d_pretrained(device=torch.device('cpu')):
10
+ i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
11
+ filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
12
+ print(filepath)
13
+ if not os.path.exists(filepath):
14
+ print(f"Downloading {i3D_WEIGHTS_URL}; you can also download it manually.")
15
+ os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
16
+ i3d = torch.jit.load(filepath).eval().to(device)
17
+ i3d = torch.nn.DataParallel(i3d)
18
+ return i3d
19
+
20
+
21
+ def get_feats(videos, detector, device, bs=10):
22
+ # videos : torch.tensor BCTHW [0, 1]
23
+ detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.
24
+ feats = np.empty((0, 400))
25
+ with torch.no_grad():
26
+ for i in range((len(videos)-1)//bs + 1):
27
+ feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()])
28
+ return feats
29
+
30
+
31
+ def get_fvd_feats(videos, i3d, device, bs=10):
32
+ # videos in [0, 1] as torch tensor BCTHW
33
+ # videos = [preprocess_single(video) for video in videos]
34
+ embeddings = get_feats(videos, i3d, device, bs)
35
+ return embeddings
36
+
37
+
38
+ def preprocess_single(video, resolution=224, sequence_length=None):
39
+ # video: CTHW, [0, 1]
40
+ c, t, h, w = video.shape
41
+
42
+ # temporal crop
43
+ if sequence_length is not None:
44
+ assert sequence_length <= t
45
+ video = video[:, :sequence_length]
46
+
47
+ # scale shorter side to resolution
48
+ scale = resolution / min(h, w)
49
+ if h < w:
50
+ target_size = (resolution, math.ceil(w * scale))
51
+ else:
52
+ target_size = (math.ceil(h * scale), resolution)
53
+ video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)
54
+
55
+ # center crop
56
+ c, t, h, w = video.shape
57
+ w_start = (w - resolution) // 2
58
+ h_start = (h - resolution) // 2
59
+ video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
60
+
61
+ # [0, 1] -> [-1, 1]
62
+ video = (video - 0.5) * 2
63
+
64
+ return video.contiguous()
65
+
66
+
67
+ """
68
+ Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
69
+ """
70
+ from typing import Tuple
71
+ from scipy.linalg import sqrtm
72
+ import numpy as np
73
+
74
+
75
+ def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
76
+ mu = feats.mean(axis=0) # [d]
77
+ sigma = np.cov(feats, rowvar=False) # [d, d]
78
+ return mu, sigma
79
+
80
+
81
+ def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
82
+ mu_gen, sigma_gen = compute_stats(feats_fake)
83
+ mu_real, sigma_real = compute_stats(feats_real)
84
+ m = np.square(mu_gen - mu_real).sum()
85
+ if feats_fake.shape[0]>1:
86
+ s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
87
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
88
+ else:
89
+ fid = np.real(m)
90
+ return float(fid)
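These helpers can also be called directly, bypassing `common/calculate_fvd.py`. A minimal sketch with dummy clips in `[0, 1]`, layout `(B, C, T, H, W)`, relying on the bundled `i3d_torchscript.pt`:

```python
import torch

from common.fvd.styleganv.fvd import (
    frechet_distance,
    get_fvd_feats,
    load_i3d_pretrained,
)

device = "cuda"
i3d = load_i3d_pretrained(device=device)

# Dummy clips: (B, C, T, H, W) floats in [0, 1]; frames are resized to 224 internally.
fake = torch.rand(8, 3, 16, 64, 64)
real = torch.rand(8, 3, 16, 64, 64)

feats_fake = get_fvd_feats(fake, i3d=i3d, device=device)
feats_real = get_fvd_feats(real, i3d=i3d, device=device)
print("FVD:", frechet_distance(feats_fake, feats_real))
```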
common/fvd/styleganv/i3d_torchscript.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bec6519f66ea534e953026b4ae2c65553c17bf105611c746d904657e5860a5e2
3
+ size 51235320
common/fvd/videogpt/fvd.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ import os
3
+ import math
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import einops
7
+
8
+ def load_i3d_pretrained(device=torch.device('cpu')):
9
+ i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI"
10
+ filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')
11
+ print(filepath)
12
+ if not os.path.exists(filepath):
13
+ print(f"Downloading {i3D_WEIGHTS_URL}; you can also download it manually.")
14
+ os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
15
+ from .pytorch_i3d import InceptionI3d
16
+ i3d = InceptionI3d(400, in_channels=3).eval().to(device)
17
+ i3d.load_state_dict(torch.load(filepath, map_location=device))
18
+ i3d = torch.nn.DataParallel(i3d)
19
+ return i3d
20
+
21
+ def preprocess_single(video, resolution, sequence_length=None):
22
+ # video: THWC, {0, ..., 255}
23
+ video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
24
+ t, c, h, w = video.shape
25
+
26
+ # temporal crop
27
+ if sequence_length is not None:
28
+ assert sequence_length <= t
29
+ video = video[:sequence_length]
30
+
31
+ # scale shorter side to resolution
32
+ scale = resolution / min(h, w)
33
+ if h < w:
34
+ target_size = (resolution, math.ceil(w * scale))
35
+ else:
36
+ target_size = (math.ceil(h * scale), resolution)
37
+ video = F.interpolate(video, size=target_size, mode='bilinear',
38
+ align_corners=False)
39
+
40
+ # center crop
41
+ t, c, h, w = video.shape
42
+ w_start = (w - resolution) // 2
43
+ h_start = (h - resolution) // 2
44
+ video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
45
+ video = video.permute(1, 0, 2, 3).contiguous() # CTHW
46
+
47
+ video -= 0.5
48
+
49
+ return video
50
+
51
+ def preprocess(videos, target_resolution=224):
52
+ # transform videos in [0, 1], shape [b c t h w], torch.float
53
+ # -> videos in {0, ..., 255} [b t h w c] as np.uint8 array
54
+ videos = einops.rearrange(videos, 'b c t h w -> b t h w c')
55
+ videos = (videos*255).numpy().astype(np.uint8)
56
+
57
+ b, t, h, w, c = videos.shape
58
+ videos = torch.from_numpy(videos)
59
+ videos = torch.stack([preprocess_single(video, target_resolution) for video in videos])
60
+ return videos * 2 # [-0.5, 0.5] -> [-1, 1]
61
+
62
+ def get_fvd_logits(videos, i3d, device, bs=10):
63
+ videos = preprocess(videos)
64
+ embeddings = get_logits(i3d, videos, device, bs=bs)
65
+ return embeddings
66
+
67
+ # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
68
+ def _symmetric_matrix_square_root(mat, eps=1e-10):
69
+ u, s, v = torch.svd(mat)
70
+ si = torch.where(s < eps, s, torch.sqrt(s))
71
+ return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())
72
+
73
+ # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
74
+ def trace_sqrt_product(sigma, sigma_v):
75
+ sqrt_sigma = _symmetric_matrix_square_root(sigma)
76
+ sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
77
+ return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
78
+
79
+ # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
80
+ def cov(m, rowvar=False):
81
+ '''Estimate a covariance matrix given data.
82
+
83
+ Covariance indicates the level to which two variables vary together.
84
+ If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
85
+ then the covariance matrix element `C_{ij}` is the covariance of
86
+ `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
87
+
88
+ Args:
89
+ m: A 1-D or 2-D array containing multiple variables and observations.
90
+ Each row of `m` represents a variable, and each column a single
91
+ observation of all those variables.
92
+ rowvar: If `rowvar` is True, then each row represents a
93
+ variable, with observations in the columns. Otherwise, the
94
+ relationship is transposed: each column represents a variable,
95
+ while the rows contain observations.
96
+
97
+ Returns:
98
+ The covariance matrix of the variables.
99
+ '''
100
+ if m.dim() > 2:
101
+ raise ValueError('m has more than 2 dimensions')
102
+ if m.dim() < 2:
103
+ m = m.view(1, -1)
104
+ if not rowvar and m.size(0) != 1:
105
+ m = m.t()
106
+
107
+ fact = 1.0 / (m.size(1) - 1) # unbiased estimate
108
+ m -= torch.mean(m, dim=1, keepdim=True)
109
+ mt = m.t() # if complex: mt = m.t().conj()
110
+ return fact * m.matmul(mt).squeeze()
111
+
112
+
113
+ def frechet_distance(x1, x2):
114
+ x1 = x1.flatten(start_dim=1)
115
+ x2 = x2.flatten(start_dim=1)
116
+ m, m_w = x1.mean(dim=0), x2.mean(dim=0)
117
+ sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
118
+ mean = torch.sum((m - m_w) ** 2)
119
+ if x1.shape[0]>1:
120
+ sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
121
+ trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
122
+ fd = trace + mean
123
+ else:
124
+ fd = np.real(mean)
125
+ return float(fd)
126
+
127
+
128
+ def get_logits(i3d, videos, device, bs=10):
129
+ # assert videos.shape[0] % 16 == 0
130
+ with torch.no_grad():
131
+ logits = []
132
+ for i in range(0, videos.shape[0], bs):
133
+ batch = videos[i:i + bs].to(device)
134
+ # logits.append(i3d.module.extract_features(batch)) # wrong
135
+ logits.append(i3d(batch)) # right
136
+ logits = torch.cat(logits, dim=0)
137
+ return logits
common/fvd/videogpt/i3d_pretrained_400.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55095f049e706479d48e221adcdb145b2b9dc930ba28b081ed72367ffaa32343
3
+ size 50939526
common/fvd/videogpt/pytorch_i3d.py ADDED
@@ -0,0 +1,322 @@
1
+ # Original code from https://github.com/piergiaj/pytorch-i3d
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ class MaxPool3dSamePadding(nn.MaxPool3d):
8
+
9
+ def compute_pad(self, dim, s):
10
+ if s % self.stride[dim] == 0:
11
+ return max(self.kernel_size[dim] - self.stride[dim], 0)
12
+ else:
13
+ return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
14
+
15
+ def forward(self, x):
16
+ # compute 'same' padding
17
+ (batch, channel, t, h, w) = x.size()
18
+ out_t = np.ceil(float(t) / float(self.stride[0]))
19
+ out_h = np.ceil(float(h) / float(self.stride[1]))
20
+ out_w = np.ceil(float(w) / float(self.stride[2]))
21
+ pad_t = self.compute_pad(0, t)
22
+ pad_h = self.compute_pad(1, h)
23
+ pad_w = self.compute_pad(2, w)
24
+
25
+ pad_t_f = pad_t // 2
26
+ pad_t_b = pad_t - pad_t_f
27
+ pad_h_f = pad_h // 2
28
+ pad_h_b = pad_h - pad_h_f
29
+ pad_w_f = pad_w // 2
30
+ pad_w_b = pad_w - pad_w_f
31
+
32
+ pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
33
+ x = F.pad(x, pad)
34
+ return super(MaxPool3dSamePadding, self).forward(x)
35
+
36
+
37
+ class Unit3D(nn.Module):
38
+
39
+ def __init__(self, in_channels,
40
+ output_channels,
41
+ kernel_shape=(1, 1, 1),
42
+ stride=(1, 1, 1),
43
+ padding=0,
44
+ activation_fn=F.relu,
45
+ use_batch_norm=True,
46
+ use_bias=False,
47
+ name='unit_3d'):
48
+
49
+ """Initializes Unit3D module."""
50
+ super(Unit3D, self).__init__()
51
+
52
+ self._output_channels = output_channels
53
+ self._kernel_shape = kernel_shape
54
+ self._stride = stride
55
+ self._use_batch_norm = use_batch_norm
56
+ self._activation_fn = activation_fn
57
+ self._use_bias = use_bias
58
+ self.name = name
59
+ self.padding = padding
60
+
61
+ self.conv3d = nn.Conv3d(in_channels=in_channels,
62
+ out_channels=self._output_channels,
63
+ kernel_size=self._kernel_shape,
64
+ stride=self._stride,
65
+ padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
66
+ bias=self._use_bias)
67
+
68
+ if self._use_batch_norm:
69
+ self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001)
70
+
71
+ def compute_pad(self, dim, s):
72
+ if s % self._stride[dim] == 0:
73
+ return max(self._kernel_shape[dim] - self._stride[dim], 0)
74
+ else:
75
+ return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
76
+
77
+
78
+ def forward(self, x):
79
+ # compute 'same' padding
80
+ (batch, channel, t, h, w) = x.size()
81
+ out_t = np.ceil(float(t) / float(self._stride[0]))
82
+ out_h = np.ceil(float(h) / float(self._stride[1]))
83
+ out_w = np.ceil(float(w) / float(self._stride[2]))
84
+ pad_t = self.compute_pad(0, t)
85
+ pad_h = self.compute_pad(1, h)
86
+ pad_w = self.compute_pad(2, w)
87
+
88
+ pad_t_f = pad_t // 2
89
+ pad_t_b = pad_t - pad_t_f
90
+ pad_h_f = pad_h // 2
91
+ pad_h_b = pad_h - pad_h_f
92
+ pad_w_f = pad_w // 2
93
+ pad_w_b = pad_w - pad_w_f
94
+
95
+ pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
96
+ x = F.pad(x, pad)
97
+
98
+ x = self.conv3d(x)
99
+ if self._use_batch_norm:
100
+ x = self.bn(x)
101
+ if self._activation_fn is not None:
102
+ x = self._activation_fn(x)
103
+ return x
104
+
105
+
106
+
107
+ class InceptionModule(nn.Module):
108
+ def __init__(self, in_channels, out_channels, name):
109
+ super(InceptionModule, self).__init__()
110
+
111
+ self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
112
+ name=name+'/Branch_0/Conv3d_0a_1x1')
113
+ self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
114
+ name=name+'/Branch_1/Conv3d_0a_1x1')
115
+ self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
116
+ name=name+'/Branch_1/Conv3d_0b_3x3')
117
+ self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
118
+ name=name+'/Branch_2/Conv3d_0a_1x1')
119
+ self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
120
+ name=name+'/Branch_2/Conv3d_0b_3x3')
121
+ self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
122
+ stride=(1, 1, 1), padding=0)
123
+ self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
124
+ name=name+'/Branch_3/Conv3d_0b_1x1')
125
+ self.name = name
126
+
127
+ def forward(self, x):
128
+ b0 = self.b0(x)
129
+ b1 = self.b1b(self.b1a(x))
130
+ b2 = self.b2b(self.b2a(x))
131
+ b3 = self.b3b(self.b3a(x))
132
+ return torch.cat([b0,b1,b2,b3], dim=1)
133
+
134
+
135
+ class InceptionI3d(nn.Module):
136
+ """Inception-v1 I3D architecture.
137
+ The model is introduced in:
138
+ Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
139
+ Joao Carreira, Andrew Zisserman
140
+ https://arxiv.org/pdf/1705.07750v1.pdf.
141
+ See also the Inception architecture, introduced in:
142
+ Going deeper with convolutions
143
+ Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
144
+ Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
145
+ http://arxiv.org/pdf/1409.4842v1.pdf.
146
+ """
147
+
148
+ # Endpoints of the model in order. During construction, all the endpoints up
149
+ # to a designated `final_endpoint` are returned in a dictionary as the
150
+ # second return value.
151
+ VALID_ENDPOINTS = (
152
+ 'Conv3d_1a_7x7',
153
+ 'MaxPool3d_2a_3x3',
154
+ 'Conv3d_2b_1x1',
155
+ 'Conv3d_2c_3x3',
156
+ 'MaxPool3d_3a_3x3',
157
+ 'Mixed_3b',
158
+ 'Mixed_3c',
159
+ 'MaxPool3d_4a_3x3',
160
+ 'Mixed_4b',
161
+ 'Mixed_4c',
162
+ 'Mixed_4d',
163
+ 'Mixed_4e',
164
+ 'Mixed_4f',
165
+ 'MaxPool3d_5a_2x2',
166
+ 'Mixed_5b',
167
+ 'Mixed_5c',
168
+ 'Logits',
169
+ 'Predictions',
170
+ )
171
+
172
+ def __init__(self, num_classes=400, spatial_squeeze=True,
173
+ final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
174
+ """Initializes I3D model instance.
175
+ Args:
176
+ num_classes: The number of outputs in the logit layer (default 400, which
177
+ matches the Kinetics dataset).
178
+ spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
179
+ before returning (default True).
180
+ final_endpoint: The model contains many possible endpoints.
181
+ `final_endpoint` specifies the last endpoint for the model to be built
182
+ up to. In addition to the output at `final_endpoint`, all the outputs
183
+ at endpoints up to `final_endpoint` will also be returned, in a
184
+ dictionary. `final_endpoint` must be one of
185
+ InceptionI3d.VALID_ENDPOINTS (default 'Logits').
186
+ name: A string (optional). The name of this module.
187
+ Raises:
188
+ ValueError: if `final_endpoint` is not recognized.
189
+ """
190
+
191
+ if final_endpoint not in self.VALID_ENDPOINTS:
192
+ raise ValueError('Unknown final endpoint %s' % final_endpoint)
193
+
194
+ super(InceptionI3d, self).__init__()
195
+ self._num_classes = num_classes
196
+ self._spatial_squeeze = spatial_squeeze
197
+ self._final_endpoint = final_endpoint
198
+ self.logits = None
199
+
200
+ if self._final_endpoint not in self.VALID_ENDPOINTS:
201
+ raise ValueError('Unknown final endpoint %s' % self._final_endpoint)
202
+
203
+ self.end_points = {}
204
+ end_point = 'Conv3d_1a_7x7'
205
+ self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
206
+ stride=(2, 2, 2), padding=(3,3,3), name=name+end_point)
207
+ if self._final_endpoint == end_point: return
208
+
209
+ end_point = 'MaxPool3d_2a_3x3'
210
+ self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
211
+ padding=0)
212
+ if self._final_endpoint == end_point: return
213
+
214
+ end_point = 'Conv3d_2b_1x1'
215
+ self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
216
+ name=name+end_point)
217
+ if self._final_endpoint == end_point: return
218
+
219
+ end_point = 'Conv3d_2c_3x3'
220
+ self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
221
+ name=name+end_point)
222
+ if self._final_endpoint == end_point: return
223
+
224
+ end_point = 'MaxPool3d_3a_3x3'
225
+ self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
226
+ padding=0)
227
+ if self._final_endpoint == end_point: return
228
+
229
+ end_point = 'Mixed_3b'
230
+ self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)
231
+ if self._final_endpoint == end_point: return
232
+
233
+ end_point = 'Mixed_3c'
234
+ self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)
235
+ if self._final_endpoint == end_point: return
236
+
237
+ end_point = 'MaxPool3d_4a_3x3'
238
+ self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
239
+ padding=0)
240
+ if self._final_endpoint == end_point: return
241
+
242
+ end_point = 'Mixed_4b'
243
+ self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)
244
+ if self._final_endpoint == end_point: return
245
+
246
+ end_point = 'Mixed_4c'
247
+ self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)
248
+ if self._final_endpoint == end_point: return
249
+
250
+ end_point = 'Mixed_4d'
251
+ self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)
252
+ if self._final_endpoint == end_point: return
253
+
254
+ end_point = 'Mixed_4e'
255
+ self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)
256
+ if self._final_endpoint == end_point: return
257
+
258
+ end_point = 'Mixed_4f'
259
+ self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)
260
+ if self._final_endpoint == end_point: return
261
+
262
+ end_point = 'MaxPool3d_5a_2x2'
263
+ self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
264
+ padding=0)
265
+ if self._final_endpoint == end_point: return
266
+
267
+ end_point = 'Mixed_5b'
268
+ self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)
269
+ if self._final_endpoint == end_point: return
270
+
271
+ end_point = 'Mixed_5c'
272
+ self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)
273
+ if self._final_endpoint == end_point: return
274
+
275
+ end_point = 'Logits'
276
+ self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
277
+ stride=(1, 1, 1))
278
+ self.dropout = nn.Dropout(dropout_keep_prob)
279
+ self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
280
+ kernel_shape=[1, 1, 1],
281
+ padding=0,
282
+ activation_fn=None,
283
+ use_batch_norm=False,
284
+ use_bias=True,
285
+ name='logits')
286
+
287
+ self.build()
288
+
289
+
290
+ def replace_logits(self, num_classes):
291
+ self._num_classes = num_classes
292
+ self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
293
+ kernel_shape=[1, 1, 1],
294
+ padding=0,
295
+ activation_fn=None,
296
+ use_batch_norm=False,
297
+ use_bias=True,
298
+ name='logits')
299
+
300
+
301
+ def build(self):
302
+ for k in self.end_points.keys():
303
+ self.add_module(k, self.end_points[k])
304
+
305
+ def forward(self, x):
306
+ for end_point in self.VALID_ENDPOINTS:
307
+ if end_point in self.end_points:
308
+ x = self._modules[end_point](x) # use _modules to work with dataparallel
309
+
310
+ x = self.logits(self.dropout(self.avg_pool(x)))
311
+ if self._spatial_squeeze:
312
+ logits = x.squeeze(3).squeeze(3)
313
+ logits = logits.mean(dim=2)
314
+ # averaging the per-frame logits over time yields a batch X classes tensor
315
+ return logits
316
+
317
+
318
+ def extract_features(self, x):
319
+ for end_point in self.VALID_ENDPOINTS:
320
+ if end_point in self.end_points:
321
+ x = self._modules[end_point](x)
322
+ return self.avg_pool(x)
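
The `InceptionI3d` above is the backbone that the FVD utilities in `common/fvd/videogpt/fvd.py` build on; `extract_features` returns the pooled Mixed_5c activations rather than class logits. Below is a minimal sketch of driving it directly, assuming the bundled `i3d_pretrained_400.pt` checkpoint matches this 400-class architecture and that clips arrive as `(batch, channels, time, height, width)` tensors; the preprocessing here is illustrative, not the repo's exact pipeline.

```python
import torch

from common.fvd.videogpt.pytorch_i3d import InceptionI3d  # module path assumed from this commit

# Build the 400-class Kinetics I3D and load the checkpoint shipped in this repo.
i3d = InceptionI3d(num_classes=400, in_channels=3)
i3d.load_state_dict(torch.load("common/fvd/videogpt/i3d_pretrained_400.pt", map_location="cpu"))
i3d.eval()

# Dummy clip batch: 2 videos, 16 RGB frames at 224x224 (value range is an assumption).
videos = torch.rand(2, 3, 16, 224, 224)
with torch.no_grad():
    feats = i3d.extract_features(videos)  # (batch, 1024, T', 1, 1) pooled Mixed_5c features
print(feats.flatten(start_dim=1).shape)   # flattened per-video feature vectors for FVD statistics
```
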
common/inception.py ADDED
@@ -0,0 +1,344 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" # noqa: E501
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3, # Final average pooling features
29
+ }
30
+
31
+ def __init__(
32
+ self,
33
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
34
+ resize_input=True,
35
+ normalize_input=True,
36
+ requires_grad=False,
37
+ use_fid_inception=True,
38
+ ):
39
+ """Build pretrained InceptionV3
40
+
41
+ Parameters
42
+ ----------
43
+ output_blocks : list of int
44
+ Indices of blocks to return features of. Possible values are:
45
+ - 0: corresponds to output of first max pooling
46
+ - 1: corresponds to output of second max pooling
47
+ - 2: corresponds to output which is fed to aux classifier
48
+ - 3: corresponds to output of final average pooling
49
+ resize_input : bool
50
+ If true, bilinearly resizes input to width and height 299 before
51
+ feeding input to model. As the network without fully connected
52
+ layers is fully convolutional, it should be able to handle inputs
53
+ of arbitrary size, so resizing might not be strictly needed
54
+ normalize_input : bool
55
+ If true, scales the input from range (0, 1) to the range the
56
+ pretrained Inception network expects, namely (-1, 1)
57
+ requires_grad : bool
58
+ If true, parameters of the model require gradients. Possibly useful
59
+ for finetuning the network
60
+ use_fid_inception : bool
61
+ If true, uses the pretrained Inception model used in Tensorflow's
62
+ FID implementation. If false, uses the pretrained Inception model
63
+ available in torchvision. The FID Inception model has different
64
+ weights and a slightly different structure from torchvision's
65
+ Inception model. If you want to compute FID scores, you are
66
+ strongly advised to set this parameter to true to get comparable
67
+ results.
68
+ """
69
+ super(InceptionV3, self).__init__()
70
+
71
+ self.resize_input = resize_input
72
+ self.normalize_input = normalize_input
73
+ self.output_blocks = sorted(output_blocks)
74
+ self.last_needed_block = max(output_blocks)
75
+
76
+ assert self.last_needed_block <= 3, "Last possible output block index is 3"
77
+
78
+ self.blocks = nn.ModuleList()
79
+
80
+ if use_fid_inception:
81
+ inception = fid_inception_v3()
82
+ else:
83
+ inception = _inception_v3(weights="DEFAULT")
84
+
85
+ # Block 0: input to maxpool1
86
+ block0 = [
87
+ inception.Conv2d_1a_3x3,
88
+ inception.Conv2d_2a_3x3,
89
+ inception.Conv2d_2b_3x3,
90
+ nn.MaxPool2d(kernel_size=3, stride=2),
91
+ ]
92
+ self.blocks.append(nn.Sequential(*block0))
93
+
94
+ # Block 1: maxpool1 to maxpool2
95
+ if self.last_needed_block >= 1:
96
+ block1 = [
97
+ inception.Conv2d_3b_1x1,
98
+ inception.Conv2d_4a_3x3,
99
+ nn.MaxPool2d(kernel_size=3, stride=2),
100
+ ]
101
+ self.blocks.append(nn.Sequential(*block1))
102
+
103
+ # Block 2: maxpool2 to aux classifier
104
+ if self.last_needed_block >= 2:
105
+ block2 = [
106
+ inception.Mixed_5b,
107
+ inception.Mixed_5c,
108
+ inception.Mixed_5d,
109
+ inception.Mixed_6a,
110
+ inception.Mixed_6b,
111
+ inception.Mixed_6c,
112
+ inception.Mixed_6d,
113
+ inception.Mixed_6e,
114
+ ]
115
+ self.blocks.append(nn.Sequential(*block2))
116
+
117
+ # Block 3: aux classifier to final avgpool
118
+ if self.last_needed_block >= 3:
119
+ block3 = [
120
+ inception.Mixed_7a,
121
+ inception.Mixed_7b,
122
+ inception.Mixed_7c,
123
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
124
+ ]
125
+ self.blocks.append(nn.Sequential(*block3))
126
+
127
+ for param in self.parameters():
128
+ param.requires_grad = requires_grad
129
+
130
+ def forward(self, inp):
131
+ """Get Inception feature maps
132
+
133
+ Parameters
134
+ ----------
135
+ inp : torch.autograd.Variable
136
+ Input tensor of shape Bx3xHxW. Values are expected to be in
137
+ range (0, 1)
138
+
139
+ Returns
140
+ -------
141
+ List of torch.autograd.Variable, corresponding to the selected output
142
+ block, sorted ascending by index
143
+ """
144
+ outp = []
145
+ x = inp
146
+
147
+ if self.resize_input:
148
+ x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
149
+
150
+ if self.normalize_input:
151
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
152
+
153
+ for idx, block in enumerate(self.blocks):
154
+ x = block(x)
155
+ if idx in self.output_blocks:
156
+ outp.append(x)
157
+
158
+ if idx == self.last_needed_block:
159
+ break
160
+
161
+ return outp
162
+
163
+
164
+ def _inception_v3(*args, **kwargs):
165
+ """Wraps `torchvision.models.inception_v3`"""
166
+ try:
167
+ version = tuple(map(int, torchvision.__version__.split(".")[:2]))
168
+ except ValueError:
169
+ # Just a caution against weird version strings
170
+ version = (0,)
171
+
172
+ # Skips default weight initialization if supported by torchvision
173
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
174
+ if version >= (0, 6):
175
+ kwargs["init_weights"] = False
176
+
177
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
178
+ # argument prior to version 0.13.
179
+ if version < (0, 13) and "weights" in kwargs:
180
+ if kwargs["weights"] == "DEFAULT":
181
+ kwargs["pretrained"] = True
182
+ elif kwargs["weights"] is None:
183
+ kwargs["pretrained"] = False
184
+ else:
185
+ raise ValueError(
186
+ "weights=={} not supported in torchvision {}".format(
187
+ kwargs["weights"], torchvision.__version__
188
+ )
189
+ )
190
+ del kwargs["weights"]
191
+
192
+ return torchvision.models.inception_v3(*args, **kwargs)
193
+
194
+
195
+ def fid_inception_v3():
196
+ """Build pretrained Inception model for FID computation
197
+
198
+ The Inception model for FID computation uses a different set of weights
199
+ and has a slightly different structure than torchvision's Inception.
200
+
201
+ This method first constructs torchvision's Inception and then patches the
202
+ necessary parts that are different in the FID Inception model.
203
+ """
204
+ inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
205
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
206
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
207
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
208
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
209
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
210
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
211
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
212
+ inception.Mixed_7b = FIDInceptionE_1(1280)
213
+ inception.Mixed_7c = FIDInceptionE_2(2048)
214
+
215
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
216
+ inception.load_state_dict(state_dict)
217
+ return inception
218
+
219
+
220
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
221
+ """InceptionA block patched for FID computation"""
222
+
223
+ def __init__(self, in_channels, pool_features):
224
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
225
+
226
+ def forward(self, x):
227
+ branch1x1 = self.branch1x1(x)
228
+
229
+ branch5x5 = self.branch5x5_1(x)
230
+ branch5x5 = self.branch5x5_2(branch5x5)
231
+
232
+ branch3x3dbl = self.branch3x3dbl_1(x)
233
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
234
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
235
+
236
+ # Patch: Tensorflow's average pool does not use the padded zeros in
237
+ # its average calculation
238
+ branch_pool = F.avg_pool2d(
239
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
240
+ )
241
+ branch_pool = self.branch_pool(branch_pool)
242
+
243
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
244
+ return torch.cat(outputs, 1)
245
+
246
+
247
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
248
+ """InceptionC block patched for FID computation"""
249
+
250
+ def __init__(self, in_channels, channels_7x7):
251
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
252
+
253
+ def forward(self, x):
254
+ branch1x1 = self.branch1x1(x)
255
+
256
+ branch7x7 = self.branch7x7_1(x)
257
+ branch7x7 = self.branch7x7_2(branch7x7)
258
+ branch7x7 = self.branch7x7_3(branch7x7)
259
+
260
+ branch7x7dbl = self.branch7x7dbl_1(x)
261
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
262
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
263
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
264
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
265
+
266
+ # Patch: Tensorflow's average pool does not use the padded zeros in
267
+ # its average calculation
268
+ branch_pool = F.avg_pool2d(
269
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
270
+ )
271
+ branch_pool = self.branch_pool(branch_pool)
272
+
273
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
274
+ return torch.cat(outputs, 1)
275
+
276
+
277
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
278
+ """First InceptionE block patched for FID computation"""
279
+
280
+ def __init__(self, in_channels):
281
+ super(FIDInceptionE_1, self).__init__(in_channels)
282
+
283
+ def forward(self, x):
284
+ branch1x1 = self.branch1x1(x)
285
+
286
+ branch3x3 = self.branch3x3_1(x)
287
+ branch3x3 = [
288
+ self.branch3x3_2a(branch3x3),
289
+ self.branch3x3_2b(branch3x3),
290
+ ]
291
+ branch3x3 = torch.cat(branch3x3, 1)
292
+
293
+ branch3x3dbl = self.branch3x3dbl_1(x)
294
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
295
+ branch3x3dbl = [
296
+ self.branch3x3dbl_3a(branch3x3dbl),
297
+ self.branch3x3dbl_3b(branch3x3dbl),
298
+ ]
299
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
300
+
301
+ # Patch: Tensorflow's average pool does not use the padded zeros in
302
+ # its average calculation
303
+ branch_pool = F.avg_pool2d(
304
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
305
+ )
306
+ branch_pool = self.branch_pool(branch_pool)
307
+
308
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
309
+ return torch.cat(outputs, 1)
310
+
311
+
312
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
313
+ """Second InceptionE block patched for FID computation"""
314
+
315
+ def __init__(self, in_channels):
316
+ super(FIDInceptionE_2, self).__init__(in_channels)
317
+
318
+ def forward(self, x):
319
+ branch1x1 = self.branch1x1(x)
320
+
321
+ branch3x3 = self.branch3x3_1(x)
322
+ branch3x3 = [
323
+ self.branch3x3_2a(branch3x3),
324
+ self.branch3x3_2b(branch3x3),
325
+ ]
326
+ branch3x3 = torch.cat(branch3x3, 1)
327
+
328
+ branch3x3dbl = self.branch3x3dbl_1(x)
329
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
330
+ branch3x3dbl = [
331
+ self.branch3x3dbl_3a(branch3x3dbl),
332
+ self.branch3x3dbl_3b(branch3x3dbl),
333
+ ]
334
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
335
+
336
+ # Patch: The FID Inception model uses max pooling instead of average
337
+ # pooling. This is likely an error in this specific Inception
338
+ # implementation, as other Inception models use average pooling here
339
+ # (which matches the description in the paper).
340
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
341
+ branch_pool = self.branch_pool(branch_pool)
342
+
343
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
344
+ return torch.cat(outputs, 1)
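
`InceptionV3` wraps torchvision's Inception into the block structure the FID computation in `common/fid_score.py` expects; with the default `use_fid_inception=True` it fetches the ported TF FID weights from `FID_WEIGHTS_URL` on first use. A minimal sketch of pulling the 2048-d final-pool features used for the FID statistics (batch size and image size here are arbitrary):

```python
import torch

from common.inception import InceptionV3  # module path assumed from this commit

# Select the final-average-pool block, which yields 2048-d features.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3(output_blocks=[block_idx]).eval()

# Frames are expected in [0, 1]; they are resized to 299x299 and rescaled to [-1, 1] internally.
frames = torch.rand(8, 3, 256, 256)
with torch.no_grad():
    feats = model(frames)[0]           # (8, 2048, 1, 1) from the selected block
feats = feats.squeeze(-1).squeeze(-1)  # (8, 2048) vectors to accumulate FID mean/covariance
```
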
common/plot/__init__.py ADDED
File without changes
common/plot/aggregated_output.csv ADDED
@@ -0,0 +1,18 @@
1
+ name,bridge_data_v2/teacher_force_psnr,bridge_data_v2/teacher_force_psnr_delta,bridge_data_v2/teacher_force_ssim,bridge_data_v2/teacher_force_pred_lpips,bridge_data_v2/teacher_force_loss,bridge_data_v2/num_examples,fractal20220817_data/teacher_force_psnr,fractal20220817_data/teacher_force_psnr_delta,fractal20220817_data/teacher_force_ssim,fractal20220817_data/teacher_force_pred_lpips,fractal20220817_data/teacher_force_loss,fractal20220817_data/num_examples,language_table/teacher_force_psnr,language_table/teacher_force_psnr_delta,language_table/teacher_force_ssim,language_table/teacher_force_pred_lpips,language_table/teacher_force_loss,language_table/num_examples,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_psnr,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_ssim,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_loss,ucsd_pick_and_place_dataset_converted_externally_to_rlds/num_examples,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_psnr,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_psnr_delta,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_ssim,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_pred_lpips,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_loss,kaist_nonprehensile_converted_externally_to_rlds/num_examples,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_psnr,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_ssim,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_loss,ucsd_kitchen_dataset_converted_externally_to_rlds/num_examples,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_psnr,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_psnr_delta,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_ssim,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_pred_lpips,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_loss,utokyo_xarm_bimanual_converted_externally_to_rlds/num_examples,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_psnr,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_ssim,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_loss,stanford_hydra_dataset_converted_externally_to_rlds/num_examples,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_psnr,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_ssim,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_loss,austin_sirius_dataset_converted_externally_to_rlds/num_examples,berkeley_fanuc_manipulation/teacher_force_psnr,berkeley_fanuc_manipulation/teacher_force_psnr_delta,berkeley_fanuc_manipulation/teacher_force_ssim,berkeley_fanuc_manipulation/teacher_force_pred_lpips,berkeley_fanuc_manipulation/teacher_force_loss,berkeley_fanuc_manipulation/num_examples,berkeley_mvp_convert
ed_externally_to_rlds/teacher_force_psnr,berkeley_mvp_converted_externally_to_rlds/teacher_force_psnr_delta,berkeley_mvp_converted_externally_to_rlds/teacher_force_ssim,berkeley_mvp_converted_externally_to_rlds/teacher_force_pred_lpips,berkeley_mvp_converted_externally_to_rlds/teacher_force_loss,berkeley_mvp_converted_externally_to_rlds/num_examples,berkeley_rpt_converted_externally_to_rlds/teacher_force_psnr,berkeley_rpt_converted_externally_to_rlds/teacher_force_psnr_delta,berkeley_rpt_converted_externally_to_rlds/teacher_force_ssim,berkeley_rpt_converted_externally_to_rlds/teacher_force_pred_lpips,berkeley_rpt_converted_externally_to_rlds/teacher_force_loss,berkeley_rpt_converted_externally_to_rlds/num_examples,cmu_play_fusion/teacher_force_psnr,cmu_play_fusion/teacher_force_psnr_delta,cmu_play_fusion/teacher_force_ssim,cmu_play_fusion/teacher_force_pred_lpips,cmu_play_fusion/teacher_force_loss,cmu_play_fusion/num_examples,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_psnr,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_psnr_delta,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_ssim,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_pred_lpips,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_loss,iamlab_cmu_pickup_insert_converted_externally_to_rlds/num_examples,qut_dexterous_manpulation/teacher_force_psnr,qut_dexterous_manpulation/teacher_force_psnr_delta,qut_dexterous_manpulation/teacher_force_ssim,qut_dexterous_manpulation/teacher_force_pred_lpips,qut_dexterous_manpulation/teacher_force_loss,qut_dexterous_manpulation/num_examples,robo_net/teacher_force_psnr,robo_net/teacher_force_psnr_delta,robo_net/teacher_force_ssim,robo_net/teacher_force_pred_lpips,robo_net/teacher_force_loss,robo_net/num_examples,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_psnr,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_ssim,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_loss,furniture_bench_dataset_converted_externally_to_rlds/num_examples,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_psnr,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_psnr_delta,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_ssim,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_pred_lpips,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_loss,dlr_sara_grid_clamp_converted_externally_to_rlds/num_examples,cmu_stretch/teacher_force_psnr,cmu_stretch/teacher_force_psnr_delta,cmu_stretch/teacher_force_ssim,cmu_stretch/teacher_force_pred_lpips,cmu_stretch/teacher_force_loss,cmu_stretch/num_examples,spoc/teacher_force_psnr,spoc/teacher_force_psnr_delta,spoc/teacher_force_ssim,spoc/teacher_force_pred_lpips,spoc/teacher_force_loss,spoc/num_examples,columbia_cairlab_pusht_real/teacher_force_psnr,columbia_cairlab_pusht_real/teacher_force_psnr_delta,columbia_cairlab_pusht_real/teacher_force_ssim,columbia_cairlab_pusht_real/teacher_force_pred_lpips,columbia_cairlab_pusht_real/teacher_force_loss,columbia_cairlab_pusht_real/num_examples,droid/teacher_force_psnr,droid/teacher_force_psnr_delta,droid/teacher_force_ssim,droid/teacher_force_pred_lpips,droid/teacher_force_loss,droid/num_examples,toto/teacher_force_psnr,toto/teacher_force_psnr_delta,toto/teacher_force_ssim,toto/
teacher_force_pred_lpips,toto/teacher_force_loss,toto/num_examples,io_ai_tech/teacher_force_psnr,io_ai_tech/teacher_force_psnr_delta,io_ai_tech/teacher_force_ssim,io_ai_tech/teacher_force_pred_lpips,io_ai_tech/teacher_force_loss,io_ai_tech/num_examples,conq_hose_manipulation/teacher_force_psnr,conq_hose_manipulation/teacher_force_psnr_delta,conq_hose_manipulation/teacher_force_ssim,conq_hose_manipulation/teacher_force_pred_lpips,conq_hose_manipulation/teacher_force_loss,conq_hose_manipulation/num_examples,dobbe/teacher_force_psnr,dobbe/teacher_force_psnr_delta,dobbe/teacher_force_ssim,dobbe/teacher_force_pred_lpips,dobbe/teacher_force_loss,dobbe/num_examples,berkeley_gnm_cory_hall/teacher_force_psnr,berkeley_gnm_cory_hall/teacher_force_psnr_delta,berkeley_gnm_cory_hall/teacher_force_ssim,berkeley_gnm_cory_hall/teacher_force_pred_lpips,berkeley_gnm_cory_hall/teacher_force_loss,berkeley_gnm_cory_hall/num_examples,plex_robosuite/teacher_force_psnr,plex_robosuite/teacher_force_psnr_delta,plex_robosuite/teacher_force_ssim,plex_robosuite/teacher_force_pred_lpips,plex_robosuite/teacher_force_loss,plex_robosuite/num_examples,usc_cloth_sim_converted_externally_to_rlds/teacher_force_psnr,usc_cloth_sim_converted_externally_to_rlds/teacher_force_psnr_delta,usc_cloth_sim_converted_externally_to_rlds/teacher_force_ssim,usc_cloth_sim_converted_externally_to_rlds/teacher_force_pred_lpips,usc_cloth_sim_converted_externally_to_rlds/teacher_force_loss,usc_cloth_sim_converted_externally_to_rlds/num_examples,berkeley_cable_routing/teacher_force_psnr,berkeley_cable_routing/teacher_force_psnr_delta,berkeley_cable_routing/teacher_force_ssim,berkeley_cable_routing/teacher_force_pred_lpips,berkeley_cable_routing/teacher_force_loss,berkeley_cable_routing/num_examples,imperial_wrist_dataset/teacher_force_psnr,imperial_wrist_dataset/teacher_force_psnr_delta,imperial_wrist_dataset/teacher_force_ssim,imperial_wrist_dataset/teacher_force_pred_lpips,imperial_wrist_dataset/teacher_force_loss,imperial_wrist_dataset/num_examples,bc_z/teacher_force_psnr,bc_z/teacher_force_psnr_delta,bc_z/teacher_force_ssim,bc_z/teacher_force_pred_lpips,bc_z/teacher_force_loss,bc_z/num_examples,kuka/teacher_force_psnr,kuka/teacher_force_psnr_delta,kuka/teacher_force_ssim,kuka/teacher_force_pred_lpips,kuka/teacher_force_loss,kuka/num_examples,roboturk/teacher_force_psnr,roboturk/teacher_force_psnr_delta,roboturk/teacher_force_ssim,roboturk/teacher_force_pred_lpips,roboturk/teacher_force_loss,roboturk/num_examples,metaworld/teacher_force_psnr,metaworld/teacher_force_psnr_delta,metaworld/teacher_force_ssim,metaworld/teacher_force_pred_lpips,metaworld/teacher_force_loss,metaworld/num_examples,robomimic/teacher_force_psnr,robomimic/teacher_force_psnr_delta,robomimic/teacher_force_ssim,robomimic/teacher_force_pred_lpips,robomimic/teacher_force_loss,robomimic/num_examples,epic_kitchen/teacher_force_psnr,epic_kitchen/teacher_force_psnr_delta,epic_kitchen/teacher_force_ssim,epic_kitchen/teacher_force_pred_lpips,epic_kitchen/teacher_force_loss,epic_kitchen/num_examples,ego4d/teacher_force_psnr,ego4d/teacher_force_psnr_delta,ego4d/teacher_force_ssim,ego4d/teacher_force_pred_lpips,ego4d/teacher_force_loss,ego4d/num_examples,nyu_door_opening_surprising_effectiveness/teacher_force_psnr,nyu_door_opening_surprising_effectiveness/teacher_force_psnr_delta,nyu_door_opening_surprising_effectiveness/teacher_force_ssim,nyu_door_opening_surprising_effectiveness/teacher_force_pred_lpips,nyu_door_opening_surprising_effectiveness/teacher_force_loss,nyu_door_opening_surp
rising_effectiveness/num_examples
2
+ 20.11654281616211,0.2406988888978958,0.6522064805030823,0.16225013136863708,4.879761219024658,500,22.18416404724121,0.35506966710090637,0.6869667768478394,0.15024960041046143,4.770835876464844,500,21.715396881103516,0.3029274344444275,0.6992139220237732,0.16471771895885468,5.294665813446045,198,21.16069793701172,0.8068287372589111,0.856053352355957,0.13014769554138184,6.175273895263672,268,20.990825653076172,-0.10340484231710434,0.7044205069541931,0.18493501842021942,7.725522518157959,215,15.698708534240723,0.0012377961538732052,0.5898451209068298,0.17795829474925995,5.001875400543213,44,19.282543182373047,0.25010883808135986,0.6505519151687622,0.17733542621135712,5.174380779266357,26,17.12480926513672,-0.2587044835090637,0.705141007900238,0.19221702218055725,5.820067882537842,500,19.765209197998047,0.520576000213623,0.7926568984985352,0.1714717447757721,6.499162673950195,500,24.512998580932617,0.05698919668793678,0.7914198637008667,0.10842812061309814,3.0430972576141357,197,25.59234619140625,0.10779287666082382,0.8336195349693298,0.19609792530536652,9.673583984375,145,19.04483985900879,-1.549579381942749,0.8432617783546448,0.23590296506881714,9.34917163848877,500,25.48477554321289,-0.2415604293346405,0.808574378490448,0.14057636260986328,3.712907552719116,500,20.40241241455078,-1.7759240865707397,0.695418119430542,0.18546271324157715,5.312252044677734,500,18.150922775268555,0.07440949231386185,0.623611569404602,0.19956742227077484,2.496238946914673,500,19.270723342895508,0.13051848113536835,0.6409730315208435,0.21255843341350555,5.373872756958008,400,14.527210235595703,-1.1576420068740845,0.6855177283287048,0.2406100034713745,6.9484028816223145,500,22.118375778198242,0.12427309900522232,0.760040819644928,0.13304449617862701,2.0709831714630127,142,27.101343154907227,0.002104964340105653,0.8569799661636353,0.11864025890827179,4.056473731994629,304,14.499595642089844,-2.287384033203125,0.6487097144126892,0.5801433324813843,11.471040725708008,500,19.862470626831055,0.004642155021429062,0.8116016387939453,0.15320290625095367,7.407879829406738,290,18.82098960876465,1.036180019378662,0.6922352910041809,0.18524041771888733,5.3734941482543945,500,19.008501052856445,0.08843769133090973,0.654740035533905,0.19600719213485718,5.878746032714844,500,25.3173885345459,-1.4838885068893433,0.8376902341842651,0.091035857796669,6.641345977783203,500,0,0,0,0,0,0,23.15846061706543,0.19177615642547607,0.7468024492263794,0.14149916172027588,7.459285259246826,500,20.05571746826172,-0.6787086129188538,0.8033839464187622,0.23725976049900055,9.829147338867188,185,25.470836639404297,-0.04034167900681496,0.8336919546127319,0.13884975016117096,2.244248867034912,307,27.915611267089844,2.5503063201904297,0.9527876377105713,0.09793127328157425,1.2040963172912598,200,20.912508010864258,0.591270387172699,0.8214857578277588,0.2098020762205124,4.666227340698242,3,23.71303939819336,-0.19690543413162231,0.6411390900611877,0.09051118791103363,8.156048774719238,58,23.800819396972656,0.357938289642334,0.7015560269355774,0.16095396876335144,5.172791004180908,500,19.602933883666992,0.08206112682819366,0.6492921710014343,0.20972613990306854,5.722141265869141,136,15.88235092163086,0.034370679408311844,0.6604549288749695,0.26275190711021423,7.8473687171936035,475,0,0,0,0,0,0,20.04568099975586,-0.6712338328361511,0.8586370944976807,0.14845611155033112,4.173150062561035,60,0,0,0,0,0,0,0,0,0,0,0,0,16.780973434448242,0.36480912566185,0.5418305993080139,0.23359735310077667,9.743389129638672,63,0
3
+ 20.11654281616211,0.2406988888978958,0.6522064805030823,0.16225013136863708,4.879761219024658,500,22.18416404724121,0.35506966710090637,0.6869667768478394,0.15024960041046143,4.770835876464844,500,21.715396881103516,0.3029274344444275,0.6992139220237732,0.16471771895885468,5.294665813446045,198,21.16069793701172,0.8068287372589111,0.856053352355957,0.13014769554138184,6.175273895263672,268,20.990825653076172,-0.10340484231710434,0.7044205069541931,0.18493501842021942,7.725522518157959,215,15.698708534240723,0.0012377961538732052,0.5898451209068298,0.17795829474925995,5.001875400543213,44,19.282543182373047,0.25010883808135986,0.6505519151687622,0.17733542621135712,5.174380779266357,26,17.12480926513672,-0.2587044835090637,0.705141007900238,0.19221702218055725,5.820067882537842,500,19.765209197998047,0.520576000213623,0.7926568984985352,0.1714717447757721,6.499162673950195,500,24.512998580932617,0.05698919668793678,0.7914198637008667,0.10842812061309814,3.0430972576141357,197,25.59234619140625,0.10779287666082382,0.8336195349693298,0.19609792530536652,9.673583984375,145,19.04483985900879,-1.549579381942749,0.8432617783546448,0.23590296506881714,9.34917163848877,500,25.48477554321289,-0.2415604293346405,0.808574378490448,0.14057636260986328,3.712907552719116,500,20.40241241455078,-1.7759240865707397,0.695418119430542,0.18546271324157715,5.312252044677734,500,18.150922775268555,0.07440949231386185,0.623611569404602,0.19956742227077484,2.496238946914673,500,19.270723342895508,0.13051848113536835,0.6409730315208435,0.21255843341350555,5.373872756958008,400,14.527210235595703,-1.1576420068740845,0.6855177283287048,0.2406100034713745,6.9484028816223145,500,22.118375778198242,0.12427309900522232,0.760040819644928,0.13304449617862701,2.0709831714630127,142,27.101343154907227,0.002104964340105653,0.8569799661636353,0.11864025890827179,4.056473731994629,304,14.499595642089844,-2.287384033203125,0.6487097144126892,0.5801433324813843,11.471040725708008,500,19.862470626831055,0.004642155021429062,0.8116016387939453,0.15320290625095367,7.407879829406738,290,18.82098960876465,1.036180019378662,0.6922352910041809,0.18524041771888733,5.3734941482543945,500,19.008501052856445,0.08843769133090973,0.654740035533905,0.19600719213485718,5.878746032714844,500,25.3173885345459,-1.4838885068893433,0.8376902341842651,0.091035857796669,6.641345977783203,500,0,0,0,0,0,0,23.15846061706543,0.19177615642547607,0.7468024492263794,0.14149916172027588,7.459285259246826,500,20.05571746826172,-0.6787086129188538,0.8033839464187622,0.23725976049900055,9.829147338867188,185,25.470836639404297,-0.04034167900681496,0.8336919546127319,0.13884975016117096,2.244248867034912,307,27.915611267089844,2.5503063201904297,0.9527876377105713,0.09793127328157425,1.2040963172912598,200,20.912508010864258,0.591270387172699,0.8214857578277588,0.2098020762205124,4.666227340698242,3,23.71303939819336,-0.19690543413162231,0.6411390900611877,0.09051118791103363,8.156048774719238,58,23.800819396972656,0.357938289642334,0.7015560269355774,0.16095396876335144,5.172791004180908,500,19.602933883666992,0.08206112682819366,0.6492921710014343,0.20972613990306854,5.722141265869141,136,15.88235092163086,0.034370679408311844,0.6604549288749695,0.26275190711021423,7.8473687171936035,475,0,0,0,0,0,0,20.04568099975586,-0.6712338328361511,0.8586370944976807,0.14845611155033112,4.173150062561035,60,0,0,0,0,0,0,0,0,0,0,0,0,16.780973434448242,0.36480912566185,0.5418305993080139,0.23359735310077667,9.743389129638672,63,0
4
+ 20.22026252746582,0.38262513279914856,0.6533553600311279,0.1599833220243454,4.625724792480469,500,22.20106315612793,1.1525933742523193,0.6865711212158203,0.1498977541923523,4.709808826446533,500,21.663299560546875,0.315200537443161,0.6973788142204285,0.16729794442653656,5.099008083343506,198,21.049013137817383,1.241530179977417,0.8553103804588318,0.1382942497730255,6.023053169250488,268,20.187990188598633,0.2288614809513092,0.6624469757080078,0.21102775633335114,8.721923828125,215,14.019247055053711,0.1637289673089981,0.5697340369224548,0.2109413594007492,5.538760662078857,44,19.34811782836914,1.7672139406204224,0.65287184715271,0.17289860546588898,4.801314830780029,26,17.78873634338379,0.7138077020645142,0.7086768746376038,0.17418105900287628,5.320345878601074,500,19.624094009399414,0.8938031792640686,0.7965993285179138,0.19896067678928375,6.5370965003967285,500,24.533777236938477,0.036842938512563705,0.7917919754981995,0.10872980952262878,2.8953208923339844,197,25.017925262451172,0.39651936292648315,0.8296465873718262,0.19340361654758453,9.546339988708496,145,19.407617568969727,0.20671948790550232,0.8511928915977478,0.21369731426239014,8.630515098571777,500,25.74650001525879,0.004801941104233265,0.8135946989059448,0.13683146238327026,3.3414487838745117,500,18.16292953491211,-1.5063064098358154,0.6798076033592224,0.2016068398952484,4.924650192260742,500,17.171728134155273,0.12000428140163422,0.6030166149139404,0.21893559396266937,2.7387208938598633,200,19.084640502929688,0.17159269750118256,0.6400240659713745,0.21725738048553467,5.156170845031738,200,14.959861755371094,-0.08570398390293121,0.6782578825950623,0.24334610998630524,6.023822784423828,200,20.60854148864746,1.239460825920105,0.7399578094482422,0.15486611425876617,2.6809659004211426,142,27.038379669189453,-0.04835420474410057,0.8582215905189514,0.11914262175559998,3.703599214553833,200,16.528011322021484,0.07733757048845291,0.6863109469413757,0.43281883001327515,10.088181495666504,200,20.27916717529297,0.39283424615859985,0.8197209239006042,0.1518518626689911,7.038107872009277,200,19.77338218688965,-0.10181392729282379,0.7324086427688599,0.1659238040447235,5.009759426116943,200,16.711793899536133,-0.7416815161705017,0.5757627487182617,0.25752225518226624,7.026858329772949,200,26.67308807373047,-0.19710086286067963,0.849867582321167,0.08147656917572021,6.09824275970459,200,0,0,0,0,0,0,22.918180465698242,0.21646614372730255,0.7544977068901062,0.13648433983325958,6.675195693969727,200,20.61359977722168,-0.2607249617576599,0.8171404600143433,0.2431134730577469,9.399118423461914,185,25.659948348999023,-0.007021207828074694,0.8343022465705872,0.13590268790721893,2.109799385070801,200,28.551301956176758,0.999920666217804,0.9539777040481567,0.09078202396631241,1.1626181602478027,200,20.705917358398438,0.19353322684764862,0.822070300579071,0.2162921279668808,4.443334579467773,3,23.890663146972656,0.042181383818387985,0.6438125371932983,0.09075239300727844,7.78416109085083,58,23.974327087402344,0.4367068409919739,0.6935603022575378,0.15629488229751587,4.970770835876465,200,19.62123680114746,0.2064054161310196,0.650097131729126,0.21027418971061707,5.509530067443848,136,15.70105266571045,-0.008274257183074951,0.6658218502998352,0.26436904072761536,7.878448009490967,200,0,0,0,0,0,0,21.392261505126953,0.8078638315200806,0.8658612966537476,0.1169213280081749,4.069273471832275,60,0,0,0,0,0,0,0,0,0,0,0,0,16.521770477294922,0.4119645357131958,0.5416176319122314,0.2529039680957794,9.572311401367188,63,0
5
+ 20.192195892333984,0.0013111135922372341,0.6530546545982361,0.16100579500198364,4.674829006195068,500,22.106380462646484,0.002372157759964466,0.6869567632675171,0.14720509946346283,4.874122619628906,500,21.668710708618164,0.010731762275099754,0.6994706392288208,0.1702597439289093,5.343525409698486,198,20.726055145263672,0.004278174135833979,0.8526590466499329,0.13251639902591705,6.267956256866455,268,20.887666702270508,0.3664732277393341,0.6933053135871887,0.18257129192352295,7.749646186828613,200,15.50479507446289,0.011554110795259476,0.5916057229042053,0.18086765706539154,5.101010799407959,44,19.26924705505371,-0.0402817502617836,0.6491515636444092,0.17109069228172302,5.088503360748291,26,18.38506317138672,-0.01206012349575758,0.7032961249351501,0.17353220283985138,4.688672065734863,200,19.471759796142578,-0.0025741406716406345,0.790705680847168,0.18272539973258972,6.433710098266602,200,24.408714294433594,-0.006203812547028065,0.7900201678276062,0.1093222051858902,3.0680618286132812,197,25.618268966674805,-0.012119447812438011,0.8320527672767639,0.189222514629364,9.541532516479492,145,20.02849769592285,0.0040948037058115005,0.8293033838272095,0.22746650874614716,8.622618675231934,200,25.799362182617188,0.002677352400496602,0.8147953748703003,0.1341477930545807,3.4496517181396484,200,22.345197677612305,0.27116137742996216,0.7258548736572266,0.1409902721643448,3.6172609329223633,200,16.546030044555664,-0.10674090683460236,0.5869951844215393,0.23679934442043304,3.374926805496216,200,19.099502563476562,-0.007533765863627195,0.6395227909088135,0.21552789211273193,5.256075382232666,200,16.655963897705078,0.0018004697049036622,0.6986208558082581,0.1904933899641037,5.870639801025391,200,22.06020736694336,0.564301609992981,0.7581813931465149,0.13439540565013885,3.1246254444122314,142,27.067359924316406,0.00568796694278717,0.857913076877594,0.11922232806682587,4.030396461486816,200,16.69654655456543,0.013287489302456379,0.6759082078933716,0.41035643219947815,10.485578536987305,200,20.1484317779541,-0.0037386345211416483,0.8214008808135986,0.1595878303050995,7.100029945373535,200,19.99980926513672,0.002995690330862999,0.7338225841522217,0.16079355776309967,5.090161323547363,200,17.28361701965332,-0.0767323225736618,0.5945534706115723,0.23762698471546173,6.790170192718506,200,26.707612991333008,-0.010888484306633472,0.8517170548439026,0.08222655206918716,6.144393444061279,200,0,0,0,0,0,0,22.81102752685547,0.004911952186375856,0.753410816192627,0.13840247690677643,6.590641021728516,200,20.56325340270996,-0.0036200578324496746,0.8122907876968384,0.2328769862651825,9.606425285339355,185,25.634384155273438,0.014794806018471718,0.8339281678199768,0.13365812599658966,2.2014026641845703,200,28.079368591308594,-0.27046048641204834,0.9531135559082031,0.09380464255809784,1.130263090133667,200,20.635210037231445,0.014526949264109135,0.8181746006011963,0.21357515454292297,4.678038597106934,3,23.946765899658203,0.026584567502141,0.6427099704742432,0.09646405279636383,8.022929191589355,58,23.835311889648438,-0.00212214607745409,0.6980120539665222,0.15597480535507202,5.109379768371582,200,19.57789421081543,-0.008618982508778572,0.6491334438323975,0.21285711228847504,5.622872829437256,136,15.693717956542969,-0.0042601898312568665,0.6714515686035156,0.2660149037837982,7.933395862579346,200,0,0,0,0,0,0,20.719491958618164,0.012509683147072792,0.8595790266990662,0.11407399922609329,3.9665277004241943,60,0,0,0,0,0,0,0,0,0,0,0,0,16.573131561279297,-0.02216508239507675,0.5317869782447815,0.25623247027397156,9.64720249176
0254,63,0
6
+ 20.11272621154785,0.2109360694885254,0.6520361304283142,0.1624097228050232,4.9084296226501465,500,22.162460327148438,0.5982567667961121,0.6869907379150391,0.14949150383472443,4.766613483428955,500,21.70660972595215,0.24707387387752533,0.6991060376167297,0.16490808129310608,5.31504487991333,198,20.89645767211914,0.6825409531593323,0.8557612299919128,0.13019102811813354,6.398017406463623,200,21.12166976928711,0.08367721736431122,0.7037967443466187,0.1830134242773056,7.72568416595459,200,15.728557586669922,0.023094290867447853,0.5895631909370422,0.17670421302318573,4.983802318572998,44,19.269149780273438,0.21575377881526947,0.6504108905792236,0.17488287389278412,5.12293004989624,26,17.66543197631836,-0.3417667746543884,0.6988275647163391,0.1885838657617569,5.298708438873291,200,19.65871810913086,0.5588710904121399,0.7912994027137756,0.17227208614349365,6.4804582595825195,200,24.516550064086914,0.020750248804688454,0.7902164459228516,0.10835054516792297,3.0472919940948486,197,25.76390266418457,0.2111617475748062,0.8347102999687195,0.1926291286945343,9.65595817565918,145,18.151044845581055,-1.411789894104004,0.7994647026062012,0.26881223917007446,9.587594032287598,200,25.545425415039062,-0.2518838047981262,0.8101628422737122,0.13834546506404877,3.7476205825805664,200,20.21263313293457,-2.1581101417541504,0.6755650043487549,0.20787520706653595,5.349665641784668,200,17.979320526123047,0.029456928372383118,0.6195611953735352,0.2056853473186493,2.5060770511627197,200,19.21215057373047,0.09863721579313278,0.6396337747573853,0.21282075345516205,5.47512149810791,200,14.731433868408203,-1.0437582731246948,0.6715344190597534,0.24017825722694397,6.835475444793701,200,22.121034622192383,0.13931676745414734,0.7599001526832581,0.13316860795021057,2.0559332370758057,142,26.967485427856445,-0.04352593049407005,0.8562372326850891,0.11917727440595627,4.126804828643799,200,14.093514442443848,-2.401867628097534,0.6462867259979248,0.5737171769142151,11.239895820617676,200,20.1369686126709,0.03518101945519447,0.8176986575126648,0.15042880177497864,7.3873419761657715,200,19.936338424682617,1.2820953130722046,0.7327397465705872,0.16244173049926758,5.061605930328369,200,18.608089447021484,0.03232511132955551,0.6288440823554993,0.19996103644371033,6.023675918579102,200,25.265308380126953,-1.5252397060394287,0.8366735577583313,0.09203425794839859,6.713019847869873,200,0,0,0,0,0,0,22.606639862060547,0.23754052817821503,0.7469227313995361,0.14175154268741608,7.264929294586182,200,20.01392364501953,-0.7243794202804565,0.8044877052307129,0.24030713737010956,9.7893705368042,185,25.513294219970703,0.022836336866021156,0.8328461647033691,0.13896547257900238,2.2256522178649902,200,27.747732162475586,2.1492269039154053,0.9519210457801819,0.09740724414587021,1.2134103775024414,200,20.910673141479492,0.47492536902427673,0.8211309909820557,0.21091899275779724,4.660539150238037,3,23.682945251464844,-0.09257392585277557,0.6411334872245789,0.09143876284360886,8.146902084350586,58,23.916791915893555,0.2757120430469513,0.6961292028427124,0.15634751319885254,5.073945999145508,200,19.578807830810547,0.06890798360109329,0.648792028427124,0.2106357365846634,5.744231224060059,136,15.669443130493164,0.014397966675460339,0.662750780582428,0.26728084683418274,8.170258522033691,200,0,0,0,0,0,0,20.507707595825195,-0.1476416438817978,0.8596673011779785,0.13412821292877197,4.167920112609863,60,0,0,0,0,0,0,0,0,0,0,0,0,16.79804229736328,0.39872896671295166,0.5435536503791809,0.23285678029060364,9.751370429992676,63,0
7
+ 20.09503936767578,0.19324186444282532,0.6532765030860901,0.16777116060256958,4.9215545654296875,200,21.670238494873047,0.5962230563163757,0.6772172451019287,0.15711373090744019,4.973745822906494,200,21.32579231262207,0.23807521164417267,0.6979929804801941,0.17777009308338165,5.471674919128418,198,19.964054107666016,0.6898609399795532,0.8480191826820374,0.14881540834903717,6.650026798248291,200,19.73797035217285,0.028328483924269676,0.6416213512420654,0.2283506691455841,9.225159645080566,200,14.12078857421875,-0.07745029032230377,0.5660436749458313,0.21174412965774536,5.737281322479248,44,17.832124710083008,-0.6108675003051758,0.625728189945221,0.2204168289899826,5.574310302734375,26,17.75128746032715,0.5052798986434937,0.6958627700805664,0.1850184202194214,5.1611528396606445,200,18.28720474243164,3.523144483566284,0.77419114112854,0.22147363424301147,6.997479438781738,200,24.396211624145508,0.042355746030807495,0.7892429828643799,0.11021917313337326,3.1949589252471924,197,24.727394104003906,0.08382178097963333,0.831762433052063,0.21840521693229675,9.826996803283691,145,15.711186408996582,-1.1402037143707275,0.7454285621643066,0.4373578727245331,11.309600830078125,200,25.52533721923828,-0.00760778971016407,0.8099644184112549,0.13890613615512848,3.6684834957122803,200,19.276363372802734,-2.6745126247406006,0.6672981381416321,0.21036404371261597,4.968836784362793,200,15.376065254211426,0.2176371067762375,0.5345624089241028,0.28238123655319214,4.76574182510376,200,18.881689071655273,0.1820032000541687,0.6362051367759705,0.22268341481685638,5.450001239776611,200,15.536110877990723,0.5788084268569946,0.6868603229522705,0.2221454679965973,6.200928211212158,200,18.588090896606445,-0.1351645439863205,0.6845595240592957,0.20317442715168,4.309319496154785,142,26.693002700805664,-0.017797470092773438,0.85491943359375,0.12290634214878082,4.107079982757568,200,15.977954864501953,-0.2343152016401291,0.6691921949386597,0.4571930468082428,10.901474952697754,200,19.851110458374023,0.4773397743701935,0.8219828605651855,0.17668022215366364,7.3162055015563965,200,19.206226348876953,0.6738173961639404,0.7266687750816345,0.17411881685256958,5.285086631774902,200,15.480487823486328,-0.05200222134590149,0.5209740400314331,0.30362191796302795,8.292901992797852,200,25.806957244873047,-0.5103187561035156,0.8443582653999329,0.08975253254175186,6.479008197784424,200,0,0,0,0,0,0,21.907487869262695,-0.3219781517982483,0.7322761416435242,0.15371987223625183,7.148579120635986,200,19.59992027282715,-0.06399036198854446,0.8096731305122375,0.2715202569961548,9.681106567382812,185,25.1120662689209,-0.05029723793268204,0.829463541507721,0.14083868265151978,2.2757794857025146,200,28.31503677368164,1.267438292503357,0.9527142643928528,0.0927368700504303,1.3062008619308472,200,20.59347152709961,0.03865854814648628,0.8206701874732971,0.2205837219953537,4.7697625160217285,3,23.66497230529785,0.0852617546916008,0.6408542990684509,0.10270028561353683,8.262017250061035,58,23.514087677001953,0.18420732021331787,0.6972275972366333,0.1637846827507019,5.285017490386963,200,19.399612426757812,0.14159218966960907,0.6437479257583618,0.2207319140434265,6.007476806640625,136,15.523843765258789,0.04549547657370567,0.6679399609565735,0.2731015682220459,8.212705612182617,200,0,0,0,0,0,0,18.983938217163086,-0.7280352711677551,0.8515005111694336,0.16164684295654297,4.466729640960693,60,0,0,0,0,0,0,0,0,0,0,0,0,16.20790672302246,0.4906918406486511,0.5336827039718628,0.28068792819976807,10.018815994262695,63,0
8
+ 18.422204971313477,-0.12648740410804749,0.6362218260765076,0.19450049102306366,5.315305233001709,500,20.882783889770508,-0.07433691620826721,0.6690763831138611,0.1761208325624466,5.387838363647461,500,21.112810134887695,-0.03544004634022713,0.699590802192688,0.18499258160591125,5.969436168670654,198,18.350582122802734,0.12320471554994583,0.8303707838058472,0.1668887436389923,6.814785480499268,268,20.747133255004883,-0.008221889846026897,0.6795586943626404,0.18530716001987457,8.013651847839355,215,13.65688705444336,-0.014612981118261814,0.5605658888816833,0.21683458983898163,5.646177291870117,44,17.740320205688477,-0.01435495913028717,0.6188032031059265,0.20206505060195923,5.870429039001465,26,16.441152572631836,-0.09573374688625336,0.6951659917831421,0.20139959454536438,5.628950119018555,500,18.033527374267578,-0.049468427896499634,0.7758861184120178,0.20358926057815552,7.416662693023682,500,23.77448272705078,-0.004841374699026346,0.7814860939979553,0.11732590198516846,3.6382157802581787,197,23.89436149597168,-0.07056951522827148,0.8185535073280334,0.22474637627601624,10.225882530212402,145,19.539384841918945,-0.2427220195531845,0.8561981320381165,0.23518529534339905,9.096972465515137,500,24.909347534179688,-0.04073095694184303,0.8004875779151917,0.146201491355896,4.055650234222412,500,21.47697639465332,-0.03675007075071335,0.7250679135322571,0.1535453498363495,4.3143134117126465,500,18.370075225830078,-0.002068187575787306,0.6328396201133728,0.19349025189876556,2.4523301124572754,500,18.51810073852539,-0.029367070645093918,0.6309530138969421,0.23161163926124573,5.59820556640625,400,15.171514511108398,-0.08596291393041611,0.6846583485603333,0.2208338975906372,6.640722751617432,500,22.067739486694336,-0.010470133274793625,0.759013831615448,0.13397353887557983,2.6300439834594727,142,26.65119743347168,-0.02217714861035347,0.8541759252548218,0.12250284105539322,4.489205837249756,304,16.662843704223633,0.030067602172493935,0.6663406491279602,0.45978274941444397,10.977742195129395,500,18.961740493774414,-0.03515905141830444,0.8050862550735474,0.18823504447937012,8.212032318115234,290,17.795196533203125,-0.022556299343705177,0.6786172986030579,0.21285778284072876,5.880309581756592,500,18.647926330566406,-0.020784933120012283,0.6541447043418884,0.21047918498516083,6.204220771789551,500,25.266008377075195,-0.13009877502918243,0.8358787298202515,0.09196046739816666,6.571922779083252,500,0,0,0,0,0,0,21.6362247467041,0.043379079550504684,0.7133622169494629,0.16810670495033264,7.582168102264404,500,19.400468826293945,-0.06924106180667877,0.7891562581062317,0.2562311589717865,10.40406322479248,185,24.272430419921875,0.002408565254881978,0.822277307510376,0.14490512013435364,3.1252973079681396,307,21.627735137939453,0.28700390458106995,0.9307183623313904,0.15895913541316986,3.7723772525787354,200,20.32921028137207,0.07073872536420822,0.8144434094429016,0.21742278337478638,5.295686721801758,3,23.026090621948242,-0.008077435195446014,0.63054358959198,0.1302529126405716,8.783021926879883,58,22.135250091552734,-0.020904889330267906,0.6749143004417419,0.18395783007144928,5.7703423500061035,500,18.66996192932129,-0.03733401745557785,0.6296820640563965,0.2478480488061905,6.311908721923828,136,15.479337692260742,-0.004108104854822159,0.6553525328636169,0.28199443221092224,8.236729621887207,475,0,0,0,0,0,0,18.864093780517578,0.1739964336156845,0.8528752326965332,0.1548851728439331,5.3060221672058105,60,0,0,0,0,0,0,0,0,0,0,0,0,15.678476333618164,0.038419850170612335,0.5233398675918579,0.3035353720188141,10.510383
605957031,63,0
9
+ 18.889339447021484,-0.1648501455783844,0.6437658667564392,0.1826857179403305,5.060956954956055,500,21.38709831237793,0.1416708528995514,0.6777622103691101,0.16441290080547333,5.186842441558838,500,21.354217529296875,0.0018506277119740844,0.7000029683113098,0.17723077535629272,5.897092342376709,198,19.61304473876953,0.037835195660591125,0.8433040380477905,0.14655180275440216,6.7373223304748535,268,20.715747833251953,0.01130291074514389,0.678013265132904,0.17803651094436646,8.016233444213867,215,15.14285659790039,0.26080322265625,0.5890591144561768,0.18768545985221863,5.359984874725342,44,18.543058395385742,-0.08016975224018097,0.6347747445106506,0.18409821391105652,5.6670966148376465,26,16.97516441345215,-0.1479617804288864,0.702846884727478,0.18982268869876862,5.4662394523620605,500,18.89791488647461,0.24670402705669403,0.7856683135032654,0.18493321537971497,7.272751808166504,500,24.04220199584961,0.019914034754037857,0.7841129302978516,0.11466450989246368,3.5969300270080566,197,25.081993103027344,0.19039775431156158,0.8243004679679871,0.19027777016162872,10.178629875183105,145,20.056028366088867,-0.3057011365890503,0.8603973984718323,0.2097621113061905,9.087891578674316,500,25.227895736694336,-0.03144318237900734,0.8052465319633484,0.1430753916501999,3.845712900161743,500,21.70534896850586,-0.05035046488046646,0.7276560068130493,0.148365318775177,4.039795398712158,500,18.364177703857422,-2.2224783151614247e-06,0.6325255632400513,0.19385698437690735,2.432619333267212,500,18.738195419311523,-0.033635493367910385,0.6359359622001648,0.22341181337833405,5.412669658660889,400,16.01719856262207,0.0007864768267609179,0.7023988366127014,0.2010006606578827,6.480032920837402,500,22.102428436279297,-0.0012201054487377405,0.759409487247467,0.13381606340408325,2.5473294258117676,142,26.880590438842773,-0.02358187735080719,0.8565123081207275,0.12121907621622086,4.459400177001953,304,16.68733787536621,0.07558043301105499,0.6599492430686951,0.4475219249725342,11.352530479431152,500,19.459217071533203,-0.03683672845363617,0.809934139251709,0.16714370250701904,8.145930290222168,290,18.282310485839844,-0.0003248922876082361,0.6865655779838562,0.20002803206443787,5.662166595458984,500,18.83856964111328,-0.03184810280799866,0.6561446785926819,0.20359492301940918,6.031191349029541,500,25.960927963256836,-0.08454426378011703,0.8430694341659546,0.08546590059995651,6.4937214851379395,500,0,0,0,0,0,0,21.94908332824707,-0.3419593274593353,0.7211637496948242,0.16153308749198914,7.5465989112854,500,20.40264892578125,0.017866995185613632,0.8031715750694275,0.22785907983779907,10.408190727233887,185,24.751502990722656,-0.008820587769150734,0.8268305659294128,0.14042264223098755,2.8674750328063965,307,23.79220199584961,0.023781709372997284,0.9390408396720886,0.13440102338790894,3.189241647720337,200,20.65520668029785,-0.02165459282696247,0.8165206909179688,0.21889762580394745,5.296838283538818,3,23.0324764251709,0.130624458193779,0.6325167417526245,0.1328558623790741,8.84994888305664,58,23.088464736938477,0.06231540068984032,0.6918238997459412,0.16935458779335022,5.561310291290283,500,19.111549377441406,0.09548354148864746,0.6398555040359497,0.23092451691627502,5.997190952301025,136,15.769152641296387,0.028782520443201065,0.6602852940559387,0.26686036586761475,8.060165405273438,475,0,0,0,0,0,0,19.613920211791992,-0.03737274929881096,0.8565442562103271,0.13404110074043274,4.84287166595459,60,0,0,0,0,0,0,0,0,0,0,0,0,16.006196975708008,-0.16971516609191895,0.5292651653289795,0.2796056866645813,10.46243953704834,63,0
10
+ 19.012269973754883,-0.04718773066997528,0.6433225870132446,0.18231874704360962,5.107872009277344,500,21.322338104248047,0.0140613978728652,0.6770919561386108,0.16535094380378723,5.265127182006836,500,21.2813777923584,-0.021526703611016273,0.6984917521476746,0.1797674298286438,5.906203269958496,198,19.346487045288086,0.10074374079704285,0.8397706151008606,0.1516176462173462,6.717349529266357,268,20.726825714111328,0.025485752150416374,0.6783157587051392,0.18283355236053467,7.897275447845459,215,14.350356101989746,-0.07720185816287994,0.5780261754989624,0.20371182262897491,5.403542995452881,44,18.631084442138672,0.07769659161567688,0.6344185471534729,0.18611928820610046,5.661814212799072,26,17.056427001953125,-0.1821698099374771,0.7035198211669922,0.18846037983894348,5.462797164916992,500,18.501523971557617,0.08108856528997421,0.7818804383277893,0.19710351526737213,7.370274543762207,500,24.048419952392578,0.013744712807238102,0.7855278849601746,0.11609163880348206,3.6646697521209717,197,24.893909454345703,-0.17591805756092072,0.8268886208534241,0.20372600853443146,10.059407234191895,145,19.865501403808594,-0.22679699957370758,0.8597549796104431,0.2231135070323944,9.073450088500977,500,25.364770889282227,0.010617231950163841,0.8068088889122009,0.14136189222335815,4.029416084289551,500,21.7684383392334,-0.011133561842143536,0.7283687591552734,0.14749296009540558,4.200531482696533,500,18.374786376953125,-0.0007336796843446791,0.6325269341468811,0.1934446394443512,2.4958865642547607,500,18.839426040649414,-0.03957090154290199,0.6344597339630127,0.2224958837032318,5.42381477355957,400,15.805286407470703,0.08770310878753662,0.6977684497833252,0.20588622987270355,6.46218729019165,500,22.08108139038086,-0.01628425344824791,0.7589410543441772,0.13417819142341614,2.693875789642334,142,26.95757484436035,0.0022796352859586477,0.8572258353233337,0.12103652209043503,4.541953086853027,304,16.950056076049805,0.05445929989218712,0.6725068092346191,0.42972901463508606,10.8408203125,500,19.32314682006836,0.034813400357961655,0.810303807258606,0.1743694394826889,8.154923439025879,290,18.345930099487305,0.09873536974191666,0.6867643594741821,0.20015820860862732,5.688977241516113,500,18.867534637451172,0.008000208996236324,0.6561379432678223,0.20259518921375275,6.0598602294921875,500,25.886808395385742,-0.1851750612258911,0.843407154083252,0.08705346286296844,6.513862609863281,500,0,0,0,0,0,0,22.645614624023438,0.03813670575618744,0.7365572452545166,0.14904607832431793,7.366730213165283,500,20.234968185424805,-0.028652122244238853,0.8059276938438416,0.23735074698925018,10.289590835571289,185,24.92990493774414,-0.05894114822149277,0.8283792734146118,0.13925190269947052,3.199906349182129,307,23.42925453186035,-0.41295871138572693,0.9373413920402527,0.14144201576709747,3.7806153297424316,200,20.85039710998535,-0.004692706745117903,0.8196279406547546,0.2168300598859787,5.288863182067871,3,23.50503158569336,-0.028557421639561653,0.6380681991577148,0.1192559152841568,8.550318717956543,58,22.69325828552246,0.08483685553073883,0.6885713338851929,0.17399273812770844,5.622918605804443,500,18.995891571044922,-0.025361914187669754,0.6374244689941406,0.23452074825763702,6.074242115020752,136,15.669060707092285,0.025412963703274727,0.6598301529884338,0.2714788615703583,8.040945053100586,475,0,0,0,0,0,0,19.33572006225586,-0.150547593832016,0.8552142381668091,0.1460384875535965,5.214033126831055,60,0,0,0,0,0,0,0,0,0,0,0,0,16.531322479248047,0.06065846234560013,0.5375595688819885,0.25438711047172546,10.283487319946289,63,0
11
+ 18.73299217224121,-0.13161632418632507,0.6400277018547058,0.18755526840686798,5.157333850860596,500,21.10808753967285,-0.16676977276802063,0.672235906124115,0.1692410707473755,5.296654224395752,500,21.203834533691406,-0.024857260286808014,0.6980849504470825,0.18170055747032166,5.922959327697754,198,19.301345825195312,0.2718823254108429,0.8395335078239441,0.15078316628932953,6.726401329040527,268,20.7464542388916,0.0026871892623603344,0.67714923620224,0.18162284791469574,7.944480895996094,215,14.387572288513184,-0.0007205808651633561,0.5734782814979553,0.20236076414585114,5.430235385894775,44,18.500873565673828,0.08837074786424637,0.6331009864807129,0.18800735473632812,5.786571979522705,26,16.948963165283203,-0.07204114645719528,0.7015774846076965,0.19183480739593506,5.483725547790527,500,18.69580078125,0.08595505356788635,0.7832303643226624,0.18746890127658844,7.323951244354248,500,24.063138961791992,0.040227603167295456,0.7842744588851929,0.11548798531293869,3.7416088581085205,197,24.77498435974121,-0.10164796561002731,0.8261294364929199,0.19882343709468842,10.059757232666016,145,20.417898178100586,0.12687283754348755,0.8632937669754028,0.20207682251930237,9.024741172790527,500,25.293771743774414,-0.013352192007005215,0.8060333728790283,0.1428879052400589,3.982311248779297,500,21.808292388916016,0.05234861373901367,0.7285376191139221,0.14711451530456543,4.239038467407227,500,18.38561248779297,0.0016012159176170826,0.6329339742660522,0.19295348227024078,2.502042531967163,500,18.688907623291016,-0.04777387157082558,0.6325020790100098,0.22649399936199188,5.479030609130859,400,16.001020431518555,0.1790591925382614,0.7002289891242981,0.20206555724143982,6.487364292144775,500,22.093610763549805,-0.01092259306460619,0.7591899633407593,0.13419727981090546,2.800780773162842,142,26.88434410095215,0.008505810052156448,0.8564297556877136,0.1211276724934578,4.6540913581848145,304,16.507360458374023,-0.14780215919017792,0.6714658737182617,0.4597805440425873,10.8370361328125,500,19.03635597229004,-0.03866236284375191,0.8065902590751648,0.179178386926651,8.137178421020508,290,18.207599639892578,0.0390840582549572,0.6847269535064697,0.20280848443508148,5.730045318603516,500,18.824954986572266,0.007314752321690321,0.656466543674469,0.20462752878665924,6.068840980529785,500,25.785615921020508,-0.16119396686553955,0.8409033417701721,0.0872565507888794,6.494687557220459,500,0,0,0,0,0,0,22.554359436035156,0.17164531350135803,0.7343575954437256,0.15075889229774475,7.297224998474121,500,20.00868797302246,-0.20269323885440826,0.8015546202659607,0.2370498925447464,10.263086318969727,185,24.849260330200195,0.014324101619422436,0.8274909257888794,0.13913463056087494,3.273496627807617,307,24.143491744995117,-0.0864129289984703,0.9391130208969116,0.1349482536315918,3.8407604694366455,200,20.61045265197754,-0.044369861483573914,0.817896842956543,0.22000306844711304,5.338271617889404,3,23.173229217529297,0.06823521107435226,0.6345419883728027,0.11324688047170639,8.597124099731445,58,22.52924156188965,-0.005226884037256241,0.6790374517440796,0.17646169662475586,5.648611068725586,500,18.822338104248047,-0.034071650356054306,0.6334968209266663,0.23989541828632355,6.15177059173584,136,15.601283073425293,-0.00634151604026556,0.6596918702125549,0.2742881774902344,8.023792266845703,475,0,0,0,0,0,0,19.87816619873047,-0.21906106173992157,0.8562472462654114,0.12401121854782104,5.19991397857666,60,0,0,0,0,0,0,0,0,0,0,0,0,15.993595123291016,-0.15662315487861633,0.5254011154174805,0.27742305397987366,10.258377075195312,63,0
12
+ 19.110883712768555,0.03614996746182442,0.6461726427078247,0.1795150190591812,5.265549182891846,500,21.108360290527344,-0.14474032819271088,0.6736648678779602,0.16756995022296906,5.4420576095581055,500,21.234786987304688,-0.0323927141726017,0.6997060179710388,0.18202681839466095,6.203959941864014,198,19.807920455932617,0.17171825468540192,0.842271089553833,0.1452997773885727,6.870168685913086,268,20.6026668548584,-0.011249484494328499,0.6755874752998352,0.18391308188438416,8.184722900390625,215,14.93524169921875,-0.0469074510037899,0.5757077932357788,0.19150333106517792,5.594109058380127,44,18.379505157470703,-0.03754724934697151,0.6315571665763855,0.18700724840164185,5.797399520874023,26,17.05131721496582,-0.10010848939418793,0.7025959491729736,0.18903829157352448,5.659502983093262,500,19.137954711914062,0.14187031984329224,0.785306453704834,0.18198907375335693,7.440811634063721,500,23.754358291625977,0.0686846598982811,0.7810045480728149,0.1172272339463234,3.8019583225250244,197,24.170063018798828,-0.11422554403543472,0.8158525824546814,0.21402356028556824,10.47574234008789,145,20.21315574645996,-0.09578195214271545,0.859663188457489,0.20872971415519714,9.281356811523438,500,25.30156707763672,-0.018742039799690247,0.8054926991462708,0.14185506105422974,4.103004455566406,500,21.711626052856445,0.006813677493482828,0.7276286482810974,0.14892178773880005,4.297058582305908,500,18.365007400512695,-0.007851418107748032,0.6329214572906494,0.19421109557151794,2.7666542530059814,500,18.607521057128906,-0.061320219188928604,0.6336644291877747,0.2250998169183731,5.655640125274658,400,15.967201232910156,-0.06718071550130844,0.6980611681938171,0.20418159663677216,6.667073726654053,500,22.090274810791016,-0.0017378028715029359,0.7593238949775696,0.13377660512924194,2.6701622009277344,142,26.750892639160156,-0.07790957391262054,0.8558348417282104,0.12187864631414413,4.571890354156494,304,16.300031661987305,0.010518108494579792,0.6394075155258179,0.46379354596138,11.451972961425781,500,19.428918838500977,0.1081659346818924,0.8055146336555481,0.16705267131328583,8.28512191772461,290,18.22205924987793,-0.02810261771082878,0.6851233839988708,0.20087064802646637,5.899336814880371,500,18.846923828125,-0.021323775872588158,0.6569703221321106,0.20638112723827362,6.310214519500732,500,25.63957977294922,-0.21977517008781433,0.8407760858535767,0.08965358138084412,6.699525833129883,500,0,0,0,0,0,0,22.031572341918945,-0.16287170350551605,0.7234898805618286,0.15870793163776398,7.465205192565918,500,19.877155303955078,0.0820217877626419,0.7918612360954285,0.24016961455345154,10.721766471862793,185,24.740259170532227,0.05753420665860176,0.8264721035957336,0.14053989946842194,3.2130744457244873,307,23.240985870361328,-0.9365432858467102,0.9365918040275574,0.14060433208942413,3.1881415843963623,200,20.191814422607422,-0.015667753294110298,0.8101682662963867,0.22013281285762787,5.518276214599609,3,22.47603416442871,0.04551320895552635,0.6252691149711609,0.11578691005706787,9.172477722167969,58,22.722797393798828,0.1353231817483902,0.6838358640670776,0.17324240505695343,5.799938678741455,500,18.984954833984375,0.008981794118881226,0.6379337906837463,0.23654848337173462,6.328289031982422,136,15.746525764465332,0.024799227714538574,0.6614393591880798,0.27319034934043884,8.200506210327148,475,0,0,0,0,0,0,19.888586044311523,-0.006714705843478441,0.8541179299354553,0.12033425271511078,5.116232872009277,60,0,0,0,0,0,0,0,0,0,0,0,0,15.9192476272583,-0.10522013157606125,0.5274612903594971,0.29392895102500916,10.56518840789795,63,0
13
+ 20.46278190612793,-0.060763970017433167,0.6579670906066895,0.161322221159935,4.680773735046387,200,22.052663803100586,0.5723397135734558,0.6804402470588684,0.1476287543773651,4.742996692657471,200,21.605833053588867,0.23160338401794434,0.6969097852706909,0.171411395072937,5.295865535736084,198,20.8725643157959,0.6837916374206543,0.8545865416526794,0.12993884086608887,6.236001968383789,200,19.573144912719727,0.29120177030563354,0.6459593772888184,0.22614720463752747,9.27774429321289,200,14.133930206298828,0.22245623171329498,0.564282238483429,0.2065575122833252,5.674943447113037,0,18.949764251708984,0.17191997170448303,0.6465899348258972,0.19255967438220978,5.078993797302246,0,17.931673049926758,-0.11049146950244904,0.6982738971710205,0.18147899210453033,5.114457130432129,200,19.872821807861328,3.0028798580169678,0.7939698696136475,0.17656543850898743,6.240903377532959,200,24.46942138671875,0.02408456988632679,0.7884182333946228,0.10927800089120865,2.891601324081421,197,25.840389251708984,0.19366736710071564,0.8314366936683655,0.1776755452156067,9.571022033691406,145,16.929052352905273,-1.5327644348144531,0.7828646302223206,0.3492167592048645,10.210173606872559,200,25.710983276367188,-0.017311187461018562,0.8133223652839661,0.1370222568511963,3.3832216262817383,200,18.50092124938965,-2.022146224975586,0.6764024496078491,0.19759753346443176,4.851006031036377,200,17.02499008178711,0.16653648018836975,0.5973562002182007,0.22383597493171692,3.0246450901031494,200,19.098665237426758,0.004295928869396448,0.638796329498291,0.21742229163646698,5.272642612457275,200,16.54633140563965,0.17737846076488495,0.6975648403167725,0.19392725825309753,5.778793811798096,200,20.060346603393555,0.9734287261962891,0.7293187975883484,0.16429120302200317,3.052821397781372,0,27.030488967895508,-0.08356904238462448,0.8584595322608948,0.12005755305290222,3.672435998916626,0,16.670869827270508,0.08606883883476257,0.6815004348754883,0.4169979691505432,10.184615135192871,0,20.275964736938477,0.1098201647400856,0.8205265402793884,0.15371590852737427,7.1788201332092285,200,19.987890243530273,0.2744450271129608,0.7336406707763672,0.16191206872463226,5.04759407043457,200,17.112150192260742,-0.015052754431962967,0.5950587391853333,0.23803958296775818,6.651583671569824,200,26.6037654876709,-0.26065897941589355,0.8506165146827698,0.08368074893951416,6.122287750244141,200,0,0,0,0,0,0,22.674583435058594,-0.11895836144685745,0.749716579914093,0.14290745556354523,6.852635383605957,200,20.810087203979492,-0.07519318163394928,0.8198818564414978,0.23864984512329102,9.383285522460938,185,25.744356155395508,0.10595372319221497,0.8350962400436401,0.13477018475532532,2.0739593505859375,200,28.547643661499023,0.7507640719413757,0.9549750089645386,0.08899791538715363,1.1312161684036255,200,21.079015731811523,0.029884053394198418,0.8219752907752991,0.21175755560398102,4.344815731048584,3,23.93878173828125,0.05293554812669754,0.6428359746932983,0.08624686300754547,7.878372669219971,58,23.858797073364258,0.06702835857868195,0.6956433653831482,0.15657857060432434,5.089795112609863,200,19.62722396850586,0.05636562407016754,0.650608479976654,0.2117949277162552,5.699061870574951,136,15.717203140258789,0.02669934555888176,0.672042965888977,0.26545801758766174,7.867374420166016,200,0,0,0,0,0,0,21.55535316467285,1.0019360780715942,0.8686189651489258,0.10703979432582855,4.072519302368164,60,0,0,0,0,0,0,0,0,0,0,0,0,16.39645767211914,0.24261629581451416,0.538216233253479,0.2558179199695587,9.674732208251953,63,0
14
+ 20.47854995727539,1.2024067640304565,0.6582081317901611,0.15989229083061218,4.665087699890137,200,22.276348114013672,0.9851991534233093,0.6830069422721863,0.14797556400299072,4.454257488250732,200,21.658119201660156,0.8381527066230774,0.6987597942352295,0.16515423357486725,5.077767848968506,198,21.830154418945312,3.4031131267547607,0.8640546202659607,0.12628214061260223,5.637884140014648,200,19.192798614501953,0.3516661822795868,0.6184120774269104,0.2570144832134247,10.193868637084961,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
15
+ 20.590303421020508,0.4962620437145233,0.6601793169975281,0.15753625333309174,4.5785813331604,200,22.237953186035156,0.4489673674106598,0.6821855902671814,0.14713634550571442,4.530069351196289,200,21.67139434814453,0.5253610610961914,0.6990557909011841,0.16455212235450745,5.083188533782959,198,21.511552810668945,2.3355321884155273,0.8608049750328064,0.12826672196388245,5.940978527069092,200,21.198034286499023,0.4582754373550415,0.696334183216095,0.17922863364219666,7.798947811126709,200,14.655498504638672,0.11240727454423904,0.571294903755188,0.19869443774223328,5.37660551071167,44,18.72746467590332,-0.0767064094543457,0.6407261490821838,0.19800549745559692,5.168641090393066,26,18.191055297851562,0.7278627157211304,0.6975538730621338,0.17866922914981842,5.047889232635498,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
16
+ 20.47854995727539,1.2024067640304565,0.6582081317901611,0.15989229083061218,4.665087699890137,200,22.276348114013672,0.9851991534233093,0.6830069422721863,0.14797556400299072,4.454257488250732,200,21.658119201660156,0.8381527066230774,0.6987597942352295,0.16515423357486725,5.077767848968506,198,21.830154418945312,3.4031131267547607,0.8640546202659607,0.12628214061260223,5.637884140014648,200,19.192798614501953,0.3516661822795868,0.6184120774269104,0.2570144832134247,10.193868637084961,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
17
+ 20.5020809173584,0.5488608479499817,0.6597924828529358,0.16010788083076477,4.721373081207275,200,22.316869735717773,1.2297122478485107,0.6829593181610107,0.1500682681798935,4.489542007446289,200,21.685548782348633,0.7039672136306763,0.6994300484657288,0.1671903133392334,5.135034561157227,198,21.7237548828125,2.7444534301757812,0.8628481030464172,0.12802647054195404,5.7939772605896,200,20.3883113861084,0.08283903449773788,0.6582912802696228,0.216960147023201,8.84621810913086,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
18
+ 21.46029465948103,0.10821035103956703,0.7790344667110257,0.1146072332425551,4.8639076042175295,200,22.858641416969302,0.2752841716840785,0.8049575514436705,0.11705265556546775,4.970765066146851,200,22.984534566644847,0.28235093636870023,0.8254322399853716,0.11391306117594352,5.201311178881713,198,21.269003861197834,0.48774929318372584,0.8845344550020229,0.10633078400722959,5.946489062309265,200,21.91150318956801,-0.1983188362236588,0.8238611486475108,0.14899968794744575,6.809058917893304,144,15.47896710908777,-0.21478780557785693,0.766425387992816,0.16411230564964088,5.260037183761597,32,20.303700892418085,0.11696818304556747,0.7842201829523877,0.14423518538136373,5.439062690734863,20,17.87179956351499,-0.3067176983676024,0.8296448574724792,0.12977059845050629,5.341754336357116,200,19.857372164518466,0.05305542383233112,0.8275884621862206,0.1549614530158314,6.752370271682739,200,27.3605180375065,0.04183501128873296,0.8984342113196049,0.06288346686315807,3.3972074127197267,200,26.833560528931578,0.8061492074839005,0.8918025717551408,0.1506120693911968,9.562130060024604,167,15.137075498804533,-5.621109492690501,0.6937904475067103,0.3866491428085349,10.7611030960083,200,27.60886437063087,-0.5479598479597061,0.9037209833626163,0.08412209230221131,4.320116324424744,200,17.608785845649514,-6.357329803241343,0.6815140742522903,0.24886811768551442,5.626099944114685,200,24.33990679792855,-0.10166818044734177,0.8951419864328773,0.0715118810958864,2.0131524705886843,200,21.411267073256052,-0.021985724226232916,0.7967679150871174,0.13873621311716058,4.87985538482666,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
common/plot/plot_arch_ablation.py ADDED
@@ -0,0 +1,60 @@
1
+
2
+ import seaborn as sns
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+
9
+
10
+
11
+
14
+
15
+ # Sample data based on the provided image structure
16
+ tasks = [ "Add", "Concat", "Cross Attention", "Modulation"]
17
+ values = np.array([
18
+ [6.35],
19
+ [5.68],
20
+ [5.26],
21
+ [5.02],
22
+
23
+ # [0.87, 0.55, 0.25, 0.03, 0.01, 0.0]
24
+ ])
25
+ values = np.exp(values)
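+ # np.exp converts the (log-scale) loss values above into the perplexity plotted on the y-axis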
26
+ # Bar colors matching the provided image
27
+ bar_colors = ['#1f78b4', '#ffffff', '#a6cee3', '#cab2d6', '#b3b3cc', '#33a02c']
28
+
29
+ # Plotting the data
30
+ fig, ax = plt.subplots(figsize=(5, 3))
31
+
32
+ # Set bar width and x positions for each group
33
+ bar_width = 0.4
34
+ x = np.arange(len(tasks))
35
+
36
+ # Plot each group's bars with the specified colors
37
+ for i in range(values.shape[1]):
38
+ bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')
39
+
40
+ for container in ax.containers:
41
+ ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.1f")
42
+ bars[-1].set_color('#cab2d6')
43
+ bars[-1].set_edgecolor('black')
44
+
45
+ # Set titles, labels, and ticks
46
+ # ax.set_title("Zero-Shot Performance Comparison Across Tasks")
47
+ ax.set_xlabel("Model", fontsize=14)
48
+ ax.set_ylabel("Perplexity", fontsize=14)
49
+ ax.set_xticks(x )
50
+ ax.tick_params(axis='x', rotation=15)
51
+ ax.set_xticklabels(tasks, fontsize=12)
52
+ ax.set_ylim(values.min() - 10, values.max() + 50)
53
+
54
+ # Adding the legend outside the plot area
55
+ # ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3)
56
+
57
+ # Display the plot
58
+ plt.tight_layout()
59
+ # plt.show()
60
+ plt.savefig("output/arch_ablation.png", dpi=300)
common/plot/plot_arch_ablation_deltapsnr.py ADDED
@@ -0,0 +1,49 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ # Sample data based on the provided image structure
8
+ tasks = [ "Add", "Concat", "Cross Attention", "Modulation"]
9
+ values = np.array([
10
+ [0.46],
11
+ [0.18],
12
+ [0.02],
13
+ [1.87],
14
+ ])
15
+ # Bar colors matching the provided image
16
+ bar_colors = ['#1f78b4', '#a6cee3', '#1f78b4', '#ffffff', '#cab2d6', '#b3b3cc', '#33a02c']
17
+
18
+ # Plotting the data
19
+ fig, ax = plt.subplots(figsize=(5, 3))
20
+
21
+ # Set bar width and x positions for each group
22
+ bar_width = 0.4
23
+ x = np.arange(len(tasks))
24
+
25
+ # Plot each group's bars with the specified colors
26
+ for i in range(values.shape[1]):
27
+ bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')
28
+
29
+ for container in ax.containers:
30
+ ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")
31
+ bars[-1].set_color('#cab2d6')
32
+ bars[-1].set_edgecolor('black')
33
+
34
+ # Set titles, labels, and ticks
35
+ # ax.set_title("Zero-Shot Performance Comparison Across Tasks")
36
+ ax.set_xlabel("Model", fontsize=14)
37
+ ax.set_ylabel("Delta PSNR", fontsize=14)
38
+ ax.set_xticks(x )
39
+ ax.tick_params(axis='x', rotation=15)
40
+ ax.set_xticklabels(tasks, fontsize=12)
41
+ ax.set_ylim(0, values.max() + 0.2)
42
+
43
+ # Adding the legend outside the plot area
44
+ # ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3)
45
+
46
+ # Display the plot
47
+ plt.tight_layout()
48
+ # plt.show()
49
+ plt.savefig("output/arch_ablation_controllability.png", dpi=300)
common/plot/plot_dataset_scale.py ADDED
@@ -0,0 +1,69 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ # Adjusting the line thickness to better match the provided example
7
+
8
+ x = [1, 5, 10]
9
+ tasks = ["Paper Towel Replacement\n(Bi-UR5e)", "Items in Drawer\n(Franka)",
10
+ "Stack Bowls\n(UR5e)", "Tupperware in Microwave\n(Bi-ARX)"]
11
+
12
+ # Define y-values for each line type
13
+ y_values = {
14
+ "π₀": [0.9, 0.85, 0.8],
15
+ "π₀ (scratch)": [0.7, 0.75, 0.72],
16
+ "DP": [0.2, 0.3, 0.4],
17
+ "Octo": [0.5, 0.6, 0.55],
18
+ "OpenVLA": [0.1, 0.15, 0.2],
19
+ "ACT": [0.4, 0.5, 0.6]
20
+ }
21
+
22
+ # Define markers, line styles, colors for each line type
23
+ markers = {"π₀": 'o', "π₀ (scratch)": 'o', "DP": 'o', "Octo": 'D', "OpenVLA": '*', "ACT": 'o'}
24
+ styles = {"π₀": '-', "π₀ (scratch)": '--', "DP": '-', "Octo": '-', "OpenVLA": '', "ACT": '-'}
25
+ colors = {"π₀": '#1f78b4', "π₀ (scratch)": '#1f78b4', "DP": '#e31a1c', "Octo": '#33a02c', "OpenVLA": '#6a3d9a', "ACT": '#ff7f00'}
26
+
27
+ # Set line width for enhanced visibility
28
+
29
+ # Create subplots
30
+
31
+
32
+ fig, ax = plt.subplots( figsize=(5, 4))
33
+
34
+ x_values = [5, 10, 20, 30, 40]
35
+ y_values = [5.94,5.72, 5.21,5.15,5.02]
36
+ y_values = np.exp(y_values)
37
+
38
+ # Set line width for each line plot
39
+ line_width = 1.5
40
+ x = []
41
+ # Iterate over each subplot (task) and plot the lines with specified styles, markers, and adjusted line width
42
+
43
+
44
+ fig, ax1 = plt.subplots(figsize=(5, 4))
45
+
46
+ # Plot Perplexity (left y-axis)
47
+ ax1.plot(x_values, y_values, marker='o', linestyle='-', color='#1f78b4', linewidth=line_width)
48
+ ax1.annotate(f"{y_values[-1]:.1f}", (x_values[-1], y_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')
49
+ ax1.set_xscale('log')
50
+ ax1.set_xlabel("# Dataset", fontsize=14)
51
+ ax1.set_ylabel("Perplexity", fontsize=14, color='#1f78b4')
52
+ ax1.tick_params(axis='y', labelcolor='#1f78b4')
53
+
54
+ # Create a twin y-axis for controllability (right y-axis)
55
+ ax2 = ax1.twinx()
56
+ controllability_values = [ 0.46, 0.55, 1.69, 1.5, 1.87] # Example values for controllability
57
+ ax2.plot(x_values, controllability_values, marker='s', linestyle='--', color='#006400', linewidth=line_width)
58
+ ax2.set_ylabel("Delta PSNR", fontsize=14, color='#006400')
59
+ ax2.set_ylim(0, 2.1)
60
+ ax2.tick_params(axis='y', labelcolor='#006400')
61
+ ax2.annotate(f"{controllability_values[-1]:.1f}", (x_values[-1], controllability_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')
62
+
63
+ # Save the figure in high resolution
64
+ plt.tight_layout()
65
+ # plt.show()
66
+
67
+
68
+ plt.savefig(f"output/dataset_sizes.png", dpi=300) # Save the figure in high resolution
69
+
common/plot/plot_dataset_traj_scale.py ADDED
@@ -0,0 +1,48 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ # Adjusting the line thickness to better match the provided example
7
+
8
+ fig, ax = plt.subplots( figsize=(5, 4))
9
+
10
+ x_values = [8287, 77664, 532150,1126876,2070965,3163485]
11
+ y_values = [9.46, 6.94, 5.81, 5.70, 5.09, 5.02]
12
+ y_values = np.exp(y_values)
13
+
14
+ # Set line width for each line plot
15
+ line_width = 1.5
16
+ x = []
17
+ # Iterate over each subplot (task) and plot the lines with specified styles, markers, and adjusted line width
18
+ # for i, task in enumerate(tasks):
19
+
20
+ # Adding a centralized legend that appears above the plot
21
+ # fig.legend(y_values, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3, frameon=False, markerscale=1.5)
22
+
23
+ fig, ax1 = plt.subplots(figsize=(5, 4))
24
+
25
+ # Plot Perplexity (left y-axis)
26
+ ax1.plot(x_values, y_values, marker='o', linestyle='-', color='#1f78b4', linewidth=line_width)
27
+ ax1.annotate(f"{y_values[-1]:.1f}", (x_values[-1], y_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')
28
+ ax1.set_xscale('log')
29
+ ax1.set_xlabel("# Trajectory", fontsize=14)
30
+ ax1.set_ylabel("Perplexity", fontsize=14, color='#1f78b4')
31
+ ax1.tick_params(axis='y', labelcolor='#1f78b4')
32
+
33
+ # Create a twin y-axis for controllability (right y-axis)
34
+ ax2 = ax1.twinx()
35
+ controllability_values = [0.,0.10,1.20,1.41,1.56, 1.87] # Example values for controllability
36
+ ax2.plot(x_values, controllability_values, marker='s', linestyle='--', color='#006400', linewidth=line_width)
37
+ ax2.set_ylabel("Delta PSNR", fontsize=14, color='#006400')
38
+ ax2.annotate(f"{controllability_values[-1]:.1f}", (x_values[-1], controllability_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')
39
+
40
+ ax2.set_ylim(0, 2.1)
41
+ ax2.tick_params(axis='y', labelcolor='#006400')
42
+
43
+ # Save the figure in high resolution
44
+ plt.tight_layout()
45
+ #plt.show()
46
+
47
+ plt.savefig(f"output/traj_sizes.png", dpi=300)
48
+
common/plot/plot_dynamics_ablation.py ADDED
@@ -0,0 +1,56 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+
8
+
9
+
10
+
11
+
14
+
15
+ # Sample data based on the provided image structure
16
+ tasks = ["Passive Dynamics", "Full Dynamics", "Forward Dynamics"]
17
+ bar_labels = ["Passive Dynamics", "Full Dynamics", "Forward Dynamics"]
18
+ values = np.array([
19
+ [6.29],
20
+ [5.21],
21
+ [5.02],
22
+ ])
23
+ values = np.exp(values)
24
+ # Bar colors matching the provided image
25
+ bar_colors = ['#a6cee3', '#ffffff', '#a6cee3', '#cab2d6', '#b3b3cc', '#33a02c']
26
+
27
+ # Plotting the data
28
+ fig, ax = plt.subplots(figsize=(5, 3))
29
+
30
+ # Set bar width and x positions for each group
31
+ bar_width = 0.4
32
+ x = np.arange(len(tasks))
33
+
34
+ # Plot each group's bars with the specified colors
35
+ for i in range(values.shape[1]):
36
+ bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')
37
+ bars[-1].set_color('#cab2d6')
38
+ bars[-1].set_edgecolor('black')
39
+ for container in ax.containers:
40
+ ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.1f")
41
+
42
+ # Set titles, labels, and ticks
43
+ # ax.set_title("Zero-Shot Performance Comparison Across Tasks")
44
+ ax.set_xlabel("Model", fontsize=14)
45
+ ax.set_ylabel("Perplexity", fontsize=14)
46
+ ax.set_xticks(x )
47
+ ax.set_xticklabels(tasks, fontsize=12)
48
+ ax.set_ylim(values.min() - 10, values.max() + 50)
49
+ ax.tick_params(axis='x', rotation=15)
50
+ # Adding the legend outside the plot area
51
+ # ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3)
52
+
53
+ # Display the plot
54
+ plt.tight_layout()
55
+ # plt.show()
56
+ plt.savefig("output/dynamics_ablation.png", dpi=300)
common/plot/plot_dynamics_ablation_deltapsnr.py ADDED
@@ -0,0 +1,51 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
7
+
8
+ # Sample data based on the provided image structure
9
+ tasks = ["Passive Dynamics", "Full Dynamics", "Forward Dynamics"]
10
+ bar_labels = ["Passive Dynamics", "Full Dynamics", "Forward Dynamics"]
11
+ values = np.array([
12
+ [0.33],
13
+ [1.23],
14
+ [1.87],
15
+ # [0.87, 0.55, 0.25, 0.03, 0.01, 0.0]
16
+ ])
17
+ # Bar colors matching the provided image
18
+ bar_colors = ['#a6cee3', '#1f78b4', '#ffffff', '#cab2d6', '#b3b3cc', '#33a02c']
19
+
20
+ # Plotting the data
21
+ fig, ax = plt.subplots(figsize=(5, 3))
22
+
23
+ # Set bar width and x positions for each group
24
+ bar_width = 0.4
25
+ x = np.arange(len(tasks))
26
+
27
+ # Plot each group's bars with the specified colors
28
+ for i in range(values.shape[1]):
29
+ bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')
30
+
31
+ bars[-1].set_color('#cab2d6')
32
+ bars[-1].set_edgecolor('black')
33
+
34
+ for container in ax.containers:
35
+ ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")
36
+
37
+ # Set titles, labels, and ticks
38
+ # ax.set_title("Zero-Shot Performance Comparison Across Tasks")
39
+ ax.set_xlabel("Model", fontsize=14)
40
+ ax.set_ylabel("Delta PSNR", fontsize=14)
41
+ ax.set_xticks(x )
42
+ ax.set_xticklabels(tasks, fontsize=12)
43
+ ax.set_ylim(values.min() - 0.1, values.max() + 0.2)
44
+ ax.tick_params(axis='x', rotation=15)
45
+ # Adding the legend outside the plot area
46
+ # ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=3)
47
+
48
+ # Display the plot
49
+ plt.tight_layout()
50
+ # plt.show()
51
+ plt.savefig("output/dynamics_ablation_controllability.png", dpi=300)
common/plot/plot_from_wandb.py ADDED
@@ -0,0 +1,185 @@
1
+ import wandb
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import sys
5
+ import argparse
6
+ import os
7
+
8
+ """
9
+ Running plotting scripts over key metrics and key runs
10
+ export MODEL=final2_40dataset_waction_concat_gpu_8_nodes_1
11
+ python common/plot/plot_from_wandb.py --run_id $MODEL
12
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_noaction_gpu_8_nodes_1_step15k_v5
13
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_modulate_gpu_8_nodes_1_step15k_v5
14
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_attn_gpu_8_nodes_1_step15k_v5
15
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_add_gpu_8_nodes_1_step15k_v5
16
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_d64_gpu_8_nodes_1_step15k_v5
17
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_forward_dynamics_gpu_8_nodes_1_step15k_v5
18
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_full_dynamics_gpu_8_nodes_1_step15k_v5
19
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100000_gpu_8_nodes_1_68536steps_step15k_v5
20
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj10000_gpu_8_nodes_1_68536steps_step15k_v5
21
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100_gpu_8_nodes_1_68536steps_step15k_v5
22
+ python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj1000_gpu_8_nodes_1_68536steps_step15k_v5
23
+ python common/plot/plot_from_wandb.py --run_id final2_5dataset_waction_gpu_8_nodes_1_step24k_v5
24
+ python common/plot/plot_from_wandb.py --run_id final2_30dataset_waction_gpu_8_nodes_1_step24k_v5
25
+ python common/plot/plot_from_wandb.py --run_id final2_5dataset_waction_gpu_8_nodes_1_step24k_v5
26
+ python common/plot/plot_from_wandb.py --run_id final2_10dataset_waction_gpu_8_nodes_1_step24k_v5
27
+
28
+ """
29
+ # Initialize the wandb API client
30
+ api = wandb.Api()
31
+ pwd = os.path.dirname(os.path.abspath(__file__))
32
+
33
+ # Replace with your specific project and entity
34
+ entity = "latent-mage"
35
+ project = "video_val"
36
+
37
+ # List of datasets to process
38
+ datasets = [
39
+ "bridge_data_v2",
40
+ "fractal20220817_data",
41
+ "language_table",
42
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
43
+ "kaist_nonprehensile_converted_externally_to_rlds",
44
+ "ucsd_kitchen_dataset_converted_externally_to_rlds",
45
+ "utokyo_xarm_bimanual_converted_externally_to_rlds",
46
+ "stanford_hydra_dataset_converted_externally_to_rlds",
47
+ "austin_sirius_dataset_converted_externally_to_rlds",
48
+ "berkeley_fanuc_manipulation",
49
+ "berkeley_mvp_converted_externally_to_rlds",
50
+ "berkeley_rpt_converted_externally_to_rlds",
51
+ "cmu_play_fusion",
52
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
53
+ "qut_dexterous_manpulation",
54
+ "robo_net",
55
+ "furniture_bench_dataset_converted_externally_to_rlds",
56
+ "dlr_sara_grid_clamp_converted_externally_to_rlds",
57
+ "cmu_stretch",
58
+ "spoc",
59
+ "columbia_cairlab_pusht_real",
60
+ "droid",
61
+ "toto",
62
+ "io_ai_tech",
63
+ "conq_hose_manipulation",
64
+ "dobbe",
65
+ "berkeley_gnm_cory_hall",
66
+ "plex_robosuite",
67
+ "usc_cloth_sim_converted_externally_to_rlds",
68
+ "berkeley_cable_routing",
69
+ "imperial_wrist_dataset",
70
+ "bc_z",
71
+ "kuka",
72
+ "roboturk",
73
+ "metaworld",
74
+ "robomimic",
75
+ "epic_kitchen",
76
+ "ego4d",
77
+ "nyu_door_opening_surprising_effectiveness"
78
+ ]
79
+
80
+ def normalize_dataset(metric, runs):
81
+ """
82
+ Figure out best and worst values for a metric across all runs
83
+ and use it for normalization
84
+ """
85
+ pass
86
+
87
+ # List to store dataframes of PSNR metrics for each dataset
88
+ metrics_data = []
89
+ # Get runs based on a path
90
+ # Set up argument parser
91
+ parser = argparse.ArgumentParser(description='Aggregate per-dataset validation metrics from a wandb run into a CSV.')
92
+ parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')
93
+
94
+ # Parse arguments
95
+ args = parser.parse_args()
96
+
97
+ fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
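+ # wandb logs each metric under a per-dataset key of the form "<dataset>/<field>"; those are the columns gathered below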
98
+ num_fields = len(fields)
99
+ run_id = args.run_id
100
+
101
+ runs_path = f"{entity}/{project}/runs"
102
+ run = api.run(f"{entity}/{project}/runs/{run_id}")
103
+
104
+ # Get the history dataframe of a run
105
+ history = run.history(pandas=True)
106
+ model_step = 0
107
+ summary_metrics = run.summary
108
+ num_datasets = 0
109
+
110
+ # output the field into csv
111
+ # csv_output = f"{pwd}/aggregated_output.csv"
112
+ csv_output = f"aggregated_output.csv"
113
+
114
+ # initialize the csv file
115
+ if not os.path.exists(csv_output):
116
+ with open(csv_output, 'w') as f:
117
+ field_str = f"name,"
118
+ for dataset in datasets:
119
+ for field in fields:
120
+ field_str += f"{dataset}/{field},"
121
+ f.write(field_str.rstrip(",") + "\n")
122
+
123
+ results = [run_id] + [None] * len(datasets) * num_fields
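+ # one CSV row per run: the run id followed by num_fields values for each dataset, filled in below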
124
+ for field_idx, field in enumerate(fields):
125
+ if not history.empty:
126
+ # Filter the history to only include PSNR metrics for the specified datasets
127
+ for dataset_idx, dataset in enumerate(datasets):
128
+ field_col = f"{dataset}/{field}"
129
+ col_idx = dataset_idx * num_fields + field_idx + 1
130
+ if field == "num_examples":
131
+ if f"{dataset}/num_examples" in summary_metrics:
132
+ results[col_idx] = summary_metrics[f"{dataset}/num_examples"]
133
+
134
+ continue
135
+ if field_col in history.columns:
136
+ # Calculate PSNR divided by the number of examples (uncomment if needed)
137
+ # history[field_col] = history[field_col] / history.shape[0]
138
+ valid_field = history[field_col].dropna()
139
+ if not valid_field.empty:
140
+ last_valid_value = valid_field.iloc[-1] # Get the last non-NaN value
141
+ num_datasets += 1
142
+ metrics = pd.DataFrame({field_col: [last_valid_value]})
143
+ metrics['dataset'] = dataset
144
+ results[col_idx] = last_valid_value
145
+ metrics_data.append(metrics)
146
+ else:
147
+ pass
148
+ # print("missing dataset:", dataset)
149
+
150
+ if f"{dataset}/model_step" in summary_metrics:
151
+ model_step = summary_metrics[f"{dataset}/model_step"]
152
+
153
+ # Combine all the metric dataframes into one
154
+ if metrics_data:
155
+ all_metrics_df = pd.concat(metrics_data, ignore_index=True)
156
+
157
+ # # Compute aggregated statistics (mean, median, std, etc.) for PSNR
158
+ # aggregated_stats = all_metrics_df.groupby('dataset').mean()
159
+ #
160
+ # # Plot the mean PSNR for each dataset
161
+ # plt.figure(figsize=(12, 8))
162
+ # aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
163
+ # aggregated_stats[f'{field}'].plot(kind='bar')
164
+ # # print number of steps in the wandb run
165
+ # print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(metrics_data)}")
166
+ # print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}", )
167
+ #
168
+ # plt.title(f"Mean {field} for Each Dataset")
169
+ # plt.xlabel("Dataset")
170
+ # plt.ylabel(f"Mean {field} ")
171
+ # plt.xticks(rotation=90)
172
+ # plt.tight_layout()
173
+ #
174
+ # # Save the plot
175
+ # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")
176
+
177
+ # write the results into csv
178
+ with open(csv_output, 'a+') as f:
179
+ f.write(",".join([str(x) for x in results]) + "\n")
180
+
181
+ # Display aggregated statistics
182
+ # print(aggregated_stats)
183
+
184
+ # Save the aggregated statistics as a CSV if needed
185
+ # aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)
common/plot/plot_from_wandb_singledataset.py ADDED
@@ -0,0 +1,144 @@
1
+ import wandb
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import sys
5
+ import argparse
6
+
7
+ """
8
+ Running plotting scripts over key metrics and key runs
9
+ export MODEL=40dataset_waction_add_gpu_8_nodes_1
10
+ python common/plot/plot_from_wandb.py --field teacher_force_psnr --run_id $MODEL
11
+ python common/plot/plot_from_wandb.py --field teacher_force_psnr_delta --run_id $MODEL
12
+ python common/plot/plot_from_wandb.py --field teacher_force_ssim --run_id $MODEL
13
+ python common/plot/plot_from_wandb.py --field teacher_force_pred_lpips --run_id $MODEL
14
+ python common/plot/plot_from_wandb.py --field teacher_force_loss --run_id $MODEL
15
+
16
+ """
17
+ # Initialize the wandb API client
18
+ api = wandb.Api()
19
+
20
+ # Replace with your specific project and entity
21
+ entity = "latent-mage"
22
+ project = "video_val"
23
+
24
+ # List of datasets to process
25
+ datasets = [
26
+ "bridge_data_v2",
27
+ "fractal20220817_data",
28
+ "language_table",
29
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
30
+ "kaist_nonprehensile_converted_externally_to_rlds",
31
+ "ucsd_kitchen_dataset_converted_externally_to_rlds",
32
+ "utokyo_xarm_bimanual_converted_externally_to_rlds",
33
+ "stanford_hydra_dataset_converted_externally_to_rlds",
34
+ "austin_sirius_dataset_converted_externally_to_rlds",
35
+ "berkeley_fanuc_manipulation",
36
+ "berkeley_mvp_converted_externally_to_rlds",
37
+ "berkeley_rpt_converted_externally_to_rlds",
38
+ "cmu_play_fusion",
39
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
40
+ "qut_dexterous_manpulation",
41
+ "robo_net",
42
+ "furniture_bench_dataset_converted_externally_to_rlds",
43
+ "dlr_sara_grid_clamp_converted_externally_to_rlds",
44
+ "cmu_stretch",
45
+ "spoc",
46
+ "columbia_cairlab_pusht_real",
47
+ "droid",
48
+ "toto",
49
+ "io_ai_tech",
50
+ "conq_hose_manipulation",
51
+ "dobbe",
52
+ "berkeley_gnm_cory_hall",
53
+ "plex_robosuite",
54
+ "usc_cloth_sim_converted_externally_to_rlds",
55
+ "berkeley_cable_routing",
56
+ "imperial_wrist_dataset",
57
+ "bc_z",
58
+ "kuka",
59
+ "roboturk",
60
+ "metaworld",
61
+ "robomimic",
62
+ "epic_kitchen",
63
+ "ego4d",
64
+ "nyu_door_opening_surprising_effectiveness"
65
+ ]
66
+
67
+ # List to store dataframes of PSNR metrics for each dataset
68
+
69
+ # Get runs based on a path
70
+ # Set up argument parser
71
+ parser = argparse.ArgumentParser(description='Summarize per-dataset validation metrics for a single wandb run.')
72
+ parser.add_argument('--field', type=str, default='teacher_force_psnr', help='The field to process')
73
+ parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')
74
+
75
+ # Parse arguments
76
+ args = parser.parse_args()
77
+
78
+ field = args.field
79
+ run_id = args.run_id
80
+
81
+ runs_path = f"{entity}/{project}/runs"
82
+ run = api.run(f"{entity}/{project}/runs/{run_id}")
83
+
84
+ # Get the history dataframe of a run
85
+ history = run.history(pandas=True)
86
+ model_step = 0
87
+ summary_metrics = run.summary
88
+ num_datasets = 0
89
+ fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
90
+
91
+ for field in fields:
92
+ metrics_data = []
93
+ if not history.empty:
94
+ # Filter the history to only include PSNR metrics for the specified datasets
95
+ for dataset in datasets:
96
+ field_col = f"{dataset}/{field}"
97
+ step_col = f"{dataset}/model_step"
98
+ if field_col in history.columns:
99
+ # Calculate PSNR divided by the number of examples (uncomment if needed)
100
+ # history[field_col] = history[field_col] / history.shape[0]
101
+ valid_field = history[field_col].dropna()
102
+ if not valid_field.empty:
103
+ last_valid_value = valid_field.iloc[-1] # Get the last non-NaN value
104
+ num_datasets += 1
105
+ metrics = pd.DataFrame({field_col: [last_valid_value]})
106
+ metrics['dataset'] = dataset
107
+ metrics_data.append(metrics)
108
+
109
+ if step_col in summary_metrics:
110
+ model_step = summary_metrics[step_col]
111
+
112
+ # Combine all the metric dataframes into one
113
+ if metrics_data:
114
+ all_metrics_df = pd.concat(metrics_data, ignore_index=True)
115
+
116
+ # Print columns for debugging
117
+
118
+ # Compute aggregated statistics (mean, median, std, etc.) for PSNR
119
+ aggregated_stats = all_metrics_df.groupby('dataset').mean()
120
+
121
+ # Plot the mean PSNR for each dataset
122
+ plt.figure(figsize=(12, 8))
123
+ aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
124
+ aggregated_stats[f'{field}'].plot(kind='bar')
125
+ # print number of steps in the wandb run
126
+ print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(metrics_data)}")
127
+ print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}", )
128
+
129
+ # plt.title(f"Mean {field} for Each Dataset")
130
+ # plt.xlabel("Dataset")
131
+ # plt.ylabel(f"Mean {field} ")
132
+ # plt.xticks(rotation=90)
133
+ # plt.tight_layout()
134
+
135
+ # # Save the plot
136
+ # import os
137
+ # pwd = os.path.dirname(os.path.abspath(__file__))
138
+ # plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")
139
+
140
+ # Display aggregated statistics
141
+ # print(aggregated_stats)
142
+
143
+ # Save the aggregated statistics as a CSV if needed
144
+ # aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)
common/plot/plot_model_scale.py ADDED
@@ -0,0 +1,64 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ # Adjusting the line thickness to better match the provided example
7
+
8
+
9
+
10
+ fig, ax = plt.subplots( figsize=(5, 4))
11
+ # 64, 64sqrt2, 128, 128sqrt2, 256, 256sqrt2, 512
12
+ x_values = [2.8, 10.2, 36.9, 174.22, 366.9, 755.9] #
13
+ y_values = [6.33, 5.52, 5.25, 5.19, 5.02, 5.02] # , 5.09
14
+ y_values = np.exp(y_values)
15
+ # 256sqrt2->700m
16
+ # 512->1.3billion
17
+
18
+ # Set line width for each line plot
19
+ line_width = 1.5
20
+ x = []
21
+ # Iterate over each subplot (task) and plot the lines with specified styles, markers, and adjusted line width
22
+ # for i, task in enumerate(tasks):
23
+
24
+ # ax.plot(x_values, y_values, marker='o', linestyle='--', color='#1f78b4', linewidth=line_width)
25
+ # # for i, txt in enumerate(y_values):
26
+ # # ax.annotate(f"{txt:.1f}", (x_values[i], y_values[i]), textcoords="offset points", xytext=(0,10), ha='center')
27
+ # ax.annotate(f"{y_values[-1]:.1f}", (x_values[-1], y_values[-1]), textcoords="offset points", xytext=(0,10), ha='center')
28
+
29
+
30
+ # # Set individual titles and axis labels for each subplot
31
+
32
+ # ax.set_xlabel("Model Parameters(M)", fontsize=14)
33
+ # ax.set_ylabel("Perplexity", fontsize=14)
34
+ # ax.set_ylim(0, 1)
35
+ fig, ax1 = plt.subplots(figsize=(5, 4))
36
+
37
+ INDEX = -2
38
+ # Plot Perplexity (left y-axis)
39
+ ax1.plot(x_values, y_values, marker='o', linestyle='-', color='#1f78b4', linewidth=line_width)
40
+ ax1.annotate(f"{y_values[INDEX]:.1f}", (x_values[INDEX], y_values[INDEX]), textcoords="offset points", xytext=(0, 10), ha='center')
41
+ ax1.set_xscale('log')
42
+ ax1.set_xlabel("Model Parameters(M)", fontsize=14)
43
+ ax1.set_ylabel("Perplexity", fontsize=14, color='#1f78b4')
44
+ ax1.tick_params(axis='y', labelcolor='#1f78b4')
45
+ # , 1.18
46
+
47
+ # Create a twin y-axis for controllability (right y-axis)
48
+ ax2 = ax1.twinx()
49
+ controllability_values = [0.11, 1.02, 1.07, 1.12, 1.87, 1.34] # Example values for controllability
50
+ ax2.plot(x_values, controllability_values, marker='s', linestyle='--', color='#006400', linewidth=line_width)
51
+ ax2.set_ylabel("Delta PSNR", fontsize=14, color='#006400')
52
+ ax2.set_ylim(0, np.max(controllability_values) + 0.2)
53
+ ax2.tick_params(axis='y', labelcolor='#006400')
54
+ ax2.annotate(f"{controllability_values[INDEX]:.1f}", (x_values[INDEX], controllability_values[INDEX]), textcoords="offset points", xytext=(0, 10), ha='center')
55
+
56
+ # Save the figure in high resolution
57
+ plt.tight_layout()
58
+ # plt.show()
59
+
60
+ plt.savefig(f"output/model_sizes.png", dpi=300)
61
+
62
+ # Adding a centralized legend that appears above the plot
63
+ # fig.legend(y_values, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3, frameon=False, markerscale=1.5)
64
+
common/plot/plot_pretrain_ablation.py ADDED
@@ -0,0 +1,44 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ font = {
8
+ "family": "normal",
9
+ "size": 22,
10
+ }
11
+
12
+ matplotlib.rc("font", **font)
13
+ sns.set(rc={"font.family": "Times New Roman"})
14
+ sns.set(style="whitegrid")
15
+ sns.set(font_scale=3, style="whitegrid")
16
+
17
+ # Sample data for plotting
18
+ categories = ["Scratch", "Passive Pre-Train", "Pre-Train", "Pre-Train (Large)"]
19
+ values = [1.0, 1.0, 1.0, 1.0]
20
+
21
+ # Define custom colors for the bars
22
+ colors = ["#4c72b0", "#55a868", "#c44e52", "#8172b2"] # Adjust as needed
23
+
24
+ plt.figure(figsize=(14, 12))
25
+ ax = sns.barplot(
26
+ x=categories, y=values, alpha=0.9, palette=colors, edgecolor="black"
27
+ )
28
+ for container in ax.containers:
29
+ ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")
30
+
31
+ # Adding title and labels
32
+ plt.xlabel("Setting", fontsize=40)
33
+ plt.ylabel("Validation Perplexity", fontsize=40)
34
+ plt.xticks(fontsize=30)
35
+ plt.yticks(fontsize=30)
36
+ plt.legend(fontsize="small", title_fontsize="small", loc="lower left")
37
+
38
+ # Remove the borders
39
+ sns.despine(left=True, bottom=True)
40
+
41
+ # Display the plot
42
+ plt.tight_layout()
43
+ plt.savefig(f"output/model_ablation.png", dpi=300) # Save the figure in high resolution
44
+ plt.show()
common/plot/plot_pretrain_ablation_mar.py ADDED
@@ -0,0 +1,45 @@
1
+ import seaborn as sns
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ font = {
8
+ "family": "normal",
9
+ "size": 22,
10
+ }
11
+
12
+ matplotlib.rc("font", **font)
13
+ sns.set(rc={"font.family": "Times New Roman"})
14
+ sns.set(style="whitegrid")
15
+ sns.set(font_scale=3, style="whitegrid")
16
+
17
+ # Sample data for plotting
18
+ categories = ["Scratch", "Passive Pre-Train", "Pre-Train", "Pre-Train (Large)"]
19
+ values = [1.0, 1.0, 1.0, 1.0]
20
+
21
+ # Define custom colors for the bars
22
+ colors = ["#4c72b0", "#55a868", "#c44e52", "#8172b2"] # Adjust as needed
23
+
24
+ plt.figure(figsize=(14, 12))
25
+ ax = sns.barplot(
26
+ x=categories, y=values, alpha=0.9, palette=colors, edgecolor="black"
27
+ )
28
+ for container in ax.containers:
29
+ ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")
30
+
31
+ # Adding title and labels
32
+ plt.xlabel("Setting", fontsize=40)
33
+ plt.ylabel("Validation Perplexity", fontsize=40)
34
+ plt.xticks(fontsize=30)
35
+ ax.tick_params(axis='x', rotation=15)
36
+ plt.yticks(fontsize=30)
37
+ plt.legend(fontsize="small", title_fontsize="small", loc="lower left")
38
+
39
+ # Remove the borders
40
+ sns.despine(left=True, bottom=True)
41
+
42
+ # Display the plot
43
+ plt.tight_layout()
44
+ plt.savefig(f"output/model_ablation.png", dpi=300) # Save the figure in high resolution
45
+ plt.show()
cont_data.py ADDED
@@ -0,0 +1,245 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from einops import rearrange
10
+ from torch.utils.data import Dataset as TorchDataset
11
+
12
+ from datasets.encode_openx_dataset import DATA_FREQ_TABLE
13
+ from genie.config import GenieConfig
14
+ from genie.st_mask_git import cosine_schedule
15
+
16
+ SVD_SCALE = 0.18215
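+ # scaling applied to the continuous VAE latents in __getitem__; 0.18215 matches the latent scale factor commonly used with Stable Diffusion VAEs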
17
+
18
+ def normalize_actions(actions):
19
+ """
20
+ Compute the per-dimension mean and std of the actions; the normalization itself is applied inside the network.
21
+ """
22
+ mean = np.mean(actions, axis=0).tolist()
23
+ std = np.std(actions, axis=0).tolist()
24
+ return actions, [mean, std]
25
+
26
+
27
+ class RawFeatureDataset(TorchDataset):
28
+ """ Loads raw float32 tokens as memmap-backed array """
29
+ def __init__(
30
+ self,
31
+ data_dir,
32
+ window_size,
33
+ stride=1,
34
+ filter_interrupts=True,
35
+ filter_overlaps=False,
36
+ use_actions=False,
37
+ max_traj_num=1000000,
38
+ compute_stride_from_freq_table=True,
39
+ natural_hz=2,
40
+ datio_noise_ratio=0.0,
41
+ use_raw_image_as_latent=False,
42
+ domain=None,
43
+ ):
44
+ """
45
+ Args:
46
+ data_dir: directory with the same format as `data/train_v0` and `data/val_v0`.
47
+ Notably, has `video.bin` and `metadata.json`
48
+ window_size: number of frames per "video" sequence
49
+ stride: frame skip
50
+ filter_interrupts: Under 3% of training frame sequences are the concatenation of two different clips.
51
+ If filter_interrupts is True, will filter out these sequences using the segment ids.
52
+ filter_overlaps: If False (default), one frame will appear in multiple examples;
53
+ e.g. frame 0 might appear as the first frame in example 0 and also the second frame in example 15.
54
+ If True, will filter out examples so that each frame appears at most once in the dataset.
55
+ use_actions: If True, will load the actions from the `actions` folder for the models
56
+ """
57
+ data_dir = Path(data_dir)
58
+ with open(data_dir / "metadata.json") as f:
59
+ self.metadata = json.load(f)
60
+
61
+ # TODO: assert not quantized in metadata
62
+ shape = (self.metadata["num_images"], self.metadata.get("latent_channels", 4), self.metadata["h"], self.metadata["w"]) #
63
+ print("token shape:", shape)
64
+ self.use_raw_image_as_latent = use_raw_image_as_latent
65
+ if use_raw_image_as_latent:
66
+ shape = (shape[0], 3, shape[2], shape[3])
67
+ # resize to 32x32
68
+
69
+ video_tokens_path, segment_ids_path, action_tokens_path = [data_dir / f"{name}.bin"
70
+ for name in ["video", "segment_ids", "actions"]]
71
+
72
+ token_dtype = np.dtype(self.metadata.get("token_dtype", "float16"))
73
+ self.data = np.memmap(video_tokens_path, mode="r", shape=shape, dtype=token_dtype)
74
+ print("data nan:", torch.isnan(torch.from_numpy(self.data[:100].copy())).sum())
75
+ # import IPython; IPython.embed()
76
+ if use_raw_image_as_latent:
77
+ # debug for robomimic dataset
78
+ # downsample raw images (e.g. 256x256) to the 32x32 resolution set below
79
+ self.metadata["h"] = 32
80
+ self.metadata["w"] = 32
81
+ self.metadata["latent_channels"] = 3
82
+
83
+ self.window_size, self.stride = window_size, stride
84
+ self.datio_noise_ratio = datio_noise_ratio
85
+
86
+ if domain is not None: # TODO: remove
87
+ self.name = domain
88
+ else:
89
+ self.name = self.metadata["name"]
90
+
91
+ self.name = self.name.replace("_noquant", "")
92
+ self.stride = stride
93
+ if compute_stride_from_freq_table:
94
+ self.stride = max(DATA_FREQ_TABLE.get(self.name, 1) // natural_hz, 1)
95
+ self.n_action = self.metadata.get("action_dim", 1) * (self.stride)
96
+
97
+ if use_actions:
98
+ actions = []
99
+
100
+ # hack here for the separations in the 1x datasets
101
+ for action_file in sorted((data_dir / "actions").iterdir()):
102
+ actions.append(np.memmap(action_file, dtype=np.float32, mode="r").reshape(len(self.data), -1))
103
+
104
+ self.actions = np.concatenate(actions, axis=-1)
105
+ self.actions, self.action_stat = normalize_actions(self.actions)
106
+
107
+ if os.path.isfile(segment_ids_path):
108
+ self.segment_ids = np.memmap(
109
+ segment_ids_path,
110
+ dtype=np.int32,
111
+ mode="r",
112
+ shape=(self.metadata["num_images"],)
113
+ )
114
+ else:
115
+ self.segment_ids = None
116
+ if filter_interrupts:
117
+ raise NotImplementedError("Cannot filter interrupted sequences without segment ids.")
118
+
119
+ # Number of frames between the first and last frames of a video sequence (excluding one endpoint frame)
120
+ self.video_len = (self.window_size - 1) * self.stride
121
+ self.valid_start_inds = []
122
+
123
+ for start_ind in range(len(self.data) - self.video_len - self.stride):
124
+ # Assuming `segment_ids` is monotonically increasing, a sequence is interrupted (or too short)
125
+ # if the first and last frames have different segment ids.
126
+ if not (filter_interrupts and self.segment_ids[start_ind] != self.segment_ids[start_ind + self.video_len]):
127
+ self.valid_start_inds.append(start_ind)
128
+
129
+ if len(self.valid_start_inds) >= max_traj_num:
130
+ break
131
+
132
+ if filter_overlaps:
133
+ # Instead of using a sliding window, use each frame at most once
134
+ filtered_start_inds = []
135
+ for start_ind in self.valid_start_inds:
136
+ overlapping_start_inds = {start_ind - i * self.stride for i in range(1, self.window_size)}
137
+ # all sequences from `overlapping_start_inds` will also contain `start_ind`,
138
+ # so exclude sequence starting from `start_ind` if any of `overlapping_start_inds` is already being used
139
+ for existing_start_ind in filtered_start_inds[-self.window_size * self.stride:]:
140
+ # Bound could be improved
141
+ if existing_start_ind in overlapping_start_inds:
142
+ break
143
+ else:
144
+ filtered_start_inds.append(start_ind)
145
+
146
+ self.valid_start_inds = filtered_start_inds
147
+
148
+ num_videos = len(np.unique(self.segment_ids))
149
+ print(f"Loaded {len(self)} sequences from {data_dir} {self.stride=} {self.window_size=} {self.n_action=} {num_videos=}")
150
+
151
+ def __len__(self):
152
+ return len(self.valid_start_inds)
153
+
154
+ def __getitem__(self, idx):
155
+ """
156
+ Returns a flattened sequence of tokens representing `self.window_size` frames,
157
+ spaced `self.stride` apart.
158
+ """
159
+ start_ind = self.valid_start_inds[idx]
160
+ x = self.data[start_ind : start_ind + self.video_len + 1 : self.stride].copy()
161
+ x = torch.FloatTensor(x).float()
162
+ if self.use_raw_image_as_latent:
163
+ x = torch.nn.functional.interpolate(x, size=(self.metadata["h"], self.metadata["w"]))
164
+ # normalize
165
+ x = x / 255 - 0.5
166
+ else:
167
+ x = x * SVD_SCALE
168
+
169
+ x = rearrange(x, "t c h w -> (t h w) c")
170
+ # divide it when decoding
171
+ # reconstructions since the input ids and the labels are the same
172
+ attention_mask = torch.ones_like(x)
173
+ data_dict = {
174
+ "input_ids": x,
175
+ "labels": x,
176
+ "attention_mask": attention_mask,
177
+ "h": self.metadata["h"],
178
+ "w": self.metadata["w"],
179
+ "c": self.metadata["latent_channels"],
180
+ }
181
+ if hasattr(self, "actions"):
182
+ # we want to have all actions within the stride to predict the next frame at the end of the stride
183
+ # we will concatenate the actions from [window_size, d_action] to [window_size, d_action * stride]
184
+ data_dict['action_ids'] = self.actions[start_ind:start_ind + self.video_len + self.stride].reshape(self.window_size, -1)
185
+ data_dict['action_ids'] = torch.from_numpy(data_dict['action_ids'].astype(np.float32))
186
+
187
+ data_dict["domain"] = self.name.replace("_noquant", "")
188
+ return data_dict
189
+
190
+
191
+ def get_maskgit_collator_feature(config: GenieConfig):
192
+ # mask_token_id = config.image_vocab_size
193
+
194
+ def collate_fn(features) -> dict[str, torch.Tensor]:
195
+ # during training, map (z_0, z_1', z_2') -> (null, z_1, z_2)
196
+ # (z_0, z_1') -> (null, z_1) is the diffusion operator on z_1' -> z_1
197
+
198
+ h = features[0]["h"]
199
+ w = features[0]["w"]
200
+ input_ids = torch.stack([ex["input_ids"] for ex in features])
201
+ device = input_ids.device
202
+ x_THWC = rearrange(input_ids, "b (t h w) c -> b t h w c", b=len(features), t=config.T, h=h, w=w)
203
+ labels = x_THWC.clone()
204
+ first_masked_frame = config.T
205
+
206
+ mask = torch.zeros(1).long()
207
+ mask_token_indicator = torch.zeros((len(features), config.T, h, w)).long()
208
+
209
+ if config.dataloader_apply_mask:
210
+ if random.random() < config.non_mlm_ratio: # Closer to autoregressive inference
211
+ # Leave frames [0, first_masked_frame) unmasked.
212
+ first_masked_frame = random.randint(config.num_prompt_frames, config.T - 1)
213
+ else: # Typical MLM masking
214
+ first_masked_frame = 1
215
+
216
+ c = 0
217
+ while mask.max() == 0: # We could get unlucky and mask no tokens?
218
+ # per-minibatch, per-frame masking probability (could try variable masking rate from MUSE)
219
+ rand = torch.rand(len(features), config.T - first_masked_frame, 1, 1)
220
+ # add a minimum mask ratio
221
+ rand_mask = rand * (1 - config.dataloader_mask_ratio_min) + config.dataloader_mask_ratio_min
222
+ mask_prob_T = cosine_schedule(rand_mask)
223
+ r = torch.rand_like(x_THWC[:, first_masked_frame:, ..., 0], dtype=torch.float)
224
+ mask = r < mask_prob_T
225
+ c += 1
226
+
227
+ if c > 1:
228
+ print(f"Generated mask {c} > 1 times.")
229
+
230
+ mask_token_indicator = torch.cat([
231
+ torch.zeros((len(features), first_masked_frame, h, w), dtype=mask.dtype), mask], dim=1)
232
+
233
+ data_dict = {
234
+ "input_ids": rearrange(x_THWC, "b t h w c -> b (t h w) c"),
235
+ "labels": rearrange(labels, "b t h w c-> b (t h w) c"),
236
+ "masked_tokens_indicator": mask_token_indicator,
237
+ }
238
+
239
+ if "action_ids" in features[0]:
240
+ data_dict['action_ids'] = torch.stack([ex["action_ids"] for ex in features])
241
+ data_dict['domain'] = [ex["domain"] for ex in features]
242
+ data_dict['h'] = [ex["h"] for ex in features]
243
+ data_dict['w'] = [ex["w"] for ex in features]
244
+ return data_dict
245
+ return collate_fn
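
For orientation, a minimal usage sketch of `RawFeatureDataset` from the file above. The data path and `window_size` are placeholders rather than values taken from this commit; any directory with the layout written by the encoding scripts should work.

```python
from cont_data import RawFeatureDataset

# Placeholder path: any directory containing video.bin, metadata.json, segment_ids.bin
# and an actions/ folder in the format written by the encoding scripts should work.
dataset = RawFeatureDataset(data_dir="data/robomimic_temporalvae_val",
                            window_size=12, use_actions=True)
example = dataset[0]
print(example["input_ids"].shape)   # (window_size * h * w, latent_channels)
print(example["action_ids"].shape)  # (window_size, action_dim * stride)

# Batches are built with get_maskgit_collator_feature(config), where `config` is a
# GenieConfig defining T, num_prompt_frames and the masking ratios used above.
```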
data.py ADDED
@@ -0,0 +1,240 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from einops import rearrange
10
+ from torch.utils.data import Dataset as TorchDataset
11
+
12
+ from datasets.encode_openx_dataset import DATA_FREQ_TABLE
13
+ from genie.factorization_utils import factorize_token_ids, unfactorize_token_ids
14
+ from genie.config import GenieConfig
15
+ from genie.st_mask_git import cosine_schedule
16
+
17
+
18
+ def normalize_actions(actions: np.ndarray) -> tuple[np.ndarray, list[list[float]]]:
19
+ """
20
+ Compute the mean and std of the actions. The normalization itself is applied inside the network.
21
+ """
22
+ mean = np.mean(actions, axis=0).tolist()
23
+ std = np.std(actions, axis=0).tolist()
24
+ return actions, [mean, std]
25
+
26
+
27
+ class RawTokenDataset(TorchDataset):
28
+ """ Loads raw uint32 tokens as memmap-backed array """
29
+ def __init__(
30
+ self,
31
+ data_dir,
32
+ window_size,
33
+ stride=1,
34
+ filter_interrupts=True,
35
+ filter_overlaps=False,
36
+ use_actions=False,
37
+ name='',
38
+ max_traj_num=1000000,
39
+ compute_stride_from_freq_table=True,
40
+ natural_hz=2,
41
+ drop_action_ratio=0.0
42
+ ):
43
+ """
44
+ Args:
45
+ data_dir: directory with the same format as `data/train_v0` and `data/val_v0`.
46
+ Notably, has `video.bin` and `metadata.json`
47
+ window_size: number of frames per "video" sequence
48
+ stride: frame skip
49
+ filter_interrupts: Under 3% of training frame sequences are the concatenation of two different clips.
50
+ If filter_interrupts is True, will filter out these sequences using the segment ids.
51
+ filter_overlaps: If False (default), one frame will appear in multiple examples;
52
+ e.g. frame 0 might appear as the first frame in example 0 and also the second frame in example 15.
53
+ If True, will filter out examples so that each frame appears at most once in the dataset.
54
+ use_actions: If True, will load the actions from the `actions` folder for the models
55
+ name: the name of the dataset
56
+
57
+ """
58
+ data_dir = Path(data_dir)
59
+ with open(data_dir / "metadata.json") as f:
60
+ self.metadata = json.load(f)
61
+
62
+ shape = (self.metadata["num_images"], self.metadata["h"], self.metadata["w"]) # self.metadata["s"], self.metadata["s"]
63
+ video_tokens_path, segment_ids_path, action_tokens_path = [data_dir / f"{name}.bin"
64
+ for name in ["video", "segment_ids", "actions"]]
65
+ token_dtype = np.dtype(self.metadata.get("token_dtype", "uint32"))
66
+ self.data = np.memmap(video_tokens_path, dtype=token_dtype, mode="r", shape=shape)
67
+ self.window_size, self.stride = window_size, stride
68
+
69
+ if len(name) == 0:
70
+ self.name = self.metadata["name"]
71
+ else: # remove later
72
+ self.name = name
73
+
74
+ if compute_stride_from_freq_table:
75
+ self.stride = max(DATA_FREQ_TABLE.get(self.name, 1) // natural_hz, 1)
76
+ print(f"RawTokenDataset: {self.name=} {self.stride=}")
77
+
78
+ self.n_action = self.metadata.get("action_dim", 1) * (self.stride)
79
+ self.drop_action_ratio = drop_action_ratio
80
+
81
+ if use_actions:
82
+ actions = []
83
+
84
+ # hack here for the separations in the 1x datasets
85
+ for action_file in sorted((data_dir / "actions").iterdir()):
86
+ actions.append(np.memmap(action_file, dtype=np.float32, mode="r").reshape(len(self.data), -1))
87
+
88
+ self.actions = np.concatenate(actions, axis=-1)
89
+ self.actions, self.action_stat = normalize_actions(self.actions)
90
+
91
+ if os.path.isfile(segment_ids_path):
92
+ self.segment_ids = np.memmap(
93
+ segment_ids_path,
94
+ dtype=np.int32,
95
+ mode="r",
96
+ shape=(self.metadata["num_images"],)
97
+ )
98
+ else:
99
+ self.segment_ids = None
100
+ if filter_interrupts:
101
+ raise NotImplementedError("Cannot filter interrupted sequences without segment ids.")
102
+
103
+ # Number of frames between the first and last frames of a video sequence (excluding one endpoint frame)
104
+ self.video_len = (self.window_size - 1) * self.stride
105
+
106
+ self.valid_start_inds = []
107
+ for start_ind in range(len(self.data) - self.video_len - self.stride):
108
+ # Assuming `segment_ids` is monotonically increasing, a sequence is interrupted (or too short)
109
+ # if the first and last frames have different segment ids.
110
+ if not (filter_interrupts and self.segment_ids[start_ind] != self.segment_ids[start_ind + self.video_len]):
111
+ self.valid_start_inds.append(start_ind)
112
+
113
+ if self.segment_ids is not None and self.segment_ids[start_ind] >= max_traj_num: # because we will filter based on window size later
114
+ # len(self.valid_start_inds) >= max_traj_num
115
+ break
116
+
117
+ if filter_overlaps:
118
+ # Instead of using a sliding window, use each frame at most once
119
+ filtered_start_inds = []
120
+ for start_ind in self.valid_start_inds:
121
+ overlapping_start_inds = {start_ind - i * self.stride for i in range(1, self.window_size)}
122
+ # all sequences from `overlapping_start_inds` will also contain `start_ind`,
123
+ # so exclude sequence starting from `start_ind` if any of `overlapping_start_inds` is already being used
124
+ for existing_start_ind in filtered_start_inds[-self.window_size * self.stride:]:
125
+ # Bound could be improved
126
+ if existing_start_ind in overlapping_start_inds:
127
+ break
128
+ else:
129
+ filtered_start_inds.append(start_ind)
130
+
131
+ self.valid_start_inds = filtered_start_inds
132
+
133
+ self.num_videos = len(np.unique(self.segment_ids)) if self.segment_ids is not None else len(self.valid_start_inds)  # distinct trajectories when segment ids exist
134
+ print(f"Loaded {len(self)} sequences from {data_dir} {self.stride=} {self.window_size=} {self.n_action=} {self.num_videos=}")
135
+
136
+ def __len__(self):
137
+ return len(self.valid_start_inds)
138
+
139
+ def __getitem__(self, idx):
140
+ """
141
+ Returns a flattened sequence of tokens representing `self.window_size` frames,
142
+ spaced `self.stride` apart.
143
+ """
144
+ start_ind = self.valid_start_inds[idx]
145
+ x = torch.from_numpy((self.data[start_ind : start_ind + self.video_len + 1 : self.stride]).astype(np.int64))
146
+ x = x.flatten() # 16 x 16 x 16
147
+
148
+ # reconstructions since the input ids and the labels are the same
149
+ attention_mask = torch.ones_like(x)
150
+ data_dict = {
151
+ "input_ids": x,
152
+ "labels": x,
153
+ "attention_mask": attention_mask,
154
+ "h": self.metadata["h"],
155
+ "w": self.metadata["w"],
156
+ }
157
+ if hasattr(self, "actions") and np.random.uniform() > self.drop_action_ratio:
158
+ # we want to have all actions within the stride to predict the next frame at the end of the stride
159
+ # we will concatenate the actions from [window_size, d_action] to [window_size, d_action * stride]
160
+ # S x T x d_action
161
+ data_dict['action_ids'] = self.actions[start_ind:start_ind + self.video_len + self.stride].reshape(self.window_size, -1)
162
+ data_dict['action_ids'] = torch.from_numpy(data_dict['action_ids'].astype(np.float32))
163
+
164
+ data_dict["domain"] = self.name
165
+ return data_dict
166
+
167
+
168
+ def get_maskgit_collator(config: GenieConfig):
169
+ mask_token_id = config.image_vocab_size
170
+ # h = w = math.isqrt(config.S)
171
+
172
+ def collate_fn(features) -> dict[str, torch.Tensor]:
173
+ # during training, map (z_0, z_1', z_2') -> (null, z_1, z_2)
174
+ # (z_0, z_1') -> (null, z_1) is the diffusion operator on z_1' -> z_1
175
+ h = features[0]["h"]
176
+ w = features[0]["w"]
177
+ input_ids = torch.stack([ex["input_ids"] for ex in features])
178
+ device = input_ids.device
179
+ x_THW = rearrange(input_ids, "b (t h w) -> b t h w", b=len(features), t=config.T,
180
+ h=h, w=w)
181
+ x_THWC = factorize_token_ids(x_THW, config.num_factored_vocabs, config.factored_vocab_size)
182
+ labels = x_THW.clone()
183
+
184
+ if config.dataloader_apply_corruption:
185
+ # As done in Copilot-4D paper, add random noise sampled with a random rate between 0% and `config.max_corrupt_rate`
186
+ r = torch.rand(x_THWC.size(), device=device)
187
+ u01 = torch.rand((), device=device)
188
+ random_patches_mask = r < config.max_corrupt_rate * u01
189
+ random_values = torch.randint(low=0, high=config.factored_vocab_size, size=x_THWC.size(),
190
+ dtype=torch.long, device=device)
191
+ x_THWC[random_patches_mask] = random_values[random_patches_mask]
192
+
193
+ if random.random() < config.non_mlm_ratio: # Closer to autoregressive inference
194
+ # Leave frames [0, first_masked_frame) unmasked.
195
+ # first_masked_frame = random.randint(config.num_prompt_frames, config.T - 1)
196
+ first_masked_frame = random.randint(config.num_prompt_frames, config.T - 1)
197
+ x_THWC_view = x_THWC[:, first_masked_frame:]
198
+
199
+ # Arbitrary numbers here, but corrupting later frames more
200
+ # since we likely have compounding errors.
201
+ correct_rate = random.uniform(config.dataloader_mask_ratio_min, 1.0)
202
+ for i in range(x_THWC_view.size(1)):
203
+ correct_rate *= random.uniform(0.9, 1.0)
204
+ r = torch.rand((len(features), h, w, config.num_factored_vocabs), device=device)
205
+ random_patches_mask = r > correct_rate
206
+ x_THWC_view[:, i][random_patches_mask] = random_values[:, first_masked_frame + i][random_patches_mask]
207
+ else: # Typical MLM masking
208
+ first_masked_frame = 1
209
+
210
+ mask = torch.zeros(1)
211
+ if config.dataloader_apply_mask:
212
+ c = 0
213
+
214
+ while mask.max() == 0: # We could get unlucky and mask no tokens?
215
+ # per-minibatch, per-frame masking probability (could try variable masking rate from MUSE)
216
+ mask_prob_T = cosine_schedule(torch.rand(len(features), config.T - first_masked_frame, 1, 1))
217
+ r = torch.rand_like(x_THW[:, first_masked_frame:], dtype=torch.float)
218
+ mask = r < mask_prob_T
219
+ c += 1
220
+
221
+ if c > 1:
222
+ print(f"Generated mask {c} > 1 times.")
223
+
224
+ x_THW = unfactorize_token_ids(x_THWC, config.num_factored_vocabs, config.factored_vocab_size)
225
+ x_THW[:, first_masked_frame:][mask] = mask_token_id
226
+
227
+ data_dict = {
228
+ "input_ids": rearrange(x_THW, "b t h w -> b (t h w)"),
229
+ "labels": rearrange(labels, "b t h w -> b (t h w)"),
230
+ }
231
+
232
+ if "action_ids" in features[0]:
233
+ data_dict['action_ids'] = torch.stack([ex["action_ids"] for ex in features])
234
+ data_dict['domain'] = [ex["domain"] for ex in features]
235
+ data_dict['h'] = [ex["h"] for ex in features]
236
+ data_dict['w'] = [ex["w"] for ex in features]
237
+ return data_dict
238
+
239
+
240
+ return collate_fn
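
The collator above relies on `factorize_token_ids` / `unfactorize_token_ids` from `genie.factorization_utils`. As a self-contained illustration of the idea, splitting each token id into `num_factored_vocabs` digits of a base-`factored_vocab_size` number, here is a NumPy sketch; the repo's own helpers operate on torch tensors and may differ in digit ordering.

```python
import numpy as np

def factorize(ids, num_factored_vocabs, factored_vocab_size):
    """Split each token id into base-`factored_vocab_size` digits (least significant first)."""
    digits = []
    for _ in range(num_factored_vocabs):
        digits.append(ids % factored_vocab_size)
        ids = ids // factored_vocab_size
    return np.stack(digits, axis=-1)

def unfactorize(digits, num_factored_vocabs, factored_vocab_size):
    """Inverse of factorize: recombine the digits into a single token id."""
    ids = np.zeros(digits.shape[:-1], dtype=np.int64)
    for k in reversed(range(num_factored_vocabs)):
        ids = ids * factored_vocab_size + digits[..., k]
    return ids

ids = np.array([0, 5, 2**18 - 1])   # 2**18 is the magvit vocab size written by the encoders
digits = factorize(ids, 2, 512)     # e.g. two factored vocabularies of size 512 (512**2 == 2**18)
assert (unfactorize(digits, 2, 512) == ids).all()
```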
datasets/.DS_Store ADDED
Binary file (6.15 kB).
 
datasets/__init__.py ADDED
File without changes
datasets/encode_extern_dataset.py ADDED
@@ -0,0 +1,291 @@
1
+ # --------------------------------------------------------
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+ # --------------------------------------------------------
4
+ import argparse
5
+ import json
6
+ import os
7
+ import time
8
+ import traceback
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from datasets.encode_openx_dataset import MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES, get_shard_inds, VAL_RATIO, \
15
+ process_dataset_step, DATA_FREQ_TABLE
16
+ from datasets.extern.ego4d import ego4d_dataset_size, ego4d_dataset_generator
17
+ from datasets.extern.egoexo4d import egoexo4d_dataset_size, egoexo4d_dataset_generator
18
+ from datasets.extern.robomimic import robomimic_dataset_generator, robomimic_dataset_size
19
+ from . import utils
20
+
21
+
22
+ SCRIPT_DESCRIPTION="""
23
+ Similar to encode_openx_dataset.py except for non-OpenX datasets.
24
+ Again, each split can be partitioned into multiple shards,
25
+ which is useful for parallelized encoding across GPUs.
26
+
27
+ Example usage:
28
+ CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name egoexo4d --data_split train --num_shards 1000 --curr_shard_rank 400
29
+
30
+ Untested usage (SVD tokenizer):
31
+ CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name robomimic --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
32
+ """.strip()
33
+
34
+ DATASET_TO_GEN_AND_SIZE = {
35
+ "ego4d": (ego4d_dataset_generator, ego4d_dataset_size),
36
+ "egoexo4d": (egoexo4d_dataset_generator, egoexo4d_dataset_size),
37
+ "robomimic": (robomimic_dataset_generator, robomimic_dataset_size),
38
+ }
39
+
40
+
41
+ def encode_dataset_split(
42
+ extern_dataset_name: str,
43
+ split: str,
44
+ max_episodes: Optional[int],
45
+ original_res: bool,
46
+ no_quantization: bool,
47
+ curr_shard_rank: int,
48
+ num_shards: int,
49
+ root_dir: str,
50
+ encoder_type: str,
51
+ encoder_name_or_path: str,
52
+ dataset_postfix: str = "",
53
+ no_encoding: bool = False,
54
+ ):
55
+ """
56
+ Encodes (e.g. tokenizes) dataset.
57
+ The data written to disk can be used to load a `RawTokenDataset` (or the continuous version.)
58
+
59
+ Args:
60
+ extern_dataset_name: the name of the external dataset; must be a key of DATASET_TO_GEN_AND_SIZE.
61
+ split: expected to be either "train" or "val". TODO: decide how to split
62
+ max_episodes: the maximum number of trajectories to include in the dataset.
63
+ dataset_postfix: will be a suffix of the output dirname.
64
+ encoder_type: string specifying the type of image encoder/tokenizer to use.
65
+ original_res: if True, will maintain original resolution of the video rather than resizing it to 256x256.
66
+ no_quantization: if True, will not perform quantization step in image encoder.
67
+ """
68
+ extern_dataset_name = extern_dataset_name.strip() # never modified
69
+ suffixed_dataset_name = extern_dataset_name # will modify later
70
+
71
+ if original_res:
72
+ suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
73
+ if no_quantization:
74
+ suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
75
+ if no_encoding:
76
+ suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"
77
+ save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
78
+ dataset_path = os.path.join(root_dir, save_dirname)
79
+ print("=" * 25)
80
+ print(f"{dataset_path=}")
81
+ utils.mkdir_if_missing(dataset_path)
82
+
83
+ # Load data
84
+ generator, size_func = DATASET_TO_GEN_AND_SIZE[extern_dataset_name]
85
+ num_examples = size_func()
86
+ if max_episodes is not None:
87
+ num_examples = min(num_examples, max_episodes) # clip num_examples
88
+
89
+ # We will only operate on a subset of the training examples, depending on:
90
+ # 1) The split (train/val). Some examples are reserved for the other split.
91
+ # 2) Sharding
92
+ assert num_examples > MIN_VAL_EXAMPLES # non-positive number of train examples otherwise
93
+ num_val_examples = np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES)
94
+
95
+ if split == "train": # first_ind inclusive, last_ind exclusive
96
+ first_split_ind, last_split_ind = num_val_examples, num_examples
97
+ elif split == "val":
98
+ first_split_ind, last_split_ind = 0, num_val_examples
99
+ else:
100
+ raise NotImplementedError(f"{split=}")
101
+
102
+ first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
103
+ print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
104
+ print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
105
+ f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")
106
+
107
+ ##### Encode data #####
108
+ traj_lens = [] # only used to print statistics
109
+ videos = [] # NOTE: videos/actions for the entire shard are stored in RAM until the end
110
+ actions = []
111
+ segment_ids = []
112
+
113
+ # split based on some fixed batch sizes to reset RAM.
114
+ max_batch_per_loading = 10
115
+ pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
116
+ start_time = time.time()
117
+
118
+ for start_idx in pbar:
119
+ end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
120
+ pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
121
+ ds = generator(range(start_idx, end_idx))
122
+
123
+ for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
124
+ segment_id = start_idx + chunk_idx
125
+ try:
126
+ # batchify the data and then process
127
+ for step_ind, step_data in enumerate(episode["steps"]):
128
+ dataset_step = process_dataset_step(
129
+ step_data,
130
+ encoder_type=encoder_type,
131
+ encoder_name_or_path=encoder_name_or_path,
132
+ keep_res=original_res,
133
+ quantize=not no_quantization,
134
+ no_encoding=no_encoding
135
+ )
136
+
137
+ segment_ids.append(segment_id)
138
+ videos.append(dataset_step["image"])
139
+ actions.append(dataset_step["action"])
140
+
141
+ traj_lens.append(step_ind + 1) # number of steps in this trajectory
142
+ except Exception:
143
+ print("-" * 25)
144
+ print(f"Add episode failed: {segment_id=}", traceback.format_exc(), suffixed_dataset_name)
145
+
146
+ # 2 day timeout
147
+ if time.time() - start_time > 86400 * 2:
148
+ print(f"Writing dataset {suffixed_dataset_name} timed out")
149
+ break
150
+
151
+ if len(videos) == 0:
152
+ print("Empty shard!")
153
+ with open(f"{dataset_path}/error.json", "w") as f:
154
+ json.dump({"status": "empty_shard"}, f)
155
+
156
+ return
157
+
158
+ if no_quantization:
159
+ num_channels, height, width = videos[-1].shape[:3] # channels-first latents; saved as "latent_channels" in metadata below
160
+ else:
161
+ height, width = videos[-1].shape[:2]
162
+ num_channels = None
163
+
164
+ ##### Write videos, actions, segment_ids, and metadata #####
165
+ # align format to save segment_ids.bin, video.bin, actions/action.bin, metadata.json
166
+ # save videos
167
+ videos = np.stack(videos, axis=0)
168
+ # fp = np.memmap(f'{dataset_path}/video.bin', dtype=video_dtype, mode='w+', shape=videos.shape)
169
+ # fp[:] = videos[:]
170
+ videos.tofile(f'{dataset_path}/video.bin')
171
+
172
+ # save action
173
+ utils.mkdir_if_missing(f'{dataset_path}/actions')
174
+ actions = np.stack(actions, axis=0)
175
+ # fp = np.memmap(f'{dataset_path}/actions/actions.bin', dtype=np.float32, mode='w+', shape=actions.shape)
176
+ # fp[:] = actions[:]
177
+ actions = actions.astype(np.float32)
178
+ actions.tofile(f'{dataset_path}/actions/actions.bin')
179
+
180
+ # save segment_ids
181
+ segment_ids = np.array(segment_ids)
182
+ # fp = np.memmap(f'{dataset_path}/segment_ids.bin', dtype=np.int32, mode='w+', shape=segment_ids.shape)
183
+ # fp[:] = segment_ids[:] # map to trajectory index
184
+ segment_ids = segment_ids.astype(np.int32)
185
+ segment_ids.tofile(f'{dataset_path}/segment_ids.bin')
186
+
187
+ # feature_mean = np.mean(videos)
188
+ # feature_std = np.std((videos - feature_mean) / 1e9) * 1e9
189
+
190
+ # save metadata
191
+ if encoder_type == "magvit":
192
+ vocab_size = int(2 ** 18)
193
+ elif encoder_type == "temporalvae":
194
+ vocab_size = None
195
+ else:
196
+ raise NotImplementedError(f"{encoder_type=}")
197
+
198
+ with open(f'{dataset_path}/metadata.json', 'w') as f: # Technically only need to save most of this data for shard 0
199
+ json.dump({
200
+ "token_dtype": str(np.dtype(videos.dtype)),
201
+ "action_dim": actions[0].shape[-1],
202
+ "s": 16,
203
+ "h": height,
204
+ "w": width,
205
+ "vocab_size": vocab_size,
206
+ "hz": DATA_FREQ_TABLE.get(extern_dataset_name, 1), # to be loaded from the data code
207
+ "encoder_name_or_path": encoder_name_or_path,
208
+ "encoder_type": encoder_type,
209
+ "num_images": len(videos),
210
+ "latent_channels": num_channels,
211
+ "name": extern_dataset_name,
212
+ # "feature_mean": feature_mean,
213
+ # "feature_std": feature_std,
214
+ }, f)
215
+
216
+ print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
217
+ print(f"Dataset creation time: {time.time() - start_time:.3f}")
218
+
219
+
220
+ def parse_args():
221
+ parser = argparse.ArgumentParser(description=SCRIPT_DESCRIPTION)
222
+
223
+ parser.add_argument(
224
+ "--dataset_name", type=str, required=True, choices=DATASET_TO_GEN_AND_SIZE.keys(),
225
+ help="The name of the external (non-OpenX) dataset to encode; must be a key of DATASET_TO_GEN_AND_SIZE."
226
+ )
227
+ parser.add_argument(
228
+ "--data_split", type=str, choices=["train", "val"], required=True,
229
+ help="The split of the dataset to create."
230
+ )
231
+ parser.add_argument(
232
+ "--episode_cnt", type=int,
233
+ help="If specified, will limit the maximum number of trajectories to encode."
234
+ )
235
+ parser.add_argument(
236
+ "--original_res", action='store_true',
237
+ help="Maintain original resolution of the video rather than resizing it to 256x256."
238
+ )
239
+ parser.add_argument(
240
+ "--no_quantization", action='store_true',
241
+ help="Skip quantization step in visual encoder."
242
+ )
243
+ parser.add_argument(
244
+ "--num_shards", type=int, default=1,
245
+ help="The number of shards to partition the train/val dataset into."
246
+ )
247
+ parser.add_argument(
248
+ "--curr_shard_rank", type=int, default=0,
249
+ help="The (0-indexed) shard number to encode."
250
+ )
251
+ parser.add_argument(
252
+ "--root_dir", type=str, default="data",
253
+ help="The root directory to write all datasets to."
254
+ )
255
+ parser.add_argument(
256
+ "--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
257
+ help="Type of the image tokenizer."
258
+ )
259
+ parser.add_argument(
260
+ "--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
261
+ help="The path or name of the image encoder."
262
+ )
263
+ parser.add_argument(
264
+ "--no_encoding", action='store_true',
265
+ help="Preserve the groundtruth raw images to compute metrics in validation."
266
+ )
267
+ return parser.parse_args()
268
+
269
+
270
+ if __name__ == "__main__":
271
+ args = parse_args()
272
+ utils.set_seed(233)
273
+
274
+ dataset_postfix = f"shard{args.curr_shard_rank}_of_{args.num_shards}"
275
+ if args.episode_cnt is not None:
276
+ dataset_postfix = f"max{args.episode_cnt}_{dataset_postfix}"
277
+
278
+ encode_dataset_split(
279
+ extern_dataset_name=args.dataset_name,
280
+ split=args.data_split,
281
+ max_episodes=args.episode_cnt,
282
+ dataset_postfix=dataset_postfix,
283
+ original_res=args.original_res,
284
+ no_quantization=args.no_quantization,
285
+ num_shards=args.num_shards,
286
+ curr_shard_rank=args.curr_shard_rank,
287
+ root_dir=args.root_dir,
288
+ encoder_type=args.encoder_type,
289
+ encoder_name_or_path=args.encoder_name_or_path,
290
+ no_encoding=args.no_encoding,
291
+ )
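
To add another external dataset to `DATASET_TO_GEN_AND_SIZE`, the script above only needs a size function and a generator yielding episodes shaped like the OpenX ones. The sketch below shows the assumed interface with made-up `my_dataset_*` names and dummy arrays; it is not part of this commit.

```python
import numpy as np

def my_dataset_size() -> int:
    # total number of trajectories in the dataset
    return 100

def my_dataset_generator(example_inds=None):
    if example_inds is None:
        example_inds = range(my_dataset_size())
    for _ in example_inds:
        steps = []
        for _ in range(16):  # dummy 16-step trajectory
            steps.append({
                # select_image() picks observation keys containing "image" or "rgb"
                "observation": {"image": np.zeros((256, 256, 3), dtype=np.uint8)},
                "action": np.zeros(7, dtype=np.float32),
            })
        yield {"steps": steps}

# Registration would mirror the existing entries:
# DATASET_TO_GEN_AND_SIZE["my_dataset"] = (my_dataset_generator, my_dataset_size)
```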
datasets/encode_openx_dataset.py ADDED
@@ -0,0 +1,459 @@
1
+ # --------------------------------------------------------
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+ # --------------------------------------------------------
4
+ import argparse
5
+ import json
6
+ import os
7
+ import time
8
+ import traceback
9
+ from typing import Optional
10
+
11
+ import math
12
+ import numpy as np
13
+ import tensorflow_datasets as tfds
14
+ from tensorflow_datasets.core import DatasetBuilder
15
+ from tqdm import tqdm
16
+
17
+ from . import utils
18
+
19
+
20
+ SCRIPT_DESCRIPTION="""
21
+ Converts an Open X-Embodiment dataset from GS to encoded/tokenized data on disk.
22
+ This script only encodes one split (specified by `--data_split`)
23
+ of a one OpenX dataset (specified by `--dataset_name`) at a time.
24
+
25
+ Optionally, each split can be partitioned into multiple shards,
26
+ which is useful for parallelized encoding across GPUs.
27
+
28
+ Example usage:
29
+ CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name bc_z --data_split train --episode_cnt 500 --num_shards 16 --curr_shard_rank 0
30
+ CUDA_VISIBLE_DEVICES=1 python -m datasets.encode_openx_dataset --dataset_name bc_z --data_split train --episode_cnt 500 --num_shards 16 --curr_shard_rank 1
31
+
32
+ set -e
33
+ for ((i = 0; i < 64; i += 2)); do
34
+ CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name bridge --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
35
+ done
36
+
37
+ set -e
38
+ for ((i = 1; i < 64; i += 2)); do
39
+ CUDA_VISIBLE_DEVICES=1 python -m datasets.encode_openx_dataset --dataset_name bridge --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
40
+ done
41
+
42
+ Example usage (SVD tokenizer):
43
+ CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name language_table --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
44
+ """.strip()
45
+
46
+ # The validation set is the first VAL_RATIO examples in the dataset, and clipped to [MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES]
47
+ VAL_RATIO = 0.05
48
+ MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES = 20, 200
49
+
50
+
51
+ DATA_FREQ_TABLE = {
52
+ "austin_sailor_dataset_converted_externally_to_rlds": 20,
53
+ "stanford_hydra_dataset_converted_externally_to_rlds": 10,
54
+ "austin_buds_dataset_converted_externally_to_rlds": 20,
55
+ "austin_sirius_dataset_converted_externally_to_rlds": 20,
56
+ "berkeley_mvp_converted_externally_to_rlds": 5,
57
+ "berkeley_rpt_converted_externally_to_rlds": 30,
58
+ "ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
59
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
60
+ "utaustin_mutex": 20,
61
+ "imperialcollege_sawyer_wrist_cam": 10,
62
+ "language_table": 2, # changed to match frequency
63
+ "kuka": 2, # changed to match frequency
64
+ "bc_z": 10,
65
+ "robo_net": 1,
66
+ "dlr_sara_pour_converted_externally_to_rlds": 10,
67
+ "stanford_robocook_converted_externally_to_rlds": 5,
68
+ "cmu_play_fusion": 5,
69
+ "bridge": 5,
70
+ "furniture_bench_dataset_converted_externally_to_rlds": 10,
71
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
72
+ "usc_cloth_sim_converted_externally_to_rlds": 10,
73
+ "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
74
+ "roboturk": 10,
75
+ "kaist_nonprehensile_converted_externally_to_rlds": 10,
76
+ "asu_table_top_converted_externally_to_rlds": 12,
77
+ "utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
78
+ "berkeley_cable_routing": 10,
79
+ "droid": 15,
80
+ "uiuc_d3field": 1,
81
+ "robo_set": 5,
82
+ "toto": 30,
83
+ "nyu_door_opening_surprising_effectiveness": 3,
84
+ "nyu_franka_play_dataset_converted_externally_to_rlds": 3,
85
+ "mimic_play": 15,
86
+ "maniskill_dataset_converted_externally_to_rlds": 20,
87
+ "columbia_cairlab_pusht_real": 10,
88
+ "conq_hose_manipulation": 30,
89
+ "dlr_edan_shared_control_converted_externally_to_rlds": 5,
90
+ "berkeley_gnm_sac_son": 10,
91
+ "berkeley_autolab_ur5": 5,
92
+ "aloha_mobile": 30,
93
+ "1x_humanoid": 30,
94
+ "epic_kitchen_originalres": 30,
95
+ "epic_kitchen": 30,
96
+ "egoexo4d": 30,
97
+ "ego4d": 1, # less than this.
98
+ "robomimic": 6, # average length around 50
99
+ "metaworld": 6,
100
+ "frodobot": 30,
101
+ "fractal20220817_data": 3,
102
+ # robomimic variants
104
+ "robomimic_new": 6, # average length around 50
105
+ "robomimic_multitask_new": 6, # average length around 50
106
+ "robomimic_new_perturb": 6, # average length around 50
107
+ "robomimic_multitask_new_perturb": 6, # average length around 50
108
+ }
109
+
110
+
111
+
112
+ def select_image(observation, verbose=False):
113
+ """
114
+ Select a canonical frame as image observation.
115
+ """
116
+ imgs = []
117
+ # does not need to prefer wrist camera
118
+ for key in ["rgb", "image"]:
119
+ for obs_key in observation:
120
+ if key in obs_key and "depth" not in obs_key:
121
+ image = observation[obs_key]
122
+ if type(observation[obs_key]) is not np.ndarray:
123
+ image = image.numpy()
124
+ if verbose:
125
+ print("selected image key:", obs_key)
126
+ imgs.append(image)
127
+
128
+ return imgs
129
+
130
+
131
+ def process_dataset_step(step, encoder_type: str, encoder_name_or_path: str,
132
+ keep_res=False, quantize=True, no_encoding=False):
133
+ """
134
+ Map dataset-specific keys and values to a unified format.
135
+
136
+ Args:
137
+ step (dict): The step dictionary containing the dataset-specific information.
138
+ encoder_type (str, optional): The image encoder to use.
139
+ Returns:
140
+ dict: The processed step dictionary with the mapped keys and values.
141
+ """
142
+ step_dict = {}
143
+ try:
144
+ if "action" in step:
145
+ step_dict["action"] = np.array(step["action"])
146
+
147
+ # handle action
148
+ if type(step["action"]) is dict:
149
+ step_dict["action"] = step_dict["action"].item()
150
+
151
+ # outlier cases
152
+ action = []
153
+ for k, v in sorted(step_dict["action"].items()):
154
+ action.append(v.numpy().reshape(-1))
155
+ step_dict["action"] = np.concatenate(action)
156
+
157
+ # handle image
158
+ images = select_image(step["observation"])
159
+
160
+ # compute the embeddings.
161
+ if no_encoding:
162
+ step_dict["image"] = utils.resize_image(images[0])
163
+ elif quantize:
164
+ step_dict["image"] = utils.get_quantized_image_embeddings(
165
+ images[0],
166
+ encoder_type=encoder_type,
167
+ encoder_name_or_path=encoder_name_or_path,
168
+ keep_res=keep_res,
169
+ )
170
+ else:
171
+ step_dict["image"] = utils.get_vae_image_embeddings(
172
+ images[0],
173
+ encoder_type=encoder_type,
174
+ encoder_name_or_path=encoder_name_or_path,
175
+ keep_res=keep_res,
176
+ )
177
+ except Exception as e:
178
+ print("--------------------------")
179
+ print("process_dataset_step exception:", traceback.format_exc())
180
+
181
+ return step_dict
182
+
183
+
184
+ def get_dataset_builder(gs_dataset_name) -> tuple[DatasetBuilder, int]:
185
+ """
186
+ Returns the dataset builder and the total number of examples (for the train split).
187
+ """
188
+ try:
189
+ builder = tfds.builder_from_directory(builder_dir=f"gs://gresearch/robotics/{gs_dataset_name}/0.1.0/")
190
+ except:
191
+ try:
192
+ builder = tfds.builder_from_directory(builder_dir=f"gs://gresearch/robotics/{gs_dataset_name}/1.0.0/")
193
+ except:
194
+ builder = tfds.builder_from_directory(builder_dir=f"gs://gresearch/robotics/{gs_dataset_name}/0.0.1/")
195
+
196
+ info = builder.info
197
+ num_examples = info.splits["train"].num_examples
198
+
199
+ return builder, num_examples
200
+
201
+
202
+ def get_shard_inds(first_split_ind: int, last_split_ind: int, curr_shard_rank: int, num_shards: int) -> tuple[int, int]:
203
+ """
204
+ Given the indices of the first (inclusive) and last (exclusive) examples in the data split (i.e. entire train dataset or val dataset),
205
+ returns the indices of the first (inclusive) and last (exclusive) examples for the current shard in this data split.
206
+ """
207
+ split_num_examples = last_split_ind - first_split_ind
208
+ shard_size_float = split_num_examples / num_shards # average number of examples per shard
209
+ return (
210
+ first_split_ind + math.ceil(curr_shard_rank * shard_size_float),
211
+ min(first_split_ind + math.ceil((curr_shard_rank + 1) * shard_size_float), last_split_ind)
212
+ )
213
+
214
+
215
+ def encode_dataset_split(
216
+ gs_dataset_name: str,
217
+ split: str,
218
+ max_episodes: Optional[int],
219
+ original_res: bool,
220
+ no_quantization: bool,
221
+ curr_shard_rank: int,
222
+ num_shards: int,
223
+ root_dir: str,
224
+ encoder_type: str,
225
+ encoder_name_or_path: str,
226
+ dataset_postfix: str = "",
227
+ no_encoding: bool = False,
228
+ ):
229
+ """
230
+ Converts an Open X-Embodiment dataset from GS to encoded/tokenized data on disk.
231
+ The data written to disk can be used to load a `RawTokenDataset` (or the continuous version.)
232
+
233
+ Args:
234
+ gs_dataset_name: the name of the dataset in Google Storage.
235
+ Can be checked with gsutil ls -d gs://gresearch/robotics/*/
236
+ split: expected to be either "train" or "val". TODO: decide how to split
237
+ max_episodes: the maximum number of trajectories to include in the dataset.
238
+ dataset_postfix: will be a suffix of the output dirname.
239
+ encoder_type: string specifying the type of image encoder/tokenizer to use.
240
+ original_res: if True, will maintain original resolution of the video rather than resizing it to 256x256.
241
+ no_quantization: if True, will not perform quantization step in image encoder.
242
+ """
243
+ gs_dataset_name = gs_dataset_name.strip() # never modified
244
+ suffixed_dataset_name = gs_dataset_name # will modify later
245
+ if no_quantization:
246
+ video_dtype = np.float16
247
+ elif no_encoding:
248
+ video_dtype = np.uint8
249
+ else:
250
+ video_dtype = np.uint32
251
+ if original_res:
252
+ suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
253
+ if no_quantization:
254
+ suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
255
+ if no_encoding:
256
+ suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"
257
+ save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
258
+ dataset_path = os.path.join(root_dir, save_dirname)
259
+ print("=" * 25)
260
+ print(f"{dataset_path=}")
261
+ utils.mkdir_if_missing(dataset_path)
262
+
263
+ # Load data
264
+ builder, num_examples = get_dataset_builder(gs_dataset_name)
265
+ if max_episodes is not None:
266
+ num_examples = min(num_examples, max_episodes) # clip num_examples
267
+
268
+ # We will only operate on a subset of the training examples, depending on:
269
+ # 1) The split (train/val). Some examples are reserved for the other split.
270
+ # 2) Sharding
271
+ assert num_examples > MIN_VAL_EXAMPLES, f"{num_examples=} {MIN_VAL_EXAMPLES=}" # non-positive number of train examples otherwise
272
+ num_val_examples = np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES)
273
+
274
+ if split == "train": # first_ind inclusive, last_ind exclusive
275
+ first_split_ind, last_split_ind = num_val_examples, num_examples
276
+ elif split == "val":
277
+ first_split_ind, last_split_ind = 0, num_val_examples
278
+ else:
279
+ raise NotImplementedError(f"{split=}")
280
+
281
+ first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
282
+ print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
283
+ print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
284
+ f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")
285
+
286
+ ##### Encode data #####
287
+ traj_lens = [] # only used to print statistics
288
+ videos = [] # NOTE: videos/actions for the entire shard are stored in RAM until the end
289
+ actions = []
290
+ segment_ids = []
291
+
292
+ # split based on some fixed batch sizes to reset RAM.
293
+ max_batch_per_loading = 10
294
+ pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
295
+ start_time = time.time()
296
+
297
+ for start_idx in pbar:
298
+ end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
299
+ pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
300
+ ds = builder.as_dataset(split=f"train[{start_idx}:{end_idx}]")
301
+
302
+ for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
303
+ segment_id = start_idx + chunk_idx
304
+ try:
305
+ # batchify the data and then process
306
+ for step_ind, step_data in enumerate(episode["steps"]):
307
+ dataset_step = process_dataset_step(
308
+ step_data,
309
+ encoder_type=encoder_type,
310
+ encoder_name_or_path=encoder_name_or_path,
311
+ keep_res=original_res,
312
+ quantize=not no_quantization,
313
+ no_encoding=no_encoding
314
+ )
315
+
316
+ segment_ids.append(segment_id)
317
+ videos.append(dataset_step["image"])
318
+ actions.append(dataset_step["action"])
319
+
320
+ traj_lens.append(step_ind + 1) # number of steps in this trajectory
321
+ except Exception:
322
+ print("-" * 25)
323
+ print(f"Add episode failed: {segment_id=}", traceback.format_exc(), suffixed_dataset_name)
324
+
325
+ # 2 day timeout
326
+ if time.time() - start_time > 86400 * 2:
327
+ print(f"Writing dataset {suffixed_dataset_name} timed out")
328
+ break
329
+
330
+ if no_quantization:
331
+ num_channels, height, width = videos[-1].shape[:3]
332
+ else:
333
+ height, width = videos[-1].shape[:2]
334
+ num_channels = None
335
+
336
+ ##### Write videos, actions, segment_ids, and metadata #####
337
+ # align format to save segment_ids.bin, video.bin, actions/action.bin, metadata.json
338
+ # save videos
339
+ videos = np.stack(videos, axis=0)
340
+ fp = np.memmap(f'{dataset_path}/video.bin', dtype=video_dtype, mode='w+', shape=videos.shape)
341
+ fp[:] = videos[:]
342
+
343
+ # save action
344
+ utils.mkdir_if_missing(f'{dataset_path}/actions')
345
+ actions = np.stack(actions, axis=0)
346
+ fp = np.memmap(f'{dataset_path}/actions/actions.bin', dtype=np.float32, mode='w+', shape=actions.shape)
347
+ fp[:] = actions[:]
348
+
349
+ # save segment_ids
350
+ segment_ids = np.array(segment_ids)
351
+ fp = np.memmap(f'{dataset_path}/segment_ids.bin', dtype=np.int32, mode='w+', shape=segment_ids.shape)
352
+ fp[:] = segment_ids[:] # map to trajectory index
353
+
354
+ # feature_mean = float(np.mean(videos))
355
+ # feature_std = float(np.std((videos - feature_mean) / 1e9)) * 1e9
356
+ # save metadata
357
+ if encoder_type == "magvit":
358
+ vocab_size = int(2 ** 18)
359
+ elif encoder_type == "temporalvae":
360
+ vocab_size = None
361
+ else:
362
+ raise NotImplementedError(f"{encoder_type=}")
363
+
364
+ with open(f'{dataset_path}/metadata.json', 'w') as f: # Technically only need to save most of this data for shard 0
365
+ json.dump({
366
+ "token_dtype": str(np.dtype(videos.dtype)),
367
+ "action_dim": actions[0].shape[-1],
368
+ "s": 16,
369
+ "h": height,
370
+ "w": width,
371
+ "vocab_size": vocab_size,
372
+ "hz": DATA_FREQ_TABLE.get(gs_dataset_name, 1), # to be loaded from the data code TODO: remove default?
373
+ "encoder_name_or_path": encoder_name_or_path,
374
+ "encoder_type": encoder_type,
375
+ "num_images": len(videos),
376
+ "name": gs_dataset_name,
377
+ "latent_channels": num_channels,
378
+ "quantized": not no_quantization,
379
+ # "feature_mean": feature_mean,
380
+ # "feature_std": feature_std,
381
+ }, f)
382
+
383
+ print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
384
+ print(f"Dataset creation time: {time.time() - start_time:.3f}")
385
+
386
+
387
+ def parse_args():
388
+ parser = argparse.ArgumentParser(description=SCRIPT_DESCRIPTION)
389
+
390
+ parser.add_argument(
391
+ "--dataset_name", type=str, required=True,
392
+ help="The name of the Open X-Embodiment dataset on Google Storage. "
393
+ "Can be checked with gsutil ls -d gs://gresearch/robotics/*/. "
394
+ )
395
+ parser.add_argument(
396
+ "--data_split", type=str, choices=["train", "val"], required=True,
397
+ help="The split of the dataset to create."
398
+ )
399
+ parser.add_argument(
400
+ "--episode_cnt", type=int,
401
+ help="If specified, will limit the maximum number of trajectories to encode."
402
+ )
403
+ parser.add_argument(
404
+ "--original_res", action='store_true',
405
+ help="Maintain original resolution of the video rather than resizing it to 256x256."
406
+ )
407
+ parser.add_argument(
408
+ "--no_quantization", action='store_true',
409
+ help="Skip quantization step in visual encoder."
410
+ )
411
+ parser.add_argument(
412
+ "--num_shards", type=int, default=1,
413
+ help="The number of shards to partition the train/val dataset into."
414
+ )
415
+ parser.add_argument(
416
+ "--curr_shard_rank", type=int, default=0,
417
+ help="The (0-indexed) shard number to encode."
418
+ )
419
+ parser.add_argument(
420
+ "--root_dir", type=str, default="data",
421
+ help="The root directory to write all datasets to."
422
+ )
423
+ parser.add_argument(
424
+ "--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
425
+ help="Type of the image tokenizer."
426
+ )
427
+ parser.add_argument(
428
+ "--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
429
+ help="The path or name of the image encoder."
430
+ )
431
+ parser.add_argument(
432
+ "--no_encoding", action='store_true',
433
+ help="Preserve the groundtruth raw images to compute metrics in validation."
434
+ )
435
+ return parser.parse_args()
436
+
437
+
438
+ if __name__ == "__main__":
439
+ args = parse_args()
440
+ utils.set_seed(233)
441
+
442
+ dataset_postfix = f"shard{args.curr_shard_rank}_of_{args.num_shards}" if args.num_shards > 1 else ""
443
+ if args.episode_cnt is not None:
444
+ dataset_postfix = f"max{args.episode_cnt}_{dataset_postfix}" if dataset_postfix else f"max{args.episode_cnt}"
445
+
446
+ encode_dataset_split(
447
+ gs_dataset_name=args.dataset_name,
448
+ split=args.data_split,
449
+ max_episodes=args.episode_cnt,
450
+ dataset_postfix=dataset_postfix,
451
+ original_res=args.original_res,
452
+ no_quantization=args.no_quantization,
453
+ num_shards=args.num_shards,
454
+ curr_shard_rank=args.curr_shard_rank,
455
+ root_dir=args.root_dir,
456
+ encoder_type=args.encoder_type,
457
+ encoder_name_or_path=args.encoder_name_or_path,
458
+ no_encoding=args.no_encoding,
459
+ )
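
A quick sanity check of the sharding arithmetic in `get_shard_inds` above: with 1000 training examples left after reserving the first 50 for validation (5% of the data, within the [20, 200] clip), three shards cover contiguous, non-overlapping index ranges. The numbers are illustrative, and the import assumes the repo's TFDS dependencies are installed.

```python
from datasets.encode_openx_dataset import get_shard_inds

first_split_ind, last_split_ind = 50, 1050   # train split after reserving 50 val examples
for shard in range(3):
    print(shard, get_shard_inds(first_split_ind, last_split_ind, shard, num_shards=3))
# 0 (50, 384)
# 1 (384, 717)
# 2 (717, 1050)
```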
datasets/extern/__init__.py ADDED
File without changes
datasets/extern/ego4d.py ADDED
@@ -0,0 +1,193 @@
1
+ # --------------------------------------------------------
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+ # --------------------------------------------------------
4
+ import os
5
+ from typing import Iterable
6
+
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ from collections import OrderedDict
12
+ from pathlib import Path
13
+
14
+ CURRENT_DIR = os.path.dirname(__file__)
15
+ import cv2
16
+ from os.path import expanduser
17
+ import json
18
+ import matplotlib.pyplot as plt
19
+
20
+ RESOLUTION = (480, 480)
21
+ home = expanduser("~")
22
+
23
+ # Adjust these to the where-ever your detections and frames are stored.
24
+ ROOT = "/datasets01/ego4d_track2/"
25
+ LABEL_ROOT = ROOT + "v2_1/annotations/fho_main.json"
26
+ VIDEO_PATH = ROOT + "v2_1/full_scale/"
27
+ # from epic_kitchens.hoa import load_detections
28
+
29
+ # labels = json.load(open("/datasets01/ego4d_track2/v2_1/annotations/fho_main.json"))
30
+ # videos = /datasets01/ego4d_track2/v2_1/clips
31
+ def parse_video_frame(video_path, frame_id):
32
+ cap = cv2.VideoCapture(video_path)
33
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id-1)
34
+ ret, frame = cap.read()
35
+ return frame
36
+
37
+ def parse_raw_video(video_path):
38
+ cap = cv2.VideoCapture(video_path)
39
+ frames = []
40
+ while cap.isOpened():
41
+ ret, frame = cap.read()
42
+ if not ret:
43
+ break
44
+ frames.append(frame)
45
+ return frames
46
+
47
+ def compute_state_and_actions(image, curr_frame, next_frame, frame_idx, save=False):
48
+ # curr_frame is a list of bounding box labels
49
+ img_width, img_height = image.shape[1], image.shape[0]
50
+ for box in curr_frame:
51
+ if box['object_type'] == 'left_hand':
52
+ curr_hand1_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]
53
+
54
+ if box['object_type'] == 'right_hand':
55
+ curr_hand2_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]
56
+
57
+ for box in next_frame:
58
+ if box['object_type'] == 'left_hand':
59
+ next_hand1_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]
60
+
61
+ if box['object_type'] == 'right_hand':
62
+ next_hand2_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]
63
+
64
+ # normalized them
65
+ curr_hand1_center = np.array([curr_hand1_center[0] / img_width, curr_hand1_center[1] / img_height])
66
+ curr_hand2_center = np.array([curr_hand2_center[0] / img_width, curr_hand2_center[1] / img_height])
67
+
68
+ # normalize them
69
+ next_hand1_center = np.array([next_hand1_center[0] / img_width, next_hand1_center[1] / img_height])
70
+ next_hand2_center = np.array([next_hand2_center[0] / img_width, next_hand2_center[1] / img_height])
71
+
72
+ state = np.concatenate((curr_hand1_center, curr_hand2_center)) # - np.array(curr_hand1_center) - np.array(curr_hand2_center)
73
+ action = np.concatenate(
74
+ (
75
+ np.array(next_hand1_center),
76
+ np.array(next_hand2_center),
77
+ )
78
+ )
79
+ if save:
80
+ # draw the bounding boxes
81
+ cv2.circle(image, (int(curr_hand1_center[0] * img_width), int(curr_hand1_center[1] * img_height)), 10, (0, 255, 0), -1)
82
+ cv2.circle(image, (int(curr_hand2_center[0] * img_width), int(curr_hand2_center[1] * img_height)), 10, (0, 255, 0), -1)
83
+ cv2.circle(image, (int(next_hand1_center[0] * img_width), int(next_hand1_center[1] * img_height)), 10, (0, 0, 255), -1)
84
+ cv2.circle(image, (int(next_hand2_center[0] * img_width), int(next_hand2_center[1] * img_height)), 10, (0, 0, 255), -1)
85
+ # save the image
86
+ cv2.imwrite(f"/private/home/xinleic/LR/hpt_video/data/ego4d_video_label_check/img_{frame_idx}.png", image)
87
+ return state, action
88
+
100
+
101
+ def chunk_actions_and_concatenate(actions):
102
+ chunk_size = 4
103
+ chunked_actions = [actions[i:i + chunk_size] for i in range(0, len(actions), chunk_size)][:-1]
104
+ concatenated_frames = []
105
+
106
+ for chunk in chunked_actions:
107
+ frames_to_concat = []
108
+ for action in chunk:
109
+ frames = action['frames'] # Assuming 'frames' is a list or iterable
110
+ if frames is not None:
111
+ frames_to_concat.extend(frames) # Collect frames from each action
112
+ concatenated_frames.append(frames_to_concat) # Store the concatenated frames for this chunk
113
+
114
+ return concatenated_frames
115
+
116
+
117
+ def ego4d_dataset_size() -> int:
118
+ """ Returns the number of trajectories in the dataset. ~1725 for Ego4D. """
119
+ labels = json.load(open(LABEL_ROOT))
120
+ return len(labels['videos'])
121
+
122
+
123
+ # define your own dataset conversion
124
+ def ego4d_dataset_generator(example_inds: Iterable[int] = None):
125
+ """
126
+ Generator yielding data from Ego4D.
127
+ Args:
128
+ example_inds: if specified, will only yield data from these indices.
129
+ Otherwise, will default to yielding the entire dataset.
130
+ """
131
+ # convert to a list of episodes that can be added to replay buffer
132
+ labels = json.load(open(LABEL_ROOT))
133
+
134
+ if example_inds is None:
135
+ example_inds = range(len(labels['videos']))
136
+
137
+ for example_ind in example_inds:
138
+ label = labels['videos'][example_ind]
139
+ # ['annotated_intervals'][2]['narrated_actions']
140
+ video_path = VIDEO_PATH + label['video_uid'] + ".mp4"
141
+ if not os.path.exists(video_path):
142
+ print("skip", video_path)
143
+ continue
144
+
145
+ label_detections = labels
146
+ print("video_path:", video_path)
147
+ print("len label detections", len(label_detections))
148
+
149
+ # action extractions over bounding boxes subtractions of both hands.
150
+ for interval in label['annotated_intervals']:
151
+ # print(video_detections[frame_idx].hands)
152
+
153
+ lang = "use human hands to do some tasks" # dummies
154
+ # import IPython; IPython.embed()
155
+ print(f"Interval [{interval['start_sec']} - {interval['end_sec']}]")
156
+ actions = list(filter(lambda x: not (x['is_invalid_annotation'] or x['is_rejected']) and x['stage'] is not None, interval['narrated_actions']))
157
+ print(f"Actions: {len(actions)}")
158
+
159
+ # because we need to concatenate
160
+ if len(actions) < 3:
161
+ continue
162
+
163
+ # the number of frames is usually 7 and it also does not follow strict 2hz
164
+ chunk_actions = chunk_actions_and_concatenate(actions)
165
+ for frame_idx, frames in enumerate(chunk_actions):
166
+ # lang = frame['narration_text']
167
+ steps = []
168
+ # need to use dummy actions to expand from 6 frames to 16 frames
169
+ for idx, frame in enumerate(frames[:-1]):
170
+ frame_id = frame['frame_number']
171
+ next_frame = frames[idx + 1]
172
+ image = parse_video_frame(video_path, frame_id)
173
+
174
+ if len(frame['boxes']) > 2 and len(next_frame['boxes']) > 2:
175
+ try:
176
+ s, a = compute_state_and_actions(image, frame['boxes'], next_frame['boxes'], idx, save=False)
177
+ except Exception:
178
+ print(f'compute action failed idx {idx} frame idx {frame_idx}')
179
+ continue
180
+ # break into step dict
181
+ step = {
182
+ "observation": {"image": image, "state": s},
183
+ "action": a,
184
+ "language_instruction": lang,
185
+ }
186
+ steps.append(OrderedDict(step))
187
+
188
+ if len(steps) < 16:
189
+ print("skip this traj because frame window length < 16")
190
+ continue
191
+ data_dict = {"steps": steps}
192
+ yield data_dict
193
+
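The extern converters in this commit share one contract: a `*_dataset_size()` helper plus a generator that yields episodes as `{"steps": [...]}`, where each step carries an image observation, a low-dimensional state, an action, and a language string. A minimal consumer sketch under that assumption (the loop and printing below are illustrative, not code from this commit, and require the Ego4D paths configured above to exist):

```python
# Illustrative consumer of the episode-generator contract used by the extern converters.
from datasets.extern.ego4d import ego4d_dataset_size, ego4d_dataset_generator

def summarize(generator, max_episodes=2):
    """Iterate a few episodes and report their lengths and array shapes."""
    for ep_idx, episode in enumerate(generator):
        steps = episode["steps"]
        first = steps[0]
        print(
            f"episode {ep_idx}: {len(steps)} steps, "
            f"image {first['observation']['image'].shape}, "
            f"state {first['observation']['state'].shape}, "
            f"action {first['action'].shape}"
        )
        if ep_idx + 1 >= max_episodes:
            break

if __name__ == "__main__":
    print("total trajectories:", ego4d_dataset_size())
    summarize(ego4d_dataset_generator(example_inds=range(2)))
```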
datasets/extern/egoexo4d.py ADDED
@@ -0,0 +1,186 @@
1
+ # --------------------------------------------------------
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+ # --------------------------------------------------------
4
+ # https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes/blob/master/notebooks/demo.ipynb
5
+ import os
6
+ from typing import Iterable
7
+
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ from collections import OrderedDict
11
+ import os
12
+ import numpy as np
13
+ from pathlib import Path
14
+
15
+
16
+ CURRENT_DIR = os.path.dirname(__file__)
17
+ import cv2
18
+ from os.path import expanduser
19
+ import json
20
+
21
+
22
+ # Adjust these to the where-ever your detections and frames are stored.
23
+ CAM = "cam01" # cam01
24
+ ROOT = "/datasets01/egoexo4d/v2/"
25
+ LABEL_ROOT = ROOT + "annotations/ego_pose/train/hand/automatic/{}.json"
26
+ VIDEO_PATH = ROOT + "takes/{}/frame_aligned_videos/{}.mp4"
27
+ # from epic_kitchens.hoa import load_detections
28
+ TAKE_ROOT = ROOT + "takes.json"
29
+
30
+
31
+
32
+ def compute_state_and_actions(image, curr_frame, next_frame, idx, save=False):
33
+ img_width, img_height = image.shape[1], image.shape[0]
34
+
35
+ # raw 2D wrist keypoints (pixel coordinates) from the annotation
36
+ curr_hand1_center = curr_frame[0]['annotation2D'][CAM]['left_wrist']
37
+ curr_hand2_center = curr_frame[0]['annotation2D'][CAM]['right_wrist']
38
+
39
+ # normalize them to [0, 1]
40
+ curr_hand1_center = np.array([curr_hand1_center['x'] / img_width, curr_hand1_center['y'] / img_height])
41
+ curr_hand2_center = np.array([curr_hand2_center['x'] / img_width, curr_hand2_center['y'] / img_height])
42
+
43
+ next_hand1_center = next_frame[0]['annotation2D'][CAM]['left_wrist']
44
+ next_hand2_center = next_frame[0]['annotation2D'][CAM]['right_wrist']
45
+
46
+ # normalize them
47
+ next_hand1_center = np.array([next_hand1_center['x'] / img_width, next_hand1_center['y'] / img_height])
48
+ next_hand2_center = np.array([next_hand2_center['x'] / img_width, next_hand2_center['y'] / img_height])
49
+
50
+
51
+ state = np.concatenate((curr_hand1_center, curr_hand2_center)) # - np.array(curr_hand1_center) - np.array(curr_hand2_center)
52
+ action = np.concatenate(
53
+ (
54
+ np.array(next_hand1_center),
55
+ np.array(next_hand2_center),
56
+ )
57
+ )
58
+ if save:
59
+ # draw the bounding boxes
60
+ cv2.circle(image, (int(curr_hand1_center[0] * img_width), int(curr_hand1_center[1] * img_height)), 10, (0, 255, 0), -1)
61
+ cv2.circle(image, (int(curr_hand2_center[0] * img_width), int(curr_hand2_center[1] * img_height)), 10, (0, 255, 0), -1)
62
+ cv2.circle(image, (int(next_hand1_center[0] * img_width), int(next_hand1_center[1] * img_height)), 10, (0, 0, 255), -1)
63
+ cv2.circle(image, (int(next_hand2_center[0] * img_width), int(next_hand2_center[1] * img_height)), 10, (0, 0, 255), -1)
64
+ # save the image
65
+ cv2.imwrite(f"output/inspect/test_{idx}.png", image)
66
+ return state, action
67
+
68
+
69
+ def parse_raw_video(video_path):
70
+ import cv2
71
+ cap = cv2.VideoCapture(video_path)
72
+ frames = []
73
+ while cap.isOpened():
74
+ ret, frame = cap.read()
75
+ if not ret:
76
+ break
77
+ frames.append(frame)
78
+ return frames
79
+
80
+ def egoexo4d_dataset_size() -> int:
81
+ """ Returns the number of takes in the dataset. ~5k for v2. """
82
+ takes = json.load(open(TAKE_ROOT))
83
+ return len(takes)
84
+
85
+
86
+ # define your own dataset conversion
87
+ def egoexo4d_dataset_generator(example_inds: Iterable[int] = None):
88
+ """
89
+ Generator yielding data from Ego-Exo4D.
90
+ Args:
91
+ example_inds: if specified, will only yield data from these indices.
92
+ Otherwise, will default to yielding the entire dataset.
93
+ """
94
+ # convert to a list of episodes that can be added to replay buffer
95
+ MAX_EPISODE_LENGTH = 5000
96
+ TAKE_FILE = json.load(open(TAKE_ROOT))
97
+ print("total takes", len(TAKE_FILE))
98
+ # find the first camera with aria
99
+ global CAM
100
+
101
+ def find_aria_name(take):
102
+ for cam in take['cameras']:
103
+ if 'aria' in cam['name']:
104
+ return cam['name']
105
+ return None
106
+
107
+ if example_inds is None:
108
+ example_inds = range(len(TAKE_FILE))
109
+
110
+ for example_ind in example_inds:
111
+ take = TAKE_FILE[example_ind]
112
+ take_name = take['take_name']
113
+ take_uid = take['take_uid']
114
+ # CAM = find_aria_name(take)
115
+ # if CAM is None:
116
+ # continue
117
+
118
+ video_path = VIDEO_PATH.format(take_name, CAM)
119
+ label_path = LABEL_ROOT.format(take_uid)
120
+
121
+ if not os.path.exists(video_path) or not os.path.exists(label_path):
122
+ continue
123
+
124
+ video_frames = parse_raw_video(video_path)
125
+ label_detections = json.load(open(label_path))
126
+ print("video_path:", video_path)
127
+ print("len video frames", len(video_frames))
128
+ print("len label detections", len(label_detections))
129
+
130
+ # actions are derived from the 2D wrist keypoints of both hands across consecutive frames.
131
+ max_frame_idx = len(video_frames) - 1
132
+ DS_FACTOR = 1
133
+ frame_idx = 0
134
+ start_frame_idx = 0
135
+ MIN_CLIP_LENGTH = 300
136
+
137
+ def get_continuous_chunk(start_idx, label_detections):
138
+ end_idx = start_idx + 1
139
+ while str(start_idx) in label_detections and len(label_detections[str(start_idx)]) > 0 and str(end_idx) in label_detections and len(label_detections[str(end_idx)]) > 0:
140
+ end_idx += 1
141
+ return end_idx
142
+
143
+ print("TAKE", take_name)
144
+
145
+ # some frames might not have label. if there is a gap, skip
146
+ while start_frame_idx < max_frame_idx - DS_FACTOR:
147
+ # print(video_detections[frame_idx].hands)
148
+ lang = "use human hands to do some tasks" # dummies
149
+ if str(start_frame_idx) not in label_detections or str(start_frame_idx + DS_FACTOR) not in label_detections:
150
+ start_frame_idx += DS_FACTOR
151
+ continue
152
+
153
+ end_frame_idx = get_continuous_chunk(start_frame_idx, label_detections)
154
+ # print("start_frame_idx", start_frame_idx, end_frame_idx)
155
+
156
+ if end_frame_idx - start_frame_idx < MIN_CLIP_LENGTH:
157
+ start_frame_idx = end_frame_idx
158
+ continue
159
+
160
+ print("start clipping from", start_frame_idx, "to", end_frame_idx)
161
+ steps = []
162
+ for frame_idx in range(start_frame_idx, end_frame_idx - DS_FACTOR, DS_FACTOR):
163
+ image = video_frames[frame_idx][...,[2,1,0]] # RGB
164
+ try:
165
+ s, a = compute_state_and_actions(
166
+ image,
167
+ label_detections[str(frame_idx)], label_detections[str(frame_idx + DS_FACTOR)],
168
+ frame_idx, save=False
169
+ )
170
+ except Exception:
171
+ break
172
+ # break into step dict
173
+ step = {
174
+ "observation": {"image": image, "state": s},
175
+ "action": a,
176
+ "language_instruction": lang,
177
+ }
178
+ steps.append(OrderedDict(step))
179
+ if len(steps) > MAX_EPISODE_LENGTH:
180
+ break
181
+
182
+ start_frame_idx = end_frame_idx
183
+ if len(steps) >= MIN_CLIP_LENGTH:  # only yield clips that reached the minimum length
184
+ data_dict = {"steps": steps}
185
+ print(f"max_frame_idx: {max_frame_idx} ds factor: {DS_FACTOR} {len(steps)}")
186
+ yield data_dict
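The gap handling above (`get_continuous_chunk`) only keeps runs of frame indices whose every frame has a non-empty hand annotation. A standalone illustration of that rule with a hypothetical helper name (`continuous_chunks` is not part of this file):

```python
# Toy illustration of the continuous-chunk rule used by the Ego-Exo4D converter.
def continuous_chunks(label_detections, num_frames, min_len=3):
    """Yield (start, end) index ranges whose every frame has annotations."""
    def has_label(i):
        return str(i) in label_detections and len(label_detections[str(i)]) > 0

    start = 0
    while start < num_frames:
        if not has_label(start):
            start += 1
            continue
        end = start
        while end < num_frames and has_label(end):
            end += 1
        if end - start >= min_len:
            yield start, end
        start = end

labels = {"0": ["hand"], "1": ["hand"], "2": [], "3": ["hand"], "4": ["hand"], "5": ["hand"]}
print(list(continuous_chunks(labels, num_frames=6)))  # [(3, 6)]: the 0-1 run is too short, index 2 has no labels
```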
datasets/extern/epic_kitchen.py ADDED
@@ -0,0 +1,115 @@
1
+ # --------------------------------------------------------
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+ # --------------------------------------------------------
4
+ # https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes/blob/master/notebooks/demo.ipynb
5
+ import os
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from collections import OrderedDict
9
+ import os
10
+ import numpy as np
11
+ from pathlib import Path
12
+
13
+
14
+ CURRENT_DIR = os.path.dirname(__file__)
15
+ import cv2
16
+ from os.path import expanduser
17
+ from epic_kitchens.hoa.types import BBox, FloatVector, HandSide
18
+ from epic_kitchens.hoa import load_detections
19
+
20
+ RESOLUTION = (480, 480)
21
+ home = expanduser("~")
22
+
23
+ # Adjust these to the where-ever your detections and frames are stored.
24
+ DETECTION_ROOT = f"/checkpoint/xinleic/LR/epic-kitchens-100-hand-object-bboxes/labels/hand-objects"
25
+ FRAMES_ROOT = f"/datasets01/EPIC-KITCHENS-100"
26
+
27
+ # DETECTION_ROOT = f'{home}/Projects/epic_kitchen_labels/hand-objects'
28
+ # FRAMES_ROOT = f'{home}/EPIC-KITCHENS'
29
+ detections_root = Path(DETECTION_ROOT)
30
+ frames_root = Path(FRAMES_ROOT)
31
+
32
+
33
+ def compute_state_and_actions(curr_frame, next_frame):
34
+ curr_hand1, curr_hand2 = curr_frame.hands[0], curr_frame.hands[1]
35
+ if curr_hand1.side != HandSide.LEFT: # flip
36
+ curr_hand1, curr_hand2 = curr_hand2, curr_hand1
37
+
38
+ # already normalized
39
+ curr_hand1_center = curr_hand1.bbox.center
40
+ curr_hand2_center = curr_hand2.bbox.center
41
+
42
+ next_hand1, next_hand2 = next_frame.hands[0], next_frame.hands[1]
43
+ if next_hand1.side != HandSide.LEFT: # flip
44
+ next_hand1, next_hand2 = next_hand2, next_hand1
45
+
46
+ # already normalized even
47
+ next_hand1_center = next_hand1.bbox.center
48
+ next_hand2_center = next_hand2.bbox.center
49
+ state = np.concatenate((curr_hand1_center, curr_hand2_center))
50
+ action = np.concatenate(
51
+ (
52
+ np.array(next_hand1_center) - np.array(curr_hand1_center),
53
+ np.array(next_hand2_center) - np.array(curr_hand2_center),
54
+ )
55
+ )
56
+ return state, action
57
+
58
+
59
+ # define your own dataset conversion
60
+ def convert_dataset_image():
61
+ # convert to a list of episodes that can be added to replay buffer
62
+ ALL_EPISODES = os.listdir(FRAMES_ROOT)
63
+ MAX_EPISODE_LENGTH = 5000
64
+
65
+ for EPS in ALL_EPISODES:
66
+ rgb_path = os.path.join(FRAMES_ROOT, EPS, "rgb_frames")
67
+ if not os.path.exists(rgb_path):
68
+ continue
69
+ for video_id in os.listdir(rgb_path):
70
+ full_path = os.path.join(rgb_path, video_id)
71
+ if (
72
+ not full_path.endswith(".tar") and not full_path.endswith(".jpg") and not full_path.endswith("home")
73
+ ): # folder
74
+
75
+ # actions are the frame-to-frame differences of both hands' bounding-box centers.
76
+ participant_id = video_id[:3]
77
+ video_detections = load_detections(detections_root / participant_id / (video_id + ".pkl"))
78
+ max_frame_idx = len(video_detections) - 1
79
+ DS_FACTOR = 1
80
+ print(full_path)
81
+ steps = []
82
+
83
+ for frame_idx in range(0, max_frame_idx - DS_FACTOR, DS_FACTOR):
84
+ # print(video_detections[frame_idx].hands)
85
+ if (
86
+ len(video_detections[frame_idx].hands) != 2
87
+ or len(video_detections[frame_idx + DS_FACTOR].hands) != 2
88
+ ):
89
+ continue
90
+
91
+ s, a = compute_state_and_actions(
92
+ video_detections[frame_idx], video_detections[frame_idx + DS_FACTOR]
93
+ )
94
+ lang = "use human hands to do some tasks" # dummies
95
+ # print("state actions:", s, a)
96
+ image_path = frames_root / participant_id / "rgb_frames" / video_id / f"frame_{frame_idx:010d}.jpg"
97
+ # print(image_path)
98
+ image = cv2.imread(str(image_path))
99
+ if image is None:
100
+ continue
101
+ image = image[..., [2, 1, 0]] # RGB
102
+
103
+ # break into step dict
104
+ step = {
105
+ "observation": {"image": image, "state": s},
106
+ "action": a,
107
+ "language_instruction": lang,
108
+ }
109
+ steps.append(OrderedDict(step))
110
+ if len(steps) > MAX_EPISODE_LENGTH:
111
+ break
112
+ data_dict = {"steps": steps}
113
+ print(f"max_frame_idx: {max_frame_idx} ds factor: {DS_FACTOR} {len(steps)}")
114
+ yield data_dict
115
+
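In this file the state stacks the (already normalized) left/right hand-box centers at time t, and the action is the per-hand displacement to time t+1, so `state_t + action_t` approximately equals the next state. A small numeric sketch of that convention (the values are made up):

```python
import numpy as np

# state = [left_x, left_y, right_x, right_y] at time t; action = per-hand displacement to t+1
state_t = np.array([0.30, 0.55, 0.70, 0.52])
action_t = np.array([0.02, -0.01, -0.03, 0.00])
state_t1 = state_t + action_t
print(state_t1)  # hand centers implied at t+1 by the delta action
```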
datasets/extern/frodobot.py ADDED
@@ -0,0 +1,128 @@
1
+ # --------------------------------------------------------
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+ # --------------------------------------------------------
4
+ import random
5
+ import os
6
+ import time
7
+ import sys
8
+ import numpy as np
9
+ import IPython
10
+ import torch
11
+ from tqdm import tqdm
12
+ from collections import OrderedDict
13
+ import os
14
+ import PIL.Image
15
+ import numpy as np
16
+ from typing import Union, List
17
+ from pathlib import Path
18
+ import re
19
+
20
+ CURRENT_DIR = os.path.dirname(__file__)
21
+ import cv2
22
+ from os.path import expanduser
23
+ import pickle
24
+ import cv2
25
+ from matplotlib import pyplot as plt
26
+
27
+ import pandas as pd
28
+ import json
29
+
30
+ RESOLUTION = (480, 480)
31
+ DATA = "/home/liruiw/Projects/frodobot/"
32
+ # https://colab.research.google.com/#scrollTo=50ce529a-a20a-4852-9a5a-114b52b98f2e&fileId=https%3A//huggingface.co/datasets/frodobots/FrodoBots-2K/blob/main/helpercode.ipynb
33
+
34
+
35
+ # #### control data
36
+ import pandas as pd
37
+
38
+ # print(f"{dataset_dir}/control_data_{ride_id}.csv")
39
+
40
+
41
+ def convert_img_dataset(
42
+ dataset_dir="/home/liruiw/Projects/frodobot/output_rides_22",
43
+ env_names=None,
44
+ gui=False,
45
+ episode_num_pertask=2000,
46
+ **kwargs,
47
+ ):
48
+ # convert to a list of episodes that can be added to replay buffer
49
+ for eps_file in os.listdir(dataset_dir)[:50]: # 50 trajectories
50
+ dataset_dir_ = os.path.join(dataset_dir, eps_file)
51
+ if os.path.isdir(dataset_dir_):
52
+ ride_id = dataset_dir_.split("_")[-2]
53
+ print(dataset_dir_)
54
+
55
+ ##### control data
56
+ control = pd.read_csv(f"{dataset_dir_}/control_data_{ride_id}.csv")
57
+ control_data_dict = control.set_index("timestamp").T.to_dict("list")
58
+ control_sorted_keys = sorted(list(control_data_dict.keys()))
59
+
60
+ ##### IMU data
61
+ gyro_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "gyroscope"]]
62
+ gyro_data_dict = gyro_data.set_index("timestamp").T.to_dict("list")
63
+ gyro_sorted_keys = sorted(list(gyro_data_dict.keys()))
64
+
65
+ compass_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "compass"]]
66
+ compass_data_dict = compass_data.set_index("timestamp").T.to_dict("list")
67
+ compass_sorted_keys = sorted(list(compass_data_dict.keys()))
68
+
69
+ accel_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "accelerometer"]]
70
+ accel_data_dict = accel_data.set_index("timestamp").T.to_dict("list")
71
+ accel_sorted_keys = sorted(list(accel_data_dict.keys()))
72
+
73
+ ##### Camera data
74
+ camera_data = pd.read_csv(f"{dataset_dir_}/front_camera_timestamps_{ride_id}.csv")
75
+ camera_data_dict = camera_data.set_index("timestamp").T.to_dict("list")
76
+ camera_sorted_keys = sorted(list(camera_data_dict.keys()))
77
+
78
+ images = sorted(os.listdir(f"{dataset_dir_}/front_camera/"))
79
+
80
+ # #### front camera video
81
+ # front_camera = f"{dataset_dir}/recordings/0f0e8539d249f38e3ae7b18660f5af8c_ride_39572__uid_s_1000__uid_e_video_20240502221408754.ts"
82
+ languages = "drive around to play" # dummy
83
+ steps = []
84
+ SUBSAMPLE_IDX = 5
85
+
86
+ for idx, control_t in enumerate(control_sorted_keys):
87
+
88
+ # enumerate along actions and only pick matched timesteps
89
+ action = control_data_dict[control_t]
90
+ camera_t = camera_sorted_keys[np.argmin(np.abs(np.array(camera_sorted_keys) - control_t))]
91
+ camera_path = images[camera_data_dict[camera_t][0]]
92
+ img = cv2.resize(cv2.imread(f"{dataset_dir_}/front_camera/{camera_path}"), None, fx=0.5, fy=0.5)
93
+ gyro = gyro_data_dict[gyro_sorted_keys[np.argmin(np.abs(np.array(gyro_sorted_keys) - control_t))]]
94
+ first_three_strings = eval(gyro[0])[0][:3]
95
+ gyro_array = np.array(first_three_strings, dtype=float)
96
+
97
+ compass = compass_data_dict[compass_sorted_keys[np.argmin(np.abs(np.array(compass_sorted_keys) - control_t))]]
98
+ first_three_strings = eval(compass[0])[0][:3]
99
+ compass_array = np.array(first_three_strings, dtype=float)
100
+
101
+ accel = accel_data_dict[accel_sorted_keys[np.argmin(np.abs(np.array(accel_sorted_keys) - control_t))]]
102
+ first_three_strings = eval(accel[0])[0][:3]
103
+ accel_array = np.array(first_three_strings, dtype=float)
104
+
105
+ prop = np.concatenate((gyro_array, compass_array, accel_array))
106
+ step = {
107
+ "observation": {"state": prop, "image": img},
108
+ "action": action,
109
+ "language_instruction": languages,
110
+ }
111
+ steps.append(OrderedDict(step))
112
+ data_dict = {"steps": steps}
113
+ yield data_dict
114
+
115
+
116
+ class RolloutRunner:
117
+ """evaluate policy rollouts"""
118
+
119
+ def __init__(self, env_names, episode_num, save_video=False):
120
+ self.env_names = env_names
121
+ self.episode_num = episode_num
122
+ self.envs = []
123
+ self.scene_files = []
124
+ self.save_video = save_video
125
+
126
+ @torch.no_grad()
127
+ def run(self, policy, save_video=False, gui=False, video_postfix="", seed=233, env_name=None, **kwargs):
128
+ pass
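The FrodoBots streams (control, camera, IMU) are logged at different rates, so each control timestamp is paired with the closest timestamp from the other streams by absolute gap. A minimal sketch of that matching with a hypothetical helper name (`nearest_key` is not in this file):

```python
import numpy as np

def nearest_key(sorted_keys, t):
    """Return the timestamp in sorted_keys closest to t by absolute difference."""
    keys = np.asarray(sorted_keys, dtype=np.float64)
    return sorted_keys[int(np.argmin(np.abs(keys - t)))]

camera_ts = [0.00, 0.12, 0.21, 0.33]
print(nearest_key(camera_ts, 0.19))  # 0.21
```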
datasets/extern/robomimic.py ADDED
@@ -0,0 +1,108 @@
1
+ # --------------------------------------------------------
2
+ # Licensed under The MIT License [see LICENSE for details]
3
+ # --------------------------------------------------------
4
+
5
+ """
6
+ TODO: explain
7
+ """
8
+ import h5py
9
+ import numpy as np
10
+ import cv2
11
+ import time
12
+ from collections import OrderedDict
13
+ import robomimic.utils.file_utils as FileUtils
14
+
15
+ from sim.robomimic.robomimic_runner import (
16
+ create_env, OBS_KEYS, RESOLUTION
17
+ )
18
+ from sim.robomimic.robomimic_wrapper import RobomimicLowdimWrapper
19
+
20
+ from typing import Optional, Iterable
21
+
22
+ DATASET_DIR = 'data/robomimic/datasets'
23
+ SUPPORTED_ENVS = ['lift', 'square', 'can']
24
+ NUM_EPISODES_PER_TASK = 200
25
+
26
+
27
+ def render_step(env, state):
28
+ env.env.env.sim.set_state_from_flattened(state)
29
+ env.env.env.sim.forward()
30
+ img = env.render()
31
+ img = cv2.resize(img, RESOLUTION)
32
+ return img
33
+
34
+
35
+ def robomimic_dataset_size() -> int:
36
+ return len(SUPPORTED_ENVS) * NUM_EPISODES_PER_TASK
37
+
38
+
39
+ def robomimic_dataset_generator(example_inds: Optional[Iterable[int]] = None):
40
+ if example_inds is None:
41
+ example_inds = range(robomimic_dataset_size())
42
+
43
+ curr_env_name = None
44
+ for idx in example_inds:
45
+ # get env_name corresponding to idx
46
+ env_name = SUPPORTED_ENVS[idx // NUM_EPISODES_PER_TASK]
47
+ if curr_env_name is None or curr_env_name != env_name:
48
+ # need to load new env
49
+ dataset = f"{DATASET_DIR}/{env_name}/ph/image.hdf5"
50
+ env_meta = FileUtils.get_env_metadata_from_dataset(dataset)
51
+ env_meta["use_image_obs"] = True
52
+ env = create_env(env_meta=env_meta, obs_keys=OBS_KEYS)
53
+ env = RobomimicLowdimWrapper(env=env)
54
+ env.reset() # NOTE: this is necessary to remove green laser bug
55
+ curr_env_name = env_name
56
+
57
+ with h5py.File(dataset) as file:
58
+ demos = file["data"]
59
+ local_episode_idx = idx % NUM_EPISODES_PER_TASK
60
+ if f"demo_{local_episode_idx}" not in demos:
61
+ continue
62
+
63
+ demo = demos[f"demo_{local_episode_idx}"]
64
+ obs = demo["obs"]
65
+ states = demo["states"]
66
+ action = demo["actions"][:].astype(np.float32)
67
+ step_obs = np.concatenate([obs[key] for key in OBS_KEYS], axis=-1).astype(np.float32)
68
+ steps = []
69
+ for a, o, s in zip(action, step_obs, states):
70
+ # break into step dict
71
+ image = render_step(env, s)
72
+ step = {
73
+ "observation": {"state": o, "image": image},
74
+ "action": a,
75
+ "language_instruction": f"{env_name}",
76
+ }
77
+ steps.append(OrderedDict(step))
78
+ data_dict = {"steps": steps}
79
+ yield data_dict
80
+
81
+ # # import imageio
82
+ # for _ in range(3):
83
+ # steps = []
84
+ # perturbed_action = action + np.random.normal(0, 0.2, action.shape)
85
+ # current_state = states[0]
86
+ # _ = render_step(env, current_state)
87
+ # for someindex in range(len(action)):
88
+ # image = env.render()
89
+ # step = {
90
+ # "observation": {"image": image},
91
+ # "action": action[someindex],
92
+ # "language_instruction": f"{env_name}",
93
+ # }
94
+ # steps.append(OrderedDict(step))
95
+
96
+ # # simulate action
97
+ # env.step(perturbed_action[someindex])
98
+
99
+ # # # save video
100
+ # # frames = [step["observation"]["image"] for step in steps]
101
+ # # imageio.mimsave(f"test.mp4", frames, fps=10)
102
+ # # while not (user_input := input("Continue? (y/n)")) in ["y", "n"]:
103
+ # # print("Invalid input")
104
+
105
+ # data_dict = {"steps": steps}
106
+ # yield data_dict
107
+
108
+ env.close()
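The `example_inds` argument lets the conversion be sharded across workers; the robomimic generator indexes the 600 demos (3 tasks x 200 episodes) by a flat integer. A sketch of one possible contiguous split (the shard arithmetic here is illustrative; the actual splitting lives in the `encode_*` scripts):

```python
def shard_indices(total, shard_ind, num_shards):
    """Contiguous split of range(total) into num_shards near-equal pieces."""
    per_shard = (total + num_shards - 1) // num_shards
    start = shard_ind * per_shard
    return range(start, min(start + per_shard, total))

# e.g. shard 1 of 4 over the 600 robomimic episodes
inds = shard_indices(600, shard_ind=1, num_shards=4)  # range(150, 300)
# generator = robomimic_dataset_generator(example_inds=inds)
```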
datasets/merge_shards.py ADDED
@@ -0,0 +1,113 @@
1
+ """
2
+ Merge data shards generated from `encode_{extern,openx}_dataset.py`
3
+ In addition to CLI args, `SHARD_DATA_FORMAT` must be changed depending on the dataset.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+
10
+ import numpy as np
11
+ from tqdm.auto import tqdm
12
+
13
+ SHARD_DATA_FORMAT = "/private/home/xinleic/LR/HPT-Video-KZ/sharded_data/droid_magvit_shard{}_of_{}_train"
14
+
15
+
16
+ if __name__ == "__main__":
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--out_data_dir", type=str, required=True,
19
+ help="Directory to save merged data, must not exist.")
20
+ parser.add_argument("--num_shards", type=int, required=True, help="Number of shards the dataset was split into.")
21
+
22
+ args = parser.parse_args()
23
+ assert not os.path.exists(args.out_data_dir), "Will not overwrite existing directory."
24
+ os.makedirs(os.path.join(args.out_data_dir, "actions"), exist_ok=True)
25
+
26
+ num_frames = 0
27
+ valid_inds = []
28
+
29
+ for shard_ind in range(args.num_shards):
30
+ shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
31
+ if os.path.isfile(os.path.join(shard_path, "metadata.json")):
32
+ valid_inds.append(shard_ind)
33
+ with open(os.path.join(shard_path, "metadata.json"), "r") as f:
34
+ shard_metadata = json.load(f)
35
+
36
+ num_frames += shard_metadata["num_images"]
37
+ else:
38
+ print(f"{shard_ind=} is invalid.")
39
+
40
+ if num_frames == 0:
41
+ print("No valid shards")
42
+ exit(0)
43
+
44
+ token_dtype = np.dtype(shard_metadata["token_dtype"])
45
+ if shard_metadata["quantized"]:
46
+ frame_dims = (shard_metadata["h"], shard_metadata["w"])
47
+ else:
48
+ frame_dims = (shard_metadata["latent_channels"], shard_metadata["h"], shard_metadata["w"])
49
+
50
+ action_dim = shard_metadata["action_dim"]
51
+ videos = np.memmap(
52
+ os.path.join(args.out_data_dir, "video.bin"),
53
+ dtype=token_dtype,
54
+ mode="write",
55
+ shape=(num_frames, *frame_dims)
56
+ )
57
+
58
+ actions = np.memmap(
59
+ os.path.join(args.out_data_dir, "actions", "actions.bin"),
60
+ dtype=np.float32,
61
+ mode="write",
62
+ shape=(num_frames, action_dim)
63
+ )
64
+
65
+ segment_ids = np.memmap(
66
+ os.path.join(args.out_data_dir, "segment_ids.bin"),
67
+ dtype=np.int32,
68
+ mode="write",
69
+ shape=(num_frames,)
70
+ )
71
+
72
+ prev_frame_ind = 0
73
+ prev_segment_id = 0
74
+
75
+ for shard_ind in tqdm(valid_inds):
76
+ shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
77
+ with open(os.path.join(shard_path, "metadata.json"), "r") as f:
78
+ shard_metadata = json.load(f)
79
+
80
+ shard_num_frames = shard_metadata["num_images"]
81
+ videos[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
82
+ os.path.join(shard_path, "video.bin"),
83
+ dtype=np.dtype(shard_metadata["token_dtype"]),
84
+ mode="r",
85
+ shape=(shard_num_frames, *frame_dims),
86
+ )
87
+
88
+ actions[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
89
+ os.path.join(shard_path, "actions", "actions.bin"),
90
+ dtype=np.float32,
91
+ mode="r",
92
+ shape=(shard_num_frames, action_dim),
93
+ )
94
+
95
+ segment_ids[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
96
+ os.path.join(shard_path, "segment_ids.bin"),
97
+ dtype=np.int32,
98
+ mode="r",
99
+ shape=(shard_num_frames,),
100
+ ) + prev_segment_id
101
+
102
+ prev_segment_id = segment_ids[prev_frame_ind + shard_num_frames - 1] + 1
103
+ prev_frame_ind += shard_num_frames
104
+
105
+ assert prev_frame_ind == num_frames
106
+ print("Finished")
107
+
108
+ with open(os.path.join(args.out_data_dir, "metadata.json"), "w") as f:
109
+ merged_metadata = shard_metadata \
110
+ | vars(args) \
111
+ | {"num_images": num_frames, "input_path": SHARD_DATA_FORMAT.format(0, args.num_shards)}
112
+
113
+ json.dump(merged_metadata, f)
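The merged output can be memory-mapped back using the metadata written above (keys: `num_images`, `token_dtype`, `quantized`, `h`, `w`, `latent_channels`, `action_dim`). A read-back sketch, assuming a hypothetical output directory:

```python
import json, os
import numpy as np

out_dir = "data/merged_dataset"  # hypothetical path passed as --out_data_dir
with open(os.path.join(out_dir, "metadata.json")) as f:
    meta = json.load(f)

frame_dims = (meta["h"], meta["w"]) if meta["quantized"] else (
    meta["latent_channels"], meta["h"], meta["w"])

videos = np.memmap(os.path.join(out_dir, "video.bin"),
                   dtype=np.dtype(meta["token_dtype"]), mode="r",
                   shape=(meta["num_images"], *frame_dims))
actions = np.memmap(os.path.join(out_dir, "actions", "actions.bin"),
                    dtype=np.float32, mode="r",
                    shape=(meta["num_images"], meta["action_dim"]))
segment_ids = np.memmap(os.path.join(out_dir, "segment_ids.bin"),
                        dtype=np.int32, mode="r", shape=(meta["num_images"],))
print(videos.shape, actions.shape, int(segment_ids[-1]) + 1, "episodes")
```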
datasets/utils.py ADDED
@@ -0,0 +1,244 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms.v2.functional as transforms_f
7
+ from diffusers import AutoencoderKLTemporalDecoder
8
+ from einops import rearrange
9
+ from transformers import T5Tokenizer, T5Model
10
+
11
+ from magvit2.config import VQConfig
12
+ from magvit2.models.lfqgan import VQModel
13
+
14
+ vision_model = None
+ global_language_model = None
+ t5_tok = None
15
+
16
+
17
+ def get_image_encoder(encoder_type: str, encoder_name_or_path: str):
18
+ encoder_type = encoder_type.lower()
19
+ if encoder_type == "magvit":
20
+ return VQModel(VQConfig(), ckpt_path=encoder_name_or_path)
21
+ elif encoder_type == "temporalvae":
22
+ return AutoencoderKLTemporalDecoder.from_pretrained(encoder_name_or_path, subfolder="vae")
23
+ else:
24
+ raise NotImplementedError(f"{encoder_type=}")
25
+
26
+
27
+ def set_seed(seed):
28
+ # set seed for reproducibility
29
+ torch.manual_seed(seed)
30
+ np.random.seed(seed)
31
+
32
+
33
+ def mkdir_if_missing(dst_dir):
34
+ """make destination folder if it's missing"""
35
+ if not os.path.exists(dst_dir):
36
+ os.makedirs(dst_dir)
37
+
38
+ def resize_image(image, resize=True):
39
+ MAX_RES = 1024
40
+
41
+ # convert to array
42
+ image = np.asarray(image)
43
+ h, w = image.shape[:2]
44
+ if h > MAX_RES or w > MAX_RES:
45
+ if h < w:
46
+ new_h, new_w = int(MAX_RES * w / h), MAX_RES
47
+ else:
48
+ new_h, new_w = MAX_RES, int(MAX_RES * h / w)
49
+ image = cv2.resize(image, (new_w, new_h))
50
+
51
+ if resize:
52
+ # resize the shorter side to 256 and then do a center crop
53
+ h, w = image.shape[:2]
54
+ if h < w:
55
+ new_h, new_w = 256, int(256 * w / h)
56
+ else:
57
+ new_h, new_w = int(256 * h / w), 256
58
+ image = cv2.resize(image, (new_w, new_h))
59
+
60
+ h, w = image.shape[:2]
61
+ crop_h, crop_w = 256, 256
62
+ start_h = (h - crop_h) // 2
63
+ start_w = (w - crop_w) // 2
64
+ image = image[start_h:start_h + crop_h, start_w:start_w + crop_w]
65
+ return image
66
+
67
+ def normalize_image(image, resize=True):
68
+ """
69
+ H x W x 3(uint8) -> imagenet normalized 3 x H x W
70
+
71
+ Normalizes image to [-1, 1].
72
+ Resizes the image if resize=True or if the image resolution > MAX_RES
73
+ """
74
+ image = resize_image(image, resize=resize)
75
+ # normalize between -1 and 1
76
+ image = image / 255.0
77
+ image = (image * 2 - 1.)
78
+ return torch.from_numpy(image.transpose(2, 0, 1))
79
+
80
+
81
+ def unnormalize_image(magvit_output):
82
+ """
83
+ [-1, 1] -> [0, 255]
84
+
85
+ Important: clip to [0, 255]
86
+ """
87
+ rescaled_output = ((magvit_output.detach().cpu() + 1) * 127.5)
88
+ clipped_output = torch.clamp(rescaled_output, 0, 255).to(dtype=torch.uint8)
89
+ return clipped_output
90
+
91
+ @torch.inference_mode()
92
+ @torch.no_grad()
93
+ def get_quantized_image_embeddings(
94
+ image,
95
+ encoder_type,
96
+ encoder_name_or_path,
97
+ keep_res=False,
98
+ device="cuda",
99
+ ):
100
+ """
101
+ image: (h, w, 3)
102
+ """
103
+ global vision_model
104
+ DEBUG = False
105
+ dtype = torch.bfloat16
106
+
107
+ if vision_model is None:
108
+ vision_model = get_image_encoder(encoder_type=encoder_type, encoder_name_or_path=encoder_name_or_path)
109
+ vision_model = vision_model.to(device=device, dtype=dtype)
110
+ vision_model.eval()
111
+
112
+ batch = normalize_image(image, resize=not keep_res)[None]
113
+ if not keep_res:
114
+ img_h, img_w = 256, 256
115
+ else:
116
+ img_h, img_w = batch.shape[2:]
117
+
118
+ h, w = img_h // 16, img_w // 16
119
+
120
+ with vision_model.ema_scope():
121
+ quant_, _, indices, _ = vision_model.encode(batch.to(device=device, dtype=dtype), flip=True)
122
+ indices = rearrange(indices, "(h w) -> h w", h=h, w=w)
123
+
124
+ # alternative way to get indices
125
+ # indices_ = vision_model.quantize.bits_to_indices(quant_.permute(0, 2, 3, 1) > 0).cpu().numpy()
126
+ # indices_ = rearrange(indices_, "(h w) -> h w", h=h, w=w)
127
+
128
+ if DEBUG:
129
+ # sanity check: decode and then visualize
130
+ with vision_model.ema_scope():
131
+ indices = indices[None]
132
+ # bit representations
133
+ quant = vision_model.quantize.get_codebook_entry(rearrange(indices, "b h w -> b (h w)"),
134
+ bhwc=indices.shape + (vision_model.quantize.codebook_dim,)).flip(1)
135
+ ## why is there a flip(1) needed for the codebook bits?
136
+ decoded_img = unnormalize_image(vision_model.decode(quant.to(device=device, dtype=dtype)))
137
+ transforms_f.to_pil_image(decoded_img[0]).save("decoded.png")
138
+ transforms_f.to_pil_image(image).save("original.png") # show()
139
+
140
+ # 18 x 16 x 16 of [-1., 1.] - > 16 x 16 uint32
141
+ indices = indices.type(torch.int32)
142
+ indices = indices.detach().cpu().numpy().astype(np.uint32)
143
+ return indices
144
+
145
+
146
+ @torch.inference_mode()
147
+ @torch.no_grad()
148
+ def get_vae_image_embeddings(
149
+ image,
150
+ encoder_type,
151
+ encoder_name_or_path,
152
+ keep_res: bool = False,
153
+ device="cuda",
154
+ ):
155
+ """
156
+ image: (h, w, 3), in [-1, 1]
157
+ use SD VAE to encode and decode the images.
158
+ """
159
+ global vision_model
160
+ DEBUG = False
161
+ dtype = torch.bfloat16
162
+
163
+ if vision_model is None:
164
+ vision_model = get_image_encoder(encoder_type, encoder_name_or_path)
165
+ vision_model = vision_model.to(device=device, dtype=dtype)
166
+ vision_model.eval()
167
+
168
+ # https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
169
+ # if args.use_temporal_decoder:
170
+ # vae = AutoencoderKLTemporalDecoder.from_pretrained(args.vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
171
+ # else:
172
+ # vae = AutoencoderKL.from_pretrained(args.vae_model_path, subfolder="vae").to(device)
173
+ # x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor) ?
174
+
175
+ batch = normalize_image(image, resize=not keep_res)[None]
176
+
177
+ if isinstance(vision_model, AutoencoderKLTemporalDecoder):
178
+ # Think SVD expects images in [-1, 1] so we don't have to change anything?
179
+ # https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/video_sampling.py#L182
180
+ # https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/streamlit_helpers.py#L894
181
+ z = vision_model.encode(batch.to(device=device, dtype=dtype)).latent_dist.mean
182
+ elif isinstance(vision_model, VQModel): # vision_model should be VQModel
183
+ # with vision_model.ema_scope(): # doesn't matter due to bugged VQModel ckpt_path arg
184
+ z = vision_model.encode_without_quantize(batch.to(device=device, dtype=dtype))
185
+ else:
186
+ raise NotImplementedError(f"{vision_model=}")
187
+
188
+ if DEBUG:
189
+ decoded_img = unnormalize_image(vision_model.decode(z.to(device=device, dtype=dtype)))
190
+ transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
191
+ transforms_f.to_pil_image(image).save("original.png")
192
+
193
+ return z[0].detach().cpu().float().numpy().astype(np.float16)
194
+
195
+ # switch to VAE in SD
196
+ # https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae
197
+ # https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
198
+ # from diffusers.models import AutoencoderKL,AutoencoderKLTemporalDecoder
199
+ # vae_model_path = 'pretrained_models/stabilityai/stable-diffusion-xl-base-1.0'
200
+ # if args.use_temporal_decoder:
201
+ # vae = AutoencoderKLTemporalDecoder.from_pretrained(vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
202
+ # else:
203
+ # vae = AutoencoderKL.from_pretrained(vae_model_path, subfolder="vae").to(device)
204
+ # z = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
205
+ # if DEBUG:
206
+ # decoded_img = unnormalize_image(vae.decode(z.to(device=device, dtype=dtype) / vae.config.scaling_factor))
207
+ # transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
208
+ # transforms_f.to_pil_image(image).save("original.png")
209
+
210
+
211
+ @torch.no_grad()
212
+ def get_t5_embeddings(language, per_token=True, max_length=16, device="cpu"):
213
+ """Get T5 embedding"""
214
+ global global_language_model, t5_tok
215
+ if global_language_model is None:
216
+ try:
217
+ t5_model = T5Model.from_pretrained("t5-base")
218
+ t5_tok = T5Tokenizer.from_pretrained("t5-base")
219
+ except Exception:  # offline fallback
220
+ t5_model = T5Model.from_pretrained("t5-base", local_files_only=True)
221
+ t5_tok = T5Tokenizer.from_pretrained("t5-base", local_files_only=True)
222
+ t5_model = t5_model.to(device)
223
+ global_language_model = t5_model
224
+ global_language_model.eval()
225
+
226
+ # forward pass through encoder only
227
+ enc = t5_tok(
228
+ [language],
229
+ return_tensors="pt",
230
+ padding="max_length",
231
+ truncation=True,
232
+ max_length=max_length,
233
+ ).to(device)
234
+
235
+ output = global_language_model.encoder(
236
+ input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], return_dict=True
237
+ )
238
+ torch.cuda.empty_cache()
239
+ if per_token:
240
+ return output.last_hidden_state[0].detach().cpu().numpy()
241
+ else:
242
+ # get the final hidden states. average across tokens.
243
+ emb = output.last_hidden_state[0].mean(dim=0).detach().cpu().numpy()
244
+ return emb
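A quick round trip through the image helpers above: `normalize_image` maps an HxWx3 uint8 array to a [-1, 1] tensor of shape 3x256x256 (when `resize=True`), and `unnormalize_image` maps model outputs back to clipped uint8. A usage sketch, assuming the repo's dependencies are installed so `datasets.utils` imports cleanly:

```python
import numpy as np
from datasets.utils import normalize_image, unnormalize_image

dummy = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
x = normalize_image(dummy)      # torch.Tensor, 3 x 256 x 256, values in [-1, 1]
y = unnormalize_image(x[None])  # back to uint8, clipped to [0, 255]
print(x.shape, float(x.min()), float(x.max()), y.dtype)
```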
experiments/.DS_Store ADDED
Binary file (6.15 kB).
 
experiments/datasplit/.DS_Store ADDED
Binary file (6.15 kB).
 
experiments/datasplit/dataset1.yaml ADDED
@@ -0,0 +1,2 @@
1
+ domains: >
2
+ language_table
experiments/datasplit/dataset10.yaml ADDED
@@ -0,0 +1,11 @@
1
+ domains: >
2
+ bridge_data_v2,
3
+ fractal20220817_data,
4
+ language_table,
5
+ ucsd_pick_and_place_dataset_converted_externally_to_rlds,
6
+ kaist_nonprehensile_converted_externally_to_rlds,
7
+ ucsd_kitchen_dataset_converted_externally_to_rlds,
8
+ berkeley_fanuc_manipulation,
9
+ bc_z,
10
+ cmu_play_fusion,
11
+ kuka
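Each datasplit YAML stores the domain list as a single folded string. A parsing sketch, assuming PyYAML and that the training code splits the `domains` field on commas (the exact loader is not shown in this commit):

```python
import yaml

with open("experiments/datasplit/dataset10.yaml") as f:
    cfg = yaml.safe_load(f)

# folded scalar (">") joins the lines into one comma-separated string
domains = [d.strip() for d in cfg["domains"].split(",") if d.strip()]
print(len(domains), domains[:3])
```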
experiments/datasplit/dataset15.yaml ADDED
@@ -0,0 +1,16 @@
1
+ domains: >
2
+ bridge_data_v2,
3
+ fractal20220817_data,
4
+ language_table,
5
+ ucsd_pick_and_place_dataset_converted_externally_to_rlds,
6
+ kaist_nonprehensile_converted_externally_to_rlds,
7
+ ucsd_kitchen_dataset_converted_externally_to_rlds,
8
+ utokyo_xarm_bimanual_converted_externally_to_rlds,
9
+ stanford_hydra_dataset_converted_externally_to_rlds,
10
+ austin_sirius_dataset_converted_externally_to_rlds,
11
+ berkeley_fanuc_manipulation,
12
+ berkeley_mvp_converted_externally_to_rlds,
13
+ berkeley_rpt_converted_externally_to_rlds,
14
+ cmu_play_fusion,
15
+ iamlab_cmu_pickup_insert_converted_externally_to_rlds,
16
+ qut_dexterous_manpulation
experiments/datasplit/dataset15_vae.yaml ADDED
@@ -0,0 +1,16 @@
1
+ domains: >
2
+ bridge_data_v2,
3
+ fractal20220817_data,
4
+ language_table,
5
+ ucsd_pick_and_place_dataset_converted_externally_to_rlds,
6
+ kaist_nonprehensile_converted_externally_to_rlds,
7
+ ucsd_kitchen_dataset_converted_externally_to_rlds,
8
+ utokyo_xarm_bimanual_converted_externally_to_rlds,
9
+ stanford_hydra_dataset_converted_externally_to_rlds,
10
+ austin_sirius_dataset_converted_externally_to_rlds,
11
+ berkeley_fanuc_manipulation,
12
+ berkeley_mvp_converted_externally_to_rlds,
13
+ berkeley_rpt_converted_externally_to_rlds,
14
+ cmu_play_fusion,
15
+ iamlab_cmu_pickup_insert_converted_externally_to_rlds,
16
+ qut_dexterous_manpulation
experiments/datasplit/dataset20.yaml ADDED
@@ -0,0 +1,21 @@
1
+ domains: >
2
+ bridge_data_v2,
3
+ fractal20220817_data,
4
+ language_table,
5
+ ucsd_pick_and_place_dataset_converted_externally_to_rlds,
6
+ kaist_nonprehensile_converted_externally_to_rlds,
7
+ ucsd_kitchen_dataset_converted_externally_to_rlds,
8
+ utokyo_xarm_bimanual_converted_externally_to_rlds,
9
+ stanford_hydra_dataset_converted_externally_to_rlds,
10
+ austin_sirius_dataset_converted_externally_to_rlds,
11
+ berkeley_fanuc_manipulation,
12
+ berkeley_mvp_converted_externally_to_rlds,
13
+ cmu_play_fusion,
14
+ robo_net,
15
+ furniture_bench_dataset_converted_externally_to_rlds,
16
+ dlr_sara_grid_clamp_converted_externally_to_rlds,
17
+ cmu_stretch,
18
+ droid,
19
+ toto,
20
+ bc_z,
21
+ kuka
experiments/datasplit/dataset20_vae.yaml ADDED
@@ -0,0 +1,21 @@
1
+ domains: >
2
+ language_table,
3
+ ucsd_pick_and_place_dataset_converted_externally_to_rlds,
4
+ kaist_nonprehensile_converted_externally_to_rlds,
5
+ ucsd_kitchen_dataset_converted_externally_to_rlds,
6
+ utokyo_xarm_bimanual_converted_externally_to_rlds,
7
+ stanford_hydra_dataset_converted_externally_to_rlds,
8
+ austin_sirius_dataset_converted_externally_to_rlds,
9
+ berkeley_fanuc_manipulation,
10
+ berkeley_mvp_converted_externally_to_rlds,
11
+ berkeley_rpt_converted_externally_to_rlds,
12
+ cmu_play_fusion,
13
+ iamlab_cmu_pickup_insert_converted_externally_to_rlds,
14
+ qut_dexterous_manpulation,
15
+ robo_net,
16
+ dlr_sara_grid_clamp_converted_externally_to_rlds,
17
+ cmu_stretch,
18
+ columbia_cairlab_pusht_real,
19
+ droid,
20
+ toto,
21
+ io_ai_tech
experiments/datasplit/dataset25.yaml ADDED
@@ -0,0 +1,26 @@
1
+ domains: >
2
+ bridge_data_v2,
3
+ fractal20220817_data,
4
+ language_table,
5
+ ucsd_pick_and_place_dataset_converted_externally_to_rlds,
6
+ kaist_nonprehensile_converted_externally_to_rlds,
7
+ ucsd_kitchen_dataset_converted_externally_to_rlds,
8
+ utokyo_xarm_bimanual_converted_externally_to_rlds,
9
+ stanford_hydra_dataset_converted_externally_to_rlds,
10
+ austin_sirius_dataset_converted_externally_to_rlds,
11
+ berkeley_fanuc_manipulation,
12
+ berkeley_mvp_converted_externally_to_rlds,
13
+ cmu_play_fusion,
14
+ iamlab_cmu_pickup_insert_converted_externally_to_rlds,
15
+ robo_net,
16
+ furniture_bench_dataset_converted_externally_to_rlds,
17
+ dlr_sara_grid_clamp_converted_externally_to_rlds,
18
+ cmu_stretch,
19
+ droid,
20
+ toto,
21
+ io_ai_tech,
22
+ bc_z,
23
+ roboturk,
24
+ cmu_franka_exploration_dataset_converted_externally_to_rlds,
25
+ nyu_door_opening_surprising_effectiveness,
26
+ kuka