# hma /
# LeroyWaa's picture
import json
import math
import os
import random
from pathlib import Path
import numpy as np
import torch
from einops import rearrange
from import Dataset as TorchDataset
from datasets.encode_openx_dataset import DATA_FREQ_TABLE
from genie.config import GenieConfig
from genie.st_mask_git import cosine_schedule
# Multiplicative scale applied to features in RawFeatureDataset.__getitem__;
# the dataset code notes it should be divided out again when decoding.
# NOTE(review): 0.18215 matches the latent scale factor commonly used with
# Stable Diffusion VAEs — presumably that is its origin here; confirm.
SVD_SCALE: float = 0.18215
def normalize_actions(actions):
    """Compute per-dimension mean and std of ``actions`` without modifying them.

    Normalizing the actions themselves is done inside the network; this helper
    only gathers the statistics the network needs.

    Args:
        actions: array reduced over axis 0 — presumably (num_frames, action_dim);
            confirm against callers.

    Returns:
        ``(actions, [mean, std])`` where ``actions`` is the input object, returned
        unchanged, and ``mean``/``std`` are plain Python lists produced by
        ``ndarray.tolist()``.
    """
    mean = np.mean(actions, axis=0).tolist()
    std = np.std(actions, axis=0).tolist()
    return actions, [mean, std]
class RawFeatureDataset(TorchDataset):
    """Loads raw (unquantized) float features as a memmap-backed array.

    Each item is a window of ``window_size`` frames spaced ``stride`` apart,
    flattened to a ``(t * h * w, c)`` tensor, optionally together with the
    actions that fall inside that window.
    """

    # NOTE(review): the original parameter list was lost from SOURCE. Names below are
    # the ones the body uses; their order and defaults are reconstructed — confirm
    # against callers.
    def __init__(
        self,
        data_dir,
        window_size,
        stride=1,
        filter_interrupts=True,
        filter_overlaps=False,
        use_actions=False,
        use_raw_image_as_latent=False,
        datio_noise_ratio=0.0,
        domain=None,
        compute_stride_from_freq_table=False,
        natural_hz=1,
        max_traj_num=None,
    ):
        """
        Args:
            data_dir: directory with the same format as `data/train_v0` and `data/val_v0`.
                Notably, has `video.bin` and `metadata.json`; optionally `segment_ids.bin`
                and an `actions/` folder.
            window_size: number of frames per "video" sequence.
            stride: frame skip. Replaced when `compute_stride_from_freq_table` is True.
            filter_interrupts: Under 3% of training frame sequences are the concatenation of two different clips.
                If filter_interrupts is True, will filter out these sequences using the segment ids.
                Raises NotImplementedError if `segment_ids.bin` is missing.
            filter_overlaps: If False (default), one frame will appear in multiple examples;
                e.g. frame 0 might appear as the first frame in example 0 and also the second frame in example 15.
                If True, will filter out examples so that each frame appears at most once in the dataset.
            use_actions: If True, will load the actions from the `actions` folder for the models.
            use_raw_image_as_latent: If True, `video.bin` holds 3-channel frames that are resized to
                32x32 and normalized in `__getitem__` instead of precomputed latents.
            datio_noise_ratio: stored as `self.datio_noise_ratio`; not used inside this class.
            domain: name used for the `domain` field and the frequency-table lookup. Defaults to
                `metadata["name"]` with "_noquant" removed.
            compute_stride_from_freq_table: if True,
                `stride = max(DATA_FREQ_TABLE.get(name, 1) // natural_hz, 1)`.
            natural_hz: divisor used with `compute_stride_from_freq_table`.
            max_traj_num: maximum number of windows to keep; None keeps all of them.
        """
        data_dir = Path(data_dir)
        with open(data_dir / "metadata.json") as f:
            self.metadata = json.load(f)

        # TODO: assert not quantized in metadata
        shape = (
            self.metadata["num_images"],
            self.metadata.get("latent_channels", 4),
            self.metadata["h"],
            self.metadata["w"],
        )
        print("token shape:", shape)

        self.use_raw_image_as_latent = use_raw_image_as_latent
        if use_raw_image_as_latent:
            # Raw 3-channel frames are stored instead of latents.
            shape = (shape[0], 3, shape[2], shape[3])

        video_tokens_path = data_dir / "video.bin"
        segment_ids_path = data_dir / "segment_ids.bin"
        token_dtype = np.dtype(self.metadata.get("token_dtype", "float16"))
        self.data = np.memmap(video_tokens_path, mode="r", shape=shape, dtype=token_dtype)
        print("data nan:", torch.isnan(torch.from_numpy(self.data[:100].copy())).sum())

        if use_raw_image_as_latent:
            # Debug path for the robomimic dataset: frames are resized to 32x32 in __getitem__,
            # so downstream consumers must see the post-resize geometry.
            self.metadata["h"] = 32
            self.metadata["w"] = 32
            self.metadata["latent_channels"] = 3

        self.window_size = window_size
        self.datio_noise_ratio = datio_noise_ratio
        if domain is not None:  # TODO: remove
            self.name = domain
        else:
            self.name = self.metadata["name"].replace("_noquant", "")

        self.stride = stride
        if compute_stride_from_freq_table:
            self.stride = max(DATA_FREQ_TABLE.get(self.name, 1) // natural_hz, 1)
        # Per window step, the actions of every frame inside one stride are concatenated.
        self.n_action = self.metadata.get("action_dim", 1) * self.stride

        if use_actions:
            actions = []
            # hack here for the separations in the 1x datasets
            for action_file in sorted((data_dir / "actions").iterdir()):
                actions.append(np.memmap(action_file, dtype=np.float32, mode="r").reshape(len(self.data), -1))
            self.actions = np.concatenate(actions, axis=-1)
            self.actions, self.action_stat = normalize_actions(self.actions)

        if os.path.isfile(segment_ids_path):
            # NOTE(review): the memmap arguments were lost from SOURCE; one int32 id per frame
            # is assumed here — confirm against whatever writes segment_ids.bin.
            self.segment_ids = np.memmap(
                segment_ids_path, dtype=np.int32, mode="r", shape=(self.metadata["num_images"],)
            )
        else:
            self.segment_ids = None
            if filter_interrupts:
                raise NotImplementedError("Cannot filter interrupted sequences without segment ids.")

        # Number of frames between the first and last frames of a video sequence (excluding one endpoint frame)
        self.video_len = (self.window_size - 1) * self.stride
        self.valid_start_inds = self._find_valid_start_inds(filter_interrupts, max_traj_num)
        if filter_overlaps:
            self.valid_start_inds = self._drop_overlapping_windows(self.valid_start_inds)

        # np.unique(None) would silently report one video, so report None when ids are absent.
        num_videos = len(np.unique(self.segment_ids)) if self.segment_ids is not None else None
        print(f"Loaded {len(self)} sequences from {data_dir} {self.stride=} {self.window_size=} {self.n_action=} {num_videos=}")

    def _find_valid_start_inds(self, filter_interrupts, max_traj_num):
        """Return window start indices, skipping interrupted windows and keeping at most `max_traj_num`."""
        valid_start_inds = []
        # The extra `- self.stride` presumably leaves room for the action slice in __getitem__,
        # which extends past the window's last frame.
        for start_ind in range(len(self.data) - self.video_len - self.stride):
            if max_traj_num is not None and len(valid_start_inds) >= max_traj_num:
                break
            # Assuming `segment_ids` is monotonically increasing, a sequence is interrupted (or too short)
            # if the first and last frames have different segment ids.
            if filter_interrupts and self.segment_ids[start_ind] != self.segment_ids[start_ind + self.video_len]:
                continue
            valid_start_inds.append(start_ind)
        return valid_start_inds

    def _drop_overlapping_windows(self, start_inds):
        """Greedily keep windows so that no frame is shared by two kept windows."""
        filtered_start_inds = []
        for start_ind in start_inds:
            overlapping_start_inds = {start_ind - i * self.stride for i in range(1, self.window_size)}
            # all sequences from `overlapping_start_inds` will also contain `start_ind`,
            # so exclude sequence starting from `start_ind` if any of `overlapping_start_inds` is already being used.
            # Only the most recent kept windows can overlap (bound could be improved).
            recent = filtered_start_inds[-self.window_size * self.stride:]
            if not any(existing in overlapping_start_inds for existing in recent):
                filtered_start_inds.append(start_ind)
        return filtered_start_inds

    def __len__(self):
        return len(self.valid_start_inds)

    def __getitem__(self, idx):
        """
        Returns a flattened sequence of tokens representing `self.window_size` frames,
        spaced `self.stride` apart, plus per-window metadata (and actions if loaded).
        """
        start_ind = self.valid_start_inds[idx]
        x = self.data[start_ind : start_ind + self.video_len + 1 : self.stride].copy()
        x = torch.from_numpy(x).float()  # (t, c, h, w), per the rearrange pattern below
        if self.use_raw_image_as_latent:
            x = torch.nn.functional.interpolate(x, size=(self.metadata["h"], self.metadata["w"]))
            # normalize; assumes pixel values in [0, 255]
            x = x / 255 - 0.5
        # NOTE(review): indentation was lost in SOURCE; this scale is applied to every sample
        # here — confirm whether it was meant only for the non-raw (latent) path.
        x = x * SVD_SCALE
        x = rearrange(x, "t c h w -> (t h w) c")
        # divide it when decoding

        # reconstructions since the input ids and the labels are the same
        attention_mask = torch.ones_like(x)
        data_dict = {
            "input_ids": x,
            "labels": x,
            "attention_mask": attention_mask,
            "h": self.metadata["h"],
            "w": self.metadata["w"],
            # Same default as the shape computation in __init__, so a missing key does not raise.
            "c": self.metadata.get("latent_channels", 4),
        }
        if hasattr(self, "actions"):
            # we want to have all actions within the stride to predict the next frame at the end of the stride
            # the slice holds window_size * stride rows; reshape [window_size * stride, d_action]
            # into [window_size, d_action * stride]
            window_actions = self.actions[start_ind : start_ind + self.video_len + self.stride]
            data_dict["action_ids"] = torch.from_numpy(
                window_actions.reshape(self.window_size, -1).astype(np.float32)
            )
        data_dict["domain"] = self.name.replace("_noquant", "")
        return data_dict
def get_maskgit_collator_feature(config: GenieConfig):
    """Build a collate function that batches feature windows and samples MaskGIT-style masks.

    Args:
        config: reads `T`, `dataloader_apply_mask`, `non_mlm_ratio`, `num_prompt_frames`
            and `dataloader_mask_ratio_min`.

    Returns:
        `collate_fn(features)`, returning a dict with "input_ids", "labels",
        "masked_tokens_indicator", "action_ids" (only when the first example has them),
        and per-example lists "domain", "h", "w".
    """
    # mask_token_id = config.image_vocab_size

    def collate_fn(features) -> dict[str, object]:
        # during training, map (z_0, z_1', z_2') -> (null, z_1, z_2)
        # (z_0, z_1') -> (null, z_1) is the diffusion operator on z_1' -> z_1
        h = features[0]["h"]
        w = features[0]["w"]
        batch_size = len(features)
        input_ids = torch.stack([ex["input_ids"] for ex in features])
        x_THWC = rearrange(input_ids, "b (t h w) c -> b t h w c", b=batch_size, t=config.T, h=h, w=w)
        labels = x_THWC.clone()

        # Default when masking is disabled: nothing is marked as masked.
        # NOTE(review): this default is int64 while the masked branch yields bool — confirm
        # that consumers accept both dtypes.
        mask_token_indicator = torch.zeros((batch_size, config.T, h, w)).long()

        if config.dataloader_apply_mask:
            if random.random() < config.non_mlm_ratio:  # Closer to autoregressive inference
                # Leave frames [0, first_masked_frame) unmasked.
                first_masked_frame = random.randint(config.num_prompt_frames, config.T - 1)
            else:  # Typical MLM masking
                first_masked_frame = 1

            mask = torch.zeros(1).long()
            c = 0
            while mask.max() == 0:  # We could get unlucky and mask no tokens; resample until one is masked
                # per-minibatch, per-frame masking probability (could try variable masking rate from MUSE)
                rand = torch.rand(batch_size, config.T - first_masked_frame, 1, 1)
                # add a minimum mask ratio
                rand_mask = rand * (1 - config.dataloader_mask_ratio_min) + config.dataloader_mask_ratio_min
                mask_prob_T = cosine_schedule(rand_mask)
                r = torch.rand_like(x_THWC[:, first_masked_frame:, ..., 0], dtype=torch.float)
                mask = r < mask_prob_T
                c += 1

            if c > 1:
                print(f"Generated mask {c} > 1 times.")

            # Frames before `first_masked_frame` are never masked.
            mask_token_indicator = torch.cat(
                [torch.zeros((batch_size, first_masked_frame, h, w), dtype=mask.dtype), mask], dim=1
            )

        data_dict = {
            "input_ids": rearrange(x_THWC, "b t h w c -> b (t h w) c"),
            "labels": rearrange(labels, "b t h w c-> b (t h w) c"),
            "masked_tokens_indicator": mask_token_indicator,
        }

        if "action_ids" in features[0]:
            data_dict['action_ids'] = torch.stack([ex["action_ids"] for ex in features])
        data_dict['domain'] = [ex["domain"] for ex in features]
        data_dict['h'] = [ex["h"] for ex in features]
        data_dict['w'] = [ex["w"] for ex in features]
        return data_dict

    return collate_fn