draft

This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +40 -13
- app.py +69 -0
- build.sh +12 -0
- common/__init__.py +0 -0
- common/calculate_fvd.py +80 -0
- common/data_sampler.py +334 -0
- common/eval_utils.py +105 -0
- common/fid_score.py +382 -0
- common/fvd/styleganv/fvd.py +90 -0
- common/fvd/styleganv/i3d_torchscript.pt +3 -0
- common/fvd/videogpt/fvd.py +137 -0
- common/fvd/videogpt/i3d_pretrained_400.pt +3 -0
- common/fvd/videogpt/pytorch_i3d.py +322 -0
- common/inception.py +344 -0
- common/plot/__init__.py +0 -0
- common/plot/aggregated_output.csv +18 -0
- common/plot/plot_arch_ablation.py +60 -0
- common/plot/plot_arch_ablation_deltapsnr.py +49 -0
- common/plot/plot_dataset_scale.py +69 -0
- common/plot/plot_dataset_traj_scale.py +48 -0
- common/plot/plot_dynamics_ablation.py +56 -0
- common/plot/plot_dynamics_ablation_deltapsnr.py +51 -0
- common/plot/plot_from_wandb.py +185 -0
- common/plot/plot_from_wandb_singledataset.py +144 -0
- common/plot/plot_model_scale.py +64 -0
- common/plot/plot_pretrain_ablation.py +44 -0
- common/plot/plot_pretrain_ablation_mar.py +45 -0
- cont_data.py +245 -0
- data.py +240 -0
- datasets/.DS_Store +0 -0
- datasets/__init__.py +0 -0
- datasets/encode_extern_dataset.py +291 -0
- datasets/encode_openx_dataset.py +459 -0
- datasets/extern/__init__.py +0 -0
- datasets/extern/ego4d.py +193 -0
- datasets/extern/egoexo4d.py +186 -0
- datasets/extern/epic_kitchen.py +115 -0
- datasets/extern/frodobot.py +128 -0
- datasets/extern/robomimic.py +108 -0
- datasets/merge_shards.py +113 -0
- datasets/utils.py +244 -0
- experiments/.DS_Store +0 -0
- experiments/datasplit/.DS_Store +0 -0
- experiments/datasplit/dataset1.yaml +2 -0
- experiments/datasplit/dataset10.yaml +11 -0
- experiments/datasplit/dataset15.yaml +16 -0
- experiments/datasplit/dataset15_vae.yaml +16 -0
- experiments/datasplit/dataset20.yaml +21 -0
- experiments/datasplit/dataset20_vae.yaml +21 -0
- experiments/datasplit/dataset25.yaml +26 -0
README.md
CHANGED
@@ -1,13 +1,40 @@
````markdown
# Heterogeneous World Modeling with Actions

Progress in video generation may soon make it possible to evaluate robot policies in a completely learned world model.
Modified from [here](https://github.com/1x-technologies/1xgpt)

## Getting Started
We require `Python 3.10` or later. This code was tested with `Python 3.10.12`.

```
# Install dependencies and download data
./build.sh

# Source the Python environment
source venv/bin/activate
```

## File Structures
```angular2html
├── ...
├── HPT-Video
|   |── data                 # cached token datasets and model checkpoints
|   |── genie                # main modeling code
|   |   |── diffusion        # diffusion loss related
|   |   |── evaluate.py      # evaluate a trained model
|   |   |── st_mar.py        # spatial-time MAR
|   |   |── generate.py      # generate tokens from a trained model
|   |   |── st_maskgit.py    # spatial-time MaskGIT
|   |── magvit               # MAGVIT code
|   |── sim                  # simulation-related codebase
|   |── experiments
|   |   |── cmd              # handy commands
|   |   |── datasplit        # dataset splits
|   |   |── scripts          # ablation and training scripts
|   |── common               # common utilities and plot scripts
|   |── train.py             # train using MAGVIT
|   |── train_diffusion.py   # train using MAR
|   |── train_multi.py       # train on multiple datasets jointly
|   |── visualize.py         # visualize generated tokens
└── ...
```
````
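As a quick sanity check after `./build.sh` completes, you can confirm the interpreter inside the virtual environment meets the stated requirement (a minimal sketch; the exact patch version on your machine may differ):

```
source venv/bin/activate
python --version   # expect Python 3.10.12 or newer
```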
app.py
ADDED
@@ -0,0 +1,69 @@
```python
import gradio as gr
import numpy as np
from PIL import Image
import cv2
from sim.simulator import GenieSimulator

RES = 512
image = Image.open("sim/assets/langtable_prompt/frame_06.png")
genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=11,
    action_stride=1,
    domain='language_table',
)
prompt_image = np.tile(
    np.array(image), (genie.prompt_horizon, 1, 1, 1)
).astype(np.uint8)
prompt_action = np.zeros(
    (genie.prompt_horizon, genie.action_stride, 2)
).astype(np.float32)
genie.set_initial_state((prompt_image, prompt_action))
image = genie.reset()
image = cv2.resize(image, (RES, RES))
image = Image.fromarray(image)

# Maps a direction to a 2D action and steps the simulator one frame.
def model(direction: str, genie=genie):
    if direction == 'right':
        action = np.array([0, 0.05])
    elif direction == 'left':
        action = np.array([0, -0.05])
    elif direction == 'down':
        action = np.array([0.05, 0])
    elif direction == 'up':
        action = np.array([-0.05, 0])
    else:
        raise ValueError(f"Invalid direction: {direction}")
    next_image = genie.step(action)['pred_next_frame']
    next_image = cv2.resize(next_image, (RES, RES))
    return Image.fromarray(next_image)

# Gradio callback: handle a button press and return the predicted frame.
def handle_input(direction):
    print(f"User clicked: {direction}")
    new_image = model(direction)  # Get a new image from the model
    return new_image

if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            image_display = gr.Image(value=image, type="pil", label="Generated Image")
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define button interactions
        up.click(fn=lambda: handle_input("up"), outputs=image_display)
        down.click(fn=lambda: handle_input("down"), outputs=image_display)
        left.click(fn=lambda: handle_input("left"), outputs=image_display)
        right.click(fn=lambda: handle_input("right"), outputs=image_display)

        demo.launch()
```
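To try the demo locally, run the script directly; `demo.launch()` starts a local Gradio server (Gradio's default address is http://127.0.0.1:7860):

```
python app.py
```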
build.sh
ADDED
@@ -0,0 +1,12 @@
```bash
#!/usr/bin/bash

python3 -m venv venv
source venv/bin/activate
python -m pip install -r requirements.txt
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE python -m pip install flash-attn==2.5.8 --no-build-isolation

# Download datasets to data/train_v1.1, data/val_v1.1
huggingface-cli download 1x-technologies/worldmodel --repo-type dataset --local-dir data

mv data/val_v1.1 data/1x_humanoid_magvit_traj1000_val
mv data/train_v1.1 data/1x_humanoid_magvit_traj1000_train
```
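`huggingface-cli` ships with the `huggingface_hub` package; if the download step fails because the command is missing, install it explicitly and rerun (assuming `requirements.txt` does not already provide it):

```
python -m pip install -U huggingface_hub
```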
common/__init__.py
ADDED
File without changes
common/calculate_fvd.py
ADDED
@@ -0,0 +1,80 @@
```python
# Code adapted from https://github.com/JunyaoHu/common_metrics_on_video_quality
import numpy as np
import torch
from tqdm import tqdm

def trans(x):
    # if grayscale images, add a channel dimension
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)

    # permute BTCHW -> BCTHW
    x = x.permute(0, 2, 1, 3, 4)

    return x

def calculate_fvd(videos1, videos2, device="cuda", method='styleganv'):

    if method == 'styleganv':
        from .fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
    elif method == 'videogpt':
        from .fvd.videogpt.fvd import load_i3d_pretrained
        from .fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
        from .fvd.videogpt.fvd import frechet_distance

    # videos [batch_size, timestamps, channel, h, w]
    assert videos1.shape == videos2.shape

    i3d = load_i3d_pretrained(device=device)
    fvd_results = []

    # support grayscale input: if grayscale -> channel*3
    # BTCHW -> BCTHW
    # videos -> [batch_size, channel, timestamps, h, w]
    videos1 = trans(videos1)
    videos2 = trans(videos2)

    # to calculate FVD, each clip_timestamp must be >= 10
    for clip_timestamp in tqdm(range(10, videos1.shape[-3] + 1)):
        # get a video clip: videos_clip [batch_size, channel, timestamps[:clip], h, w]
        videos_clip1 = videos1[:, :, :clip_timestamp]
        videos_clip2 = videos2[:, :, :clip_timestamp]

        # get FVD features
        feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
        feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)

        # calculate FVD for timestamps[:clip]
        fvd_results.append(frechet_distance(feats1, feats2))

    return fvd_results[-1]  # only the last step

# test code / usage example

def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    device = torch.device("cuda")
    # device = torch.device("cpu")

    import json
    result = calculate_fvd(videos1, videos2, device, method='videogpt')
    print(json.dumps(result, indent=4))

    result = calculate_fvd(videos1, videos2, device, method='styleganv')
    print(json.dumps(result, indent=4))

if __name__ == "__main__":
    main()
```
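A minimal sketch of preparing real frame data for `calculate_fvd` (assuming frames arrive as uint8 arrays of shape (B, T, H, W, C); per the comments above, the function expects float tensors in [0, 1] laid out (B, T, C, H, W), with T >= 10 for the clip loop):

```python
import numpy as np
import torch

# stand-in uint8 frames, shape (B, T, H, W, C)
frames = np.random.randint(0, 256, size=(2, 16, 64, 64, 3), dtype=np.uint8)
videos = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).float() / 255.0  # -> (B, T, C, H, W) in [0, 1]
# fvd = calculate_fvd(videos, videos, device="cuda", method="styleganv")
```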
common/data_sampler.py
ADDED
@@ -0,0 +1,334 @@
```python
import copy
import os
import random
from operator import itemgetter
from typing import Optional, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from torch.utils.data import Dataset, Sampler
from torch.utils.data import DistributedSampler


def chunk_indices(indices: list[int], size: int) -> tuple[torch.Tensor, ...]:
    return torch.split(torch.tensor(indices), size)


class CombinedDataLoader:
    def __init__(self, dataloaders, reinit=True):
        """
        :param dataloaders: list of PyTorch dataloaders
        """
        self.dataloaders = dataloaders
        self.reinit = reinit
        self.dataloader_idx = 0
        self.loader_iters = [iter(dataloader) for dataloader in self.dataloaders]

    def __iter__(self):
        return self

    def __next__(self):
        # Draw the next batch from the current dataloader.
        chosen_loader_iter = self.loader_iters[self.dataloader_idx]

        try:
            data = next(chosen_loader_iter)
            return data
        except StopIteration:
            # The current dataloader is exhausted: advance to the next one,
            # or stop once every dataloader has been consumed.
            self.dataloader_idx = self.dataloader_idx + 1
            if self.dataloader_idx == len(self.loader_iters):
                self.dataloader_idx = 0  # reset
                raise StopIteration
            return self.__next__()

    def __len__(self):
        return sum([len(dataloader) for dataloader in self.dataloaders])


class CombinedBatchSampler(torch.utils.data.Sampler):
    # For validation dataloaders.
    def __init__(self, datasets, batch_size, num_processes=1, shuffle=False):
        super().__init__()  # no-op
        prev_idx = 0
        all_batches = []

        for dataset in datasets:
            indices = list(range(prev_idx, prev_idx + len(dataset)))
            if shuffle:
                random.shuffle(indices)

            # exclude the remainder, if necessary
            remainder = len(indices) % (batch_size * num_processes)
            if remainder > 0:
                indices = indices[:-remainder]  # exclude last

            chunk_i = chunk_indices(indices, batch_size)  # equally sized
            all_batches += chunk_i

            # advance past the new indices, excluding the dropped remainder
            prev_idx += len(chunk_i) * batch_size  # len(dataset)

        if shuffle:
            random.shuffle(all_batches)

        self.all_batches = all_batches

    def __iter__(self):
        return iter(self.all_batches)

    def __len__(self):
        return len(self.all_batches)


# https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py
class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):
        """Iterate over sampler.

        Returns:
            python iterator
        """
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))


# https://github.com/rabeehk/hyperformer/blob/main/hyperformer/data/multitask_sampler.py
class MultiTaskBatchSampler(Sampler):
    """Defines a sampler to sample multiple datasets with temperature sampling
    in a distributed fashion."""

    def __init__(
        self,
        dataset_sizes: List[int],
        batch_size: int,
        temperature: float,
        dataset_groups=[],
        num_replicas: Optional[int] = 1,
        rank: Optional[int] = 0,
        seed: int = 0,
        shuffle: bool = True,
        shuffle_task: bool = True,
    ) -> None:
        """Constructor for MultiTaskBatchSampler.

        Args:
            dataset_sizes: a list of integers, specifies the number of samples in
                each dataset.
            batch_size: integer, specifies the batch size.
            temperature: float, temperature used for temperature sampling. The larger
                the value, the more equally the datasets are sampled; for a value of 1,
                the datasets are sampled in proportion to their number of samples.
            num_replicas: integer, specifies the number of processes.
            rank: integer, specifies the rank of the current process.
            seed: integer, random seed.
            shuffle: bool, if set to true, the datasets will be shuffled in each epoch.
            shuffle_task: bool, if set to true, task choices are reseeded per rank so
                that different ranks draw different tasks.
        """

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        print("data sampler rank:", rank)

        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1)
            )

        self.dataset_groups = dataset_groups
        print("dataset groups:", self.dataset_groups)

        self.num_replicas = num_replicas
        self.shuffle_task = shuffle_task
        self.rank = rank
        self.batch_size = batch_size
        self.dataset_sizes = dataset_sizes

        # By default we drop the last elements if the dataset is not divisible by the number of ranks.
        self.rank_dataset_sizes = [dataset_size // self.num_replicas for dataset_size in self.dataset_sizes]
        self.dataset_offsets = torch.cumsum(torch.LongTensor([0] + dataset_sizes), 0)
        self.total_sizes = [
            (dataset_size // self.num_replicas) * self.num_replicas for dataset_size in self.dataset_sizes
        ]
        self.temperature = temperature
        self.seed = seed
        self.epoch = 0
        self.num_batches_per_epoch = (
            (np.sum(dataset_sizes) + self.batch_size - 1) // self.batch_size // self.num_replicas
        )
        self.shuffle = shuffle
        print(f"{num_replicas=} {rank=} {self.num_batches_per_epoch=} {self.total_sizes=} self.weights={self.generate_tasks_distribution()}")

    def generate_tasks_distribution(self):
        """Given the dataset sizes, computes the weights to sample each dataset
        according to temperature sampling."""
        if len(self.dataset_groups) > 0:
            # normalize across groups first
            weights = []
            num_groups = len(self.dataset_groups)
            for group in self.dataset_groups:
                lo, hi = group
                dataset_sizes = [self.dataset_sizes[idx] for idx in range(lo, hi)]
                total_size = sum(dataset_sizes)
                group_weights = np.array([(size / total_size) ** (1.0 / self.temperature) for size in dataset_sizes])
                group_weights = group_weights / np.sum(group_weights) / num_groups
                weights = np.concatenate((weights, group_weights))
        else:
            total_size = sum(self.dataset_sizes)
            weights = np.array([(size / total_size) ** (1.0 / self.temperature) for size in self.dataset_sizes])
            weights = weights / np.sum(weights)
        return torch.as_tensor(weights, dtype=torch.double)

    def __iter__(self):
        # Defines a torch generator to make random choices consistent across cores:
        # across epochs, the seed is set based on both the base seed and the epoch.
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)

        # Shuffles the datasets if shuffle is set to true.
        indices = []
        for dataset_size in self.dataset_sizes:
            if self.shuffle:
                indices.append(torch.randperm(dataset_size, generator=generator).tolist())
            else:
                indices.append(list(range(dataset_size)))

        # Shards the datasets across all processes.
        self.rank_indices = []
        for i in range(len(self.dataset_sizes)):
            self.rank_indices.append(indices[i][self.rank : self.total_sizes[i] : self.num_replicas])

        # To make the model consistent across different processes, since the
        # model is based on tasks, we need to make sure the same task is selected
        # across different processes.
        tasks_distribution: torch.Tensor = self.generate_tasks_distribution()

        # Chooses the tasks which will be used in each batch in one epoch.
        # With a shared generator, this choice is consistent across processes;
        # when shuffle_task is set, each rank reseeds so the choices differ.
        if self.shuffle_task:
            generator.manual_seed(self.seed + self.epoch + self.rank)
        batch_task_assignments = torch.multinomial(
            tasks_distribution, self.num_batches_per_epoch, replacement=True, generator=generator
        )

        for batch_task in batch_task_assignments:
            # Gets the number of samples of the selected dataset available for the current rank.
            num_task_samples = self.rank_dataset_sizes[batch_task]
            # Computes the random samples from the chosen dataset.
            indices = torch.randint(low=0, high=num_task_samples, size=(self.batch_size,), generator=generator).tolist()
            # Converts the selected indices to the global indices of the given dataset.
            results = (self.dataset_offsets[batch_task] + torch.tensor(self.rank_indices[batch_task])[indices]).tolist()
            yield results

    def __len__(self):
        return self.num_batches_per_epoch

    def set_epoch(self, epoch):
        self.epoch = epoch


def make_dataset_pie_plot(domains, traj_nums):
    """Draw the dataset mixture as a pie plot."""
    new_domains = []
    for idx, domain in enumerate(domains):
        new_domains.append(domain)
    plt.cla()
    fig1, ax1 = plt.subplots(figsize=(40, 40))
    traj_prob = np.array(traj_nums) / np.sum(traj_nums)
    tab20 = plt.get_cmap("tab20").colors
    tab20b = plt.get_cmap("tab20b").colors
    tab20c = plt.get_cmap("tab20c").colors

    # Combine them to get 60 distinct colors
    colors = tab20 + tab20b + tab20c
    patches, _ = ax1.pie(traj_prob, startangle=90, colors=colors[: len(traj_prob)])
    ax1.axis("equal")
    ax1.legend(patches, new_domains, loc="center left", bbox_to_anchor=(0.8, 0.5), prop={"size": 32})
    fig1.canvas.draw()

    return Image.frombytes("RGB", fig1.canvas.get_width_height(), fig1.canvas.tostring_rgb())
```
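To make the temperature parameter concrete, here is the weight formula from `generate_tasks_distribution` applied to three hypothetical dataset sizes:

```python
import numpy as np

sizes = np.array([1000, 100, 10])
for T in (1.0, 2.0, 10.0):
    w = (sizes / sizes.sum()) ** (1.0 / T)
    print(T, np.round(w / w.sum(), 3))
# T=1.0  -> [0.901 0.09  0.009]  (proportional to dataset size)
# T=2.0  -> [0.706 0.223 0.071]
# T=10.0 -> [0.412 0.328 0.26 ]  (approaching uniform)
```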
common/eval_utils.py
ADDED
@@ -0,0 +1,105 @@
```python
from typing import Callable

import torch
import torchvision.transforms.functional as transforms_f
from einops import rearrange

from genie.factorization_utils import factorize_labels


class AvgMetric:
    """Records a running sum and count to compute the mean."""
    def __init__(self):
        self.total = 0
        self.count = 0

    def update(self, val, batch_size=1):
        self.total += val * batch_size
        self.count += batch_size

    def update_list(self, flat_vals):
        self.total += sum(flat_vals)
        self.count += len(flat_vals)

    def mean(self):
        if self.count == 0:
            return 0
        return self.total / self.count


def decode_tokens(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
    """
    Converts quantized latent-space tokens to images.

    Args:
        reshaped_token_ids: shape (B, T, H, W).
        decode_latents: instance of `decode_latents_wrapper()`

    Returns:
        (B, T, 3, 256, 256)
    """
    decoded_imgs = decode_latents(rearrange(reshaped_token_ids, "b t h w -> (b t) h w").cpu().numpy())
    decoded_tensor = torch.stack([transforms_f.pil_to_tensor(pred_img) for pred_img in decoded_imgs])
    return rearrange(decoded_tensor, "(b t) c H W -> b t c H W", b=reshaped_token_ids.size(0))


def decode_features(reshaped_token_ids: torch.LongTensor, decode_latents: Callable) -> torch.ByteTensor:
    """
    Converts continuous latent-space features to images.

    Args:
        reshaped_token_ids: shape (B, T, H, W, C).
        decode_latents: instance of `decode_latents_wrapper()`

    Returns:
        (B, T, 3, 256, 256)
    """
    decoded_imgs = decode_latents(rearrange(reshaped_token_ids, "b t h w c -> (b t) c h w").cpu().numpy())
    decoded_tensor = torch.stack([transforms_f.pil_to_tensor(pred_img) for pred_img in decoded_imgs])
    return rearrange(decoded_tensor, "(b t) c H W -> b t c H W", b=reshaped_token_ids.size(0))


def compute_loss(
    labels_flat: torch.LongTensor,
    factored_logits: torch.FloatTensor,
    num_factored_vocabs: int = 2,
    factored_vocab_size: int = 512,
) -> float:
    """
    If applicable (the model returns logits), computes the cross-entropy loss.
    In the case of a factorized vocabulary, sums the cross-entropy losses over the vocabularies.

    Assumes all submissions use the parametrization num_factored_vocabs = 2, factored_vocab_size = 512.

    Args:
        labels_flat: size (B, T*H*W) corresponding to flattened, tokenized images.
        factored_logits: size (B, factored_vocab_size, num_factored_vocabs, T-1, H, W).
            E.g. output of `genie.evaluate.GenieEvaluator.predict_zframe_logits()`
        num_factored_vocabs: Should be 2 for v1.0 of the challenge.
        factored_vocab_size: Should be 512 for v1.0 of the challenge.
    Returns:
        Cross-entropy loss
    """
    assert factored_logits.dim() == 6 \
        and factored_logits.size()[:3] == (labels_flat.size(0), factored_vocab_size, num_factored_vocabs), \
        f"Shape of `logits` should be (B, {factored_vocab_size}, {num_factored_vocabs}, T-1, H, W)"
    t = factored_logits.size(3) + 1
    h, w = factored_logits.size()[-2:]
    assert t * h * w == labels_flat.size(1), "Shape of `factored_logits` does not match flattened latent image size."

    labels_THW = rearrange(labels_flat, "b (t h w) -> b t h w", t=t, h=h, w=w)
    labels_THW = labels_THW[:, 1:].to(factored_logits.device)

    factored_labels = factorize_labels(labels_THW, num_factored_vocabs, factored_vocab_size)
    return torch.nn.functional.cross_entropy(factored_logits, factored_labels, reduction="none")\
        .sum(dim=1).mean().item()  # Final loss is the sum of the two losses across the size-512 vocabularies


def compute_lpips(frames_a: torch.ByteTensor, frames_b: torch.ByteTensor, lpips_func: Callable) -> list:
    """
    Given two batches of video data of shape (B, T, 3, 256, 256), computes the LPIPS score frame by frame.
    Cannot use `lpips_func` directly because it expects at most 4D input.
    """
    # LPIPS expects pixel values in [-1, 1]
    flattened_a, flattened_b = [rearrange(frames / 127.5 - 1, "b t c H W -> (b t) c H W")
                                for frames in (frames_a, frames_b)]
    return lpips_func(flattened_a, flattened_b).flatten().tolist()
```
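A minimal sketch of `AvgMetric` accumulating per-batch means into a running average (values are illustrative):

```python
metric = AvgMetric()
metric.update(0.5, batch_size=8)  # batch of 8 with mean value 0.5
metric.update(0.7, batch_size=4)  # batch of 4 with mean value 0.7
print(metric.mean())              # (0.5*8 + 0.7*4) / 12 ≈ 0.567
```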
common/fid_score.py
ADDED
@@ -0,0 +1,382 @@
```python
"""Calculates the Frechet Inception Distance (FID) to evaluate GANs

The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.

When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).

The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.

See --help to see further details.

Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow

Copyright 2018 Institute of Bioinformatics, JKU Linz

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# code adapted from https://github.com/mseitzer/pytorch-fid/tree/master

import os
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import numpy as np
import torch
import torchvision.transforms as TF
from PIL import Image
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a mock version of it
    def tqdm(x):
        return x


from .inception import InceptionV3

IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}


class ImagePathDataset(torch.utils.data.Dataset):
    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        img = Image.open(path).convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img


def get_activations(
    files, model, batch_size=50, dims=2048, device="cpu", num_workers=1
):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files       : List of image file paths
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    if batch_size > len(files):
        print(
            (
                "Warning: batch size is bigger than the data size. "
                "Setting batch size to data size"
            )
        )
        batch_size = len(files)

    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

    pred_arr = np.empty((len(files), dims))

    start_idx = 0

    for batch in tqdm(dataloader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal to 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx : start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr


def get_activations_images(
    dataset, model, batch_size=50, dims=2048, device="cpu", num_workers=0
):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- dataset     : Tensor of video frames, shape (B, T, C, H, W)
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()
    # fold the temporal axis into the batch axis
    dataset = torch.cat([dataset[:, i] for i in range(dataset.shape[1])], dim=0).to("cpu")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=num_workers,
    )

    pred_arr = np.empty((len(dataset), dims))

    start_idx = 0

    for batch in tqdm(dataloader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal to 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx : start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.

    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert (
        mu1.shape == mu2.shape
    ), "Training and test mean vectors have different lengths"
    assert (
        sigma1.shape == sigma2.shape
    ), "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = (
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


def calculate_activation_statistics(
    images, model, batch_size=50, dims=2048, device="cpu", num_workers=1
):
    """Calculation of the statistics used by the FID.
    Params:
    -- images      : Tensor of video frames, shape (B, T, C, H, W)
    -- model       : Instance of inception model
    -- batch_size  : The images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size
                     depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- device      : Device to run calculations
    -- num_workers : Number of parallel dataloader workers

    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations_images(images, model, batch_size, dims, device, num_workers)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def compute_statistics(images, model, batch_size, dims, device, num_workers=1):
    m, s = calculate_activation_statistics(
        images, model, batch_size, dims, device, num_workers
    )
    return m, s


def calculate_fid(pred_images, gt_images, batch_size=16, device="cuda", dims=2048, num_workers=1):
    """Calculates the FID between two batches of video frames."""

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    m1, s1 = compute_statistics(
        pred_images, model, batch_size, dims, device, num_workers
    )
    m2, s2 = compute_statistics(
        gt_images, model, batch_size, dims, device, num_workers
    )
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value


def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
    """Calculates the FID of two paths."""
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError("Invalid path: %s" % p)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    m1, s1 = compute_statistics_of_path(
        paths[0], model, batch_size, dims, device, num_workers
    )
    m2, s2 = compute_statistics_of_path(
        paths[1], model, batch_size, dims, device, num_workers
    )
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)

    return fid_value


def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
    """Saves FID statistics of one path."""
    if not os.path.exists(paths[0]):
        raise RuntimeError("Invalid path: %s" % paths[0])

    if os.path.exists(paths[1]):
        raise RuntimeError("Existing output file: %s" % paths[1])

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

    model = InceptionV3([block_idx]).to(device)

    print(f"Saving statistics for {paths[0]}")

    m1, s1 = compute_statistics_of_path(
        paths[0], model, batch_size, dims, device, num_workers
    )

    np.savez_compressed(paths[1], mu=m1, sigma=s1)


def main():
    args = parser.parse_args()

    if args.device is None:
        device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    else:
        device = torch.device(args.device)

    if args.num_workers is None:
        try:
            num_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # os.sched_getaffinity is not available under Windows, use
            # os.cpu_count instead (which may not return the *available* number
            # of CPUs).
            num_cpus = os.cpu_count()

        num_workers = min(num_cpus, 8) if num_cpus is not None else 0
    else:
        num_workers = args.num_workers

    if args.save_stats:
        save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
        return

    fid_value = calculate_fid_given_paths(
        args.path, args.batch_size, device, args.dims, num_workers
    )
    print("FID: ", fid_value)


if __name__ == "__main__":
    main()
```
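A sketch of calling the tensor-based entry point above (dummy data; `calculate_fid` folds the time axis into the batch via `get_activations_images`, so inputs are float videos shaped (B, T, C, H, W) in [0, 1]). Note that the path-based helpers (`calculate_fid_given_paths`, `save_fid_stats`, `main`) reference `compute_statistics_of_path` and `parser`, which are not defined in this file; they appear to be leftovers from the upstream pytorch-fid script, while the tensor-based path is self-contained.

```python
import torch

pred = torch.rand(4, 8, 3, 256, 256)  # stand-in generated videos in [0, 1]
gt = torch.rand(4, 8, 3, 256, 256)    # stand-in ground-truth videos
# fid = calculate_fid(pred, gt, batch_size=16, device="cuda")
```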
common/fvd/styleganv/fvd.py
ADDED
@@ -0,0 +1,90 @@
```python
import torch
import os
import math
import torch.nn.functional as F

# https://github.com/universome/fvd-comparison


def load_i3d_pretrained(device=torch.device('cpu')):
    i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
    print(filepath)
    if not os.path.exists(filepath):
        print(f"preparing to download {i3D_WEIGHTS_URL}; you can also download it yourself.")
        os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
    i3d = torch.jit.load(filepath).eval().to(device)
    i3d = torch.nn.DataParallel(i3d)
    return i3d


def get_feats(videos, detector, device, bs=10):
    # videos: torch.Tensor, BCTHW, in [0, 1]
    detector_kwargs = dict(rescale=False, resize=False, return_features=True)  # Return raw features before the softmax layer.
    feats = np.empty((0, 400))  # numpy is imported at the bottom of this module
    with torch.no_grad():
        for i in range((len(videos) - 1) // bs + 1):
            feats = np.vstack([
                feats,
                detector(
                    torch.stack([preprocess_single(video) for video in videos[i * bs:(i + 1) * bs]]).to(device),
                    **detector_kwargs,
                ).detach().cpu().numpy(),
            ])
    return feats


def get_fvd_feats(videos, i3d, device, bs=10):
    # videos in [0, 1] as torch tensor BCTHW
    # videos = [preprocess_single(video) for video in videos]
    embeddings = get_feats(videos, i3d, device, bs)
    return embeddings


def preprocess_single(video, resolution=224, sequence_length=None):
    # video: CTHW, [0, 1]
    c, t, h, w = video.shape

    # temporal crop
    if sequence_length is not None:
        assert sequence_length <= t
        video = video[:, :sequence_length]

    # scale shorter side to resolution
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)

    # center crop
    c, t, h, w = video.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]

    # [0, 1] -> [-1, 1]
    video = (video - 0.5) * 2

    return video.contiguous()


"""
Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
"""
from typing import Tuple
from scipy.linalg import sqrtm
import numpy as np


def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    mu = feats.mean(axis=0)  # [d]
    sigma = np.cov(feats, rowvar=False)  # [d, d]
    return mu, sigma


def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
    mu_gen, sigma_gen = compute_stats(feats_fake)
    mu_real, sigma_real = compute_stats(feats_real)
    m = np.square(mu_gen - mu_real).sum()
    if feats_fake.shape[0] > 1:
        s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False)  # pylint: disable=no-member
        fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    else:
        fid = np.real(m)
    return float(fid)
```
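The `frechet_distance` above is the standard Fréchet distance between Gaussians fitted to the real and generated I3D features:

$$\mathrm{FVD} = \lVert \mu_{\mathrm{fake}} - \mu_{\mathrm{real}} \rVert^2 + \mathrm{Tr}\!\left(\Sigma_{\mathrm{fake}} + \Sigma_{\mathrm{real}} - 2\,(\Sigma_{\mathrm{fake}}\Sigma_{\mathrm{real}})^{1/2}\right)$$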
common/fvd/styleganv/i3d_torchscript.pt
ADDED
@@ -0,0 +1,3 @@
```text
version https://git-lfs.github.com/spec/v1
oid sha256:bec6519f66ea534e953026b4ae2c65553c17bf105611c746d904657e5860a5e2
size 51235320
```
common/fvd/videogpt/fvd.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
import einops
|
7 |
+
|
8 |
+
def load_i3d_pretrained(device=torch.device('cpu')):
|
9 |
+
i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI"
|
10 |
+
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')
|
11 |
+
print(filepath)
|
12 |
+
if not os.path.exists(filepath):
|
13 |
+
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
|
14 |
+
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
|
15 |
+
from .pytorch_i3d import InceptionI3d
|
16 |
+
i3d = InceptionI3d(400, in_channels=3).eval().to(device)
|
17 |
+
i3d.load_state_dict(torch.load(filepath, map_location=device))
|
18 |
+
i3d = torch.nn.DataParallel(i3d)
|
19 |
+
return i3d
|
20 |
+
|
21 |
+
def preprocess_single(video, resolution, sequence_length=None):
|
22 |
+
# video: THWC, {0, ..., 255}
|
23 |
+
video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
|
24 |
+
t, c, h, w = video.shape
|
25 |
+
|
26 |
+
# temporal crop
|
27 |
```python
    if sequence_length is not None:
        assert sequence_length <= t
        video = video[:sequence_length]

    # scale shorter side to resolution
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    video = F.interpolate(video, size=target_size, mode='bilinear',
                          align_corners=False)

    # center crop
    t, c, h, w = video.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
    video = video.permute(1, 0, 2, 3).contiguous()  # CTHW

    video -= 0.5

    return video


def preprocess(videos, target_resolution=224):
    # input: videos in [0, 1], [b c t h w], as torch.float
    # -> videos in {0, ..., 255}, [b t h w c], as np.uint8 array
    videos = einops.rearrange(videos, 'b c t h w -> b t h w c')
    videos = (videos * 255).numpy().astype(np.uint8)

    b, t, h, w, c = videos.shape
    videos = torch.from_numpy(videos)
    videos = torch.stack([preprocess_single(video, target_resolution) for video in videos])
    return videos * 2  # [-0.5, 0.5] -> [-1, 1]


def get_fvd_logits(videos, i3d, device, bs=10):
    videos = preprocess(videos)
    embeddings = get_logits(i3d, videos, device, bs=bs)  # forward the caller's batch size instead of a hard-coded 10
    return embeddings


# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
def _symmetric_matrix_square_root(mat, eps=1e-10):
    u, s, v = torch.svd(mat)
    si = torch.where(s < eps, s, torch.sqrt(s))  # keep tiny singular values as-is to avoid NaNs
    return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())


# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
def trace_sqrt_product(sigma, sigma_v):
    sqrt_sigma = _symmetric_matrix_square_root(sigma)
    sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
    return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))


# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
def cov(m, rowvar=False):
    '''Estimate a covariance matrix given data.

    Covariance indicates the level to which two variables vary together.
    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
    then the covariance matrix element `C_{ij}` is the covariance of
    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

    Args:
        m: A 1-D or 2-D array containing multiple variables and observations.
            Each row of `m` represents a variable, and each column a single
            observation of all those variables.
        rowvar: If `rowvar` is True, then each row represents a
            variable, with observations in the columns. Otherwise, the
            relationship is transposed: each column represents a variable,
            while the rows contain observations.

    Returns:
        The covariance matrix of the variables.
    '''
    if m.dim() > 2:
        raise ValueError('m has more than 2 dimensions')
    if m.dim() < 2:
        m = m.view(1, -1)
    if not rowvar and m.size(0) != 1:
        m = m.t()

    fact = 1.0 / (m.size(1) - 1)  # unbiased estimate
    m -= torch.mean(m, dim=1, keepdim=True)
    mt = m.t()  # if complex: mt = m.t().conj()
    return fact * m.matmul(mt).squeeze()


def frechet_distance(x1, x2):
    x1 = x1.flatten(start_dim=1)
    x2 = x2.flatten(start_dim=1)
    m, m_w = x1.mean(dim=0), x2.mean(dim=0)
    sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
    mean = torch.sum((m - m_w) ** 2)
    if x1.shape[0] > 1:
        sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
        trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
        fd = trace + mean
    else:
        # with a single sample the covariance term is undefined; keep only the mean term
        fd = mean
    return float(fd)
```
+
|
127 |
+
|
128 |
+
def get_logits(i3d, videos, device, bs=10):
|
129 |
+
# assert videos.shape[0] % 16 == 0
|
130 |
+
with torch.no_grad():
|
131 |
+
logits = []
|
132 |
+
for i in range(0, videos.shape[0], bs):
|
133 |
+
batch = videos[i:i + bs].to(device)
|
134 |
+
# logits.append(i3d.module.extract_features(batch)) # wrong
|
135 |
+
logits.append(i3d(batch)) # right
|
136 |
+
logits = torch.cat(logits, dim=0)
|
137 |
+
return logits
|
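A minimal end-to-end sketch of how these helpers fit together (hedged: import paths are assumed from the repo layout, the clips are random stand-ins, and the checkpoint is the `i3d_pretrained_400.pt` file added below):

```python
import torch
from common.fvd.videogpt.fvd import get_fvd_logits, frechet_distance
from common.fvd.videogpt.pytorch_i3d import InceptionI3d

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
i3d = InceptionI3d(num_classes=400, in_channels=3).eval().to(device)
i3d.load_state_dict(torch.load('common/fvd/videogpt/i3d_pretrained_400.pt',
                               map_location=device))

# Stand-ins for real/generated clips: [b, c, t, h, w], values in [0, 1].
real = torch.rand(16, 3, 16, 64, 64)
fake = torch.rand(16, 3, 16, 64, 64)
fvd = frechet_distance(get_fvd_logits(real, i3d, device),
                       get_fvd_logits(fake, i3d, device))
print(f'FVD: {fvd:.2f}')
```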
common/fvd/videogpt/i3d_pretrained_400.pt
ADDED
@@ -0,0 +1,3 @@
```
version https://git-lfs.github.com/spec/v1
oid sha256:55095f049e706479d48e221adcdb145b2b9dc930ba28b081ed72367ffaa32343
size 50939526
```
common/fvd/videogpt/pytorch_i3d.py
ADDED
@@ -0,0 +1,322 @@
```python
# Original code from https://github.com/piergiaj/pytorch-i3d
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MaxPool3dSamePadding(nn.MaxPool3d):

    def compute_pad(self, dim, s):
        if s % self.stride[dim] == 0:
            return max(self.kernel_size[dim] - self.stride[dim], 0)
        else:
            return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)

    def forward(self, x):
        # compute 'same' padding
        (batch, channel, t, h, w) = x.size()
        out_t = np.ceil(float(t) / float(self.stride[0]))
        out_h = np.ceil(float(h) / float(self.stride[1]))
        out_w = np.ceil(float(w) / float(self.stride[2]))
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)
        return super(MaxPool3dSamePadding, self).forward(x)
```
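A quick worked example of the 'same' padding rule above (a sanity check, not part of the file): with axis length 16, kernel 3, stride 2, the pad is max(3 - 2, 0) = 1, split front 0 / back 1, so the output length is ceil(16 / 2) = 8 on every axis:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 16, 16, 16)             # [b, c, t, h, w]
x = F.pad(x, (0, 1, 0, 1, 0, 1))              # back-pad w, h, t by one each
y = F.max_pool3d(x, kernel_size=3, stride=2)
assert y.shape[-3:] == (8, 8, 8)              # ceil(16 / 2) along every axis
```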
```python
class Unit3D(nn.Module):

    def __init__(self, in_channels,
                 output_channels,
                 kernel_shape=(1, 1, 1),
                 stride=(1, 1, 1),
                 padding=0,
                 activation_fn=F.relu,
                 use_batch_norm=True,
                 use_bias=False,
                 name='unit_3d'):
        """Initializes Unit3D module."""
        super(Unit3D, self).__init__()

        self._output_channels = output_channels
        self._kernel_shape = kernel_shape
        self._stride = stride
        self._use_batch_norm = use_batch_norm
        self._activation_fn = activation_fn
        self._use_bias = use_bias
        self.name = name
        self.padding = padding

        self.conv3d = nn.Conv3d(in_channels=in_channels,
                                out_channels=self._output_channels,
                                kernel_size=self._kernel_shape,
                                stride=self._stride,
                                padding=0,  # always 0 here; we dynamically pad based on input size in forward
                                bias=self._use_bias)

        if self._use_batch_norm:
            self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001)

    def compute_pad(self, dim, s):
        if s % self._stride[dim] == 0:
            return max(self._kernel_shape[dim] - self._stride[dim], 0)
        else:
            return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)

    def forward(self, x):
        # compute 'same' padding
        (batch, channel, t, h, w) = x.size()
        out_t = np.ceil(float(t) / float(self._stride[0]))
        out_h = np.ceil(float(h) / float(self._stride[1]))
        out_w = np.ceil(float(w) / float(self._stride[2]))
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)

        x = self.conv3d(x)
        if self._use_batch_norm:
            x = self.bn(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x


class InceptionModule(nn.Module):
    def __init__(self, in_channels, out_channels, name):
        super(InceptionModule, self).__init__()

        self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
                         name=name + '/Branch_0/Conv3d_0a_1x1')
        self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_1/Conv3d_0a_1x1')
        self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
                          name=name + '/Branch_1/Conv3d_0b_3x3')
        self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_2/Conv3d_0a_1x1')
        self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
                          name=name + '/Branch_2/Conv3d_0b_3x3')
        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
                                        stride=(1, 1, 1), padding=0)
        self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_3/Conv3d_0b_1x1')
        self.name = name

    def forward(self, x):
        b0 = self.b0(x)
        b1 = self.b1b(self.b1a(x))
        b2 = self.b2b(self.b2a(x))
        b3 = self.b3b(self.b3a(x))
        return torch.cat([b0, b1, b2, b3], dim=1)


class InceptionI3d(nn.Module):
    """Inception-v1 I3D architecture.
    The model is introduced in:
        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
        Joao Carreira, Andrew Zisserman
        https://arxiv.org/pdf/1705.07750v1.pdf.
    See also the Inception architecture, introduced in:
        Going deeper with convolutions
        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
        http://arxiv.org/pdf/1409.4842v1.pdf.
    """

    # Endpoints of the model in order. During construction, all the endpoints up
    # to a designated `final_endpoint` are returned in a dictionary as the
    # second return value.
    VALID_ENDPOINTS = (
        'Conv3d_1a_7x7',
        'MaxPool3d_2a_3x3',
        'Conv3d_2b_1x1',
        'Conv3d_2c_3x3',
        'MaxPool3d_3a_3x3',
        'Mixed_3b',
        'Mixed_3c',
        'MaxPool3d_4a_3x3',
        'Mixed_4b',
        'Mixed_4c',
        'Mixed_4d',
        'Mixed_4e',
        'Mixed_4f',
        'MaxPool3d_5a_2x2',
        'Mixed_5b',
        'Mixed_5c',
        'Logits',
        'Predictions',
    )

    def __init__(self, num_classes=400, spatial_squeeze=True,
                 final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
        """Initializes I3D model instance.
        Args:
          num_classes: The number of outputs in the logit layer (default 400, which
              matches the Kinetics dataset).
          spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
              before returning (default True).
          final_endpoint: The model contains many possible endpoints.
              `final_endpoint` specifies the last endpoint for the model to be built
              up to. In addition to the output at `final_endpoint`, all the outputs
              at endpoints up to `final_endpoint` will also be returned, in a
              dictionary. `final_endpoint` must be one of
              InceptionI3d.VALID_ENDPOINTS (default 'Logits').
          name: A string (optional). The name of this module.
        Raises:
          ValueError: if `final_endpoint` is not recognized.
        """

        if final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError('Unknown final endpoint %s' % final_endpoint)

        super(InceptionI3d, self).__init__()
        self._num_classes = num_classes
        self._spatial_squeeze = spatial_squeeze
        self._final_endpoint = final_endpoint
        self.logits = None

        self.end_points = {}
        end_point = 'Conv3d_1a_7x7'
        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
                                            stride=(2, 2, 2), padding=(3, 3, 3), name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_2a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Conv3d_2b_1x1'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
                                            name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Conv3d_2c_3x3'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
                                            name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_3a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_3b'
        self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_3c'
        self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_4a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4b'
        self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4c'
        self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4d'
        self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4e'
        self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4f'
        self.end_points[end_point] = InceptionModule(112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_5a_2x2'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_5b'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_5c'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Logits'
        self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
                                     stride=(1, 1, 1))
        self.dropout = nn.Dropout(dropout_keep_prob)
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')

        self.build()

    def replace_logits(self, num_classes):
        self._num_classes = num_classes
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')

    def build(self):
        for k in self.end_points.keys():
            self.add_module(k, self.end_points[k])

    def forward(self, x):
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)  # use _modules to work with dataparallel

        x = self.logits(self.dropout(self.avg_pool(x)))
        if self._spatial_squeeze:
            x = x.squeeze(3).squeeze(3)
        logits = x.mean(dim=2)  # average over time
        # logits is batch X classes, which is what we want to work with
        return logits

    def extract_features(self, x):
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)
        return self.avg_pool(x)
```
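A small shape check for the two entry points (a sketch; the 16-frame, 224x224 input matches the preprocessing in fvd.py above):

```python
import torch

model = InceptionI3d(num_classes=400, in_channels=3).eval()
clip = torch.rand(2, 3, 16, 224, 224)         # [b, c, t, h, w]
with torch.no_grad():
    logits = model(clip)                      # [2, 400]: time-averaged class logits (used for FVD)
    feats = model.extract_features(clip)      # [2, 1024, 1, 1, 1]: pooled Mixed_5c features
```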
common/inception.py
ADDED
@@ -0,0 +1,344 @@
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth"  # noqa: E501


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=(DEFAULT_BLOCK_INDEX,),
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, "Last possible output block index is 3"

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(weights="DEFAULT")

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        blocks, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`"""
    try:
        version = tuple(map(int, torchvision.__version__.split(".")[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    # Skips default weight initialization if supported by torchvision
    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
    if version >= (0, 6):
        kwargs["init_weights"] = False

    # Backwards compatibility: `weights` argument was handled by `pretrained`
    # argument prior to version 0.13.
    if version < (0, 13) and "weights" in kwargs:
        if kwargs["weights"] == "DEFAULT":
            kwargs["pretrained"] = True
        elif kwargs["weights"] is None:
            kwargs["pretrained"] = False
        else:
            raise ValueError(
                "weights=={} not supported in torchvision {}".format(
                    kwargs["weights"], torchvision.__version__
                )
            )
        del kwargs["weights"]

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
```
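A minimal sketch of pulling FID activations with this class (hypothetical batch; the 2048-d final-pool features are the standard choice for FID):

```python
import torch

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]      # final average pooling block
model = InceptionV3(output_blocks=[block_idx]).eval() # downloads the FID weights on first use

images = torch.rand(8, 3, 256, 256)                   # [b, 3, h, w] in [0, 1]; resized to 299 internally
with torch.no_grad():
    feats = model(images)[0]                          # [8, 2048, 1, 1]
acts = feats.squeeze(-1).squeeze(-1)                  # activations used to fit the FID Gaussian
```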
common/plot/__init__.py
ADDED
File without changes
common/plot/aggregated_output.csv
ADDED
@@ -0,0 +1,18 @@
name,bridge_data_v2/teacher_force_psnr,bridge_data_v2/teacher_force_psnr_delta,bridge_data_v2/teacher_force_ssim,bridge_data_v2/teacher_force_pred_lpips,bridge_data_v2/teacher_force_loss,bridge_data_v2/num_examples,fractal20220817_data/teacher_force_psnr,fractal20220817_data/teacher_force_psnr_delta,fractal20220817_data/teacher_force_ssim,fractal20220817_data/teacher_force_pred_lpips,fractal20220817_data/teacher_force_loss,fractal20220817_data/num_examples,language_table/teacher_force_psnr,language_table/teacher_force_psnr_delta,language_table/teacher_force_ssim,language_table/teacher_force_pred_lpips,language_table/teacher_force_loss,language_table/num_examples,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_psnr,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_ssim,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,ucsd_pick_and_place_dataset_converted_externally_to_rlds/teacher_force_loss,ucsd_pick_and_place_dataset_converted_externally_to_rlds/num_examples,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_psnr,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_psnr_delta,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_ssim,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_pred_lpips,kaist_nonprehensile_converted_externally_to_rlds/teacher_force_loss,kaist_nonprehensile_converted_externally_to_rlds/num_examples,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_psnr,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_ssim,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,ucsd_kitchen_dataset_converted_externally_to_rlds/teacher_force_loss,ucsd_kitchen_dataset_converted_externally_to_rlds/num_examples,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_psnr,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_psnr_delta,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_ssim,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_pred_lpips,utokyo_xarm_bimanual_converted_externally_to_rlds/teacher_force_loss,utokyo_xarm_bimanual_converted_externally_to_rlds/num_examples,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_psnr,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_ssim,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,stanford_hydra_dataset_converted_externally_to_rlds/teacher_force_loss,stanford_hydra_dataset_converted_externally_to_rlds/num_examples,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_psnr,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_ssim,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,austin_sirius_dataset_converted_externally_to_rlds/teacher_force_loss,austin_sirius_dataset_converted_externally_to_rlds/num_examples,berkeley_fanuc_manipulation/teacher_force_psnr,berkeley_fanuc_manipulation/teacher_force_psnr_delta,berkeley_fanuc_manipulation/teacher_force_ssim,berkeley_fanuc_manipulation/teacher_force_pred_lpips,berkeley_fanuc_manipulation/teacher_force_loss,berkeley_fanuc_manipulation/num_examples,berkeley_mvp_converted
_externally_to_rlds/teacher_force_psnr,berkeley_mvp_converted_externally_to_rlds/teacher_force_psnr_delta,berkeley_mvp_converted_externally_to_rlds/teacher_force_ssim,berkeley_mvp_converted_externally_to_rlds/teacher_force_pred_lpips,berkeley_mvp_converted_externally_to_rlds/teacher_force_loss,berkeley_mvp_converted_externally_to_rlds/num_examples,berkeley_rpt_converted_externally_to_rlds/teacher_force_psnr,berkeley_rpt_converted_externally_to_rlds/teacher_force_psnr_delta,berkeley_rpt_converted_externally_to_rlds/teacher_force_ssim,berkeley_rpt_converted_externally_to_rlds/teacher_force_pred_lpips,berkeley_rpt_converted_externally_to_rlds/teacher_force_loss,berkeley_rpt_converted_externally_to_rlds/num_examples,cmu_play_fusion/teacher_force_psnr,cmu_play_fusion/teacher_force_psnr_delta,cmu_play_fusion/teacher_force_ssim,cmu_play_fusion/teacher_force_pred_lpips,cmu_play_fusion/teacher_force_loss,cmu_play_fusion/num_examples,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_psnr,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_psnr_delta,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_ssim,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_pred_lpips,iamlab_cmu_pickup_insert_converted_externally_to_rlds/teacher_force_loss,iamlab_cmu_pickup_insert_converted_externally_to_rlds/num_examples,qut_dexterous_manpulation/teacher_force_psnr,qut_dexterous_manpulation/teacher_force_psnr_delta,qut_dexterous_manpulation/teacher_force_ssim,qut_dexterous_manpulation/teacher_force_pred_lpips,qut_dexterous_manpulation/teacher_force_loss,qut_dexterous_manpulation/num_examples,robo_net/teacher_force_psnr,robo_net/teacher_force_psnr_delta,robo_net/teacher_force_ssim,robo_net/teacher_force_pred_lpips,robo_net/teacher_force_loss,robo_net/num_examples,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_psnr,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_psnr_delta,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_ssim,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_pred_lpips,furniture_bench_dataset_converted_externally_to_rlds/teacher_force_loss,furniture_bench_dataset_converted_externally_to_rlds/num_examples,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_psnr,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_psnr_delta,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_ssim,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_pred_lpips,dlr_sara_grid_clamp_converted_externally_to_rlds/teacher_force_loss,dlr_sara_grid_clamp_converted_externally_to_rlds/num_examples,cmu_stretch/teacher_force_psnr,cmu_stretch/teacher_force_psnr_delta,cmu_stretch/teacher_force_ssim,cmu_stretch/teacher_force_pred_lpips,cmu_stretch/teacher_force_loss,cmu_stretch/num_examples,spoc/teacher_force_psnr,spoc/teacher_force_psnr_delta,spoc/teacher_force_ssim,spoc/teacher_force_pred_lpips,spoc/teacher_force_loss,spoc/num_examples,columbia_cairlab_pusht_real/teacher_force_psnr,columbia_cairlab_pusht_real/teacher_force_psnr_delta,columbia_cairlab_pusht_real/teacher_force_ssim,columbia_cairlab_pusht_real/teacher_force_pred_lpips,columbia_cairlab_pusht_real/teacher_force_loss,columbia_cairlab_pusht_real/num_examples,droid/teacher_force_psnr,droid/teacher_force_psnr_delta,droid/teacher_force_ssim,droid/teacher_force_pred_lpips,droid/teacher_force_loss,droid/num_examples,toto/teacher_force_psnr,toto/teacher_force_psnr_delta,toto/teacher_force_ssim,toto/te
acher_force_pred_lpips,toto/teacher_force_loss,toto/num_examples,io_ai_tech/teacher_force_psnr,io_ai_tech/teacher_force_psnr_delta,io_ai_tech/teacher_force_ssim,io_ai_tech/teacher_force_pred_lpips,io_ai_tech/teacher_force_loss,io_ai_tech/num_examples,conq_hose_manipulation/teacher_force_psnr,conq_hose_manipulation/teacher_force_psnr_delta,conq_hose_manipulation/teacher_force_ssim,conq_hose_manipulation/teacher_force_pred_lpips,conq_hose_manipulation/teacher_force_loss,conq_hose_manipulation/num_examples,dobbe/teacher_force_psnr,dobbe/teacher_force_psnr_delta,dobbe/teacher_force_ssim,dobbe/teacher_force_pred_lpips,dobbe/teacher_force_loss,dobbe/num_examples,berkeley_gnm_cory_hall/teacher_force_psnr,berkeley_gnm_cory_hall/teacher_force_psnr_delta,berkeley_gnm_cory_hall/teacher_force_ssim,berkeley_gnm_cory_hall/teacher_force_pred_lpips,berkeley_gnm_cory_hall/teacher_force_loss,berkeley_gnm_cory_hall/num_examples,plex_robosuite/teacher_force_psnr,plex_robosuite/teacher_force_psnr_delta,plex_robosuite/teacher_force_ssim,plex_robosuite/teacher_force_pred_lpips,plex_robosuite/teacher_force_loss,plex_robosuite/num_examples,usc_cloth_sim_converted_externally_to_rlds/teacher_force_psnr,usc_cloth_sim_converted_externally_to_rlds/teacher_force_psnr_delta,usc_cloth_sim_converted_externally_to_rlds/teacher_force_ssim,usc_cloth_sim_converted_externally_to_rlds/teacher_force_pred_lpips,usc_cloth_sim_converted_externally_to_rlds/teacher_force_loss,usc_cloth_sim_converted_externally_to_rlds/num_examples,berkeley_cable_routing/teacher_force_psnr,berkeley_cable_routing/teacher_force_psnr_delta,berkeley_cable_routing/teacher_force_ssim,berkeley_cable_routing/teacher_force_pred_lpips,berkeley_cable_routing/teacher_force_loss,berkeley_cable_routing/num_examples,imperial_wrist_dataset/teacher_force_psnr,imperial_wrist_dataset/teacher_force_psnr_delta,imperial_wrist_dataset/teacher_force_ssim,imperial_wrist_dataset/teacher_force_pred_lpips,imperial_wrist_dataset/teacher_force_loss,imperial_wrist_dataset/num_examples,bc_z/teacher_force_psnr,bc_z/teacher_force_psnr_delta,bc_z/teacher_force_ssim,bc_z/teacher_force_pred_lpips,bc_z/teacher_force_loss,bc_z/num_examples,kuka/teacher_force_psnr,kuka/teacher_force_psnr_delta,kuka/teacher_force_ssim,kuka/teacher_force_pred_lpips,kuka/teacher_force_loss,kuka/num_examples,roboturk/teacher_force_psnr,roboturk/teacher_force_psnr_delta,roboturk/teacher_force_ssim,roboturk/teacher_force_pred_lpips,roboturk/teacher_force_loss,roboturk/num_examples,metaworld/teacher_force_psnr,metaworld/teacher_force_psnr_delta,metaworld/teacher_force_ssim,metaworld/teacher_force_pred_lpips,metaworld/teacher_force_loss,metaworld/num_examples,robomimic/teacher_force_psnr,robomimic/teacher_force_psnr_delta,robomimic/teacher_force_ssim,robomimic/teacher_force_pred_lpips,robomimic/teacher_force_loss,robomimic/num_examples,epic_kitchen/teacher_force_psnr,epic_kitchen/teacher_force_psnr_delta,epic_kitchen/teacher_force_ssim,epic_kitchen/teacher_force_pred_lpips,epic_kitchen/teacher_force_loss,epic_kitchen/num_examples,ego4d/teacher_force_psnr,ego4d/teacher_force_psnr_delta,ego4d/teacher_force_ssim,ego4d/teacher_force_pred_lpips,ego4d/teacher_force_loss,ego4d/num_examples,nyu_door_opening_surprising_effectiveness/teacher_force_psnr,nyu_door_opening_surprising_effectiveness/teacher_force_psnr_delta,nyu_door_opening_surprising_effectiveness/teacher_force_ssim,nyu_door_opening_surprising_effectiveness/teacher_force_pred_lpips,nyu_door_opening_surprising_effectiveness/teacher_force_loss,nyu_door_opening_surpri
sing_effectiveness/num_examples
20.11654281616211,0.2406988888978958,0.6522064805030823,0.16225013136863708,4.879761219024658,500,22.18416404724121,0.35506966710090637,0.6869667768478394,0.15024960041046143,4.770835876464844,500,21.715396881103516,0.3029274344444275,0.6992139220237732,0.16471771895885468,5.294665813446045,198,21.16069793701172,0.8068287372589111,0.856053352355957,0.13014769554138184,6.175273895263672,268,20.990825653076172,-0.10340484231710434,0.7044205069541931,0.18493501842021942,7.725522518157959,215,15.698708534240723,0.0012377961538732052,0.5898451209068298,0.17795829474925995,5.001875400543213,44,19.282543182373047,0.25010883808135986,0.6505519151687622,0.17733542621135712,5.174380779266357,26,17.12480926513672,-0.2587044835090637,0.705141007900238,0.19221702218055725,5.820067882537842,500,19.765209197998047,0.520576000213623,0.7926568984985352,0.1714717447757721,6.499162673950195,500,24.512998580932617,0.05698919668793678,0.7914198637008667,0.10842812061309814,3.0430972576141357,197,25.59234619140625,0.10779287666082382,0.8336195349693298,0.19609792530536652,9.673583984375,145,19.04483985900879,-1.549579381942749,0.8432617783546448,0.23590296506881714,9.34917163848877,500,25.48477554321289,-0.2415604293346405,0.808574378490448,0.14057636260986328,3.712907552719116,500,20.40241241455078,-1.7759240865707397,0.695418119430542,0.18546271324157715,5.312252044677734,500,18.150922775268555,0.07440949231386185,0.623611569404602,0.19956742227077484,2.496238946914673,500,19.270723342895508,0.13051848113536835,0.6409730315208435,0.21255843341350555,5.373872756958008,400,14.527210235595703,-1.1576420068740845,0.6855177283287048,0.2406100034713745,6.9484028816223145,500,22.118375778198242,0.12427309900522232,0.760040819644928,0.13304449617862701,2.0709831714630127,142,27.101343154907227,0.002104964340105653,0.8569799661636353,0.11864025890827179,4.056473731994629,304,14.499595642089844,-2.287384033203125,0.6487097144126892,0.5801433324813843,11.471040725708008,500,19.862470626831055,0.004642155021429062,0.8116016387939453,0.15320290625095367,7.407879829406738,290,18.82098960876465,1.036180019378662,0.6922352910041809,0.18524041771888733,5.3734941482543945,500,19.008501052856445,0.08843769133090973,0.654740035533905,0.19600719213485718,5.878746032714844,500,25.3173885345459,-1.4838885068893433,0.8376902341842651,0.091035857796669,6.641345977783203,500,0,0,0,0,0,0,23.15846061706543,0.19177615642547607,0.7468024492263794,0.14149916172027588,7.459285259246826,500,20.05571746826172,-0.6787086129188538,0.8033839464187622,0.23725976049900055,9.829147338867188,185,25.470836639404297,-0.04034167900681496,0.8336919546127319,0.13884975016117096,2.244248867034912,307,27.915611267089844,2.5503063201904297,0.9527876377105713,0.09793127328157425,1.2040963172912598,200,20.912508010864258,0.591270387172699,0.8214857578277588,0.2098020762205124,4.666227340698242,3,23.71303939819336,-0.19690543413162231,0.6411390900611877,0.09051118791103363,8.156048774719238,58,23.800819396972656,0.357938289642334,0.7015560269355774,0.16095396876335144,5.172791004180908,500,19.602933883666992,0.08206112682819366,0.6492921710014343,0.20972613990306854,5.722141265869141,136,15.88235092163086,0.034370679408311844,0.6604549288749695,0.26275190711021423,7.8473687171936035,475,0,0,0,0,0,0,20.04568099975586,-0.6712338328361511,0.8586370944976807,0.14845611155033112,4.173150062561035,60,0,0,0,0,0,0,0,0,0,0,0,0,16.780973434448242,0.36480912566185,0.5418305993080139,0.23359735310077667,9.743389129638672,63,0
20.11654281616211,0.2406988888978958,0.6522064805030823,0.16225013136863708,4.879761219024658,500,22.18416404724121,0.35506966710090637,0.6869667768478394,0.15024960041046143,4.770835876464844,500,21.715396881103516,0.3029274344444275,0.6992139220237732,0.16471771895885468,5.294665813446045,198,21.16069793701172,0.8068287372589111,0.856053352355957,0.13014769554138184,6.175273895263672,268,20.990825653076172,-0.10340484231710434,0.7044205069541931,0.18493501842021942,7.725522518157959,215,15.698708534240723,0.0012377961538732052,0.5898451209068298,0.17795829474925995,5.001875400543213,44,19.282543182373047,0.25010883808135986,0.6505519151687622,0.17733542621135712,5.174380779266357,26,17.12480926513672,-0.2587044835090637,0.705141007900238,0.19221702218055725,5.820067882537842,500,19.765209197998047,0.520576000213623,0.7926568984985352,0.1714717447757721,6.499162673950195,500,24.512998580932617,0.05698919668793678,0.7914198637008667,0.10842812061309814,3.0430972576141357,197,25.59234619140625,0.10779287666082382,0.8336195349693298,0.19609792530536652,9.673583984375,145,19.04483985900879,-1.549579381942749,0.8432617783546448,0.23590296506881714,9.34917163848877,500,25.48477554321289,-0.2415604293346405,0.808574378490448,0.14057636260986328,3.712907552719116,500,20.40241241455078,-1.7759240865707397,0.695418119430542,0.18546271324157715,5.312252044677734,500,18.150922775268555,0.07440949231386185,0.623611569404602,0.19956742227077484,2.496238946914673,500,19.270723342895508,0.13051848113536835,0.6409730315208435,0.21255843341350555,5.373872756958008,400,14.527210235595703,-1.1576420068740845,0.6855177283287048,0.2406100034713745,6.9484028816223145,500,22.118375778198242,0.12427309900522232,0.760040819644928,0.13304449617862701,2.0709831714630127,142,27.101343154907227,0.002104964340105653,0.8569799661636353,0.11864025890827179,4.056473731994629,304,14.499595642089844,-2.287384033203125,0.6487097144126892,0.5801433324813843,11.471040725708008,500,19.862470626831055,0.004642155021429062,0.8116016387939453,0.15320290625095367,7.407879829406738,290,18.82098960876465,1.036180019378662,0.6922352910041809,0.18524041771888733,5.3734941482543945,500,19.008501052856445,0.08843769133090973,0.654740035533905,0.19600719213485718,5.878746032714844,500,25.3173885345459,-1.4838885068893433,0.8376902341842651,0.091035857796669,6.641345977783203,500,0,0,0,0,0,0,23.15846061706543,0.19177615642547607,0.7468024492263794,0.14149916172027588,7.459285259246826,500,20.05571746826172,-0.6787086129188538,0.8033839464187622,0.23725976049900055,9.829147338867188,185,25.470836639404297,-0.04034167900681496,0.8336919546127319,0.13884975016117096,2.244248867034912,307,27.915611267089844,2.5503063201904297,0.9527876377105713,0.09793127328157425,1.2040963172912598,200,20.912508010864258,0.591270387172699,0.8214857578277588,0.2098020762205124,4.666227340698242,3,23.71303939819336,-0.19690543413162231,0.6411390900611877,0.09051118791103363,8.156048774719238,58,23.800819396972656,0.357938289642334,0.7015560269355774,0.16095396876335144,5.172791004180908,500,19.602933883666992,0.08206112682819366,0.6492921710014343,0.20972613990306854,5.722141265869141,136,15.88235092163086,0.034370679408311844,0.6604549288749695,0.26275190711021423,7.8473687171936035,475,0,0,0,0,0,0,20.04568099975586,-0.6712338328361511,0.8586370944976807,0.14845611155033112,4.173150062561035,60,0,0,0,0,0,0,0,0,0,0,0,0,16.780973434448242,0.36480912566185,0.5418305993080139,0.23359735310077667,9.743389129638672,63,0
20.22026252746582,0.38262513279914856,0.6533553600311279,0.1599833220243454,4.625724792480469,500,22.20106315612793,1.1525933742523193,0.6865711212158203,0.1498977541923523,4.709808826446533,500,21.663299560546875,0.315200537443161,0.6973788142204285,0.16729794442653656,5.099008083343506,198,21.049013137817383,1.241530179977417,0.8553103804588318,0.1382942497730255,6.023053169250488,268,20.187990188598633,0.2288614809513092,0.6624469757080078,0.21102775633335114,8.721923828125,215,14.019247055053711,0.1637289673089981,0.5697340369224548,0.2109413594007492,5.538760662078857,44,19.34811782836914,1.7672139406204224,0.65287184715271,0.17289860546588898,4.801314830780029,26,17.78873634338379,0.7138077020645142,0.7086768746376038,0.17418105900287628,5.320345878601074,500,19.624094009399414,0.8938031792640686,0.7965993285179138,0.19896067678928375,6.5370965003967285,500,24.533777236938477,0.036842938512563705,0.7917919754981995,0.10872980952262878,2.8953208923339844,197,25.017925262451172,0.39651936292648315,0.8296465873718262,0.19340361654758453,9.546339988708496,145,19.407617568969727,0.20671948790550232,0.8511928915977478,0.21369731426239014,8.630515098571777,500,25.74650001525879,0.004801941104233265,0.8135946989059448,0.13683146238327026,3.3414487838745117,500,18.16292953491211,-1.5063064098358154,0.6798076033592224,0.2016068398952484,4.924650192260742,500,17.171728134155273,0.12000428140163422,0.6030166149139404,0.21893559396266937,2.7387208938598633,200,19.084640502929688,0.17159269750118256,0.6400240659713745,0.21725738048553467,5.156170845031738,200,14.959861755371094,-0.08570398390293121,0.6782578825950623,0.24334610998630524,6.023822784423828,200,20.60854148864746,1.239460825920105,0.7399578094482422,0.15486611425876617,2.6809659004211426,142,27.038379669189453,-0.04835420474410057,0.8582215905189514,0.11914262175559998,3.703599214553833,200,16.528011322021484,0.07733757048845291,0.6863109469413757,0.43281883001327515,10.088181495666504,200,20.27916717529297,0.39283424615859985,0.8197209239006042,0.1518518626689911,7.038107872009277,200,19.77338218688965,-0.10181392729282379,0.7324086427688599,0.1659238040447235,5.009759426116943,200,16.711793899536133,-0.7416815161705017,0.5757627487182617,0.25752225518226624,7.026858329772949,200,26.67308807373047,-0.19710086286067963,0.849867582321167,0.08147656917572021,6.09824275970459,200,0,0,0,0,0,0,22.918180465698242,0.21646614372730255,0.7544977068901062,0.13648433983325958,6.675195693969727,200,20.61359977722168,-0.2607249617576599,0.8171404600143433,0.2431134730577469,9.399118423461914,185,25.659948348999023,-0.007021207828074694,0.8343022465705872,0.13590268790721893,2.109799385070801,200,28.551301956176758,0.999920666217804,0.9539777040481567,0.09078202396631241,1.1626181602478027,200,20.705917358398438,0.19353322684764862,0.822070300579071,0.2162921279668808,4.443334579467773,3,23.890663146972656,0.042181383818387985,0.6438125371932983,0.09075239300727844,7.78416109085083,58,23.974327087402344,0.4367068409919739,0.6935603022575378,0.15629488229751587,4.970770835876465,200,19.62123680114746,0.2064054161310196,0.650097131729126,0.21027418971061707,5.509530067443848,136,15.70105266571045,-0.008274257183074951,0.6658218502998352,0.26436904072761536,7.878448009490967,200,0,0,0,0,0,0,21.392261505126953,0.8078638315200806,0.8658612966537476,0.1169213280081749,4.069273471832275,60,0,0,0,0,0,0,0,0,0,0,0,0,16.521770477294922,0.4119645357131958,0.5416176319122314,0.2529039680957794,9.572311401367188,63,0
20.192195892333984,0.0013111135922372341,0.6530546545982361,0.16100579500198364,4.674829006195068,500,22.106380462646484,0.002372157759964466,0.6869567632675171,0.14720509946346283,4.874122619628906,500,21.668710708618164,0.010731762275099754,0.6994706392288208,0.1702597439289093,5.343525409698486,198,20.726055145263672,0.004278174135833979,0.8526590466499329,0.13251639902591705,6.267956256866455,268,20.887666702270508,0.3664732277393341,0.6933053135871887,0.18257129192352295,7.749646186828613,200,15.50479507446289,0.011554110795259476,0.5916057229042053,0.18086765706539154,5.101010799407959,44,19.26924705505371,-0.0402817502617836,0.6491515636444092,0.17109069228172302,5.088503360748291,26,18.38506317138672,-0.01206012349575758,0.7032961249351501,0.17353220283985138,4.688672065734863,200,19.471759796142578,-0.0025741406716406345,0.790705680847168,0.18272539973258972,6.433710098266602,200,24.408714294433594,-0.006203812547028065,0.7900201678276062,0.1093222051858902,3.0680618286132812,197,25.618268966674805,-0.012119447812438011,0.8320527672767639,0.189222514629364,9.541532516479492,145,20.02849769592285,0.0040948037058115005,0.8293033838272095,0.22746650874614716,8.622618675231934,200,25.799362182617188,0.002677352400496602,0.8147953748703003,0.1341477930545807,3.4496517181396484,200,22.345197677612305,0.27116137742996216,0.7258548736572266,0.1409902721643448,3.6172609329223633,200,16.546030044555664,-0.10674090683460236,0.5869951844215393,0.23679934442043304,3.374926805496216,200,19.099502563476562,-0.007533765863627195,0.6395227909088135,0.21552789211273193,5.256075382232666,200,16.655963897705078,0.0018004697049036622,0.6986208558082581,0.1904933899641037,5.870639801025391,200,22.06020736694336,0.564301609992981,0.7581813931465149,0.13439540565013885,3.1246254444122314,142,27.067359924316406,0.00568796694278717,0.857913076877594,0.11922232806682587,4.030396461486816,200,16.69654655456543,0.013287489302456379,0.6759082078933716,0.41035643219947815,10.485578536987305,200,20.1484317779541,-0.0037386345211416483,0.8214008808135986,0.1595878303050995,7.100029945373535,200,19.99980926513672,0.002995690330862999,0.7338225841522217,0.16079355776309967,5.090161323547363,200,17.28361701965332,-0.0767323225736618,0.5945534706115723,0.23762698471546173,6.790170192718506,200,26.707612991333008,-0.010888484306633472,0.8517170548439026,0.08222655206918716,6.144393444061279,200,0,0,0,0,0,0,22.81102752685547,0.004911952186375856,0.753410816192627,0.13840247690677643,6.590641021728516,200,20.56325340270996,-0.0036200578324496746,0.8122907876968384,0.2328769862651825,9.606425285339355,185,25.634384155273438,0.014794806018471718,0.8339281678199768,0.13365812599658966,2.2014026641845703,200,28.079368591308594,-0.27046048641204834,0.9531135559082031,0.09380464255809784,1.130263090133667,200,20.635210037231445,0.014526949264109135,0.8181746006011963,0.21357515454292297,4.678038597106934,3,23.946765899658203,0.026584567502141,0.6427099704742432,0.09646405279636383,8.022929191589355,58,23.835311889648438,-0.00212214607745409,0.6980120539665222,0.15597480535507202,5.109379768371582,200,19.57789421081543,-0.008618982508778572,0.6491334438323975,0.21285711228847504,5.622872829437256,136,15.693717956542969,-0.0042601898312568665,0.6714515686035156,0.2660149037837982,7.933395862579346,200,0,0,0,0,0,0,20.719491958618164,0.012509683147072792,0.8595790266990662,0.11407399922609329,3.9665277004241943,60,0,0,0,0,0,0,0,0,0,0,0,0,16.573131561279297,-0.02216508239507675,0.5317869782447815,0.25623247027397156,9.6472024917602
54,63,0
20.11272621154785,0.2109360694885254,0.6520361304283142,0.1624097228050232,4.9084296226501465,500,22.162460327148438,0.5982567667961121,0.6869907379150391,0.14949150383472443,4.766613483428955,500,21.70660972595215,0.24707387387752533,0.6991060376167297,0.16490808129310608,5.31504487991333,198,20.89645767211914,0.6825409531593323,0.8557612299919128,0.13019102811813354,6.398017406463623,200,21.12166976928711,0.08367721736431122,0.7037967443466187,0.1830134242773056,7.72568416595459,200,15.728557586669922,0.023094290867447853,0.5895631909370422,0.17670421302318573,4.983802318572998,44,19.269149780273438,0.21575377881526947,0.6504108905792236,0.17488287389278412,5.12293004989624,26,17.66543197631836,-0.3417667746543884,0.6988275647163391,0.1885838657617569,5.298708438873291,200,19.65871810913086,0.5588710904121399,0.7912994027137756,0.17227208614349365,6.4804582595825195,200,24.516550064086914,0.020750248804688454,0.7902164459228516,0.10835054516792297,3.0472919940948486,197,25.76390266418457,0.2111617475748062,0.8347102999687195,0.1926291286945343,9.65595817565918,145,18.151044845581055,-1.411789894104004,0.7994647026062012,0.26881223917007446,9.587594032287598,200,25.545425415039062,-0.2518838047981262,0.8101628422737122,0.13834546506404877,3.7476205825805664,200,20.21263313293457,-2.1581101417541504,0.6755650043487549,0.20787520706653595,5.349665641784668,200,17.979320526123047,0.029456928372383118,0.6195611953735352,0.2056853473186493,2.5060770511627197,200,19.21215057373047,0.09863721579313278,0.6396337747573853,0.21282075345516205,5.47512149810791,200,14.731433868408203,-1.0437582731246948,0.6715344190597534,0.24017825722694397,6.835475444793701,200,22.121034622192383,0.13931676745414734,0.7599001526832581,0.13316860795021057,2.0559332370758057,142,26.967485427856445,-0.04352593049407005,0.8562372326850891,0.11917727440595627,4.126804828643799,200,14.093514442443848,-2.401867628097534,0.6462867259979248,0.5737171769142151,11.239895820617676,200,20.1369686126709,0.03518101945519447,0.8176986575126648,0.15042880177497864,7.3873419761657715,200,19.936338424682617,1.2820953130722046,0.7327397465705872,0.16244173049926758,5.061605930328369,200,18.608089447021484,0.03232511132955551,0.6288440823554993,0.19996103644371033,6.023675918579102,200,25.265308380126953,-1.5252397060394287,0.8366735577583313,0.09203425794839859,6.713019847869873,200,0,0,0,0,0,0,22.606639862060547,0.23754052817821503,0.7469227313995361,0.14175154268741608,7.264929294586182,200,20.01392364501953,-0.7243794202804565,0.8044877052307129,0.24030713737010956,9.7893705368042,185,25.513294219970703,0.022836336866021156,0.8328461647033691,0.13896547257900238,2.2256522178649902,200,27.747732162475586,2.1492269039154053,0.9519210457801819,0.09740724414587021,1.2134103775024414,200,20.910673141479492,0.47492536902427673,0.8211309909820557,0.21091899275779724,4.660539150238037,3,23.682945251464844,-0.09257392585277557,0.6411334872245789,0.09143876284360886,8.146902084350586,58,23.916791915893555,0.2757120430469513,0.6961292028427124,0.15634751319885254,5.073945999145508,200,19.578807830810547,0.06890798360109329,0.648792028427124,0.2106357365846634,5.744231224060059,136,15.669443130493164,0.014397966675460339,0.662750780582428,0.26728084683418274,8.170258522033691,200,0,0,0,0,0,0,20.507707595825195,-0.1476416438817978,0.8596673011779785,0.13412821292877197,4.167920112609863,60,0,0,0,0,0,0,0,0,0,0,0,0,16.79804229736328,0.39872896671295166,0.5435536503791809,0.23285678029060364,9.751370429992676,63,0
20.09503936767578,0.19324186444282532,0.6532765030860901,0.16777116060256958,4.9215545654296875,200,21.670238494873047,0.5962230563163757,0.6772172451019287,0.15711373090744019,4.973745822906494,200,21.32579231262207,0.23807521164417267,0.6979929804801941,0.17777009308338165,5.471674919128418,198,19.964054107666016,0.6898609399795532,0.8480191826820374,0.14881540834903717,6.650026798248291,200,19.73797035217285,0.028328483924269676,0.6416213512420654,0.2283506691455841,9.225159645080566,200,14.12078857421875,-0.07745029032230377,0.5660436749458313,0.21174412965774536,5.737281322479248,44,17.832124710083008,-0.6108675003051758,0.625728189945221,0.2204168289899826,5.574310302734375,26,17.75128746032715,0.5052798986434937,0.6958627700805664,0.1850184202194214,5.1611528396606445,200,18.28720474243164,3.523144483566284,0.77419114112854,0.22147363424301147,6.997479438781738,200,24.396211624145508,0.042355746030807495,0.7892429828643799,0.11021917313337326,3.1949589252471924,197,24.727394104003906,0.08382178097963333,0.831762433052063,0.21840521693229675,9.826996803283691,145,15.711186408996582,-1.1402037143707275,0.7454285621643066,0.4373578727245331,11.309600830078125,200,25.52533721923828,-0.00760778971016407,0.8099644184112549,0.13890613615512848,3.6684834957122803,200,19.276363372802734,-2.6745126247406006,0.6672981381416321,0.21036404371261597,4.968836784362793,200,15.376065254211426,0.2176371067762375,0.5345624089241028,0.28238123655319214,4.76574182510376,200,18.881689071655273,0.1820032000541687,0.6362051367759705,0.22268341481685638,5.450001239776611,200,15.536110877990723,0.5788084268569946,0.6868603229522705,0.2221454679965973,6.200928211212158,200,18.588090896606445,-0.1351645439863205,0.6845595240592957,0.20317442715168,4.309319496154785,142,26.693002700805664,-0.017797470092773438,0.85491943359375,0.12290634214878082,4.107079982757568,200,15.977954864501953,-0.2343152016401291,0.6691921949386597,0.4571930468082428,10.901474952697754,200,19.851110458374023,0.4773397743701935,0.8219828605651855,0.17668022215366364,7.3162055015563965,200,19.206226348876953,0.6738173961639404,0.7266687750816345,0.17411881685256958,5.285086631774902,200,15.480487823486328,-0.05200222134590149,0.5209740400314331,0.30362191796302795,8.292901992797852,200,25.806957244873047,-0.5103187561035156,0.8443582653999329,0.08975253254175186,6.479008197784424,200,0,0,0,0,0,0,21.907487869262695,-0.3219781517982483,0.7322761416435242,0.15371987223625183,7.148579120635986,200,19.59992027282715,-0.06399036198854446,0.8096731305122375,0.2715202569961548,9.681106567382812,185,25.1120662689209,-0.05029723793268204,0.829463541507721,0.14083868265151978,2.2757794857025146,200,28.31503677368164,1.267438292503357,0.9527142643928528,0.0927368700504303,1.3062008619308472,200,20.59347152709961,0.03865854814648628,0.8206701874732971,0.2205837219953537,4.7697625160217285,3,23.66497230529785,0.0852617546916008,0.6408542990684509,0.10270028561353683,8.262017250061035,58,23.514087677001953,0.18420732021331787,0.6972275972366333,0.1637846827507019,5.285017490386963,200,19.399612426757812,0.14159218966960907,0.6437479257583618,0.2207319140434265,6.007476806640625,136,15.523843765258789,0.04549547657370567,0.6679399609565735,0.2731015682220459,8.212705612182617,200,0,0,0,0,0,0,18.983938217163086,-0.7280352711677551,0.8515005111694336,0.16164684295654297,4.466729640960693,60,0,0,0,0,0,0,0,0,0,0,0,0,16.20790672302246,0.4906918406486511,0.5336827039718628,0.28068792819976807,10.018815994262695,63,0
18.422204971313477,-0.12648740410804749,0.6362218260765076,0.19450049102306366,5.315305233001709,500,20.882783889770508,-0.07433691620826721,0.6690763831138611,0.1761208325624466,5.387838363647461,500,21.112810134887695,-0.03544004634022713,0.699590802192688,0.18499258160591125,5.969436168670654,198,18.350582122802734,0.12320471554994583,0.8303707838058472,0.1668887436389923,6.814785480499268,268,20.747133255004883,-0.008221889846026897,0.6795586943626404,0.18530716001987457,8.013651847839355,215,13.65688705444336,-0.014612981118261814,0.5605658888816833,0.21683458983898163,5.646177291870117,44,17.740320205688477,-0.01435495913028717,0.6188032031059265,0.20206505060195923,5.870429039001465,26,16.441152572631836,-0.09573374688625336,0.6951659917831421,0.20139959454536438,5.628950119018555,500,18.033527374267578,-0.049468427896499634,0.7758861184120178,0.20358926057815552,7.416662693023682,500,23.77448272705078,-0.004841374699026346,0.7814860939979553,0.11732590198516846,3.6382157802581787,197,23.89436149597168,-0.07056951522827148,0.8185535073280334,0.22474637627601624,10.225882530212402,145,19.539384841918945,-0.2427220195531845,0.8561981320381165,0.23518529534339905,9.096972465515137,500,24.909347534179688,-0.04073095694184303,0.8004875779151917,0.146201491355896,4.055650234222412,500,21.47697639465332,-0.03675007075071335,0.7250679135322571,0.1535453498363495,4.3143134117126465,500,18.370075225830078,-0.002068187575787306,0.6328396201133728,0.19349025189876556,2.4523301124572754,500,18.51810073852539,-0.029367070645093918,0.6309530138969421,0.23161163926124573,5.59820556640625,400,15.171514511108398,-0.08596291393041611,0.6846583485603333,0.2208338975906372,6.640722751617432,500,22.067739486694336,-0.010470133274793625,0.759013831615448,0.13397353887557983,2.6300439834594727,142,26.65119743347168,-0.02217714861035347,0.8541759252548218,0.12250284105539322,4.489205837249756,304,16.662843704223633,0.030067602172493935,0.6663406491279602,0.45978274941444397,10.977742195129395,500,18.961740493774414,-0.03515905141830444,0.8050862550735474,0.18823504447937012,8.212032318115234,290,17.795196533203125,-0.022556299343705177,0.6786172986030579,0.21285778284072876,5.880309581756592,500,18.647926330566406,-0.020784933120012283,0.6541447043418884,0.21047918498516083,6.204220771789551,500,25.266008377075195,-0.13009877502918243,0.8358787298202515,0.09196046739816666,6.571922779083252,500,0,0,0,0,0,0,21.6362247467041,0.043379079550504684,0.7133622169494629,0.16810670495033264,7.582168102264404,500,19.400468826293945,-0.06924106180667877,0.7891562581062317,0.2562311589717865,10.40406322479248,185,24.272430419921875,0.002408565254881978,0.822277307510376,0.14490512013435364,3.1252973079681396,307,21.627735137939453,0.28700390458106995,0.9307183623313904,0.15895913541316986,3.7723772525787354,200,20.32921028137207,0.07073872536420822,0.8144434094429016,0.21742278337478638,5.295686721801758,3,23.026090621948242,-0.008077435195446014,0.63054358959198,0.1302529126405716,8.783021926879883,58,22.135250091552734,-0.020904889330267906,0.6749143004417419,0.18395783007144928,5.7703423500061035,500,18.66996192932129,-0.03733401745557785,0.6296820640563965,0.2478480488061905,6.311908721923828,136,15.479337692260742,-0.004108104854822159,0.6553525328636169,0.28199443221092224,8.236729621887207,475,0,0,0,0,0,0,18.864093780517578,0.1739964336156845,0.8528752326965332,0.1548851728439331,5.3060221672058105,60,0,0,0,0,0,0,0,0,0,0,0,0,15.678476333618164,0.038419850170612335,0.5233398675918579,0.3035353720188141,10.510383605957031,63,0
18.889339447021484,-0.1648501455783844,0.6437658667564392,0.1826857179403305,5.060956954956055,500,21.38709831237793,0.1416708528995514,0.6777622103691101,0.16441290080547333,5.186842441558838,500,21.354217529296875,0.0018506277119740844,0.7000029683113098,0.17723077535629272,5.897092342376709,198,19.61304473876953,0.037835195660591125,0.8433040380477905,0.14655180275440216,6.7373223304748535,268,20.715747833251953,0.01130291074514389,0.678013265132904,0.17803651094436646,8.016233444213867,215,15.14285659790039,0.26080322265625,0.5890591144561768,0.18768545985221863,5.359984874725342,44,18.543058395385742,-0.08016975224018097,0.6347747445106506,0.18409821391105652,5.6670966148376465,26,16.97516441345215,-0.1479617804288864,0.702846884727478,0.18982268869876862,5.4662394523620605,500,18.89791488647461,0.24670402705669403,0.7856683135032654,0.18493321537971497,7.272751808166504,500,24.04220199584961,0.019914034754037857,0.7841129302978516,0.11466450989246368,3.5969300270080566,197,25.081993103027344,0.19039775431156158,0.8243004679679871,0.19027777016162872,10.178629875183105,145,20.056028366088867,-0.3057011365890503,0.8603973984718323,0.2097621113061905,9.087891578674316,500,25.227895736694336,-0.03144318237900734,0.8052465319633484,0.1430753916501999,3.845712900161743,500,21.70534896850586,-0.05035046488046646,0.7276560068130493,0.148365318775177,4.039795398712158,500,18.364177703857422,-2.2224783151614247e-06,0.6325255632400513,0.19385698437690735,2.432619333267212,500,18.738195419311523,-0.033635493367910385,0.6359359622001648,0.22341181337833405,5.412669658660889,400,16.01719856262207,0.0007864768267609179,0.7023988366127014,0.2010006606578827,6.480032920837402,500,22.102428436279297,-0.0012201054487377405,0.759409487247467,0.13381606340408325,2.5473294258117676,142,26.880590438842773,-0.02358187735080719,0.8565123081207275,0.12121907621622086,4.459400177001953,304,16.68733787536621,0.07558043301105499,0.6599492430686951,0.4475219249725342,11.352530479431152,500,19.459217071533203,-0.03683672845363617,0.809934139251709,0.16714370250701904,8.145930290222168,290,18.282310485839844,-0.0003248922876082361,0.6865655779838562,0.20002803206443787,5.662166595458984,500,18.83856964111328,-0.03184810280799866,0.6561446785926819,0.20359492301940918,6.031191349029541,500,25.960927963256836,-0.08454426378011703,0.8430694341659546,0.08546590059995651,6.4937214851379395,500,0,0,0,0,0,0,21.94908332824707,-0.3419593274593353,0.7211637496948242,0.16153308749198914,7.5465989112854,500,20.40264892578125,0.017866995185613632,0.8031715750694275,0.22785907983779907,10.408190727233887,185,24.751502990722656,-0.008820587769150734,0.8268305659294128,0.14042264223098755,2.8674750328063965,307,23.79220199584961,0.023781709372997284,0.9390408396720886,0.13440102338790894,3.189241647720337,200,20.65520668029785,-0.02165459282696247,0.8165206909179688,0.21889762580394745,5.296838283538818,3,23.0324764251709,0.130624458193779,0.6325167417526245,0.1328558623790741,8.84994888305664,58,23.088464736938477,0.06231540068984032,0.6918238997459412,0.16935458779335022,5.561310291290283,500,19.111549377441406,0.09548354148864746,0.6398555040359497,0.23092451691627502,5.997190952301025,136,15.769152641296387,0.028782520443201065,0.6602852940559387,0.26686036586761475,8.060165405273438,475,0,0,0,0,0,0,19.613920211791992,-0.03737274929881096,0.8565442562103271,0.13404110074043274,4.84287166595459,60,0,0,0,0,0,0,0,0,0,0,0,0,16.006196975708008,-0.16971516609191895,0.5292651653289795,0.2796056866645813,10.46243953704834,63,0
19.012269973754883,-0.04718773066997528,0.6433225870132446,0.18231874704360962,5.107872009277344,500,21.322338104248047,0.0140613978728652,0.6770919561386108,0.16535094380378723,5.265127182006836,500,21.2813777923584,-0.021526703611016273,0.6984917521476746,0.1797674298286438,5.906203269958496,198,19.346487045288086,0.10074374079704285,0.8397706151008606,0.1516176462173462,6.717349529266357,268,20.726825714111328,0.025485752150416374,0.6783157587051392,0.18283355236053467,7.897275447845459,215,14.350356101989746,-0.07720185816287994,0.5780261754989624,0.20371182262897491,5.403542995452881,44,18.631084442138672,0.07769659161567688,0.6344185471534729,0.18611928820610046,5.661814212799072,26,17.056427001953125,-0.1821698099374771,0.7035198211669922,0.18846037983894348,5.462797164916992,500,18.501523971557617,0.08108856528997421,0.7818804383277893,0.19710351526737213,7.370274543762207,500,24.048419952392578,0.013744712807238102,0.7855278849601746,0.11609163880348206,3.6646697521209717,197,24.893909454345703,-0.17591805756092072,0.8268886208534241,0.20372600853443146,10.059407234191895,145,19.865501403808594,-0.22679699957370758,0.8597549796104431,0.2231135070323944,9.073450088500977,500,25.364770889282227,0.010617231950163841,0.8068088889122009,0.14136189222335815,4.029416084289551,500,21.7684383392334,-0.011133561842143536,0.7283687591552734,0.14749296009540558,4.200531482696533,500,18.374786376953125,-0.0007336796843446791,0.6325269341468811,0.1934446394443512,2.4958865642547607,500,18.839426040649414,-0.03957090154290199,0.6344597339630127,0.2224958837032318,5.42381477355957,400,15.805286407470703,0.08770310878753662,0.6977684497833252,0.20588622987270355,6.46218729019165,500,22.08108139038086,-0.01628425344824791,0.7589410543441772,0.13417819142341614,2.693875789642334,142,26.95757484436035,0.0022796352859586477,0.8572258353233337,0.12103652209043503,4.541953086853027,304,16.950056076049805,0.05445929989218712,0.6725068092346191,0.42972901463508606,10.8408203125,500,19.32314682006836,0.034813400357961655,0.810303807258606,0.1743694394826889,8.154923439025879,290,18.345930099487305,0.09873536974191666,0.6867643594741821,0.20015820860862732,5.688977241516113,500,18.867534637451172,0.008000208996236324,0.6561379432678223,0.20259518921375275,6.0598602294921875,500,25.886808395385742,-0.1851750612258911,0.843407154083252,0.08705346286296844,6.513862609863281,500,0,0,0,0,0,0,22.645614624023438,0.03813670575618744,0.7365572452545166,0.14904607832431793,7.366730213165283,500,20.234968185424805,-0.028652122244238853,0.8059276938438416,0.23735074698925018,10.289590835571289,185,24.92990493774414,-0.05894114822149277,0.8283792734146118,0.13925190269947052,3.199906349182129,307,23.42925453186035,-0.41295871138572693,0.9373413920402527,0.14144201576709747,3.7806153297424316,200,20.85039710998535,-0.004692706745117903,0.8196279406547546,0.2168300598859787,5.288863182067871,3,23.50503158569336,-0.028557421639561653,0.6380681991577148,0.1192559152841568,8.550318717956543,58,22.69325828552246,0.08483685553073883,0.6885713338851929,0.17399273812770844,5.622918605804443,500,18.995891571044922,-0.025361914187669754,0.6374244689941406,0.23452074825763702,6.074242115020752,136,15.669060707092285,0.025412963703274727,0.6598301529884338,0.2714788615703583,8.040945053100586,475,0,0,0,0,0,0,19.33572006225586,-0.150547593832016,0.8552142381668091,0.1460384875535965,5.214033126831055,60,0,0,0,0,0,0,0,0,0,0,0,0,16.531322479248047,0.06065846234560013,0.5375595688819885,0.25438711047172546,10.283487319946289,63,0
18.73299217224121,-0.13161632418632507,0.6400277018547058,0.18755526840686798,5.157333850860596,500,21.10808753967285,-0.16676977276802063,0.672235906124115,0.1692410707473755,5.296654224395752,500,21.203834533691406,-0.024857260286808014,0.6980849504470825,0.18170055747032166,5.922959327697754,198,19.301345825195312,0.2718823254108429,0.8395335078239441,0.15078316628932953,6.726401329040527,268,20.7464542388916,0.0026871892623603344,0.67714923620224,0.18162284791469574,7.944480895996094,215,14.387572288513184,-0.0007205808651633561,0.5734782814979553,0.20236076414585114,5.430235385894775,44,18.500873565673828,0.08837074786424637,0.6331009864807129,0.18800735473632812,5.786571979522705,26,16.948963165283203,-0.07204114645719528,0.7015774846076965,0.19183480739593506,5.483725547790527,500,18.69580078125,0.08595505356788635,0.7832303643226624,0.18746890127658844,7.323951244354248,500,24.063138961791992,0.040227603167295456,0.7842744588851929,0.11548798531293869,3.7416088581085205,197,24.77498435974121,-0.10164796561002731,0.8261294364929199,0.19882343709468842,10.059757232666016,145,20.417898178100586,0.12687283754348755,0.8632937669754028,0.20207682251930237,9.024741172790527,500,25.293771743774414,-0.013352192007005215,0.8060333728790283,0.1428879052400589,3.982311248779297,500,21.808292388916016,0.05234861373901367,0.7285376191139221,0.14711451530456543,4.239038467407227,500,18.38561248779297,0.0016012159176170826,0.6329339742660522,0.19295348227024078,2.502042531967163,500,18.688907623291016,-0.04777387157082558,0.6325020790100098,0.22649399936199188,5.479030609130859,400,16.001020431518555,0.1790591925382614,0.7002289891242981,0.20206555724143982,6.487364292144775,500,22.093610763549805,-0.01092259306460619,0.7591899633407593,0.13419727981090546,2.800780773162842,142,26.88434410095215,0.008505810052156448,0.8564297556877136,0.1211276724934578,4.6540913581848145,304,16.507360458374023,-0.14780215919017792,0.6714658737182617,0.4597805440425873,10.8370361328125,500,19.03635597229004,-0.03866236284375191,0.8065902590751648,0.179178386926651,8.137178421020508,290,18.207599639892578,0.0390840582549572,0.6847269535064697,0.20280848443508148,5.730045318603516,500,18.824954986572266,0.007314752321690321,0.656466543674469,0.20462752878665924,6.068840980529785,500,25.785615921020508,-0.16119396686553955,0.8409033417701721,0.0872565507888794,6.494687557220459,500,0,0,0,0,0,0,22.554359436035156,0.17164531350135803,0.7343575954437256,0.15075889229774475,7.297224998474121,500,20.00868797302246,-0.20269323885440826,0.8015546202659607,0.2370498925447464,10.263086318969727,185,24.849260330200195,0.014324101619422436,0.8274909257888794,0.13913463056087494,3.273496627807617,307,24.143491744995117,-0.0864129289984703,0.9391130208969116,0.1349482536315918,3.8407604694366455,200,20.61045265197754,-0.044369861483573914,0.817896842956543,0.22000306844711304,5.338271617889404,3,23.173229217529297,0.06823521107435226,0.6345419883728027,0.11324688047170639,8.597124099731445,58,22.52924156188965,-0.005226884037256241,0.6790374517440796,0.17646169662475586,5.648611068725586,500,18.822338104248047,-0.034071650356054306,0.6334968209266663,0.23989541828632355,6.15177059173584,136,15.601283073425293,-0.00634151604026556,0.6596918702125549,0.2742881774902344,8.023792266845703,475,0,0,0,0,0,0,19.87816619873047,-0.21906106173992157,0.8562472462654114,0.12401121854782104,5.19991397857666,60,0,0,0,0,0,0,0,0,0,0,0,0,15.993595123291016,-0.15662315487861633,0.5254011154174805,0.27742305397987366,10.258377075195312,63,0
19.110883712768555,0.03614996746182442,0.6461726427078247,0.1795150190591812,5.265549182891846,500,21.108360290527344,-0.14474032819271088,0.6736648678779602,0.16756995022296906,5.4420576095581055,500,21.234786987304688,-0.0323927141726017,0.6997060179710388,0.18202681839466095,6.203959941864014,198,19.807920455932617,0.17171825468540192,0.842271089553833,0.1452997773885727,6.870168685913086,268,20.6026668548584,-0.011249484494328499,0.6755874752998352,0.18391308188438416,8.184722900390625,215,14.93524169921875,-0.0469074510037899,0.5757077932357788,0.19150333106517792,5.594109058380127,44,18.379505157470703,-0.03754724934697151,0.6315571665763855,0.18700724840164185,5.797399520874023,26,17.05131721496582,-0.10010848939418793,0.7025959491729736,0.18903829157352448,5.659502983093262,500,19.137954711914062,0.14187031984329224,0.785306453704834,0.18198907375335693,7.440811634063721,500,23.754358291625977,0.0686846598982811,0.7810045480728149,0.1172272339463234,3.8019583225250244,197,24.170063018798828,-0.11422554403543472,0.8158525824546814,0.21402356028556824,10.47574234008789,145,20.21315574645996,-0.09578195214271545,0.859663188457489,0.20872971415519714,9.281356811523438,500,25.30156707763672,-0.018742039799690247,0.8054926991462708,0.14185506105422974,4.103004455566406,500,21.711626052856445,0.006813677493482828,0.7276286482810974,0.14892178773880005,4.297058582305908,500,18.365007400512695,-0.007851418107748032,0.6329214572906494,0.19421109557151794,2.7666542530059814,500,18.607521057128906,-0.061320219188928604,0.6336644291877747,0.2250998169183731,5.655640125274658,400,15.967201232910156,-0.06718071550130844,0.6980611681938171,0.20418159663677216,6.667073726654053,500,22.090274810791016,-0.0017378028715029359,0.7593238949775696,0.13377660512924194,2.6701622009277344,142,26.750892639160156,-0.07790957391262054,0.8558348417282104,0.12187864631414413,4.571890354156494,304,16.300031661987305,0.010518108494579792,0.6394075155258179,0.46379354596138,11.451972961425781,500,19.428918838500977,0.1081659346818924,0.8055146336555481,0.16705267131328583,8.28512191772461,290,18.22205924987793,-0.02810261771082878,0.6851233839988708,0.20087064802646637,5.899336814880371,500,18.846923828125,-0.021323775872588158,0.6569703221321106,0.20638112723827362,6.310214519500732,500,25.63957977294922,-0.21977517008781433,0.8407760858535767,0.08965358138084412,6.699525833129883,500,0,0,0,0,0,0,22.031572341918945,-0.16287170350551605,0.7234898805618286,0.15870793163776398,7.465205192565918,500,19.877155303955078,0.0820217877626419,0.7918612360954285,0.24016961455345154,10.721766471862793,185,24.740259170532227,0.05753420665860176,0.8264721035957336,0.14053989946842194,3.2130744457244873,307,23.240985870361328,-0.9365432858467102,0.9365918040275574,0.14060433208942413,3.1881415843963623,200,20.191814422607422,-0.015667753294110298,0.8101682662963867,0.22013281285762787,5.518276214599609,3,22.47603416442871,0.04551320895552635,0.6252691149711609,0.11578691005706787,9.172477722167969,58,22.722797393798828,0.1353231817483902,0.6838358640670776,0.17324240505695343,5.799938678741455,500,18.984954833984375,0.008981794118881226,0.6379337906837463,0.23654848337173462,6.328289031982422,136,15.746525764465332,0.024799227714538574,0.6614393591880798,0.27319034934043884,8.200506210327148,475,0,0,0,0,0,0,19.888586044311523,-0.006714705843478441,0.8541179299354553,0.12033425271511078,5.116232872009277,60,0,0,0,0,0,0,0,0,0,0,0,0,15.9192476272583,-0.10522013157606125,0.5274612903594971,0.29392895102500916,10.56518840789795,63,0
20.46278190612793,-0.060763970017433167,0.6579670906066895,0.161322221159935,4.680773735046387,200,22.052663803100586,0.5723397135734558,0.6804402470588684,0.1476287543773651,4.742996692657471,200,21.605833053588867,0.23160338401794434,0.6969097852706909,0.171411395072937,5.295865535736084,198,20.8725643157959,0.6837916374206543,0.8545865416526794,0.12993884086608887,6.236001968383789,200,19.573144912719727,0.29120177030563354,0.6459593772888184,0.22614720463752747,9.27774429321289,200,14.133930206298828,0.22245623171329498,0.564282238483429,0.2065575122833252,5.674943447113037,0,18.949764251708984,0.17191997170448303,0.6465899348258972,0.19255967438220978,5.078993797302246,0,17.931673049926758,-0.11049146950244904,0.6982738971710205,0.18147899210453033,5.114457130432129,200,19.872821807861328,3.0028798580169678,0.7939698696136475,0.17656543850898743,6.240903377532959,200,24.46942138671875,0.02408456988632679,0.7884182333946228,0.10927800089120865,2.891601324081421,197,25.840389251708984,0.19366736710071564,0.8314366936683655,0.1776755452156067,9.571022033691406,145,16.929052352905273,-1.5327644348144531,0.7828646302223206,0.3492167592048645,10.210173606872559,200,25.710983276367188,-0.017311187461018562,0.8133223652839661,0.1370222568511963,3.3832216262817383,200,18.50092124938965,-2.022146224975586,0.6764024496078491,0.19759753346443176,4.851006031036377,200,17.02499008178711,0.16653648018836975,0.5973562002182007,0.22383597493171692,3.0246450901031494,200,19.098665237426758,0.004295928869396448,0.638796329498291,0.21742229163646698,5.272642612457275,200,16.54633140563965,0.17737846076488495,0.6975648403167725,0.19392725825309753,5.778793811798096,200,20.060346603393555,0.9734287261962891,0.7293187975883484,0.16429120302200317,3.052821397781372,0,27.030488967895508,-0.08356904238462448,0.8584595322608948,0.12005755305290222,3.672435998916626,0,16.670869827270508,0.08606883883476257,0.6815004348754883,0.4169979691505432,10.184615135192871,0,20.275964736938477,0.1098201647400856,0.8205265402793884,0.15371590852737427,7.1788201332092285,200,19.987890243530273,0.2744450271129608,0.7336406707763672,0.16191206872463226,5.04759407043457,200,17.112150192260742,-0.015052754431962967,0.5950587391853333,0.23803958296775818,6.651583671569824,200,26.6037654876709,-0.26065897941589355,0.8506165146827698,0.08368074893951416,6.122287750244141,200,0,0,0,0,0,0,22.674583435058594,-0.11895836144685745,0.749716579914093,0.14290745556354523,6.852635383605957,200,20.810087203979492,-0.07519318163394928,0.8198818564414978,0.23864984512329102,9.383285522460938,185,25.744356155395508,0.10595372319221497,0.8350962400436401,0.13477018475532532,2.0739593505859375,200,28.547643661499023,0.7507640719413757,0.9549750089645386,0.08899791538715363,1.1312161684036255,200,21.079015731811523,0.029884053394198418,0.8219752907752991,0.21175755560398102,4.344815731048584,3,23.93878173828125,0.05293554812669754,0.6428359746932983,0.08624686300754547,7.878372669219971,58,23.858797073364258,0.06702835857868195,0.6956433653831482,0.15657857060432434,5.089795112609863,200,19.62722396850586,0.05636562407016754,0.650608479976654,0.2117949277162552,5.699061870574951,136,15.717203140258789,0.02669934555888176,0.672042965888977,0.26545801758766174,7.867374420166016,200,0,0,0,0,0,0,21.55535316467285,1.0019360780715942,0.8686189651489258,0.10703979432582855,4.072519302368164,60,0,0,0,0,0,0,0,0,0,0,0,0,16.39645767211914,0.24261629581451416,0.538216233253479,0.2558179199695587,9.674732208251953,63,0
20.47854995727539,1.2024067640304565,0.6582081317901611,0.15989229083061218,4.665087699890137,200,22.276348114013672,0.9851991534233093,0.6830069422721863,0.14797556400299072,4.454257488250732,200,21.658119201660156,0.8381527066230774,0.6987597942352295,0.16515423357486725,5.077767848968506,198,21.830154418945312,3.4031131267547607,0.8640546202659607,0.12628214061260223,5.637884140014648,200,19.192798614501953,0.3516661822795868,0.6184120774269104,0.2570144832134247,10.193868637084961,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
20.590303421020508,0.4962620437145233,0.6601793169975281,0.15753625333309174,4.5785813331604,200,22.237953186035156,0.4489673674106598,0.6821855902671814,0.14713634550571442,4.530069351196289,200,21.67139434814453,0.5253610610961914,0.6990557909011841,0.16455212235450745,5.083188533782959,198,21.511552810668945,2.3355321884155273,0.8608049750328064,0.12826672196388245,5.940978527069092,200,21.198034286499023,0.4582754373550415,0.696334183216095,0.17922863364219666,7.798947811126709,200,14.655498504638672,0.11240727454423904,0.571294903755188,0.19869443774223328,5.37660551071167,44,18.72746467590332,-0.0767064094543457,0.6407261490821838,0.19800549745559692,5.168641090393066,26,18.191055297851562,0.7278627157211304,0.6975538730621338,0.17866922914981842,5.047889232635498,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
20.47854995727539,1.2024067640304565,0.6582081317901611,0.15989229083061218,4.665087699890137,200,22.276348114013672,0.9851991534233093,0.6830069422721863,0.14797556400299072,4.454257488250732,200,21.658119201660156,0.8381527066230774,0.6987597942352295,0.16515423357486725,5.077767848968506,198,21.830154418945312,3.4031131267547607,0.8640546202659607,0.12628214061260223,5.637884140014648,200,19.192798614501953,0.3516661822795868,0.6184120774269104,0.2570144832134247,10.193868637084961,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
20.5020809173584,0.5488608479499817,0.6597924828529358,0.16010788083076477,4.721373081207275,200,22.316869735717773,1.2297122478485107,0.6829593181610107,0.1500682681798935,4.489542007446289,200,21.685548782348633,0.7039672136306763,0.6994300484657288,0.1671903133392334,5.135034561157227,198,21.7237548828125,2.7444534301757812,0.8628481030464172,0.12802647054195404,5.7939772605896,200,20.3883113861084,0.08283903449773788,0.6582912802696228,0.216960147023201,8.84621810913086,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
21.46029465948103,0.10821035103956703,0.7790344667110257,0.1146072332425551,4.8639076042175295,200,22.858641416969302,0.2752841716840785,0.8049575514436705,0.11705265556546775,4.970765066146851,200,22.984534566644847,0.28235093636870023,0.8254322399853716,0.11391306117594352,5.201311178881713,198,21.269003861197834,0.48774929318372584,0.8845344550020229,0.10633078400722959,5.946489062309265,200,21.91150318956801,-0.1983188362236588,0.8238611486475108,0.14899968794744575,6.809058917893304,144,15.47896710908777,-0.21478780557785693,0.766425387992816,0.16411230564964088,5.260037183761597,32,20.303700892418085,0.11696818304556747,0.7842201829523877,0.14423518538136373,5.439062690734863,20,17.87179956351499,-0.3067176983676024,0.8296448574724792,0.12977059845050629,5.341754336357116,200,19.857372164518466,0.05305542383233112,0.8275884621862206,0.1549614530158314,6.752370271682739,200,27.3605180375065,0.04183501128873296,0.8984342113196049,0.06288346686315807,3.3972074127197267,200,26.833560528931578,0.8061492074839005,0.8918025717551408,0.1506120693911968,9.562130060024604,167,15.137075498804533,-5.621109492690501,0.6937904475067103,0.3866491428085349,10.7611030960083,200,27.60886437063087,-0.5479598479597061,0.9037209833626163,0.08412209230221131,4.320116324424744,200,17.608785845649514,-6.357329803241343,0.6815140742522903,0.24886811768551442,5.626099944114685,200,24.33990679792855,-0.10166818044734177,0.8951419864328773,0.0715118810958864,2.0131524705886843,200,21.411267073256052,-0.021985724226232916,0.7967679150871174,0.13873621311716058,4.87985538482666,200,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
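Each row above records one evaluation run; runs of zeros apparently mark dataset/metric blocks that were not evaluated for that run. A minimal sketch for loading the table, assuming its header row (earlier in this file) follows the `{dataset}/{field}` column naming that `common/plot/plot_from_wandb.py` uses when it writes `aggregated_output.csv`:

```python
import pandas as pd

# Load the aggregated metrics; columns are assumed to be named "{dataset}/{field}".
df = pd.read_csv("common/plot/aggregated_output.csv")

# Example: summarize teacher-forced PSNR across all datasets.
psnr_cols = [c for c in df.columns if c.endswith("/teacher_force_psnr")]
print(df[psnr_cols].describe())
```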
common/plot/plot_arch_ablation.py
ADDED
@@ -0,0 +1,60 @@
import matplotlib.pyplot as plt
import numpy as np

# Ablation over action-conditioning architectures:
# validation loss (log-perplexity) for each variant
tasks = ["Add", "Concat", "Cross Attention", "Modulation"]
values = np.array([
    [6.35],
    [5.68],
    [5.26],
    [5.02],
])
values = np.exp(values)  # convert log-perplexity to perplexity

# Bar colors matching the provided image
bar_colors = ['#1f78b4', '#ffffff', '#a6cee3', '#cab2d6', '#b3b3cc', '#33a02c']

fig, ax = plt.subplots(figsize=(5, 3))

# Set bar width and x positions for each group
bar_width = 0.4
x = np.arange(len(tasks))

# Plot each group's bars with the specified colors
for i in range(values.shape[1]):
    bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')

for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.1f")

# Highlight the last bar (Modulation)
bars[-1].set_color('#cab2d6')
bars[-1].set_edgecolor('black')

# Labels and ticks
ax.set_xlabel("Model", fontsize=14)
ax.set_ylabel("Perplexity", fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=12)
ax.tick_params(axis='x', rotation=15)
ax.set_ylim(values.min() - 10, values.max() + 50)

plt.tight_layout()
plt.savefig("output/arch_ablation.png", dpi=300)
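A note on the `np.exp` call above: the hard-coded numbers are validation losses (log-perplexities), so exponentiating turns them into the perplexities shown on the y-axis. A quick sanity check of that conversion:

```python
import numpy as np

# Losses hard-coded in plot_arch_ablation.py; exp() maps log-perplexity to perplexity.
print(np.exp([6.35, 5.68, 5.26, 5.02]))  # approx. [572.5, 292.9, 192.5, 151.4]
```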
common/plot/plot_arch_ablation_deltapsnr.py
ADDED
@@ -0,0 +1,49 @@
import matplotlib.pyplot as plt
import numpy as np

# Ablation over action-conditioning architectures:
# delta PSNR (action controllability) for each variant
tasks = ["Add", "Concat", "Cross Attention", "Modulation"]
values = np.array([
    [0.46],
    [0.18],
    [0.02],
    [1.87],
])

# Bar colors matching the provided image
bar_colors = ['#1f78b4', '#a6cee3', '#1f78b4', '#ffffff', '#cab2d6', '#b3b3cc', '#33a02c']

fig, ax = plt.subplots(figsize=(5, 3))

# Set bar width and x positions for each group
bar_width = 0.4
x = np.arange(len(tasks))

# Plot each group's bars with the specified colors
for i in range(values.shape[1]):
    bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')

for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")

# Highlight the last bar (Modulation)
bars[-1].set_color('#cab2d6')
bars[-1].set_edgecolor('black')

# Labels and ticks
ax.set_xlabel("Model", fontsize=14)
ax.set_ylabel("Delta PSNR", fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=12)
ax.tick_params(axis='x', rotation=15)
ax.set_ylim(0, values.max() + 0.2)

plt.tight_layout()
plt.savefig("output/arch_ablation_controllability.png", dpi=300)
common/plot/plot_dataset_scale.py
ADDED
@@ -0,0 +1,69 @@
import matplotlib.pyplot as plt
import numpy as np

# Scaling with the number of pretraining datasets:
# validation loss (log-perplexity) and delta PSNR at 5 to 40 datasets
x_values = [5, 10, 20, 30, 40]
y_values = [5.94, 5.72, 5.21, 5.15, 5.02]
y_values = np.exp(y_values)  # convert log-perplexity to perplexity

# Line width for each line plot
line_width = 1.5

fig, ax1 = plt.subplots(figsize=(5, 4))

# Plot perplexity (left y-axis)
ax1.plot(x_values, y_values, marker='o', linestyle='-', color='#1f78b4', linewidth=line_width)
ax1.annotate(f"{y_values[-1]:.1f}", (x_values[-1], y_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')
ax1.set_xscale('log')
ax1.set_xlabel("# Dataset", fontsize=14)
ax1.set_ylabel("Perplexity", fontsize=14, color='#1f78b4')
ax1.tick_params(axis='y', labelcolor='#1f78b4')

# Create a twin y-axis for controllability (right y-axis)
ax2 = ax1.twinx()
controllability_values = [0.46, 0.55, 1.69, 1.5, 1.87]  # delta PSNR per dataset count
ax2.plot(x_values, controllability_values, marker='s', linestyle='--', color='#006400', linewidth=line_width)
ax2.set_ylabel("Delta PSNR", fontsize=14, color='#006400')
ax2.set_ylim(0, 2.1)
ax2.tick_params(axis='y', labelcolor='#006400')
ax2.annotate(f"{controllability_values[-1]:.1f}", (x_values[-1], controllability_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')

# Save the figure in high resolution
plt.tight_layout()
plt.savefig("output/dataset_sizes.png", dpi=300)
common/plot/plot_dataset_traj_scale.py
ADDED
@@ -0,0 +1,48 @@
import matplotlib.pyplot as plt
import numpy as np

# Scaling with the number of pretraining trajectories:
# validation loss (log-perplexity) and delta PSNR at each trajectory count
x_values = [8287, 77664, 532150, 1126876, 2070965, 3163485]
y_values = [9.46, 6.94, 5.81, 5.70, 5.09, 5.02]
y_values = np.exp(y_values)  # convert log-perplexity to perplexity

# Line width for each line plot
line_width = 1.5

fig, ax1 = plt.subplots(figsize=(5, 4))

# Plot perplexity (left y-axis)
ax1.plot(x_values, y_values, marker='o', linestyle='-', color='#1f78b4', linewidth=line_width)
ax1.annotate(f"{y_values[-1]:.1f}", (x_values[-1], y_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')
ax1.set_xscale('log')
ax1.set_xlabel("# Trajectory", fontsize=14)
ax1.set_ylabel("Perplexity", fontsize=14, color='#1f78b4')
ax1.tick_params(axis='y', labelcolor='#1f78b4')

# Create a twin y-axis for controllability (right y-axis)
ax2 = ax1.twinx()
controllability_values = [0.0, 0.10, 1.20, 1.41, 1.56, 1.87]  # delta PSNR per trajectory count
ax2.plot(x_values, controllability_values, marker='s', linestyle='--', color='#006400', linewidth=line_width)
ax2.set_ylabel("Delta PSNR", fontsize=14, color='#006400')
ax2.annotate(f"{controllability_values[-1]:.1f}", (x_values[-1], controllability_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')
ax2.set_ylim(0, 2.1)
ax2.tick_params(axis='y', labelcolor='#006400')

# Save the figure in high resolution
plt.tight_layout()
plt.savefig("output/traj_sizes.png", dpi=300)
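The two scaling scripts above plot perplexity against a log-scaled x-axis. To get a rough scaling slope from the hard-coded numbers rather than eyeballing the figure, a small sketch (the `np.polyfit` analysis is an illustration, not part of the original scripts):

```python
import numpy as np

# Fit validation loss against log10(#trajectories) using the values
# hard-coded in plot_dataset_traj_scale.py.
n_traj = np.array([8287, 77664, 532150, 1126876, 2070965, 3163485])
loss = np.array([9.46, 6.94, 5.81, 5.70, 5.09, 5.02])
slope, intercept = np.polyfit(np.log10(n_traj), loss, 1)
print(f"loss ~ {intercept:.2f} {slope:+.2f} * log10(n_traj)")
```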
common/plot/plot_dynamics_ablation.py
ADDED
@@ -0,0 +1,56 @@
import matplotlib.pyplot as plt
import numpy as np

# Ablation over dynamics formulations:
# validation loss (log-perplexity) for each variant
tasks = ["Passive Dynamics", "Full Dynamics", "Forward Dynamics"]
values = np.array([
    [6.29],
    [5.21],
    [5.02],
])
values = np.exp(values)  # convert log-perplexity to perplexity

# Bar colors matching the provided image
bar_colors = ['#a6cee3', '#ffffff', '#a6cee3', '#cab2d6', '#b3b3cc', '#33a02c']

fig, ax = plt.subplots(figsize=(5, 3))

# Set bar width and x positions for each group
bar_width = 0.4
x = np.arange(len(tasks))

# Plot each group's bars with the specified colors
for i in range(values.shape[1]):
    bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')

# Highlight the last bar (Forward Dynamics)
bars[-1].set_color('#cab2d6')
bars[-1].set_edgecolor('black')

for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.1f")

# Labels and ticks
ax.set_xlabel("Model", fontsize=14)
ax.set_ylabel("Perplexity", fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=12)
ax.set_ylim(values.min() - 10, values.max() + 50)
ax.tick_params(axis='x', rotation=15)

plt.tight_layout()
plt.savefig("output/dynamics_ablation.png", dpi=300)
common/plot/plot_dynamics_ablation_deltapsnr.py
ADDED
@@ -0,0 +1,51 @@
import matplotlib.pyplot as plt
import numpy as np

# Ablation over dynamics formulations:
# delta PSNR (action controllability) for each variant
tasks = ["Passive Dynamics", "Full Dynamics", "Forward Dynamics"]
values = np.array([
    [0.33],
    [1.23],
    [1.87],
])

# Bar colors matching the provided image
bar_colors = ['#a6cee3', '#1f78b4', '#ffffff', '#cab2d6', '#b3b3cc', '#33a02c']

fig, ax = plt.subplots(figsize=(5, 3))

# Set bar width and x positions for each group
bar_width = 0.4
x = np.arange(len(tasks))

# Plot each group's bars with the specified colors
for i in range(values.shape[1]):
    bars = ax.bar(x + i * bar_width, values[:, i], width=bar_width, color=bar_colors[i], edgecolor='black')

# Highlight the last bar (Forward Dynamics)
bars[-1].set_color('#cab2d6')
bars[-1].set_edgecolor('black')

for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")

# Labels and ticks
ax.set_xlabel("Model", fontsize=14)
ax.set_ylabel("Delta PSNR", fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=12)
ax.set_ylim(values.min() - 0.1, values.max() + 0.2)
ax.tick_params(axis='x', rotation=15)

plt.tight_layout()
plt.savefig("output/dynamics_ablation_controllability.png", dpi=300)
common/plot/plot_from_wandb.py
ADDED
@@ -0,0 +1,185 @@
import wandb
import pandas as pd
import argparse
import os

"""
Run this aggregation script over key metrics and key runs, e.g.:

export MODEL=final2_40dataset_waction_concat_gpu_8_nodes_1
python common/plot/plot_from_wandb.py --run_id $MODEL
python common/plot/plot_from_wandb.py --run_id final2_40dataset_noaction_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_modulate_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_attn_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_add_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_d64_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_forward_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_full_dynamics_gpu_8_nodes_1_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj10000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj100_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_40dataset_waction_traj1000_gpu_8_nodes_1_68536steps_step15k_v5
python common/plot/plot_from_wandb.py --run_id final2_5dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py --run_id final2_10dataset_waction_gpu_8_nodes_1_step24k_v5
python common/plot/plot_from_wandb.py --run_id final2_30dataset_waction_gpu_8_nodes_1_step24k_v5
"""

# Initialize the wandb API client
api = wandb.Api()

# Replace with your specific project and entity
entity = "latent-mage"
project = "video_val"

# List of datasets to process
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness",
]


def normalize_dataset(metric, runs):
    """
    Figure out the best and worst values for a metric across all runs
    and use them for normalization.
    """
    pass


# List to store dataframes of per-dataset metrics
metrics_data = []

# Set up the argument parser
parser = argparse.ArgumentParser(description='Aggregate the evaluation metrics of one wandb run into a CSV row.')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')
args = parser.parse_args()

fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']
num_fields = len(fields)
run_id = args.run_id

run = api.run(f"{entity}/{project}/runs/{run_id}")

# Get the history dataframe of the run
history = run.history(pandas=True)
model_step = 0
summary_metrics = run.summary
num_datasets = 0

# Output CSV: one row per run, one column per {dataset}/{field} pair
csv_output = "aggregated_output.csv"

# Initialize the CSV header on first use
if not os.path.exists(csv_output):
    with open(csv_output, 'w') as f:
        field_str = "name,"
        for dataset in datasets:
            for field in fields:
                field_str += f"{dataset}/{field},"
        f.write(field_str.rstrip(",") + "\n")

results = [run_id] + [None] * len(datasets) * num_fields
for field_idx, field in enumerate(fields):
    if not history.empty:
        # Keep only the metrics of the specified datasets
        for dataset_idx, dataset in enumerate(datasets):
            field_col = f"{dataset}/{field}"
            col_idx = dataset_idx * num_fields + field_idx + 1
            if field == "num_examples":
                # Example counts live in the run summary rather than the history
                if f"{dataset}/num_examples" in summary_metrics:
                    results[col_idx] = summary_metrics[f"{dataset}/num_examples"]
                continue
            if field_col in history.columns:
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # the last non-NaN value
                    num_datasets += 1
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    results[col_idx] = last_valid_value
                    metrics_data.append(metrics)

            if f"{dataset}/model_step" in summary_metrics:
                model_step = summary_metrics[f"{dataset}/model_step"]

# Combine all the per-dataset metric dataframes into one
if metrics_data:
    all_metrics_df = pd.concat(metrics_data, ignore_index=True)

# Append this run's results as one CSV row
with open(csv_output, 'a+') as f:
    f.write(",".join([str(x) for x in results]) + "\n")
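`normalize_dataset` above is left as a stub. A minimal sketch of one possible implementation; the min-max scheme and the reliance on `run.summary` are assumptions, not part of the original code:

```python
def normalize_dataset(metric, runs):
    """Sketch: min-max normalize `metric` across a list of wandb runs."""
    values = [r.summary[metric] for r in runs if metric in r.summary]
    lo, hi = min(values), max(values)
    span = (hi - lo) or 1.0  # guard against a zero range
    return {r.id: (r.summary[metric] - lo) / span
            for r in runs if metric in r.summary}
```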
common/plot/plot_from_wandb_singledataset.py
ADDED
@@ -0,0 +1,144 @@
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import sys
import argparse

"""
Run this plotting script over key metrics and key runs, e.g.:

export MODEL=40dataset_waction_add_gpu_8_nodes_1
python common/plot/plot_from_wandb_singledataset.py --field teacher_force_psnr --run_id $MODEL
python common/plot/plot_from_wandb_singledataset.py --field teacher_force_psnr_delta --run_id $MODEL
python common/plot/plot_from_wandb_singledataset.py --field teacher_force_ssim --run_id $MODEL
python common/plot/plot_from_wandb_singledataset.py --field teacher_force_pred_lpips --run_id $MODEL
python common/plot/plot_from_wandb_singledataset.py --field teacher_force_loss --run_id $MODEL
"""

# Initialize the wandb API client
api = wandb.Api()

# Replace with your specific project and entity
entity = "latent-mage"
project = "video_val"

# List of datasets to process
datasets = [
    "bridge_data_v2",
    "fractal20220817_data",
    "language_table",
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
    "kaist_nonprehensile_converted_externally_to_rlds",
    "ucsd_kitchen_dataset_converted_externally_to_rlds",
    "utokyo_xarm_bimanual_converted_externally_to_rlds",
    "stanford_hydra_dataset_converted_externally_to_rlds",
    "austin_sirius_dataset_converted_externally_to_rlds",
    "berkeley_fanuc_manipulation",
    "berkeley_mvp_converted_externally_to_rlds",
    "berkeley_rpt_converted_externally_to_rlds",
    "cmu_play_fusion",
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
    "qut_dexterous_manpulation",
    "robo_net",
    "furniture_bench_dataset_converted_externally_to_rlds",
    "dlr_sara_grid_clamp_converted_externally_to_rlds",
    "cmu_stretch",
    "spoc",
    "columbia_cairlab_pusht_real",
    "droid",
    "toto",
    "io_ai_tech",
    "conq_hose_manipulation",
    "dobbe",
    "berkeley_gnm_cory_hall",
    "plex_robosuite",
    "usc_cloth_sim_converted_externally_to_rlds",
    "berkeley_cable_routing",
    "imperial_wrist_dataset",
    "bc_z",
    "kuka",
    "roboturk",
    "metaworld",
    "robomimic",
    "epic_kitchen",
    "ego4d",
    "nyu_door_opening_surprising_effectiveness",
]

# Set up the argument parser
parser = argparse.ArgumentParser(description='Plot per-dataset metrics of one wandb run.')
parser.add_argument('--field', type=str, default='teacher_force_psnr', help='The field to process')
parser.add_argument('--run_id', type=str, default='40dataset_waction_add_gpu_8_nodes_1', help='The run ID to process')
args = parser.parse_args()

field = args.field
run_id = args.run_id

runs_path = f"{entity}/{project}/runs"
run = api.run(f"{entity}/{project}/runs/{run_id}")

# Get the history dataframe of the run
history = run.history(pandas=True)
model_step = 0
summary_metrics = run.summary
num_datasets = 0
fields = ['num_examples', 'teacher_force_psnr', 'teacher_force_psnr_delta', 'teacher_force_ssim', 'teacher_force_pred_lpips', 'teacher_force_loss']

for field in fields:
    metrics_data = []
    if not history.empty:
        # Keep only the metrics of the specified datasets
        for dataset in datasets:
            field_col = f"{dataset}/{field}"
            step_col = f"{dataset}/model_step"
            if field_col in history.columns:
                valid_field = history[field_col].dropna()
                if not valid_field.empty:
                    last_valid_value = valid_field.iloc[-1]  # the last non-NaN value
                    num_datasets += 1
                    metrics = pd.DataFrame({field_col: [last_valid_value]})
                    metrics['dataset'] = dataset
                    metrics_data.append(metrics)

            if step_col in summary_metrics:
                model_step = summary_metrics[step_col]

    # Combine all the per-dataset metric dataframes into one
    if metrics_data:
        all_metrics_df = pd.concat(metrics_data, ignore_index=True)

        # Print columns for debugging
|
117 |
+
|
118 |
+
# Compute aggregated statistics (mean, median, std, etc.) for PSNR
|
119 |
+
aggregated_stats = all_metrics_df.groupby('dataset').mean()
|
120 |
+
|
121 |
+
# Plot the mean PSNR for each dataset
|
122 |
+
plt.figure(figsize=(12, 8))
|
123 |
+
aggregated_stats[f'{field}'] = aggregated_stats.mean(axis=1)
|
124 |
+
aggregated_stats[f'{field}'].plot(kind='bar')
|
125 |
+
# print number of steps in the wandb run
|
126 |
+
print(f"run: {run_id} field: {field} steps: {model_step} num of dataset: {len(metrics_data)}")
|
127 |
+
print(f"{field}: {aggregated_stats[field].mean():.2f}+-{aggregated_stats[field].std():.2f}", )
|
128 |
+
|
129 |
+
# plt.title(f"Mean {field} for Each Dataset")
|
130 |
+
# plt.xlabel("Dataset")
|
131 |
+
# plt.ylabel(f"Mean {field} ")
|
132 |
+
# plt.xticks(rotation=90)
|
133 |
+
# plt.tight_layout()
|
134 |
+
|
135 |
+
# # Save the plot
|
136 |
+
# import os
|
137 |
+
# pwd = os.path.dirname(os.path.abspath(__file__))
|
138 |
+
# plt.savefig(f"{pwd}/output/{run.id}_{field}_plot.png")
|
139 |
+
|
140 |
+
# Display aggregated statistics
|
141 |
+
# print(aggregated_stats)
|
142 |
+
|
143 |
+
# Save the aggregated statistics as a CSV if needed
|
144 |
+
# aggregated_stats.to_csv(f"{run_id}_{field}_stat.csv", index=True)
|
common/plot/plot_model_scale.py
ADDED
@@ -0,0 +1,64 @@
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Adjusting the line thickness to better match the provided example

fig, ax = plt.subplots(figsize=(5, 4))
# model widths: 64, 64*sqrt(2), 128, 128*sqrt(2), 256, 256*sqrt(2), 512
x_values = [2.8, 10.2, 36.9, 174.22, 366.9, 755.9]
y_values = [6.33, 5.52, 5.25, 5.19, 5.02, 5.02]  # validation losses
y_values = np.exp(y_values)  # convert loss to perplexity
# 256*sqrt(2) -> ~700M parameters
# 512 -> ~1.3B parameters

# Set line width for each line plot
line_width = 1.5
x = []
# Iterate over each subplot (task) and plot the lines with specified styles, markers, and adjusted line width
# for i, task in enumerate(tasks):

# ax.plot(x_values, y_values, marker='o', linestyle='--', color='#1f78b4', linewidth=line_width)
# # for i, txt in enumerate(y_values):
# #     ax.annotate(f"{txt:.1f}", (x_values[i], y_values[i]), textcoords="offset points", xytext=(0, 10), ha='center')
# ax.annotate(f"{y_values[-1]:.1f}", (x_values[-1], y_values[-1]), textcoords="offset points", xytext=(0, 10), ha='center')

# # Set individual titles and axis labels for each subplot
# ax.set_xlabel("Model Parameters(M)", fontsize=14)
# ax.set_ylabel("Perplexity", fontsize=14)
# ax.set_ylim(0, 1)

fig, ax1 = plt.subplots(figsize=(5, 4))

INDEX = -2
# Plot Perplexity (left y-axis)
ax1.plot(x_values, y_values, marker='o', linestyle='-', color='#1f78b4', linewidth=line_width)
ax1.annotate(f"{y_values[INDEX]:.1f}", (x_values[INDEX], y_values[INDEX]), textcoords="offset points", xytext=(0, 10), ha='center')
ax1.set_xscale('log')
ax1.set_xlabel("Model Parameters (M)", fontsize=14)
ax1.set_ylabel("Perplexity", fontsize=14, color='#1f78b4')
ax1.tick_params(axis='y', labelcolor='#1f78b4')

# Create a twin y-axis for controllability (right y-axis)
ax2 = ax1.twinx()
controllability_values = [0.11, 1.02, 1.07, 1.12, 1.87, 1.34]  # Example values for controllability
ax2.plot(x_values, controllability_values, marker='s', linestyle='--', color='#006400', linewidth=line_width)
ax2.set_ylabel("Delta PSNR", fontsize=14, color='#006400')
ax2.set_ylim(0, np.max(controllability_values) + 0.2)
ax2.tick_params(axis='y', labelcolor='#006400')
ax2.annotate(f"{controllability_values[INDEX]:.1f}", (x_values[INDEX], controllability_values[INDEX]), textcoords="offset points", xytext=(0, 10), ha='center')

# Save the figure in high resolution
plt.tight_layout()
# plt.show()

plt.savefig("output/model_sizes.png", dpi=300)

# Adding a centralized legend that appears above the plot
# fig.legend(y_values, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3, frameon=False, markerscale=1.5)
ADDED
@@ -0,0 +1,44 @@
|
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

font = {
    "family": "normal",
    "size": 22,
}

matplotlib.rc("font", **font)
sns.set(rc={"font.family": "Times New Roman"})
sns.set(style="whitegrid")
sns.set(font_scale=3, style="whitegrid")

# Sample data for plotting
categories = ["Scratch", "Passive Pre-Train", "Pre-Train", "Pre-Train (Large)"]
values = [1.0, 1.0, 1.0, 1.0]

# Define custom colors for the bars
colors = ["#4c72b0", "#55a868", "#c44e52", "#8172b2"]  # Adjust as needed

plt.figure(figsize=(14, 12))
ax = sns.barplot(
    x=categories, y=values, alpha=0.9, palette=colors, edgecolor="black"
)
for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")

# Adding title and labels
plt.xlabel("Setting", fontsize=40)
plt.ylabel("Validation Perplexity", fontsize=40)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
plt.legend(fontsize="small", title_fontsize="small", loc="lower left")

# Remove the borders
sns.despine(left=True, bottom=True)

# Display the plot
plt.tight_layout()
plt.savefig("output/model_ablation.png", dpi=300)  # Save the figure in high resolution
plt.show()
common/plot/plot_pretrain_ablation_mar.py
ADDED
@@ -0,0 +1,45 @@
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

font = {
    "family": "normal",
    "size": 22,
}

matplotlib.rc("font", **font)
sns.set(rc={"font.family": "Times New Roman"})
sns.set(style="whitegrid")
sns.set(font_scale=3, style="whitegrid")

# Sample data for plotting
categories = ["Scratch", "Passive Pre-Train", "Pre-Train", "Pre-Train (Large)"]
values = [1.0, 1.0, 1.0, 1.0]

# Define custom colors for the bars
colors = ["#4c72b0", "#55a868", "#c44e52", "#8172b2"]  # Adjust as needed

plt.figure(figsize=(14, 12))
ax = sns.barplot(
    x=categories, y=values, alpha=0.9, palette=colors, edgecolor="black"
)
for container in ax.containers:
    ax.bar_label(container, label_type="edge", fontsize="x-large", fmt="%.2f")

# Adding title and labels
plt.xlabel("Setting", fontsize=40)
plt.ylabel("Validation Perplexity", fontsize=40)
plt.xticks(fontsize=30)
ax.tick_params(axis='x', rotation=15)
plt.yticks(fontsize=30)
plt.legend(fontsize="small", title_fontsize="small", loc="lower left")

# Remove the borders
sns.despine(left=True, bottom=True)

# Display the plot
plt.tight_layout()
plt.savefig("output/model_ablation.png", dpi=300)  # Save the figure in high resolution
plt.show()
cont_data.py
ADDED
@@ -0,0 +1,245 @@
import json
import math
import os
import random
from pathlib import Path

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset as TorchDataset

from datasets.encode_openx_dataset import DATA_FREQ_TABLE
from genie.config import GenieConfig
from genie.st_mask_git import cosine_schedule

SVD_SCALE = 0.18215


def normalize_actions(actions):
    """
    Compute the mean and std of actions. The actual normalization is done inside the network.
    """
    mean = np.mean(actions, axis=0).tolist()
    std = np.std(actions, axis=0).tolist()
    return actions, [mean, std]


class RawFeatureDataset(TorchDataset):
    """ Loads raw float32 tokens as a memmap-backed array """
    def __init__(
        self,
        data_dir,
        window_size,
        stride=1,
        filter_interrupts=True,
        filter_overlaps=False,
        use_actions=False,
        max_traj_num=1000000,
        compute_stride_from_freq_table=True,
        natural_hz=2,
        datio_noise_ratio=0.0,
        use_raw_image_as_latent=False,
        domain=None,
    ):
        """
        Args:
            data_dir: directory with the same format as `data/train_v0` and `data/val_v0`.
                Notably, has `video.bin` and `metadata.json`.
            window_size: number of frames per "video" sequence.
            stride: frame skip.
            filter_interrupts: Under 3% of training frame sequences are the concatenation of two different clips.
                If filter_interrupts is True, will filter out these sequences using the segment ids.
            filter_overlaps: If False (default), one frame will appear in multiple examples;
                e.g. frame 0 might appear as the first frame in example 0 and also the second frame in example 15.
                If True, will filter out examples so that each frame appears at most once in the dataset.
            use_actions: If True, will load the actions from the `actions` folder for the models.
        """
        data_dir = Path(data_dir)
        with open(data_dir / "metadata.json") as f:
            self.metadata = json.load(f)

        # TODO: assert not quantized in metadata
        shape = (self.metadata["num_images"], self.metadata.get("latent_channels", 4),
                 self.metadata["h"], self.metadata["w"])
        print("token shape:", shape)
        self.use_raw_image_as_latent = use_raw_image_as_latent
        if use_raw_image_as_latent:
            shape = (shape[0], 3, shape[2], shape[3])
            # resize to 32x32

        video_tokens_path, segment_ids_path, action_tokens_path = [
            data_dir / f"{name}.bin" for name in ["video", "segment_ids", "actions"]]

        token_dtype = np.dtype(self.metadata.get("token_dtype", "float16"))
        self.data = np.memmap(video_tokens_path, mode="r", shape=shape, dtype=token_dtype)
        print("data nan:", torch.isnan(torch.from_numpy(self.data[:100].copy())).sum())
        # import IPython; IPython.embed()
        if use_raw_image_as_latent:
            # debug for the robomimic dataset: 256 -> 32x32 raw-image "latents"
            self.metadata["h"] = 32
            self.metadata["w"] = 32
            self.metadata["latent_channels"] = 3

        self.window_size, self.stride = window_size, stride
        self.datio_noise_ratio = datio_noise_ratio

        if domain is not None:  # TODO: remove
            self.name = domain
        else:
            self.name = self.metadata["name"]

        self.name = self.name.replace("_noquant", "")
        self.stride = stride
        if compute_stride_from_freq_table:
            self.stride = max(DATA_FREQ_TABLE.get(self.name, 1) // natural_hz, 1)
        self.n_action = self.metadata.get("action_dim", 1) * self.stride

        if use_actions:
            actions = []

            # hack here for the separations in the 1x datasets
            for action_file in sorted((data_dir / "actions").iterdir()):
                actions.append(np.memmap(action_file, dtype=np.float32, mode="r").reshape(len(self.data), -1))

            self.actions = np.concatenate(actions, axis=-1)
            self.actions, self.action_stat = normalize_actions(self.actions)

        if os.path.isfile(segment_ids_path):
            self.segment_ids = np.memmap(
                segment_ids_path,
                dtype=np.int32,
                mode="r",
                shape=(self.metadata["num_images"],)
            )
        else:
            self.segment_ids = None
            if filter_interrupts:
                raise NotImplementedError("Cannot filter interrupted sequences without segment ids.")

        # Number of frames between the first and last frames of a video sequence (excluding one endpoint frame)
        self.video_len = (self.window_size - 1) * self.stride
        self.valid_start_inds = []

        for start_ind in range(len(self.data) - self.video_len - self.stride):
            # Assuming `segment_ids` is monotonically increasing, a sequence is interrupted (or too short)
            # if the first and last frames have different segment ids.
            if not (filter_interrupts and self.segment_ids[start_ind] != self.segment_ids[start_ind + self.video_len]):
                self.valid_start_inds.append(start_ind)

            if len(self.valid_start_inds) >= max_traj_num:
                break

        if filter_overlaps:
            # Instead of using a sliding window, use each frame at most once
            filtered_start_inds = []
            for start_ind in self.valid_start_inds:
                overlapping_start_inds = {start_ind - i * self.stride for i in range(1, self.window_size)}
                # all sequences from `overlapping_start_inds` will also contain `start_ind`,
                # so exclude sequence starting from `start_ind` if any of `overlapping_start_inds` is already being used
                for existing_start_ind in filtered_start_inds[-self.window_size * self.stride:]:
                    # Bound could be improved
                    if existing_start_ind in overlapping_start_inds:
                        break
                else:
                    filtered_start_inds.append(start_ind)

            self.valid_start_inds = filtered_start_inds

        num_videos = len(np.unique(self.segment_ids)) if self.segment_ids is not None else 1
        print(f"Loaded {len(self)} sequences from {data_dir} {self.stride=} {self.window_size=} {self.n_action=} {num_videos=}")

    def __len__(self):
        return len(self.valid_start_inds)

    def __getitem__(self, idx):
        """
        Returns a flattened sequence of tokens representing `self.window_size` frames,
        spaced `self.stride` apart.
        """
        start_ind = self.valid_start_inds[idx]
        x = self.data[start_ind : start_ind + self.video_len + 1 : self.stride].copy()
        x = torch.FloatTensor(x).float()
        if self.use_raw_image_as_latent:
            x = torch.nn.functional.interpolate(x, size=(self.metadata["h"], self.metadata["w"]))
            # normalize
            x = x / 255 - 0.5
        else:
            x = x * SVD_SCALE  # divide by SVD_SCALE again when decoding

        x = rearrange(x, "t c h w -> (t h w) c")
        # reconstructions since the input ids and the labels are the same
        attention_mask = torch.ones_like(x)
        data_dict = {
            "input_ids": x,
            "labels": x,
            "attention_mask": attention_mask,
            "h": self.metadata["h"],
            "w": self.metadata["w"],
            "c": self.metadata["latent_channels"],
        }
        if hasattr(self, "actions"):
            # we want to have all actions within the stride to predict the next frame at the end of the stride,
            # so we concatenate the actions from [window_size, d_action] to [window_size, d_action * stride]
            data_dict['action_ids'] = self.actions[start_ind:start_ind + self.video_len + self.stride].reshape(self.window_size, -1)
            data_dict['action_ids'] = torch.from_numpy(data_dict['action_ids'].astype(np.float32))

        data_dict["domain"] = self.name.replace("_noquant", "")
        return data_dict


def get_maskgit_collator_feature(config: GenieConfig):
    # mask_token_id = config.image_vocab_size

    def collate_fn(features) -> dict[str, torch.Tensor]:
        # during training, map (z_0, z_1', z_2') -> (null, z_1, z_2)
        # (z_0, z_1') -> (null, z_1) is the diffusion operator on z_1' -> z_1

        h = features[0]["h"]
        w = features[0]["w"]
        input_ids = torch.stack([ex["input_ids"] for ex in features])
        device = input_ids.device
        x_THWC = rearrange(input_ids, "b (t h w) c -> b t h w c", b=len(features), t=config.T, h=h, w=w)
        labels = x_THWC.clone()
        first_masked_frame = config.T

        mask = torch.zeros(1).long()
        mask_token_indicator = torch.zeros((len(features), config.T, h, w)).long()

        if config.dataloader_apply_mask:
            if random.random() < config.non_mlm_ratio:  # Closer to autoregressive inference
                # Leave frames [0, first_masked_frame) unmasked.
                first_masked_frame = random.randint(config.num_prompt_frames, config.T - 1)
            else:  # Typical MLM masking
                first_masked_frame = 1

            c = 0
            while mask.max() == 0:  # We could get unlucky and mask no tokens?
                # per-minibatch, per-frame masking probability (could try variable masking rate from MUSE)
                rand = torch.rand(len(features), config.T - first_masked_frame, 1, 1)
                # add a minimum mask ratio
                rand_mask = rand * (1 - config.dataloader_mask_ratio_min) + config.dataloader_mask_ratio_min
                mask_prob_T = cosine_schedule(rand_mask)
                r = torch.rand_like(x_THWC[:, first_masked_frame:, ..., 0], dtype=torch.float)
                mask = r < mask_prob_T
                c += 1

            if c > 1:
                print(f"Generated mask {c} > 1 times.")

            mask_token_indicator = torch.cat([
                torch.zeros((len(features), first_masked_frame, h, w), dtype=mask.dtype), mask], dim=1)

        data_dict = {
            "input_ids": rearrange(x_THWC, "b t h w c -> b (t h w) c"),
            "labels": rearrange(labels, "b t h w c -> b (t h w) c"),
            "masked_tokens_indicator": mask_token_indicator,
        }

        if "action_ids" in features[0]:
            data_dict['action_ids'] = torch.stack([ex["action_ids"] for ex in features])
        data_dict['domain'] = [ex["domain"] for ex in features]
        data_dict['h'] = [ex["h"] for ex in features]
        data_dict['w'] = [ex["w"] for ex in features]
        return data_dict

    return collate_fn
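For context, a minimal usage sketch of the classes above. The data path and the assumption that `GenieConfig` is default-constructible are hypothetical; see `genie/config.py` for the real fields:

```python
# Hypothetical usage sketch; paths and config construction are assumptions.
from torch.utils.data import DataLoader

from cont_data import RawFeatureDataset, get_maskgit_collator_feature
from genie.config import GenieConfig

config = GenieConfig()  # assumed default-constructible for illustration
dataset = RawFeatureDataset("data/val_v0", window_size=config.T, use_actions=True)
loader = DataLoader(dataset, batch_size=4, collate_fn=get_maskgit_collator_feature(config))

batch = next(iter(loader))
# input_ids: (B, T*h*w, C) continuous latents; masked_tokens_indicator: (B, T, h, w)
print(batch["input_ids"].shape, batch["masked_tokens_indicator"].shape)
```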
data.py
ADDED
@@ -0,0 +1,240 @@
import json
import math
import os
import random
from pathlib import Path

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset as TorchDataset

from datasets.encode_openx_dataset import DATA_FREQ_TABLE
from genie.factorization_utils import factorize_token_ids, unfactorize_token_ids
from genie.config import GenieConfig
from genie.st_mask_git import cosine_schedule


def normalize_actions(actions: np.ndarray) -> tuple[np.ndarray, list[list[float]]]:
    """
    Compute the mean and std of actions. The actual normalization is done inside the network.
    """
    mean = np.mean(actions, axis=0).tolist()
    std = np.std(actions, axis=0).tolist()
    return actions, [mean, std]


class RawTokenDataset(TorchDataset):
    """ Loads raw uint32 tokens as a memmap-backed array """
    def __init__(
        self,
        data_dir,
        window_size,
        stride=1,
        filter_interrupts=True,
        filter_overlaps=False,
        use_actions=False,
        name='',
        max_traj_num=1000000,
        compute_stride_from_freq_table=True,
        natural_hz=2,
        drop_action_ratio=0.0
    ):
        """
        Args:
            data_dir: directory with the same format as `data/train_v0` and `data/val_v0`.
                Notably, has `video.bin` and `metadata.json`.
            window_size: number of frames per "video" sequence.
            stride: frame skip.
            filter_interrupts: Under 3% of training frame sequences are the concatenation of two different clips.
                If filter_interrupts is True, will filter out these sequences using the segment ids.
            filter_overlaps: If False (default), one frame will appear in multiple examples;
                e.g. frame 0 might appear as the first frame in example 0 and also the second frame in example 15.
                If True, will filter out examples so that each frame appears at most once in the dataset.
            use_actions: If True, will load the actions from the `actions` folder for the models.
            name: the name of the dataset.
        """
        data_dir = Path(data_dir)
        with open(data_dir / "metadata.json") as f:
            self.metadata = json.load(f)

        shape = (self.metadata["num_images"], self.metadata["h"], self.metadata["w"])
        video_tokens_path, segment_ids_path, action_tokens_path = [
            data_dir / f"{name}.bin" for name in ["video", "segment_ids", "actions"]]
        token_dtype = np.dtype(self.metadata.get("token_dtype", "uint32"))
        self.data = np.memmap(video_tokens_path, dtype=token_dtype, mode="r", shape=shape)
        self.window_size, self.stride = window_size, stride

        if len(name) == 0:
            self.name = self.metadata["name"]
        else:  # remove later
            self.name = name

        if compute_stride_from_freq_table:
            self.stride = max(DATA_FREQ_TABLE.get(self.name, 1) // natural_hz, 1)
        print(f"RawTokenDataset: {self.name=} {self.stride=}")

        self.n_action = self.metadata.get("action_dim", 1) * self.stride
        self.drop_action_ratio = drop_action_ratio

        if use_actions:
            actions = []

            # hack here for the separations in the 1x datasets
            for action_file in sorted((data_dir / "actions").iterdir()):
                actions.append(np.memmap(action_file, dtype=np.float32, mode="r").reshape(len(self.data), -1))

            self.actions = np.concatenate(actions, axis=-1)
            self.actions, self.action_stat = normalize_actions(self.actions)

        if os.path.isfile(segment_ids_path):
            self.segment_ids = np.memmap(
                segment_ids_path,
                dtype=np.int32,
                mode="r",
                shape=(self.metadata["num_images"],)
            )
        else:
            self.segment_ids = None
            if filter_interrupts:
                raise NotImplementedError("Cannot filter interrupted sequences without segment ids.")

        # Number of frames between the first and last frames of a video sequence (excluding one endpoint frame)
        self.video_len = (self.window_size - 1) * self.stride

        self.valid_start_inds = []
        for start_ind in range(len(self.data) - self.video_len - self.stride):
            # Assuming `segment_ids` is monotonically increasing, a sequence is interrupted (or too short)
            # if the first and last frames have different segment ids.
            if not (filter_interrupts and self.segment_ids[start_ind] != self.segment_ids[start_ind + self.video_len]):
                self.valid_start_inds.append(start_ind)

            if self.segment_ids is not None and self.segment_ids[start_ind] >= max_traj_num:
                # because we will filter based on window size later
                # len(self.valid_start_inds) >= max_traj_num
                break

        if filter_overlaps:
            # Instead of using a sliding window, use each frame at most once
            filtered_start_inds = []
            for start_ind in self.valid_start_inds:
                overlapping_start_inds = {start_ind - i * self.stride for i in range(1, self.window_size)}
                # all sequences from `overlapping_start_inds` will also contain `start_ind`,
                # so exclude sequence starting from `start_ind` if any of `overlapping_start_inds` is already being used
                for existing_start_ind in filtered_start_inds[-self.window_size * self.stride:]:
                    # Bound could be improved
                    if existing_start_ind in overlapping_start_inds:
                        break
                else:
                    filtered_start_inds.append(start_ind)

            self.valid_start_inds = filtered_start_inds

        self.num_videos = len(np.unique(self.valid_start_inds))
        print(f"Loaded {len(self)} sequences from {data_dir} {self.stride=} {self.window_size=} {self.n_action=} {self.num_videos=}")

    def __len__(self):
        return len(self.valid_start_inds)

    def __getitem__(self, idx):
        """
        Returns a flattened sequence of tokens representing `self.window_size` frames,
        spaced `self.stride` apart.
        """
        start_ind = self.valid_start_inds[idx]
        x = torch.from_numpy((self.data[start_ind : start_ind + self.video_len + 1 : self.stride]).astype(np.int64))
        x = x.flatten()  # 16 x 16 x 16

        # reconstructions since the input ids and the labels are the same
        attention_mask = torch.ones_like(x)
        data_dict = {
            "input_ids": x,
            "labels": x,
            "attention_mask": attention_mask,
            "h": self.metadata["h"],
            "w": self.metadata["w"],
        }
        if hasattr(self, "actions") and np.random.uniform() > self.drop_action_ratio:
            # we want to have all actions within the stride to predict the next frame at the end of the stride,
            # so we concatenate the actions from [window_size, d_action] to [window_size, d_action * stride]
            data_dict['action_ids'] = self.actions[start_ind:start_ind + self.video_len + self.stride].reshape(self.window_size, -1)
            data_dict['action_ids'] = torch.from_numpy(data_dict['action_ids'].astype(np.float32))

        data_dict["domain"] = self.name
        return data_dict


def get_maskgit_collator(config: GenieConfig):
    mask_token_id = config.image_vocab_size
    # h = w = math.isqrt(config.S)

    def collate_fn(features) -> dict[str, torch.Tensor]:
        # during training, map (z_0, z_1', z_2') -> (null, z_1, z_2)
        # (z_0, z_1') -> (null, z_1) is the diffusion operator on z_1' -> z_1
        h = features[0]["h"]
        w = features[0]["w"]
        input_ids = torch.stack([ex["input_ids"] for ex in features])
        device = input_ids.device
        x_THW = rearrange(input_ids, "b (t h w) -> b t h w", b=len(features), t=config.T,
                          h=h, w=w)
        x_THWC = factorize_token_ids(x_THW, config.num_factored_vocabs, config.factored_vocab_size)
        labels = x_THW.clone()

        if config.dataloader_apply_corruption:
            # As done in the Copilot-4D paper, add random noise sampled with a random rate between 0% and `config.max_corrupt_rate`
            r = torch.rand(x_THWC.size(), device=device)
            u01 = torch.rand((), device=device)
            random_patches_mask = r < config.max_corrupt_rate * u01
            random_values = torch.randint(low=0, high=config.factored_vocab_size, size=x_THWC.size(),
                                          dtype=torch.long, device=device)
            x_THWC[random_patches_mask] = random_values[random_patches_mask]

        if random.random() < config.non_mlm_ratio:  # Closer to autoregressive inference
            # Leave frames [0, first_masked_frame) unmasked.
            first_masked_frame = random.randint(config.num_prompt_frames, config.T - 1)
            x_THWC_view = x_THWC[:, first_masked_frame:]

            # Arbitrary numbers here, but corrupting later frames more
            # since we likely have compounding errors.
            correct_rate = random.uniform(config.dataloader_mask_ratio_min, 1.0)
            for i in range(x_THWC_view.size(1)):
                correct_rate *= random.uniform(0.9, 1.0)
                r = torch.rand((len(features), h, w, config.num_factored_vocabs), device=device)
                random_patches_mask = r > correct_rate
                x_THWC_view[:, i][random_patches_mask] = random_values[:, first_masked_frame + i][random_patches_mask]
        else:  # Typical MLM masking
            first_masked_frame = 1

        mask = torch.zeros(1)
        if config.dataloader_apply_mask:
            c = 0

            while mask.max() == 0:  # We could get unlucky and mask no tokens?
                # per-minibatch, per-frame masking probability (could try variable masking rate from MUSE)
                mask_prob_T = cosine_schedule(torch.rand(len(features), config.T - first_masked_frame, 1, 1))
                r = torch.rand_like(x_THW[:, first_masked_frame:], dtype=torch.float)
                mask = r < mask_prob_T
                c += 1

            if c > 1:
                print(f"Generated mask {c} > 1 times.")

            x_THW = unfactorize_token_ids(x_THWC, config.num_factored_vocabs, config.factored_vocab_size)
            x_THW[:, first_masked_frame:][mask] = mask_token_id

        data_dict = {
            "input_ids": rearrange(x_THW, "b t h w -> b (t h w)"),
            "labels": rearrange(labels, "b t h w -> b (t h w)"),
        }

        if "action_ids" in features[0]:
            data_dict['action_ids'] = torch.stack([ex["action_ids"] for ex in features])
        data_dict['domain'] = [ex["domain"] for ex in features]
        data_dict['h'] = [ex["h"] for ex in features]
        data_dict['w'] = [ex["w"] for ex in features]
        return data_dict

    return collate_fn
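The masking probability above comes from `cosine_schedule`, imported from `genie.st_mask_git` (not shown in this diff). A stand-in consistent with how it is called here, following the standard MaskGIT-style schedule (an assumption about the real implementation), would be:

```python
import math
import torch

def cosine_schedule(u: torch.Tensor) -> torch.Tensor:
    # Stand-in for genie.st_mask_git.cosine_schedule (assumed definition):
    # maps u ~ Uniform(0, 1) to a mask probability, near 1 for small u
    # and decaying to 0 as u -> 1, as in MaskGIT.
    return torch.cos(u * math.pi / 2)
```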
datasets/.DS_Store
ADDED
Binary file (6.15 kB)
datasets/__init__.py
ADDED
File without changes
datasets/encode_extern_dataset.py
ADDED
@@ -0,0 +1,291 @@
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import argparse
import json
import os
import time
import traceback
from typing import Optional

import numpy as np
from tqdm import tqdm

from datasets.encode_openx_dataset import MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES, get_shard_inds, VAL_RATIO, \
    process_dataset_step, DATA_FREQ_TABLE
from datasets.extern.ego4d import ego4d_dataset_size, ego4d_dataset_generator
from datasets.extern.egoexo4d import egoexo4d_dataset_size, egoexo4d_dataset_generator
from datasets.extern.robomimic import robomimic_dataset_generator, robomimic_dataset_size
from . import utils


SCRIPT_DESCRIPTION = """
Similar to encode_openx_dataset.py except for non-OpenX datasets.
Again, each split can be partitioned into multiple shards,
which is useful for parallelized encoding across GPUs.

Example usage:
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name egoexo4d --data_split train --num_shards 1000 --curr_shard_rank 400

Untested usage (SVD tokenizer):
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name robomimic --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
""".strip()

DATASET_TO_GEN_AND_SIZE = {
    "ego4d": (ego4d_dataset_generator, ego4d_dataset_size),
    "egoexo4d": (egoexo4d_dataset_generator, egoexo4d_dataset_size),
    "robomimic": (robomimic_dataset_generator, robomimic_dataset_size),
}


def encode_dataset_split(
    extern_dataset_name: str,
    split: str,
    max_episodes: Optional[int],
    original_res: bool,
    no_quantization: bool,
    curr_shard_rank: int,
    num_shards: int,
    root_dir: str,
    encoder_type: str,
    encoder_name_or_path: str,
    dataset_postfix: str = "",
    no_encoding: bool = False,
):
    """
    Encodes (e.g. tokenizes) a dataset split.
    The data written to disk can be used to load a `RawTokenDataset` (or the continuous version).

    Args:
        extern_dataset_name: TODO
        split: expected to be either "train" or "val". TODO: decide how to split
        max_episodes: the maximum number of trajectories to include in the dataset.
        dataset_postfix: will be a suffix of the output dirname.
        encoder_type: string specifying the type of image encoder/tokenizer to use.
        original_res: if True, will maintain the original resolution of the video rather than resizing it to 256x256.
        no_quantization: if True, will not perform the quantization step in the image encoder.
    """
    extern_dataset_name = extern_dataset_name.strip()  # never modified
    suffixed_dataset_name = extern_dataset_name  # will modify later

    if original_res:
        suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
    if no_quantization:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
    if no_encoding:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"
    save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
    dataset_path = os.path.join(root_dir, save_dirname)
    print("=" * 25)
    print(f"{dataset_path=}")
    utils.mkdir_if_missing(dataset_path)

    # Load data
    generator, size_func = DATASET_TO_GEN_AND_SIZE[extern_dataset_name]
    num_examples = size_func()
    if max_episodes is not None:
        num_examples = min(num_examples, max_episodes)  # clip num_examples

    # We will only operate on a subset of the training examples, depending on:
    # 1) The split (train/val). Some examples are reserved for the other split.
    # 2) Sharding
    assert num_examples > MIN_VAL_EXAMPLES  # non-positive number of train examples otherwise
    num_val_examples = np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES)

    if split == "train":  # first_ind inclusive, last_ind exclusive
        first_split_ind, last_split_ind = num_val_examples, num_examples
    elif split == "val":
        first_split_ind, last_split_ind = 0, num_val_examples
    else:
        raise NotImplementedError(f"{split=}")

    first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
    print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
    print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
          f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")

    ##### Encode data #####
    traj_lens = []  # only used to print statistics
    videos = []  # NOTE: videos/actions for the entire shard are stored in RAM until the end
    actions = []
    segment_ids = []

    # split based on some fixed batch sizes to reset RAM.
    max_batch_per_loading = 10
    pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
    start_time = time.time()

    for start_idx in pbar:
        end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
        pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
        ds = generator(range(start_idx, end_idx))

        for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
            segment_id = start_idx + chunk_idx
            try:
                # batchify the data and then process
                for step_ind, step_data in enumerate(episode["steps"]):
                    dataset_step = process_dataset_step(
                        step_data,
                        encoder_type=encoder_type,
                        encoder_name_or_path=encoder_name_or_path,
                        keep_res=original_res,
                        quantize=not no_quantization,
                        no_encoding=no_encoding
                    )

                    segment_ids.append(segment_id)
                    videos.append(dataset_step["image"])
                    actions.append(dataset_step["action"])

                traj_lens.append(step_ind + 1)  # number of steps in this trajectory
            except Exception:
                print("-" * 25)
                print(f"Add episode failed: {segment_id=}", traceback.format_exc(), suffixed_dataset_name)

            # 2 day timeout
            if time.time() - start_time > 86400 * 2:
                print(f"Writing dataset {suffixed_dataset_name} timed out")
                break

    if len(videos) == 0:
        print("Empty shard!")
        with open(f"{dataset_path}/error.json", "w") as f:
            json.dump({"status": "empty_shard"}, f)

        return

    if no_quantization:
        num_channels, height, width = videos[-1].shape[:3]  # num_channels is not actually stored in metadata
    else:
        height, width = videos[-1].shape[:2]
        num_channels = None

    ##### Write videos, actions, segment_ids, and metadata #####
    # align format to save segment_ids.bin, video.bin, actions/actions.bin, metadata.json
    # save videos
    videos = np.stack(videos, axis=0)
    # fp = np.memmap(f'{dataset_path}/video.bin', dtype=video_dtype, mode='w+', shape=videos.shape)
    # fp[:] = videos[:]
    videos.tofile(f'{dataset_path}/video.bin')

    # save actions
    utils.mkdir_if_missing(f'{dataset_path}/actions')
    actions = np.stack(actions, axis=0)
    # fp = np.memmap(f'{dataset_path}/actions/actions.bin', dtype=np.float32, mode='w+', shape=actions.shape)
    # fp[:] = actions[:]
    actions = actions.astype(np.float32)
    actions.tofile(f'{dataset_path}/actions/actions.bin')

    # save segment_ids
    segment_ids = np.array(segment_ids)
    # fp = np.memmap(f'{dataset_path}/segment_ids.bin', dtype=np.int32, mode='w+', shape=segment_ids.shape)
    # fp[:] = segment_ids[:]  # map to trajectory index
    segment_ids = segment_ids.astype(np.int32)
    segment_ids.tofile(f'{dataset_path}/segment_ids.bin')

    # feature_mean = np.mean(videos)
    # feature_std = np.std((videos - feature_mean) / 1e9) * 1e9

    # save metadata
    if encoder_type == "magvit":
        vocab_size = int(2 ** 18)
    elif encoder_type == "temporalvae":
        vocab_size = None
    else:
        raise NotImplementedError(f"{encoder_type=}")

    with open(f'{dataset_path}/metadata.json', 'w') as f:  # Technically only need to save most of this data for shard 0
        json.dump({
            "token_dtype": str(np.dtype(videos.dtype)),
            "action_dim": actions[0].shape[-1],
            "s": 16,
            "h": height,
            "w": width,
            "vocab_size": vocab_size,
            "hz": DATA_FREQ_TABLE.get(extern_dataset_name, 1),  # to be loaded from the data code
            "encoder_name_or_path": encoder_name_or_path,
            "encoder_type": encoder_type,
            "num_images": len(videos),
            "latent_channels": num_channels,
            "name": extern_dataset_name,
            # "feature_mean": feature_mean,
            # "feature_std": feature_std,
        }, f)

    print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
    print(f"Dataset creation time: {time.time() - start_time:.3f}")


def parse_args():
    parser = argparse.ArgumentParser(description=SCRIPT_DESCRIPTION)

    parser.add_argument(
        "--dataset_name", type=str, required=True, choices=DATASET_TO_GEN_AND_SIZE.keys(),
        help="TODO"
    )
    parser.add_argument(
        "--data_split", type=str, choices=["train", "val"], required=True,
        help="The split of the dataset to create."
    )
    parser.add_argument(
        "--episode_cnt", type=int,
        help="If specified, will limit the maximum number of trajectories to encode."
    )
    parser.add_argument(
        "--original_res", action='store_true',
        help="Maintain original resolution of the video rather than resizing it to 256x256."
    )
    parser.add_argument(
        "--no_quantization", action='store_true',
        help="Skip quantization step in visual encoder."
    )
    parser.add_argument(
        "--num_shards", type=int, default=1,
        help="The number of shards to partition the train/val dataset into."
    )
    parser.add_argument(
        "--curr_shard_rank", type=int, default=0,
        help="The (0-indexed) shard number to encode."
    )
    parser.add_argument(
        "--root_dir", type=str, default="data",
        help="The root directory to write all datasets to."
    )
    parser.add_argument(
        "--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
        help="Type of the image tokenizer."
    )
    parser.add_argument(
        "--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
        help="The path or name of the image encoder."
    )
    parser.add_argument(
        "--no_encoding", action='store_true',
        help="Preserve the groundtruth raw images to compute metrics in validation."
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    utils.set_seed(233)

    dataset_postfix = f"shard{args.curr_shard_rank}_of_{args.num_shards}"
    if args.episode_cnt is not None:
        dataset_postfix = f"max{args.episode_cnt}_{dataset_postfix}"

    encode_dataset_split(
        extern_dataset_name=args.dataset_name,
        split=args.data_split,
        max_episodes=args.episode_cnt,
        dataset_postfix=dataset_postfix,
        original_res=args.original_res,
        no_quantization=args.no_quantization,
        num_shards=args.num_shards,
        curr_shard_rank=args.curr_shard_rank,
        root_dir=args.root_dir,
        encoder_type=args.encoder_type,
        encoder_name_or_path=args.encoder_name_or_path,
        no_encoding=args.no_encoding,
    )
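`get_shard_inds` is imported from `datasets.encode_openx_dataset` and its body is not visible in the portion of that file shown below. A sketch of the even-partitioning behavior the calling code relies on (an assumption about the real implementation) would be:

```python
import math

def get_shard_inds(first_ind: int, last_ind: int, shard_rank: int, num_shards: int) -> tuple[int, int]:
    # Assumed behavior: split [first_ind, last_ind) into num_shards contiguous,
    # near-equal chunks and return the half-open index range for `shard_rank`.
    shard_size = math.ceil((last_ind - first_ind) / num_shards)
    start = first_ind + shard_rank * shard_size
    end = min(start + shard_size, last_ind)
    return start, end
```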
datasets/encode_openx_dataset.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Licensed under The MIT License [see LICENSE for details]
|
3 |
+
# --------------------------------------------------------
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import traceback
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
import math
|
12 |
+
import numpy as np
|
13 |
+
import tensorflow_datasets as tfds
|
14 |
+
from tensorflow_datasets.core import DatasetBuilder
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
from . import utils
|
18 |
+
|
19 |
+
|
20 |
+
SCRIPT_DESCRIPTION="""
|
21 |
+
Converts an Open X-Embodiment dataset from GS to encoded/tokenized data on disk.
|
22 |
+
This script only encodes one split (specified by `--data_split`)
|
23 |
+
of a one OpenX dataset (specified by `--dataset_name`) at a time.
|
24 |
+
|
25 |
+
Optionally, each split can be partitioned into multiple shards,
|
26 |
+
which is useful for parallelized encoding across GPUs.
|
27 |
+
|
28 |
+
Example usage:
|
29 |
+
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name bc_z --data_split train --episode_cnt 500 --num_shards 16 --curr_shard_rank 0
|
30 |
+
CUDA_VISIBLE_DEVICES=1 python -m datasets.encode_openx_dataset --dataset_name bc_z --data_split train --episode_cnt 500 --num_shards 16 --curr_shard_rank 1
|
31 |
+
|
32 |
+
set -e
|
33 |
+
for ((i = 0; i < 64; i += 2)); do
|
34 |
+
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name bridge --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
|
35 |
+
done
|
36 |
+
|
37 |
+
set -e
|
38 |
+
for ((i = 1; i < 64; i += 2)); do
|
39 |
+
CUDA_VISIBLE_DEVICES=1 python -m datasets.encode_openx_dataset --dataset_name bridge --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
|
40 |
+
done
|
41 |
+
|
42 |
+
Example usage (SVD tokenizer):
|
43 |
+
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name language_table --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
|
44 |
+
""".strip()
|
45 |
+
|
46 |
+
# The validation set is the first VAL_RATIO examples in the dataset, and clipped to [MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES]
|
47 |
+
VAL_RATIO = 0.05
|
48 |
+
MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES = 20, 200
|


# Approximate control frequency (hz) per dataset, recorded into metadata.json below.
DATA_FREQ_TABLE = {
    "austin_sailor_dataset_converted_externally_to_rlds": 20,
    "stanford_hydra_dataset_converted_externally_to_rlds": 10,
    "austin_buds_dataset_converted_externally_to_rlds": 20,
    "austin_sirius_dataset_converted_externally_to_rlds": 20,
    "berkeley_mvp_converted_externally_to_rlds": 5,
    "berkeley_rpt_converted_externally_to_rlds": 30,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
    "utaustin_mutex": 20,
    "imperialcollege_sawyer_wrist_cam": 10,
    "language_table": 2,  # changed to match frequency
    "kuka": 2,  # changed to match frequency
    "bc_z": 10,
    "robo_net": 1,
    "dlr_sara_pour_converted_externally_to_rlds": 10,
    "stanford_robocook_converted_externally_to_rlds": 5,
    "cmu_play_fusion": 5,
    "bridge": 5,
    "furniture_bench_dataset_converted_externally_to_rlds": 10,
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
    "usc_cloth_sim_converted_externally_to_rlds": 10,
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
    "roboturk": 10,
    "kaist_nonprehensile_converted_externally_to_rlds": 10,
    "asu_table_top_converted_externally_to_rlds": 12,
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
    "berkeley_cable_routing": 10,
    "droid": 15,
    "uiuc_d3field": 1,
    "robo_set": 5,
    "toto": 30,
    "nyu_door_opening_surprising_effectiveness": 3,
    "nyu_franka_play_dataset_converted_externally_to_rlds": 3,
    "mimic_play": 15,
    "maniskill_dataset_converted_externally_to_rlds": 20,
    "columbia_cairlab_pusht_real": 10,
    "conq_hose_manipulation": 30,
    "dlr_edan_shared_control_converted_externally_to_rlds": 5,
    "berkeley_gnm_sac_son": 10,
    "berkeley_autolab_ur5": 5,
    "aloha_mobile": 30,
    "1x_humanoid": 30,
    "epic_kitchen_originalres": 30,
    "epic_kitchen": 30,
    "exoego4d": 30,
    "ego4d": 1,  # less than this.
    "metaworld": 6,
    "frodobot": 30,
    "fractal20220817_data": 3,
    # robomimic variants (average episode length around 50); duplicate "robomimic" key removed
    "robomimic": 6,
    "robomimic_new": 6,
    "robomimic_multitask_new": 6,
    "robomimic_new_perturb": 6,
    "robomimic_multitask_new_perturb": 6,
}


def select_image(observation, verbose=False):
    """
    Select a canonical frame as the image observation.
    """
    imgs = []
    # does not need to prefer the wrist camera
    for key in ["rgb", "image"]:
        for obs_key in observation:
            if key in obs_key and "depth" not in obs_key:
                image = observation[obs_key]
                if type(observation[obs_key]) is not np.ndarray:
                    image = image.numpy()
                if verbose:
                    print("selected image key:", obs_key)
                imgs.append(image)

    return imgs
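# Illustrative observation (keys vary across OpenX datasets): an observation like
# {"image": <ndarray>, "wrist_image": <tf.Tensor>, "depth_image": <ndarray>} yields
# [image, wrist_image] -- depth keys are filtered out and tensors converted to
# numpy. Downstream code uses imgs[0] as the canonical view.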


def process_dataset_step(step, encoder_type: str, encoder_name_or_path: str,
                         keep_res=False, quantize=True, no_encoding=False):
    """
    Map dataset-specific keys and values to a unified format.

    Args:
        step (dict): The step dictionary containing the dataset-specific information.
        encoder_type (str, optional): The image encoder to use.
    Returns:
        dict: The processed step dictionary with the mapped keys and values.
    """
    step_dict = {}
    try:
        if "action" in step:
            step_dict["action"] = np.array(step["action"])

            # outlier cases: some datasets store the action as a dict of named
            # components; flatten them into one vector in sorted key order.
            if type(step["action"]) is dict:
                step_dict["action"] = step_dict["action"].item()
                action = []
                for k, v in sorted(step_dict["action"].items()):
                    action.append(v.numpy().reshape(-1))
                step_dict["action"] = np.concatenate(action)

        # handle image
        images = select_image(step["observation"])

        # compute the embeddings.
        if no_encoding:
            step_dict["image"] = utils.resize_image(images[0])
        elif quantize:
            step_dict["image"] = utils.get_quantized_image_embeddings(
                images[0],
                encoder_type=encoder_type,
                encoder_name_or_path=encoder_name_or_path,
                keep_res=keep_res,
            )
        else:
            step_dict["image"] = utils.get_vae_image_embeddings(
                images[0],
                encoder_type=encoder_type,
                encoder_name_or_path=encoder_name_or_path,
                keep_res=keep_res,
            )
    except Exception:
        # on failure, return whatever was processed; the caller skips bad episodes
        print("--------------------------")
        print("process_dataset_step exception:", traceback.format_exc())

    return step_dict


def get_dataset_builder(gs_dataset_name) -> tuple[DatasetBuilder, int]:
    """
    Returns the dataset builder and the total number of examples (for the train split).
    """
    # dataset versions on GS differ; try the known ones in order
    try:
        builder = tfds.builder_from_directory(builder_dir=f"gs://gresearch/robotics/{gs_dataset_name}/0.1.0/")
    except Exception:
        try:
            builder = tfds.builder_from_directory(builder_dir=f"gs://gresearch/robotics/{gs_dataset_name}/1.0.0/")
        except Exception:
            builder = tfds.builder_from_directory(builder_dir=f"gs://gresearch/robotics/{gs_dataset_name}/0.0.1/")

    info = builder.info
    num_examples = info.splits["train"].num_examples

    return builder, num_examples


def get_shard_inds(first_split_ind: int, last_split_ind: int, curr_shard_rank: int, num_shards: int) -> tuple[int, int]:
    """
    Given the indices of the first (inclusive) and last (exclusive) examples in the data split
    (i.e. the entire train or val dataset), returns the indices of the first (inclusive) and
    last (exclusive) examples for the current shard in this data split.
    """
    split_num_examples = last_split_ind - first_split_ind
    shard_size_float = split_num_examples / num_shards  # average number of examples per shard
    return (
        first_split_ind + math.ceil(curr_shard_rank * shard_size_float),
        min(first_split_ind + math.ceil((curr_shard_rank + 1) * shard_size_float), last_split_ind)
    )
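# Worked example (illustrative): with first_split_ind=0, last_split_ind=10 and
# num_shards=3, shard_size_float = 10/3, so the shards cover [0, 4), [4, 7), [7, 10).
# Ceil-based boundaries guarantee the shards partition the split with no overlap.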


def encode_dataset_split(
    gs_dataset_name: str,
    split: str,
    max_episodes: Optional[int],
    original_res: bool,
    no_quantization: bool,
    curr_shard_rank: int,
    num_shards: int,
    root_dir: str,
    encoder_type: str,
    encoder_name_or_path: str,
    dataset_postfix: str = "",
    no_encoding: bool = False,
):
    """
    Converts an Open X-Embodiment dataset from GS to encoded/tokenized data on disk.
    The data written to disk can be used to load a `RawTokenDataset` (or the continuous version).

    Args:
        gs_dataset_name: the name of the dataset in Google Storage.
            Can be checked with gsutil ls -d gs://gresearch/robotics/*/
        split: expected to be either "train" or "val". TODO: decide how to split
        max_episodes: the maximum number of trajectories to include in the dataset.
        dataset_postfix: will be a suffix of the output dirname.
        encoder_type: string specifying the type of image encoder/tokenizer to use.
        original_res: if True, will maintain the original resolution of the video rather than resizing it to 256x256.
        no_quantization: if True, will not perform the quantization step in the image encoder.
    """
    gs_dataset_name = gs_dataset_name.strip()  # never modified
    suffixed_dataset_name = gs_dataset_name  # will modify later
    if no_quantization:
        video_dtype = np.float16
    elif no_encoding:
        video_dtype = np.uint8
    else:
        video_dtype = np.uint32
    if original_res:
        suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
    if no_quantization:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
    if no_encoding:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"
    save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
    dataset_path = os.path.join(root_dir, save_dirname)
    print("=" * 25)
    print(f"{dataset_path=}")
    utils.mkdir_if_missing(dataset_path)

    # Load data
    builder, num_examples = get_dataset_builder(gs_dataset_name)
    if max_episodes is not None:
        num_examples = min(num_examples, max_episodes)  # clip num_examples

    # We will only operate on a subset of the training examples, depending on:
    # 1) The split (train/val). Some examples are reserved for the other split.
    # 2) Sharding
    assert num_examples > MIN_VAL_EXAMPLES, f"{num_examples=} {MIN_VAL_EXAMPLES=}"  # non-positive number of train examples otherwise
    num_val_examples = np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES)

    if split == "train":  # first_ind inclusive, last_ind exclusive
        first_split_ind, last_split_ind = num_val_examples, num_examples
    elif split == "val":
        first_split_ind, last_split_ind = 0, num_val_examples
    else:
        raise NotImplementedError(f"{split=}")

    first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
    print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
    print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
          f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")

    ##### Encode data #####
    traj_lens = []  # only used to print statistics
    videos = []  # NOTE: videos/actions for the entire shard are stored in RAM until the end
    actions = []
    segment_ids = []

    # split based on some fixed batch sizes to reset RAM.
    max_batch_per_loading = 10
    pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
    start_time = time.time()

    for start_idx in pbar:
        end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
        pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
        ds = builder.as_dataset(split=f"train[{start_idx}:{end_idx}]")

        for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
            segment_id = start_idx + chunk_idx
            try:
                # batchify the data and then process
                for step_ind, step_data in enumerate(episode["steps"]):
                    dataset_step = process_dataset_step(
                        step_data,
                        encoder_type=encoder_type,
                        encoder_name_or_path=encoder_name_or_path,
                        keep_res=original_res,
                        quantize=not no_quantization,
                        no_encoding=no_encoding
                    )

                    segment_ids.append(segment_id)
                    videos.append(dataset_step["image"])
                    actions.append(dataset_step["action"])

                traj_lens.append(step_ind + 1)  # number of steps in this trajectory
            except Exception:
                print("-" * 25)
                print(f"Add episode failed: {segment_id=}", traceback.format_exc(), suffixed_dataset_name)

        # 2 day timeout
        if time.time() - start_time > 86400 * 2:
            print(f"Writing dataset {suffixed_dataset_name} timed out")
            break

    if no_quantization:
        num_channels, height, width = videos[-1].shape[:3]
    else:
        height, width = videos[-1].shape[:2]
        num_channels = None

    ##### Write videos, actions, segment_ids, and metadata #####
    # align format to save segment_ids.bin, video.bin, actions/actions.bin, metadata.json
    # save videos
    videos = np.stack(videos, axis=0)
    fp = np.memmap(f'{dataset_path}/video.bin', dtype=video_dtype, mode='w+', shape=videos.shape)
    fp[:] = videos[:]

    # save actions
    utils.mkdir_if_missing(f'{dataset_path}/actions')
    actions = np.stack(actions, axis=0)
    fp = np.memmap(f'{dataset_path}/actions/actions.bin', dtype=np.float32, mode='w+', shape=actions.shape)
    fp[:] = actions[:]

    # save segment_ids
    segment_ids = np.array(segment_ids)
    fp = np.memmap(f'{dataset_path}/segment_ids.bin', dtype=np.int32, mode='w+', shape=segment_ids.shape)
    fp[:] = segment_ids[:]  # map to trajectory index

    # feature_mean = float(np.mean(videos))
    # feature_std = float(np.std((videos - feature_mean) / 1e9)) * 1e9
    # save metadata
    if encoder_type == "magvit":
        vocab_size = int(2 ** 18)
    elif encoder_type == "temporalvae":
        vocab_size = None
    else:
        raise NotImplementedError(f"{encoder_type=}")

    with open(f'{dataset_path}/metadata.json', 'w') as f:  # Technically only need to save most of this data for shard 0
        json.dump({
            "token_dtype": str(np.dtype(videos.dtype)),
            "action_dim": actions[0].shape[-1],
            "s": 16,
            "h": height,
            "w": width,
            "vocab_size": vocab_size,
            "hz": DATA_FREQ_TABLE.get(gs_dataset_name, 1),  # to be loaded from the data code TODO: remove default?
            "encoder_name_or_path": encoder_name_or_path,
            "encoder_type": encoder_type,
            "num_images": len(videos),
            "name": gs_dataset_name,
            "latent_channels": num_channels,
            "quantized": not no_quantization,  # use the function arg, not the global `args`
            # "feature_mean": feature_mean,
            # "feature_std": feature_std,
        }, f)

    print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
    print(f"Dataset creation time: {time.time() - start_time:.3f}")
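# For reference, the on-disk layout produced by encode_dataset_split:
#   <root_dir>/<suffixed_name>_<encoder_type>_<postfix>_<split>/
#       video.bin            (num_images, h, w) tokens, or (num_images, c, h, w) latents
#       actions/actions.bin  (num_images, action_dim) float32
#       segment_ids.bin      (num_images,) int32; maps each frame to its trajectory
#       metadata.json        shapes, dtype, vocab_size, frequency, encoder info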


def parse_args():
    parser = argparse.ArgumentParser(description=SCRIPT_DESCRIPTION)

    parser.add_argument(
        "--dataset_name", type=str, required=True,
        help="The name of the Open X-Embodiment dataset on Google Storage. "
             "Can be checked with gsutil ls -d gs://gresearch/robotics/*/. "
    )
    parser.add_argument(
        "--data_split", type=str, choices=["train", "val"], required=True,
        help="The split of the dataset to create."
    )
    parser.add_argument(
        "--episode_cnt", type=int,
        help="If specified, will limit the maximum number of trajectories to encode."
    )
    parser.add_argument(
        "--original_res", action='store_true',
        help="Maintain original resolution of the video rather than resizing it to 256x256."
    )
    parser.add_argument(
        "--no_quantization", action='store_true',
        help="Skip quantization step in visual encoder."
    )
    parser.add_argument(
        "--num_shards", type=int, default=1,
        help="The number of shards to partition the train/val dataset into."
    )
    parser.add_argument(
        "--curr_shard_rank", type=int, default=0,
        help="The (0-indexed) shard number to encode."
    )
    parser.add_argument(
        "--root_dir", type=str, default="data",
        help="The root directory to write all datasets to."
    )
    parser.add_argument(
        "--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
        help="Type of the image tokenizer."
    )
    parser.add_argument(
        "--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
        help="The path or name of the image encoder."
    )
    parser.add_argument(
        "--no_encoding", action='store_true',
        help="Preserve the ground-truth raw images to compute metrics in validation."
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    utils.set_seed(233)

    dataset_postfix = f"shard{args.curr_shard_rank}_of_{args.num_shards}" if args.num_shards > 1 else ""
    if args.episode_cnt is not None:
        dataset_postfix = f"max{args.episode_cnt}_{dataset_postfix}" if dataset_postfix else f"max{args.episode_cnt}"

    encode_dataset_split(
        gs_dataset_name=args.dataset_name,
        split=args.data_split,
        max_episodes=args.episode_cnt,
        dataset_postfix=dataset_postfix,
        original_res=args.original_res,
        no_quantization=args.no_quantization,
        num_shards=args.num_shards,
        curr_shard_rank=args.curr_shard_rank,
        root_dir=args.root_dir,
        encoder_type=args.encoder_type,
        encoder_name_or_path=args.encoder_name_or_path,
        no_encoding=args.no_encoding,
    )
datasets/extern/__init__.py
ADDED
File without changes
datasets/extern/ego4d.py
ADDED
@@ -0,0 +1,193 @@
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import json
import os
from collections import OrderedDict
from os.path import expanduser
from typing import Iterable

import cv2
import numpy as np

CURRENT_DIR = os.path.dirname(__file__)

RESOLUTION = (480, 480)
home = expanduser("~")

# Adjust these to wherever your detections and frames are stored.
ROOT = "/datasets01/ego4d_track2/"
LABEL_ROOT = ROOT + "v2_1/annotations/fho_main.json"
VIDEO_PATH = ROOT + "v2_1/full_scale/"
# from epic_kitchens.hoa import load_detections

# labels = json.load(open("/datasets01/ego4d_track2/v2_1/annotations/fho_main.json"))
# videos = /datasets01/ego4d_track2/v2_1/clips


def parse_video_frame(video_path, frame_id):
    # seek to a single frame instead of decoding the whole video
    cap = cv2.VideoCapture(video_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id - 1)
    ret, frame = cap.read()
    return frame


def parse_raw_video(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    return frames


def compute_state_and_actions(image, curr_frame, next_frame, frame_idx, save=False):
    # curr_frame / next_frame are lists of bounding-box labels
    img_width, img_height = image.shape[1], image.shape[0]
    for box in curr_frame:
        if box['object_type'] == 'left_hand':
            curr_hand1_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]
        if box['object_type'] == 'right_hand':
            curr_hand2_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]

    for box in next_frame:
        if box['object_type'] == 'left_hand':
            next_hand1_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]
        if box['object_type'] == 'right_hand':
            next_hand2_center = [box['bbox']['x'] + box['bbox']['width'] / 2, box['bbox']['y'] + box['bbox']['height'] / 2]

    # normalize the centers to [0, 1] image coordinates
    curr_hand1_center = np.array([curr_hand1_center[0] / img_width, curr_hand1_center[1] / img_height])
    curr_hand2_center = np.array([curr_hand2_center[0] / img_width, curr_hand2_center[1] / img_height])
    next_hand1_center = np.array([next_hand1_center[0] / img_width, next_hand1_center[1] / img_height])
    next_hand2_center = np.array([next_hand2_center[0] / img_width, next_hand2_center[1] / img_height])

    state = np.concatenate((curr_hand1_center, curr_hand2_center))
    action = np.concatenate((next_hand1_center, next_hand2_center))
    if save:
        # draw the hand centers (green: current, red: next) and save for inspection
        cv2.circle(image, (int(curr_hand1_center[0] * img_width), int(curr_hand1_center[1] * img_height)), 10, (0, 255, 0), -1)
        cv2.circle(image, (int(curr_hand2_center[0] * img_width), int(curr_hand2_center[1] * img_height)), 10, (0, 255, 0), -1)
        cv2.circle(image, (int(next_hand1_center[0] * img_width), int(next_hand1_center[1] * img_height)), 10, (0, 0, 255), -1)
        cv2.circle(image, (int(next_hand2_center[0] * img_width), int(next_hand2_center[1] * img_height)), 10, (0, 0, 255), -1)
        cv2.imwrite(f"/private/home/xinleic/LR/hpt_video/data/ego4d_video_label_check/img_{frame_idx}.png", image)
    return state, action


def chunk_actions_and_concatenate(actions):
    chunk_size = 4
    # group actions into chunks of 4, dropping the final (possibly ragged) chunk
    chunked_actions = [actions[i:i + chunk_size] for i in range(0, len(actions), chunk_size)][:-1]
    concatenated_frames = []

    for chunk in chunked_actions:
        frames_to_concat = []
        for action in chunk:
            frames = action['frames']  # assuming 'frames' is a list or iterable
            if frames is not None:
                frames_to_concat.extend(frames)  # collect frames from each action
        concatenated_frames.append(frames_to_concat)  # store the concatenated frames for this chunk

    return concatenated_frames
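# Illustrative chunking (hypothetical sizes): with chunk_size = 4 and 9 narrated
# actions, the slices are [0:4], [4:8], [8:9]; dropping the last ragged chunk via
# [:-1] keeps two chunks, each flattened into a single list of labeled frames.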


def ego4d_dataset_size() -> int:
    """ Returns the number of trajectories in the dataset. ~1725 for Ego4D. """
    labels = json.load(open(LABEL_ROOT))
    return len(labels['videos'])


# define your own dataset conversion
def ego4d_dataset_generator(example_inds: Iterable[int] = None):
    """
    Generator yielding data from Ego4D.
    Args:
        example_inds: if specified, will only yield data from these indices.
            Otherwise, will default to yielding the entire dataset.
    """
    # convert to a list of episodes that can be added to the replay buffer
    labels = json.load(open(LABEL_ROOT))

    if example_inds is None:
        example_inds = range(len(labels['videos']))

    for example_ind in example_inds:
        label = labels['videos'][example_ind]
        # ['annotated_intervals'][2]['narrated_actions']
        video_path = VIDEO_PATH + label['video_uid'] + ".mp4"
        if not os.path.exists(video_path):
            print("skip", video_path)
            continue

        label_detections = labels
        print("video_path:", video_path)
        print("len label detections", len(label_detections))

        # extract actions from the bounding boxes of both hands.
        for interval in label['annotated_intervals']:
            lang = "use human hands to do some tasks"  # dummy instruction
            print(f"Interval [{interval['start_sec']} - {interval['end_sec']}]")
            actions = list(filter(lambda x: not (x['is_invalid_annotation'] or x['is_rejected']) and x['stage'] is not None, interval['narrated_actions']))
            print(f"Actions: {len(actions)}")

            # because we need to concatenate
            if len(actions) < 3:
                continue

            # the number of frames is usually 7, and it does not follow a strict 2hz
            chunk_actions = chunk_actions_and_concatenate(actions)
            for frame_idx, frames in enumerate(chunk_actions):
                # lang = frame['narration_text']
                steps = []
                # need to use dummy actions to expand from 6 frames to 16 frames
                for idx, frame in enumerate(frames[:-1]):
                    frame_id = frame['frame_number']
                    next_frame = frames[idx + 1]
                    image = parse_video_frame(video_path, frame_id)

                    if len(frame['boxes']) > 2 and len(next_frame['boxes']) > 2:
                        try:
                            s, a = compute_state_and_actions(image, frame['boxes'], next_frame['boxes'], idx, save=False)
                        except Exception:
                            print(f'compute action failed idx {idx} frame idx {frame_idx}')
                            continue
                        # break into step dict
                        step = {
                            "observation": {"image": image, "state": s},
                            "action": a,
                            "language_instruction": lang,
                        }
                        steps.append(OrderedDict(step))

                if len(steps) < 16:
                    print("skip this traj because frame window length < 16")
                    continue
                data_dict = {"steps": steps}
                yield data_dict
datasets/extern/egoexo4d.py
ADDED
@@ -0,0 +1,186 @@
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes/blob/master/notebooks/demo.ipynb
import json
import os
from collections import OrderedDict
from typing import Iterable

import cv2
import numpy as np

CURRENT_DIR = os.path.dirname(__file__)

# Adjust these to wherever your detections and frames are stored.
CAM = "cam01"  # cam01
ROOT = "/datasets01/egoexo4d/v2/"
LABEL_ROOT = ROOT + "annotations/ego_pose/train/hand/automatic/{}.json"
VIDEO_PATH = ROOT + "takes/{}/frame_aligned_videos/{}.mp4"
TAKE_ROOT = ROOT + "takes.json"


def compute_state_and_actions(image, curr_frame, next_frame, idx, save=False):
    img_width, img_height = image.shape[1], image.shape[0]

    # raw annotations are pixel coordinates
    curr_hand1_center = curr_frame[0]['annotation2D'][CAM]['left_wrist']
    curr_hand2_center = curr_frame[0]['annotation2D'][CAM]['right_wrist']

    # normalize to [0, 1] image coordinates
    curr_hand1_center = np.array([curr_hand1_center['x'] / img_width, curr_hand1_center['y'] / img_height])
    curr_hand2_center = np.array([curr_hand2_center['x'] / img_width, curr_hand2_center['y'] / img_height])

    next_hand1_center = next_frame[0]['annotation2D'][CAM]['left_wrist']
    next_hand2_center = next_frame[0]['annotation2D'][CAM]['right_wrist']

    next_hand1_center = np.array([next_hand1_center['x'] / img_width, next_hand1_center['y'] / img_height])
    next_hand2_center = np.array([next_hand2_center['x'] / img_width, next_hand2_center['y'] / img_height])

    state = np.concatenate((curr_hand1_center, curr_hand2_center))
    action = np.concatenate((next_hand1_center, next_hand2_center))
    if save:
        # draw the wrist keypoints (green: current, red: next) and save for inspection
        cv2.circle(image, (int(curr_hand1_center[0] * img_width), int(curr_hand1_center[1] * img_height)), 10, (0, 255, 0), -1)
        cv2.circle(image, (int(curr_hand2_center[0] * img_width), int(curr_hand2_center[1] * img_height)), 10, (0, 255, 0), -1)
        cv2.circle(image, (int(next_hand1_center[0] * img_width), int(next_hand1_center[1] * img_height)), 10, (0, 0, 255), -1)
        cv2.circle(image, (int(next_hand2_center[0] * img_width), int(next_hand2_center[1] * img_height)), 10, (0, 0, 255), -1)
        cv2.imwrite(f"output/inspect/test_{idx}.png", image)
    return state, action
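# Note: as in the ego4d converter, `state` holds the current pair of normalized
# wrist positions and `action` the *next* frame's absolute positions. For example,
# wrists at pixel (320, 240) in a 640x480 frame map to state components (0.5, 0.5).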


def parse_raw_video(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    return frames


def egoexo4d_dataset_size() -> int:
    """ Returns the number of takes in the dataset. ~5k for v2. """
    takes = json.load(open(TAKE_ROOT))
    return len(takes)


# define your own dataset conversion
def egoexo4d_dataset_generator(example_inds: Iterable[int] = None):
    """
    Generator yielding data from Ego-Exo4D.
    Args:
        example_inds: if specified, will only yield data from these indices.
            Otherwise, will default to yielding the entire dataset.
    """
    # convert to a list of episodes that can be added to the replay buffer
    MAX_EPISODE_LENGTH = 5000
    TAKE_FILE = json.load(open(TAKE_ROOT))
    print("total takes", len(TAKE_FILE))
    # find the first camera with aria
    global CAM

    def find_aria_name(take):
        for cam in take['cameras']:
            if 'aria' in cam['name']:
                return cam['name']
        return None

    if example_inds is None:
        example_inds = range(len(TAKE_FILE))

    for example_ind in example_inds:
        take = TAKE_FILE[example_ind]
        take_name = take['take_name']
        take_uid = take['take_uid']
        # CAM = find_aria_name(take)
        # if CAM is None:
        #     continue

        video_path = VIDEO_PATH.format(take_name, CAM)
        label_path = LABEL_ROOT.format(take_uid)

        if not os.path.exists(video_path) or not os.path.exists(label_path):
            continue

        video_frames = parse_raw_video(video_path)
        label_detections = json.load(open(label_path))
        print("video_path:", video_path)
        print("len video frames", len(video_frames))
        print("len label detections", len(label_detections))

        # extract actions from the wrist keypoints of both hands.
        max_frame_idx = len(video_frames) - 1
        DS_FACTOR = 1
        frame_idx = 0
        start_frame_idx = 0
        MIN_CLIP_LENGTH = 300

        def get_continuous_chunk(start_idx, label_detections):
            # advance end_idx while consecutive frames all carry hand labels
            end_idx = start_idx + 1
            while str(start_idx) in label_detections and len(label_detections[str(start_idx)]) > 0 and str(end_idx) in label_detections and len(label_detections[str(end_idx)]) > 0:
                end_idx += 1
            return end_idx

        print("TAKE", take_name)

        # some frames might not have a label; if there is a gap, skip it
        while start_frame_idx < max_frame_idx - DS_FACTOR:
            lang = "use human hands to do some tasks"  # dummy instruction
            if str(start_frame_idx) not in label_detections or str(start_frame_idx + DS_FACTOR) not in label_detections:
                start_frame_idx += DS_FACTOR
                continue

            end_frame_idx = get_continuous_chunk(start_frame_idx, label_detections)

            if end_frame_idx - start_frame_idx < MIN_CLIP_LENGTH:
                start_frame_idx = end_frame_idx
                continue

            print("start clipping from", start_frame_idx, "to", end_frame_idx)
            steps = []
            for frame_idx in range(start_frame_idx, end_frame_idx - DS_FACTOR, DS_FACTOR):
                image = video_frames[frame_idx][..., [2, 1, 0]]  # BGR -> RGB
                try:
                    s, a = compute_state_and_actions(
                        image,
                        label_detections[str(frame_idx)], label_detections[str(frame_idx + DS_FACTOR)],
                        frame_idx, save=False
                    )
                except Exception:
                    break
                # break into step dict
                step = {
                    "observation": {"image": image, "state": s},
                    "action": a,
                    "language_instruction": lang,
                }
                steps.append(OrderedDict(step))
                if len(steps) > MAX_EPISODE_LENGTH:
                    break

            start_frame_idx = end_frame_idx
            # yield only clips that reached the minimum length
            # (the draft had `<` here, which would keep only truncated clips)
            if len(steps) >= MIN_CLIP_LENGTH:
                data_dict = {"steps": steps}
                print(f"max_frame_idx: {max_frame_idx} ds factor: {DS_FACTOR} {len(steps)}")
                yield data_dict
datasets/extern/epic_kitchen.py
ADDED
@@ -0,0 +1,115 @@
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes/blob/master/notebooks/demo.ipynb
import os
from collections import OrderedDict
from os.path import expanduser
from pathlib import Path

import cv2
import numpy as np
from epic_kitchens.hoa.types import BBox, FloatVector, HandSide
from epic_kitchens.hoa import load_detections

CURRENT_DIR = os.path.dirname(__file__)

RESOLUTION = (480, 480)
home = expanduser("~")

# Adjust these to wherever your detections and frames are stored.
DETECTION_ROOT = "/checkpoint/xinleic/LR/epic-kitchens-100-hand-object-bboxes/labels/hand-objects"
FRAMES_ROOT = "/datasets01/EPIC-KITCHENS-100"

# DETECTION_ROOT = f'{home}/Projects/epic_kitchen_labels/hand-objects'
# FRAMES_ROOT = f'{home}/EPIC-KITCHENS'
detections_root = Path(DETECTION_ROOT)
frames_root = Path(FRAMES_ROOT)


def compute_state_and_actions(curr_frame, next_frame):
    curr_hand1, curr_hand2 = curr_frame.hands[0], curr_frame.hands[1]
    if curr_hand1.side != HandSide.LEFT:  # flip so hand1 is always the left hand
        curr_hand1, curr_hand2 = curr_hand2, curr_hand1

    # bbox centers are already normalized
    curr_hand1_center = curr_hand1.bbox.center
    curr_hand2_center = curr_hand2.bbox.center

    next_hand1, next_hand2 = next_frame.hands[0], next_frame.hands[1]
    if next_hand1.side != HandSide.LEFT:  # flip
        next_hand1, next_hand2 = next_hand2, next_hand1

    next_hand1_center = next_hand1.bbox.center
    next_hand2_center = next_hand2.bbox.center
    state = np.concatenate((curr_hand1_center, curr_hand2_center))
    action = np.concatenate(
        (
            np.array(next_hand1_center) - np.array(curr_hand1_center),
            np.array(next_hand2_center) - np.array(curr_hand2_center),
        )
    )
    return state, action
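# Note the asymmetry with the ego4d/egoexo4d converters: here `action` is the
# per-step *displacement* of each hand center rather than the next absolute
# position. E.g. the left hand moving from (0.40, 0.50) to (0.42, 0.48) yields
# action components (0.02, -0.02).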


# define your own dataset conversion
def convert_dataset_image():
    # convert to a list of episodes that can be added to the replay buffer
    ALL_EPISODES = os.listdir(FRAMES_ROOT)
    MAX_EPISODE_LENGTH = 5000

    for EPS in ALL_EPISODES:
        rgb_path = os.path.join(FRAMES_ROOT, EPS, "rgb_frames")
        if not os.path.exists(rgb_path):
            continue
        for video_id in os.listdir(rgb_path):
            full_path = os.path.join(rgb_path, video_id)
            if (
                not full_path.endswith(".tar") and not full_path.endswith(".jpg") and not full_path.endswith("home")
            ):  # folder

                # extract actions from the bounding-box displacements of both hands.
                participant_id = video_id[:3]
                video_detections = load_detections(detections_root / participant_id / (video_id + ".pkl"))
                max_frame_idx = len(video_detections) - 1
                DS_FACTOR = 1
                print(full_path)
                steps = []

                for frame_idx in range(0, max_frame_idx - DS_FACTOR, DS_FACTOR):
                    # skip frames where both hands are not detected
                    if (
                        len(video_detections[frame_idx].hands) != 2
                        or len(video_detections[frame_idx + DS_FACTOR].hands) != 2
                    ):
                        continue

                    s, a = compute_state_and_actions(
                        video_detections[frame_idx], video_detections[frame_idx + DS_FACTOR]
                    )
                    lang = "use human hands to do some tasks"  # dummy instruction
                    image_path = frames_root / participant_id / "rgb_frames" / video_id / f"frame_{frame_idx:010d}.jpg"
                    image = cv2.imread(str(image_path))
                    if image is None:
                        continue
                    image = image[..., [2, 1, 0]]  # BGR -> RGB

                    # break into step dict
                    step = {
                        "observation": {"image": image, "state": s},
                        "action": a,
                        "language_instruction": lang,
                    }
                    steps.append(OrderedDict(step))
                    if len(steps) > MAX_EPISODE_LENGTH:
                        break
                data_dict = {"steps": steps}
                print(f"max_frame_idx: {max_frame_idx} ds factor: {DS_FACTOR} {len(steps)}")
                yield data_dict
datasets/extern/frodobot.py
ADDED
@@ -0,0 +1,128 @@
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
from collections import OrderedDict

import cv2
import numpy as np
import pandas as pd
import torch

CURRENT_DIR = os.path.dirname(__file__)

RESOLUTION = (480, 480)
DATA = "/home/liruiw/Projects/frodobot/"
# https://colab.research.google.com/#scrollTo=50ce529a-a20a-4852-9a5a-114b52b98f2e&fileId=https%3A//huggingface.co/datasets/frodobots/FrodoBots-2K/blob/main/helpercode.ipynb

# print(f"{dataset_dir}/control_data_{ride_id}.csv")


def convert_img_dataset(
    dataset_dir="/home/liruiw/Projects/frodobot/output_rides_22",
    env_names=None,
    gui=False,
    episode_num_pertask=2000,
    **kwargs,
):
    # convert to a list of episodes that can be added to the replay buffer
    for eps_file in os.listdir(dataset_dir)[:50]:  # 50 trajectories
        dataset_dir_ = os.path.join(dataset_dir, eps_file)
        if os.path.isdir(dataset_dir_):
            ride_id = dataset_dir_.split("_")[-2]
            print(dataset_dir_)

            ##### control data
            control = pd.read_csv(f"{dataset_dir_}/control_data_{ride_id}.csv")
            control_data_dict = control.set_index("timestamp").T.to_dict("list")
            control_sorted_keys = sorted(list(control_data_dict.keys()))

            ##### IMU data
            gyro_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "gyroscope"]]
            gyro_data_dict = gyro_data.set_index("timestamp").T.to_dict("list")
            gyro_sorted_keys = sorted(list(gyro_data_dict.keys()))

            compass_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "compass"]]
            compass_data_dict = compass_data.set_index("timestamp").T.to_dict("list")
            compass_sorted_keys = sorted(list(compass_data_dict.keys()))

            accel_data = pd.read_csv(f"{dataset_dir_}/imu_data_{ride_id}.csv")[["timestamp", "accelerometer"]]
            accel_data_dict = accel_data.set_index("timestamp").T.to_dict("list")
            accel_sorted_keys = sorted(list(accel_data_dict.keys()))

            ##### Camera data
            camera_data = pd.read_csv(f"{dataset_dir_}/front_camera_timestamps_{ride_id}.csv")
            camera_data_dict = camera_data.set_index("timestamp").T.to_dict("list")
            camera_sorted_keys = sorted(list(camera_data_dict.keys()))

            images = sorted(os.listdir(f"{dataset_dir_}/front_camera/"))

            # #### front camera video
            # front_camera = f"{dataset_dir}/recordings/0f0e8539d249f38e3ae7b18660f5af8c_ride_39572__uid_s_1000__uid_e_video_20240502221408754.ts"
            languages = "drive around to play"  # dummy
            steps = []
            SUBSAMPLE_IDX = 5

            for idx, control_t in enumerate(control_sorted_keys):
                # enumerate along actions and pick the sensor reading with the nearest
                # timestamp; note the np.abs -- an argmin over signed differences
                # (as in the draft) would always pick the earliest timestamp instead.
                action = control_data_dict[control_t]
                camera_t = camera_sorted_keys[np.argmin(np.abs(np.array(camera_sorted_keys) - control_t))]
                camera_path = images[camera_data_dict[camera_t][0]]
                img = cv2.resize(cv2.imread(f"{dataset_dir_}/front_camera/{camera_path}"), None, fx=0.5, fy=0.5)

                gyro = gyro_data_dict[gyro_sorted_keys[np.argmin(np.abs(np.array(gyro_sorted_keys) - control_t))]]
                gyro_array = np.array(eval(gyro[0])[0][:3], dtype=float)

                compass = compass_data_dict[compass_sorted_keys[np.argmin(np.abs(np.array(compass_sorted_keys) - control_t))]]
                compass_array = np.array(eval(compass[0])[0][:3], dtype=float)

                accel = accel_data_dict[accel_sorted_keys[np.argmin(np.abs(np.array(accel_sorted_keys) - control_t))]]
                accel_array = np.array(eval(accel[0])[0][:3], dtype=float)

                prop = np.concatenate((gyro_array, compass_array, accel_array))
                step = {
                    "observation": {"state": prop, "image": img},
                    "action": action,
                    "language_instruction": languages,
                }
                steps.append(OrderedDict(step))
            data_dict = {"steps": steps}
            yield data_dict
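# Illustrative timestamp matching (hypothetical numbers): for control_t = 10.0 and
# camera timestamps [9.2, 9.9, 10.3], np.argmin(np.abs(...)) picks 9.9, the nearest
# frame; an argmin over signed differences would always pick 9.2, the earliest one.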


class RolloutRunner:
    """evaluate policy rollouts"""

    def __init__(self, env_names, episode_num, save_video=False):
        self.env_names = env_names
        self.episode_num = episode_num
        self.envs = []
        self.scene_files = []
        self.save_video = save_video

    @torch.no_grad()
    def run(self, policy, save_video=False, gui=False, video_postfix="", seed=233, env_name=None, **kwargs):
        pass
datasets/extern/robomimic.py
ADDED
@@ -0,0 +1,108 @@
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

"""
Convert robomimic HDF5 demonstrations into the unified step format by replaying
the stored simulator states and rendering image observations.
"""
import h5py
import numpy as np
import cv2
from collections import OrderedDict
import robomimic.utils.file_utils as FileUtils

from sim.robomimic.robomimic_runner import (
    create_env, OBS_KEYS, RESOLUTION
)
from sim.robomimic.robomimic_wrapper import RobomimicLowdimWrapper

from typing import Optional, Iterable

DATASET_DIR = 'data/robomimic/datasets'
SUPPORTED_ENVS = ['lift', 'square', 'can']
NUM_EPISODES_PER_TASK = 200


def render_step(env, state):
    # restore the flattened mujoco state, then render an image at RESOLUTION
    env.env.env.sim.set_state_from_flattened(state)
    env.env.env.sim.forward()
    img = env.render()
    img = cv2.resize(img, RESOLUTION)
    return img


def robomimic_dataset_size() -> int:
    return len(SUPPORTED_ENVS) * NUM_EPISODES_PER_TASK
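# Index layout (worked example): with NUM_EPISODES_PER_TASK = 200, global index
# 350 maps to SUPPORTED_ENVS[350 // 200] = 'square' and demo_{350 % 200} = demo_150.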


def robomimic_dataset_generator(example_inds: Optional[Iterable[int]] = None):
    if example_inds is None:
        example_inds = range(robomimic_dataset_size())

    curr_env_name = None
    for idx in example_inds:
        # get the env_name corresponding to idx
        env_name = SUPPORTED_ENVS[idx // NUM_EPISODES_PER_TASK]
        if curr_env_name is None or curr_env_name != env_name:
            # need to load a new env
            dataset = f"{DATASET_DIR}/{env_name}/ph/image.hdf5"
            env_meta = FileUtils.get_env_metadata_from_dataset(dataset)
            env_meta["use_image_obs"] = True
            env = create_env(env_meta=env_meta, obs_keys=OBS_KEYS)
            env = RobomimicLowdimWrapper(env=env)
            env.reset()  # NOTE: this is necessary to remove green laser bug
            curr_env_name = env_name

        with h5py.File(dataset) as file:
            demos = file["data"]
            local_episode_idx = idx % NUM_EPISODES_PER_TASK
            if f"demo_{local_episode_idx}" not in demos:
                continue

            demo = demos[f"demo_{local_episode_idx}"]
            obs = demo["obs"]
            states = demo["states"]
            action = demo["actions"][:].astype(np.float32)
            step_obs = np.concatenate([obs[key] for key in OBS_KEYS], axis=-1).astype(np.float32)
            steps = []
            for a, o, s in zip(action, step_obs, states):
                # break into step dict
                image = render_step(env, s)
                step = {
                    "observation": {"state": o, "image": image},
                    "action": a,
                    "language_instruction": f"{env_name}",
                }
                steps.append(OrderedDict(step))
            data_dict = {"steps": steps}
            yield data_dict

            # # import imageio
            # for _ in range(3):
            #     steps = []
            #     perturbed_action = action + np.random.normal(0, 0.2, action.shape)
            #     current_state = states[0]
            #     _ = render_step(env, current_state)
            #     for someindex in range(len(action)):
            #         image = env.render()
            #         step = {
            #             "observation": {"image": image},
            #             "action": action[someindex],
            #             "language_instruction": f"{env_name}",
            #         }
            #         steps.append(OrderedDict(step))
            #
            #         # simulate action
            #         env.step(perturbed_action[someindex])
            #
            #     # # save video
            #     # frames = [step["observation"]["image"] for step in steps]
            #     # imageio.mimsave(f"test.mp4", frames, fps=10)
            #     # while not (user_input := input("Continue? (y/n)")) in ["y", "n"]:
            #     #     print("Invalid input")
            #
            #     data_dict = {"steps": steps}
            #     yield data_dict

    env.close()
datasets/merge_shards.py
ADDED
@@ -0,0 +1,113 @@
1 |
+
"""
|
2 |
+
Merge data shards generated from `encode_{extern,openx}_dataset.py`
|
3 |
+
In addition to CLI args, `SHARD_DATA_FORMAT` must be changed depending on the dataset.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
|
13 |
+
SHARD_DATA_FORMAT = "/private/home/xinleic/LR/HPT-Video-KZ/sharded_data/droid_magvit_shard{}_of_{}_train"
|
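# Example invocation (illustrative), after encoding the 64 droid shards above:
#   python -m datasets.merge_shards --out_data_dir data/droid_magvit_train --num_shards 64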


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_data_dir", type=str, required=True,
                        help="Directory to save merged data, must not exist.")
    parser.add_argument("--num_shards", type=int, required=True, help="Number of shards the dataset was split into.")

    args = parser.parse_args()
    assert not os.path.exists(args.out_data_dir), "Will not overwrite existing directory."
    os.makedirs(os.path.join(args.out_data_dir, "actions"), exist_ok=True)

    num_frames = 0
    valid_inds = []

    # First pass: collect valid shards and count the total number of frames.
    for shard_ind in range(args.num_shards):
        shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
        if os.path.isfile(os.path.join(shard_path, "metadata.json")):
            valid_inds.append(shard_ind)
            with open(os.path.join(shard_path, "metadata.json"), "r") as f:
                shard_metadata = json.load(f)

            num_frames += shard_metadata["num_images"]
        else:
            print(f"{shard_ind=} is invalid.")

    if num_frames == 0:
        print("No valid shards")
        exit(0)

    token_dtype = np.dtype(shard_metadata["token_dtype"])
    if shard_metadata["quantized"]:
        frame_dims = (shard_metadata["h"], shard_metadata["w"])
    else:
        frame_dims = (shard_metadata["latent_channels"], shard_metadata["h"], shard_metadata["w"])

    action_dim = shard_metadata["action_dim"]
    videos = np.memmap(
        os.path.join(args.out_data_dir, "video.bin"),
        dtype=token_dtype,
        mode="write",
        shape=(num_frames, *frame_dims)
    )

    actions = np.memmap(
        os.path.join(args.out_data_dir, "actions", "actions.bin"),
        dtype=np.float32,
        mode="write",
        shape=(num_frames, action_dim)
    )

    segment_ids = np.memmap(
        os.path.join(args.out_data_dir, "segment_ids.bin"),
        dtype=np.int32,
        mode="write",
        shape=(num_frames,)
    )

    prev_frame_ind = 0
    prev_segment_id = 0

    # Second pass: copy each shard into the merged memmaps, offsetting segment ids
    # so that trajectory indices remain unique across shards.
    for shard_ind in tqdm(valid_inds):
        shard_path = SHARD_DATA_FORMAT.format(shard_ind, args.num_shards)
        with open(os.path.join(shard_path, "metadata.json"), "r") as f:
            shard_metadata = json.load(f)

        shard_num_frames = shard_metadata["num_images"]
        videos[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
            os.path.join(shard_path, "video.bin"),
            dtype=np.dtype(shard_metadata["token_dtype"]),
            mode="r",
            shape=(shard_num_frames, *frame_dims),
        )

        actions[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
            os.path.join(shard_path, "actions", "actions.bin"),
            dtype=np.float32,
            mode="r",
            shape=(shard_num_frames, action_dim),
        )

        segment_ids[prev_frame_ind: prev_frame_ind + shard_num_frames] = np.memmap(
            os.path.join(shard_path, "segment_ids.bin"),
            dtype=np.int32,
            mode="r",
            shape=(shard_num_frames,),
        ) + prev_segment_id

        prev_segment_id = segment_ids[prev_frame_ind + shard_num_frames - 1] + 1
        prev_frame_ind += shard_num_frames

    assert prev_frame_ind == num_frames
    print("Finished")

    with open(os.path.join(args.out_data_dir, "metadata.json"), "w") as f:
        merged_metadata = shard_metadata \
            | vars(args) \
            | {"num_images": num_frames, "input_path": SHARD_DATA_FORMAT.format(0, args.num_shards)}

        json.dump(merged_metadata, f)
datasets/utils.py
ADDED
@@ -0,0 +1,244 @@
import os

import cv2
import numpy as np
import torch
import torchvision.transforms.v2.functional as transforms_f
from diffusers import AutoencoderKLTemporalDecoder
from einops import rearrange
from transformers import T5Tokenizer, T5Model

from magvit2.config import VQConfig
from magvit2.models.lfqgan import VQModel

vision_model = None


def get_image_encoder(encoder_type: str, encoder_name_or_path: str):
    encoder_type = encoder_type.lower()
    if encoder_type == "magvit":
        return VQModel(VQConfig(), ckpt_path=encoder_name_or_path)
    elif encoder_type == "temporalvae":
        return AutoencoderKLTemporalDecoder.from_pretrained(encoder_name_or_path, subfolder="vae")
    else:
        raise NotImplementedError(f"{encoder_type=}")


def set_seed(seed):
    # set seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)


def mkdir_if_missing(dst_dir):
    """make destination folder if it's missing"""
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir)


def resize_image(image, resize=True):
    MAX_RES = 1024

    # convert to array
    image = np.asarray(image)
    h, w = image.shape[:2]
    if h > MAX_RES or w > MAX_RES:
        # downscale so the longer side is MAX_RES, preserving the aspect ratio
        # (the draft's ratio here scaled the shorter side up instead of down)
        if h < w:
            new_h, new_w = int(MAX_RES * h / w), MAX_RES
        else:
            new_h, new_w = MAX_RES, int(MAX_RES * w / h)
        image = cv2.resize(image, (new_w, new_h))

    if resize:
        # resize the shorter side to 256 and then do a center crop
        h, w = image.shape[:2]
        if h < w:
            new_h, new_w = 256, int(256 * w / h)
        else:
            new_h, new_w = int(256 * h / w), 256
        image = cv2.resize(image, (new_w, new_h))

        h, w = image.shape[:2]
        crop_h, crop_w = 256, 256
        start_h = (h - crop_h) // 2
        start_w = (w - crop_w) // 2
        image = image[start_h:start_h + crop_h, start_w:start_w + crop_w]
    return image
66 |
+
|
67 |
+
def normalize_image(image, resize=True):
|
68 |
+
"""
|
69 |
+
H x W x 3(uint8) -> imagenet normalized 3 x H x W
|
70 |
+
|
71 |
+
Normalizes image to [-1, 1].
|
72 |
+
Resizes the image if resize=True or if the image resolution > MAX_RES
|
73 |
+
"""
|
74 |
+
image = resize_image(image, resize=resize)
|
75 |
+
# normalize between -1 and 1
|
76 |
+
image = image / 255.0
|
77 |
+
image = (image * 2 - 1.)
|
78 |
+
return torch.from_numpy(image.transpose(2, 0, 1))
|
79 |
+
|
80 |
+
|
81 |
+
def unnormalize_image(magvit_output):
|
82 |
+
"""
|
83 |
+
[-1, 1] -> [0, 255]
|
84 |
+
|
85 |
+
Important: clip to [0, 255]
|
86 |
+
"""
|
87 |
+
rescaled_output = ((magvit_output.detach().cpu() + 1) * 127.5)
|
88 |
+
clipped_output = torch.clamp(rescaled_output, 0, 255).to(dtype=torch.uint8)
|
89 |
+
return clipped_output
|
90 |
+
|
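As a quick sanity check on the two helpers above: `normalize_image` and `unnormalize_image` round-trip up to the resize/crop and uint8 rounding. A minimal sketch with a placeholder random image:

```python
import numpy as np

# Placeholder input: a random 512x512 uint8 RGB image.
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)

x = normalize_image(image)       # 3 x 256 x 256 float tensor in [-1, 1]
restored = unnormalize_image(x)  # 3 x 256 x 256 uint8 tensor in [0, 255]
assert restored.shape == (3, 256, 256)
```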
@torch.inference_mode()
@torch.no_grad()
def get_quantized_image_embeddings(
    image,
    encoder_type,
    encoder_name_or_path,
    keep_res=False,
    device="cuda",
):
    """
    image: (h, w, 3) uint8.
    Returns an (h/16, w/16) grid of discrete MAGVIT token indices.
    """
    global vision_model
    DEBUG = False
    dtype = torch.bfloat16

    if vision_model is None:
        vision_model = get_image_encoder(encoder_type=encoder_type, encoder_name_or_path=encoder_name_or_path)
        vision_model = vision_model.to(device=device, dtype=dtype)
        vision_model.eval()

    batch = normalize_image(image, resize=not keep_res)[None]
    if not keep_res:
        img_h, img_w = 256, 256
    else:
        img_h, img_w = batch.shape[2:]

    h, w = img_h // 16, img_w // 16

    with vision_model.ema_scope():
        quant_, _, indices, _ = vision_model.encode(batch.to(device=device, dtype=dtype), flip=True)
        indices = rearrange(indices, "(h w) -> h w", h=h, w=w)

    # alternative way to get indices
    # indices_ = vision_model.quantize.bits_to_indices(quant_.permute(0, 2, 3, 1) > 0).cpu().numpy()
    # indices_ = rearrange(indices_, "(h w) -> h w", h=h, w=w)

    if DEBUG:
        # sanity check: decode and then visualize
        with vision_model.ema_scope():
            indices = indices[None]
            # bit representations
            quant = vision_model.quantize.get_codebook_entry(
                rearrange(indices, "b h w -> b (h w)"),
                bhwc=indices.shape + (vision_model.quantize.codebook_dim,),
            ).flip(1)
            # TODO: why is flip(1) needed for the codebook bits?
            decoded_img = unnormalize_image(vision_model.decode(quant.to(device=device, dtype=dtype)))
            transforms_f.to_pil_image(decoded_img[0]).save("decoded.png")
            transforms_f.to_pil_image(image).save("original.png")

    # 18 x 16 x 16 of [-1., 1.] -> 16 x 16 uint32
    indices = indices.type(torch.int32)
    indices = indices.detach().cpu().numpy().astype(np.uint32)
    return indices

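A minimal usage sketch for the quantized encoder above; the checkpoint path is hypothetical, and with the default `keep_res=False` a 256x256 center crop yields a 16x16 token grid:

```python
import numpy as np

# Placeholder frame and hypothetical checkpoint path for illustration.
frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
tokens = get_quantized_image_embeddings(
    frame,
    encoder_type="magvit",
    encoder_name_or_path="data/magvit2.ckpt",
)
print(tokens.shape, tokens.dtype)  # (16, 16) uint32
```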
@torch.inference_mode()
@torch.no_grad()
def get_vae_image_embeddings(
    image,
    encoder_type,
    encoder_name_or_path,
    keep_res: bool = False,
    device="cuda",
):
    """
    image: (h, w, 3) uint8; normalized to [-1, 1] internally.
    Uses an SD-style VAE (or MAGVIT without quantization) to encode the image
    into continuous latents.
    """
    global vision_model
    DEBUG = False
    dtype = torch.bfloat16

    if vision_model is None:
        vision_model = get_image_encoder(encoder_type, encoder_name_or_path)
        vision_model = vision_model.to(device=device, dtype=dtype)
        vision_model.eval()

    # https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
    # if args.use_temporal_decoder:
    #     vae = AutoencoderKLTemporalDecoder.from_pretrained(args.vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
    # else:
    #     vae = AutoencoderKL.from_pretrained(args.vae_model_path, subfolder="vae").to(device)
    # x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor) ?

    batch = normalize_image(image, resize=not keep_res)[None]

    if isinstance(vision_model, AutoencoderKLTemporalDecoder):
        # SVD expects images in [-1, 1], so no further rescaling is needed.
        # https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/video_sampling.py#L182
        # https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/streamlit_helpers.py#L894
        z = vision_model.encode(batch.to(device=device, dtype=dtype)).latent_dist.mean
    elif isinstance(vision_model, VQModel):
        # with vision_model.ema_scope():  # doesn't matter due to bugged VQModel ckpt_path arg
        z = vision_model.encode_without_quantize(batch.to(device=device, dtype=dtype))
    else:
        raise NotImplementedError(f"{vision_model=}")

    if DEBUG:
        decoded_img = unnormalize_image(vision_model.decode(z.to(device=device, dtype=dtype)))
        transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
        transforms_f.to_pil_image(image).save("original.png")

    return z[0].detach().cpu().float().numpy().astype(np.float16)


# switch to VAE in SD
# https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae
# https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
# from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
# vae_model_path = 'pretrained_models/stabilityai/stable-diffusion-xl-base-1.0'
# if args.use_temporal_decoder:
#     vae = AutoencoderKLTemporalDecoder.from_pretrained(vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
# else:
#     vae = AutoencoderKL.from_pretrained(vae_model_path, subfolder="vae").to(device)
# z = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
# if DEBUG:
#     decoded_img = unnormalize_image(vae.decode(z.to(device=device, dtype=dtype) / vae.config.scaling_factor))
#     transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
#     transforms_f.to_pil_image(image).save("original.png")

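For the continuous counterpart, a similar sketch (the encoder path is hypothetical). Note that `vision_model` is cached globally, so whichever encoder the first call loads is reused for the rest of the process:

```python
import numpy as np

# Placeholder frame and hypothetical encoder path for illustration.
frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
latents = get_vae_image_embeddings(
    frame,
    encoder_type="temporalvae",
    encoder_name_or_path="stabilityai/stable-video-diffusion-img2vid",
)
# For a 256x256 crop and an 8x-downsampling VAE this would be e.g. (4, 32, 32) float16.
print(latents.shape, latents.dtype)
```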
@torch.no_grad()
def get_t5_embeddings(language, per_token=True, max_length=16, device="cpu"):
    """Get T5 embeddings for a language instruction."""
    global global_language_model, t5_tok
    if global_language_model is None:
        try:
            t5_model = T5Model.from_pretrained("t5-base")
            t5_tok = T5Tokenizer.from_pretrained("t5-base")
        except OSError:
            # fall back to the local cache when the hub is unreachable
            t5_model = T5Model.from_pretrained("t5-base", local_files_only=True)
            t5_tok = T5Tokenizer.from_pretrained("t5-base", local_files_only=True)
        t5_model = t5_model.to(device)
        global_language_model = t5_model
        global_language_model.eval()

    # forward pass through encoder only
    enc = t5_tok(
        [language],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).to(device)

    output = global_language_model.encoder(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], return_dict=True
    )
    torch.cuda.empty_cache()
    if per_token:
        return output.last_hidden_state[0].detach().cpu().numpy()
    else:
        # get the final hidden states, averaged across tokens
        emb = output.last_hidden_state[0].mean(dim=0).detach().cpu().numpy()
        return emb
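A minimal sketch of the T5 helper; with `t5-base` (768-dim hidden states) and the default `max_length=16`, the shapes are:

```python
emb_tokens = get_t5_embeddings("pick up the red block")                   # (16, 768)
emb_pooled = get_t5_embeddings("pick up the red block", per_token=False)  # (768,)
print(emb_tokens.shape, emb_pooled.shape)
```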
experiments/.DS_Store
ADDED
Binary file (6.15 kB).

experiments/datasplit/.DS_Store
ADDED
Binary file (6.15 kB).

experiments/datasplit/dataset1.yaml
ADDED
@@ -0,0 +1,2 @@
domains: >
  language_table
experiments/datasplit/dataset10.yaml
ADDED
@@ -0,0 +1,11 @@
domains: >
  bridge_data_v2,
  fractal20220817_data,
  language_table,
  ucsd_pick_and_place_dataset_converted_externally_to_rlds,
  kaist_nonprehensile_converted_externally_to_rlds,
  ucsd_kitchen_dataset_converted_externally_to_rlds,
  berkeley_fanuc_manipulation,
  bc_z,
  cmu_play_fusion,
  kuka
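Each datasplit YAML stores its datasets as one folded scalar string under `domains`. A minimal sketch of turning that string into a list of names (this parser is illustrative, not necessarily the repo's own loader):

```python
import yaml

with open("experiments/datasplit/dataset10.yaml") as f:
    cfg = yaml.safe_load(f)

# "domains" folds into a single comma-separated string; split and strip it.
domains = [name.strip() for name in cfg["domains"].split(",") if name.strip()]
print(domains[:3])  # ['bridge_data_v2', 'fractal20220817_data', 'language_table']
```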
experiments/datasplit/dataset15.yaml
ADDED
@@ -0,0 +1,16 @@
domains: >
  bridge_data_v2,
  fractal20220817_data,
  language_table,
  ucsd_pick_and_place_dataset_converted_externally_to_rlds,
  kaist_nonprehensile_converted_externally_to_rlds,
  ucsd_kitchen_dataset_converted_externally_to_rlds,
  utokyo_xarm_bimanual_converted_externally_to_rlds,
  stanford_hydra_dataset_converted_externally_to_rlds,
  austin_sirius_dataset_converted_externally_to_rlds,
  berkeley_fanuc_manipulation,
  berkeley_mvp_converted_externally_to_rlds,
  berkeley_rpt_converted_externally_to_rlds,
  cmu_play_fusion,
  iamlab_cmu_pickup_insert_converted_externally_to_rlds,
  qut_dexterous_manpulation
experiments/datasplit/dataset15_vae.yaml
ADDED
@@ -0,0 +1,16 @@
domains: >
  bridge_data_v2,
  fractal20220817_data,
  language_table,
  ucsd_pick_and_place_dataset_converted_externally_to_rlds,
  kaist_nonprehensile_converted_externally_to_rlds,
  ucsd_kitchen_dataset_converted_externally_to_rlds,
  utokyo_xarm_bimanual_converted_externally_to_rlds,
  stanford_hydra_dataset_converted_externally_to_rlds,
  austin_sirius_dataset_converted_externally_to_rlds,
  berkeley_fanuc_manipulation,
  berkeley_mvp_converted_externally_to_rlds,
  berkeley_rpt_converted_externally_to_rlds,
  cmu_play_fusion,
  iamlab_cmu_pickup_insert_converted_externally_to_rlds,
  qut_dexterous_manpulation
experiments/datasplit/dataset20.yaml
ADDED
@@ -0,0 +1,21 @@
domains: >
  bridge_data_v2,
  fractal20220817_data,
  language_table,
  ucsd_pick_and_place_dataset_converted_externally_to_rlds,
  kaist_nonprehensile_converted_externally_to_rlds,
  ucsd_kitchen_dataset_converted_externally_to_rlds,
  utokyo_xarm_bimanual_converted_externally_to_rlds,
  stanford_hydra_dataset_converted_externally_to_rlds,
  austin_sirius_dataset_converted_externally_to_rlds,
  berkeley_fanuc_manipulation,
  berkeley_mvp_converted_externally_to_rlds,
  cmu_play_fusion,
  robo_net,
  furniture_bench_dataset_converted_externally_to_rlds,
  dlr_sara_grid_clamp_converted_externally_to_rlds,
  cmu_stretch,
  droid,
  toto,
  bc_z,
  kuka
experiments/datasplit/dataset20_vae.yaml
ADDED
@@ -0,0 +1,21 @@
domains: >
  language_table,
  ucsd_pick_and_place_dataset_converted_externally_to_rlds,
  kaist_nonprehensile_converted_externally_to_rlds,
  ucsd_kitchen_dataset_converted_externally_to_rlds,
  utokyo_xarm_bimanual_converted_externally_to_rlds,
  stanford_hydra_dataset_converted_externally_to_rlds,
  austin_sirius_dataset_converted_externally_to_rlds,
  berkeley_fanuc_manipulation,
  berkeley_mvp_converted_externally_to_rlds,
  berkeley_rpt_converted_externally_to_rlds,
  cmu_play_fusion,
  iamlab_cmu_pickup_insert_converted_externally_to_rlds,
  qut_dexterous_manpulation,
  robo_net,
  dlr_sara_grid_clamp_converted_externally_to_rlds,
  cmu_stretch,
  columbia_cairlab_pusht_real,
  droid,
  toto,
  io_ai_tech
experiments/datasplit/dataset25.yaml
ADDED
@@ -0,0 +1,26 @@
domains: >
  bridge_data_v2,
  fractal20220817_data,
  language_table,
  ucsd_pick_and_place_dataset_converted_externally_to_rlds,
  kaist_nonprehensile_converted_externally_to_rlds,
  ucsd_kitchen_dataset_converted_externally_to_rlds,
  utokyo_xarm_bimanual_converted_externally_to_rlds,
  stanford_hydra_dataset_converted_externally_to_rlds,
  austin_sirius_dataset_converted_externally_to_rlds,
  berkeley_fanuc_manipulation,
  berkeley_mvp_converted_externally_to_rlds,
  cmu_play_fusion,
  iamlab_cmu_pickup_insert_converted_externally_to_rlds,
  robo_net,
  furniture_bench_dataset_converted_externally_to_rlds,
  dlr_sara_grid_clamp_converted_externally_to_rlds,
  cmu_stretch,
  droid,
  toto,
  io_ai_tech,
  bc_z,
  roboturk,
  cmu_franka_exploration_dataset_converted_externally_to_rlds,
  nyu_door_opening_surprising_effectiveness,
  kuka