Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import contextlib | |
import logging | |
import math | |
import os | |
import time | |
import mup | |
import numpy as np | |
import torch | |
import torchvision.transforms.functional as transforms_f | |
from accelerate import Accelerator | |
from accelerate.logging import get_logger | |
from accelerate.utils import set_seed | |
from einops import rearrange | |
from lpips import lpips | |
from torch.utils.data import DataLoader | |
from tqdm.auto import tqdm | |
import transformers | |
import traceback | |
from transformers import ( | |
default_data_collator, | |
get_scheduler, | |
) | |
from collections import defaultdict | |
from cont_data import RawFeatureDataset, get_maskgit_collator_feature | |
from common.eval_utils import decode_tokens, compute_lpips, decode_features | |
from genie.config import DiffusionGenieConfig | |
from genie.st_mar import STMAR | |
from visualize import decode_latents_wrapper | |
from skimage import metrics as image_metrics | |
from matplotlib import pyplot as plt | |
from datetime import datetime | |
from accelerate import DistributedDataParallelKwargs | |
# Process-wide torch settings: "medium" float32 matmul precision and autograd anomaly detection.
torch.set_float32_matmul_precision("medium")
torch.autograd.set_detect_anomaly(True)

# Timestamp taken once at import time, rendered like "2024-01-31 13:45:10".
now = datetime.now()
formatted_date = f"{now:%Y-%m-%d %H:%M:%S}"

SVD_SCALE = 0.18215  # latents are divided by this factor before they are decoded
logger = get_logger(__name__)
def parse_args():
    """
    Build the command-line parser for the training script.

    Despite the name, this returns the configured `argparse.ArgumentParser`
    (not parsed arguments); the caller invokes `.parse_args()` on the result.

    Flags that default to `True` can be turned off explicitly:
    `--no-filter_overlaps`, `--no-cleanup_checkpoints`, and `--compile`
    (which clears `no_compile`). The original spellings keep working.
    """
    parser = argparse.ArgumentParser(description="Train a spatial-temporal MaskGIT-style model on video generation.")

    # Data
    parser.add_argument(
        "--train_data_dir", type=str, default="data/1x_humanoid_magvit_traj1000_train",
        help="Directory containing tokenized data, should have a `video.bin`, `metadata.json` and `segment_ids.json`."
    )
    parser.add_argument(
        "--val_data_dir", type=str, default="data/1x_humanoid_magvit_traj1000_val",
        help="Directory containing tokenized data, should have a `video.bin`, `metadata.json` and `segment_ids.json`."
    )
    parser.add_argument(
        "--domain", type=str, default="1x_humanoid",
        help="The domain name for the dataset"
    )
    parser.add_argument(
        "--window_size",
        type=int,
        default=12,
        help="Number of frames in a sequence.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=None,
        help="Difference in frame count between consecutive frames in a sequence.",
    )
    parser.add_argument(
        "--filter_overlaps",
        # BooleanOptionalAction keeps `--filter_overlaps` and adds `--no-filter_overlaps`;
        # a plain `store_true` with default=True could never be disabled.
        action=argparse.BooleanOptionalAction,
        help=(
            "Whether to filter repeated frames in the train dataset (`filter_overlaps` always true for the val set). "
            "Filtering essentially makes the training dataset less correlated but ~16x smaller, "
            "see the `filter_overlaps` argument in `RawTokenDataset` for details."),
        default=True
    )

    # Model
    parser.add_argument(
        "--llama_config",
        type=str,
        help="`transformers.LlamaConfig` json. "
             "E.g. https://huggingface.co/1x-technologies/Llama_1B_v0/blob/main/config.json",
    )
    parser.add_argument(
        "--genie_config",
        type=str,
        help="GenieConfig json."
    )
    parser.add_argument(
        "--warmstart_path",
        type=str,
        default=None,
        help="A path to a checkpoint to warmstart a model from, possibly not trained on the same dataset, "
             "will resize embeddings if needed.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )

    # Training
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=2,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=2, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--max_eval_steps",
        type=int,
        default=int(1e10),
        help="Only evaluate on `max_eval_steps` batches of validation data per process, faster.",
    )
    parser.add_argument(
        "--eval_every_n_steps",
        type=int,
        default=1000,
        help="Eval every N training steps.",
    )
    parser.add_argument(
        "--vis_every_n_steps",
        type=int,
        default=1000,
        help="Visualize every N training steps.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=str,
        default="constant_with_warmup",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "custom_cosine"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=10,
        help="Threshold to clip gradients.",
    )
    parser.add_argument(
        "--attention_dropout",
        type=float,
        default=0.01,
        help="Attention dropout prob.",
    )
    parser.add_argument(
        "--adam_beta_1",
        type=float,
        default=0.9,
    )
    parser.add_argument(
        "--adam_beta_2",
        type=float,
        default=0.95,
    )
    parser.add_argument(
        "--adam_eps",
        type=float,
        default=1e-8,
    )

    # Misc
    parser.add_argument("--output_dir", type=str, required=True, help="Where to store the model checkpoints.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="10000",
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--overfit_first_batch",
        action="store_true",
        help=(
            "Debug option that trains and validates on only the first batch of the training dataset."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help="The integration to report the results and logs to.",
    )
    parser.add_argument(
        "--mu_transfer",
        action="store_true",
        help="If specified, will train with mu transfer reparametrizations. Only supports Llama models.",
        default=False
    )
    parser.add_argument(
        "--no_compile",
        action="store_true",
        help="If specified, will not compile the model (this is the default; pass `--compile` to compile).",
        default=True
    )
    parser.add_argument(
        "--compile",
        # Shares the `no_compile` destination so existing readers of `args.no_compile` are unaffected.
        dest="no_compile",
        action="store_false",
        help="If specified, will compile the model with `torch.compile`.",
    )
    parser.add_argument(
        "--add_action_input",
        action="store_true",
        help=(
            "Whether to add action as input to the dynamics model. ")
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="video_prediction",
        help="Name appended to the launch timestamp to label the tracker run.",
    )
    parser.add_argument(
        "--cleanup_checkpoints",
        action=argparse.BooleanOptionalAction,
        help=(
            "Whether to clean up checkpoints (to keep only the last 3) along the training. "),
        default=True
    )
    parser.add_argument(
        "--use_raw_image_as_latent",
        action="store_true",
        help="If specified, will use raw image values as the latent features instead of encoder latents.",
    )
    return parser
def save_checkpoint(model, accelerator, args, filename):
    """
    Save the model weights and the accelerator state under `args.output_dir`.

    filename: name of the checkpoint folder; the full location is
        `os.path.join(args.output_dir, filename)`.

    Nothing is written unless this is the main process.
    NOTE(review): the source's indentation was lost; `save_state` is assumed to sit
    inside the main-process guard together with `save_pretrained` - confirm.
    """
    bare_model = accelerator.unwrap_model(model)
    checkpoint_dir = os.path.join(args.output_dir, filename)

    if not accelerator.is_main_process:
        return

    bare_model.save_pretrained(
        checkpoint_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )
    accelerator.save_state(checkpoint_dir)
def visualize(accelerator, model, dataloader, window_size, encoder_type,
              encoder_name_or_path, metrics_prefix="eval", max_steps=1, use_raw_image_as_latent=False):
    """
    Generate future frames from prompt frames and log figures and image metrics to wandb.

    For up to `max_steps` batches, the first `config.num_prompt_frames` frames of (at most 4)
    examples are passed to `unwrapped_model.generate`, the result is decoded to images, and on
    the main process the first two batches are plotted (context / ground truth / prediction)
    while LPIPS, PSNR and SSIM are accumulated. The averaged metrics are logged with
    `commit=False`.

    Args:
        accelerator: the `Accelerator` used for device placement, gathering and tracking.
        model: the (possibly wrapped) model; it is unwrapped before use.
        dataloader: yields batches with "labels", "domain", "w", "h" and optionally "action_ids".
        window_size: number of frames per sequence.
        encoder_type, encoder_name_or_path: forwarded to `decode_latents_wrapper`.
        metrics_prefix: prefix used in figure and metric names.
        max_steps: number of batches to process.
        use_raw_image_as_latent: when true, tokens are mapped to pixels with `(x + 0.5) * 255`
            instead of being rescaled by `SVD_SCALE` and decoded.

    Returns:
        None. Returns immediately when `config.jointly_predict_states` is falsy.
    """
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    if not unwrapped_model.config.jointly_predict_states:
        return

    metrics = defaultdict(list)
    is_main_process = accelerator.is_main_process
    # LPIPS w/ AlexNet (the fastest option); only the main process computes metrics.
    lpips_alex = lpips.LPIPS(net="alex") if is_main_process else None

    # re-initializing every time to save memory
    decode_latents = decode_latents_wrapper(encoder_type=encoder_type, encoder_name_or_path=encoder_name_or_path)

    # Note: hardcoding 4 image cap for faster inference on small models
    max_examples = 4
    exs_per_fig = 4

    unwrapped_model.eval()
    try:
        for step, batch in enumerate(dataloader):
            # `s` is really `(h, w)`
            reshaped_labels = rearrange(
                batch["labels"][:max_examples], "b (t s) c -> b t s c", t=window_size
            ).to(accelerator.device)
            domains = batch["domain"][:max_examples]
            if "action_ids" in batch:
                action_ids = batch["action_ids"][:max_examples].to(accelerator.device)
            else:
                action_ids = None

            height, width = batch["h"][0], batch["w"][0]
            num_prompt_frames = unwrapped_model.config.num_prompt_frames
            num_new_tokens = width * height * (window_size - num_prompt_frames)
            prompt_input_ids = rearrange(reshaped_labels[:, :num_prompt_frames], "b t s c -> b (t s) c")
            outputs = unwrapped_model.generate(
                input_ids=prompt_input_ids,
                attention_mask=torch.ones_like(prompt_input_ids),
                max_new_tokens=num_new_tokens,
                min_new_tokens=num_new_tokens,
                action_ids=action_ids,
                domain=domains,
                w=batch["w"][:max_examples],
                h=batch["h"][:max_examples],
            )

            # The output covers the whole window; the ground truth only the non-prompt frames.
            output_tokens = rearrange(outputs, "b (t h w) c -> b t h w c", t=window_size, h=height, w=width)
            gtruth_tokens = rearrange(reshaped_labels[:, num_prompt_frames:], "b t (h w) c -> b t h w c",
                                      h=height, w=width)

            if not use_raw_image_as_latent:
                output_tokens = output_tokens / SVD_SCALE
                gtruth_tokens = gtruth_tokens / SVD_SCALE
                decoded_output = decode_features(output_tokens.cpu(), decode_latents)
                decoded_gtruth = decode_features(gtruth_tokens.cpu(), decode_latents)
            else:
                decoded_output = ((output_tokens + 0.5) * 255).long()
                decoded_gtruth = ((gtruth_tokens + 0.5) * 255).long()

            decoded_output = accelerator.gather(decoded_output.to(accelerator.device)).cpu()
            decoded_gtruth = accelerator.gather(decoded_gtruth.to(accelerator.device)).cpu()

            if is_main_process:
                if step < 2:  # only the first two batches are plotted
                    wandb_tracker = accelerator.get_tracker("wandb")
                    for j in range(0, len(decoded_output), exs_per_fig):
                        fig, axs = plt.subplots(nrows=2 * exs_per_fig, ncols=window_size,
                                                figsize=(3 * window_size, 3 * 2 * exs_per_fig))
                        # If len(decoded_output) is not a multiple of 4, make sure to truncate properly
                        for k in range(min(exs_per_fig, len(decoded_output) - j)):
                            for i in range(num_prompt_frames):
                                for ax in (axs[k * 2, i], axs[k * 2 + 1, i]):
                                    ax.imshow(transforms_f.to_pil_image(decoded_output[j + k, i]))
                                    ax.set_title("Context")
                                    ax.axis("off")

                            for i in range(num_prompt_frames, window_size):
                                axs[k * 2, i].imshow(
                                    transforms_f.to_pil_image(decoded_gtruth[j + k, i - num_prompt_frames]))
                                axs[k * 2, i].set_title("Ground truth")
                                axs[k * 2 + 1, i].imshow(transforms_f.to_pil_image(decoded_output[j + k, i]))
                                axs[k * 2 + 1, i].set_title("Prediction")
                                for ax in axs[:, i]:
                                    ax.axis("off")

                        wandb_tracker.log({f"vis_{metrics_prefix}_{j}": fig}, commit=False)
                        wandb_tracker.log({f"{domains[0]}/vis_{metrics_prefix}_{j}": fig}, commit=False)
                        plt.close(fig)

                predicted_frames = decoded_output[:, num_prompt_frames:]
                # Note: not parallelizing right now. Computed once and reused for both metric keys.
                lpips_scores = list(compute_lpips(decoded_gtruth, predicted_frames, lpips_alex))
                metrics["ar_lpips"].extend(lpips_scores)

                gt_frames_numpy = decoded_gtruth.detach().cpu().numpy()
                pred_frames_numpy = predicted_frames.detach().cpu().numpy()
                num_examples, num_frames = gt_frames_numpy.shape[0], gt_frames_numpy.shape[1]
                # One PSNR value per example.
                psnr = [
                    image_metrics.peak_signal_noise_ratio(
                        gt_frames_numpy[b] / 255., pred_frames_numpy[b] / 255., data_range=1.0)
                    for b in range(num_examples)
                ]
                # One SSIM value per frame index, averaged over the examples.
                ssim = [
                    np.mean([
                        image_metrics.structural_similarity(
                            gt_frames_numpy[b][t] / 255., pred_frames_numpy[b][t] / 255.,
                            data_range=1.0, channel_axis=0)
                        for b in range(num_examples)
                    ])
                    for t in range(num_frames)
                ]

                metrics[f"{metrics_prefix}/ar_psnr"].extend(psnr)
                metrics[f"{metrics_prefix}/ar_ssim"].extend(ssim)
                metrics[f"{batch['domain'][0]}/ar_lpips"].extend(lpips_scores)

            if step + 1 >= max_steps:
                break
    finally:
        # Restore training mode even if generation or decoding fails.
        unwrapped_model.train()

    if is_main_process:
        metrics = {f"{metrics_prefix}_{key}": np.mean(val) for key, val in metrics.items() if len(val) > 0}
        print(f"{metrics=}")
        wandb_tracker = accelerator.get_tracker("wandb")
        wandb_tracker.log(metrics, commit=False)
def _find_resume_checkpoint(resume_from):
    """Return the checkpoint directory to resume from, or None when none can be located."""
    weight_files = ("pytorch_model.bin", "model.safetensors")
    if any(os.path.exists(os.path.join(resume_from, name)) for name in weight_files):
        return resume_from

    # Otherwise fall back to the most recently created sibling directory.
    base_path = os.path.dirname(resume_from)
    if not os.path.isdir(base_path):
        return None
    dirs = [entry.path for entry in os.scandir(base_path) if entry.is_dir()]
    if not dirs:
        return None
    dirs.sort(key=os.path.getctime)
    return dirs[-1]


def _run_evaluation(accelerator, model, eval_dataloader, args, completed_steps):
    """Evaluate on up to `args.max_eval_steps` batches; return the mean loss, or None if no batch ran."""
    model.eval()
    eval_losses = []
    # Token-level accuracy (w/ teacher forcing)
    num_correct = 0
    num_total = 0

    # Explicit iterator to resolve the data collating issues
    eval_dataloader_iter = iter(eval_dataloader)
    for _ in range(min(args.max_eval_steps, len(eval_dataloader))):
        batch = next(eval_dataloader_iter)
        batch_size = len(batch["input_ids"])  # Last batch might not be full
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        eval_losses.append(accelerator.gather_for_metrics(loss.repeat(batch_size)))
        if "acc" in outputs:
            # `num_correct` and `num_total` actually track mean accuracy in this case.
            num_correct += accelerator.reduce(outputs.acc, reduction="mean").item() * batch_size
            num_total += batch_size
        else:
            shifted_preds = torch.argmax(outputs.logits[:, :-1, :], dim=-1)
            shifted_labels = batch["labels"][:, 1:]
            num_correct += accelerator.gather_for_metrics((shifted_preds == shifted_labels).sum()).sum().item()
            num_total += accelerator.gather_for_metrics(
                torch.tensor(torch.numel(shifted_labels), device=accelerator.device)).sum().item()

        try:
            # Log a scalar: the gathered per-example tensor cannot be logged as a single value.
            accelerator.log(
                {f"stat/{batch['domain'][0]}_eval_loss": eval_losses[-1].mean().item()},
                step=completed_steps,
            )
        except Exception as e:  # best-effort logging must not abort evaluation
            print("log failed", e)

    # Switch back to train mode
    model.train()

    if not eval_losses:
        logger.warning("Evaluation dataloader produced no batches; skipping eval metrics.")
        return None

    eval_loss = torch.mean(torch.cat(eval_losses)).item()
    eval_teacher_acc = num_correct / num_total if num_total else float("nan")
    logger.info(f"{completed_steps=} {eval_loss=} {eval_teacher_acc=}")
    return eval_loss


def train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args):
    """
    Run the training loop with manual gradient accumulation, periodic eval, visualization and checkpoints.

    Args:
        accelerator: the `Accelerator` driving backward, reduction, logging and state loading.
        model, optimizer, lr_scheduler, train_dataloader, eval_dataloader: objects already
            passed through `accelerator.prepare`.
        experiment_config: dict; must contain "FLOPs_per_update_step" and "effective_batch_size".
            "encoder_type" / "encoder_name_or_path" are filled with MAGVIT defaults when missing.
        config: model config; `config.action_loss_weight` scales the optional action loss.
        args: parsed CLI namespace. `args.checkpointing_steps` may be a digit string (save every
            N update steps) or "epoch" (save after every epoch).

    A final checkpoint named "final_checkpt" is always written before training is ended.
    """
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None

    # "epoch" (or any non-numeric value) disables step-based checkpoints instead of crashing in int().
    checkpointing_steps = str(args.checkpointing_steps)
    step_checkpoint_interval = int(checkpointing_steps) if checkpointing_steps.isdigit() else None

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        checkpoint_path = _find_resume_checkpoint(args.resume_from_checkpoint)
        if checkpoint_path is None:
            print("No checkpoint found, training from scratch.")
        else:
            try:
                accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
                # for finetuning with a different structures
                print(f"loading checkpoint from {checkpoint_path}")
                # tied weights not saved so can't load strict, but also no need to tie again
                accelerator.load_state(checkpoint_path, strict=False)

                # Extract `epoch_{i}` or `step_{i}`
                training_difference = os.path.splitext(os.path.basename(checkpoint_path.rstrip("/")))[0]
                if "epoch" in training_difference:
                    starting_epoch = int(training_difference.replace("epoch_", "")) + 1
                    resume_step = None
                    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
                    completed_steps = starting_epoch * num_update_steps_per_epoch
                else:
                    # need to multiply `gradient_accumulation_steps` to reflect real steps
                    resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
                    starting_epoch = resume_step // len(train_dataloader)
                    completed_steps = resume_step // args.gradient_accumulation_steps
                    resume_step -= starting_epoch * len(train_dataloader)
            except Exception:
                # Deliberately best-effort: e.g. a folder name that is neither `epoch_{i}` nor `step_{i}`.
                print("load checkpoint incomplete", traceback.format_exc())

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)
    loss_info = torch.zeros(2, device=accelerator.device)  # sum, count

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        train_dataloader.set_epoch(epoch)

        # Potentially report stale checkpoints (everything but the 3 most recent folders).
        # Actual deletion is disabled in this script; the paths are only printed.
        if args.cleanup_checkpoints and os.path.exists(args.output_dir):
            dirs = [os.path.join(args.output_dir, f.name) for f in os.scandir(args.output_dir) if f.is_dir()]
            if len(dirs) > 3:
                dirs.sort(key=os.path.getctime)
                for path in dirs[:-3]:
                    print(f"remove rm -rf {path}")

        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader

        _time = time.time()
        dataloader_iter = iter(active_dataloader)
        for step in range(len(active_dataloader)):
            batch = next(dataloader_iter)
            batch_size = batch["input_ids"].size(0)

            # Manual gradient accumulation because accelerator somehow taking a lot of memory
            is_update_step = (step + 1) % args.gradient_accumulation_steps == 0
            ctx_manager = contextlib.nullcontext() if is_update_step else accelerator.no_sync(model)
            train_action_loss = 0

            with ctx_manager:
                outputs = model(**batch)
                loss = outputs.loss

                # isfinite covers both NaN and Inf, matching the warning below.
                if torch.isfinite(loss).all():
                    loss_info[0] += loss.detach().mean() * batch_size
                    if "action_loss" in outputs:
                        train_action_loss = outputs.action_loss.item()
                        loss += config.action_loss_weight * outputs.action_loss
                    accelerator.backward(loss / args.gradient_accumulation_steps)
                else:
                    print(f"Warning: NaN or Inf detected in loss for domain: {batch['domain']}. Skipping backward pass.")
                    dummy_loss = torch.zeros_like(loss, requires_grad=True)
                    accelerator.backward(dummy_loss)
                loss_info[1] += batch_size

            if not is_update_step:
                continue

            # Everything below only happens on update step
            if args.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            loss_info = accelerator.reduce(loss_info)
            train_action_loss = (train_action_loss / loss_info[1]).item()
            avg_train_loss = (loss_info[0] / loss_info[1]).item()  # sum / count
            loss_info *= 0  # reset sum and count

            batch_time = time.time() - _time  # accumulated batch
            _time = time.time()
            accelerator.log(
                {
                    "train_loss": avg_train_loss,
                    "train_action_loss": train_action_loss,
                    f"stat/{batch['domain'][0]}_action_loss": train_action_loss,
                    f"stat/{batch['domain'][0]}_train_loss": avg_train_loss,
                    "epoch": epoch,
                    "update_step": completed_steps,
                    "examples_processed": completed_steps * args.per_device_train_batch_size
                                          * args.gradient_accumulation_steps * accelerator.num_processes,
                    "learning_rate": lr_scheduler.get_last_lr()[0],
                    "flops": (completed_steps + 1) * experiment_config["FLOPs_per_update_step"],
                    "throughput_examples": experiment_config["effective_batch_size"] / batch_time,
                }, step=completed_steps)

            progress_bar.update(1)
            completed_steps += 1

            if step_checkpoint_interval and completed_steps % step_checkpoint_interval == 0:
                print(f"Saving checkpoint at step {completed_steps}!")
                save_checkpoint(model, accelerator, args, f"step_{completed_steps}")

            if completed_steps % args.eval_every_n_steps == 0:
                time.sleep(1)  # manual adding time sleep
                eval_loss = _run_evaluation(accelerator, model, eval_dataloader, args, completed_steps)
                if eval_loss is not None:
                    accelerator.log(
                        {
                            "eval_loss": eval_loss,
                            "epoch": epoch,
                            "update_step": completed_steps,
                            "examples_processed": completed_steps * args.per_device_train_batch_size
                                                  * args.gradient_accumulation_steps * accelerator.num_processes,
                            "flops": completed_steps * experiment_config["FLOPs_per_update_step"],
                        },
                        step=completed_steps,
                    )

            if completed_steps % args.vis_every_n_steps == 0:
                if "encoder_type" not in experiment_config:
                    experiment_config["encoder_name_or_path"] = "data/magvit2.ckpt"
                    experiment_config["encoder_type"] = "magvit"

                if not args.overfit_first_batch:  # val is same as train otherwise
                    visualize(accelerator, model, eval_dataloader, args.window_size,
                              experiment_config["encoder_type"], experiment_config["encoder_name_or_path"], "val",
                              use_raw_image_as_latent=args.use_raw_image_as_latent)

                visualize(accelerator, model, train_dataloader, args.window_size,
                          experiment_config["encoder_type"], experiment_config["encoder_name_or_path"], "train",
                          use_raw_image_as_latent=args.use_raw_image_as_latent)

            if completed_steps >= args.max_train_steps:
                break

        if args.checkpointing_steps == "epoch":
            save_checkpoint(model, accelerator, args, f"epoch_{epoch}")

        # The inner `break` only leaves the batch loop; stop the epoch loop as well.
        if completed_steps >= args.max_train_steps:
            break

    # Save before ending training: `end_training` finalizes the trackers (and, presumably in
    # newer accelerate releases, tears down distributed state - TODO confirm), so the final
    # save must not come after it.
    save_checkpoint(model, accelerator, args, "final_checkpt")
    accelerator.end_training()
def main():
    """
    Script entry point: parse CLI arguments, build datasets, model, optimizer and scheduler,
    prepare everything with `Accelerator`, then hand over to `train`.

    Raises:
        NotImplementedError: if `--llama_config` is given (the Llama path is not supported here).
        ValueError: if the train and validation dataset metadata disagree on a shared key.
    """
    parser = parse_args()
    args = parser.parse_args()

    # Manual gradient accumulation (done in `train`), so accelerate itself accumulates over 1 step.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to,
                              project_dir=args.output_dir, kwargs_handlers=[ddp_kwargs])
    accelerator.init_trackers("video")
    if accelerator.is_main_process:
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
        # `torch.cuda.current_device()` raises without CUDA; fall back to the accelerator's device.
        device_desc = torch.cuda.current_device() if torch.cuda.is_available() else accelerator.device
        print(f"Rank {accelerator.process_index} assigned to device {device_desc}")
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    config = DiffusionGenieConfig.from_pretrained(args.genie_config)
    dataset_kwargs = dict(
        window_size=args.window_size,
        stride=args.stride,
        compute_stride_from_freq_table=(args.stride is None),
        use_actions=config.use_actions,
        use_raw_image_as_latent=args.use_raw_image_as_latent,
    )
    train_dataset = RawFeatureDataset(args.train_data_dir, filter_overlaps=args.filter_overlaps, **dataset_kwargs)
    if not args.overfit_first_batch:
        eval_dataset = RawFeatureDataset(args.val_data_dir, filter_overlaps=True, **dataset_kwargs)
    else:
        # Debug mode: keep exactly one effective batch and validate on it as well.
        train_dataset.valid_start_inds = train_dataset.valid_start_inds[:args.per_device_train_batch_size
                                                                         * args.gradient_accumulation_steps
                                                                         * accelerator.num_processes]
        eval_dataset = train_dataset

    # TODO: check train/val hz per dataset?
    shared_keys = ("s", "h", "w", "vocab_size", "latent_channels", "encoder_type", "encoder_name_or_path", "quantized")
    mismatched_keys = [key for key in shared_keys
                       if train_dataset.metadata.get(key) != eval_dataset.metadata.get(key)]
    if mismatched_keys:
        # Explicit raise instead of `assert`, so the check survives `python -O` and names the keys.
        raise ValueError(f"Train and validation metadata disagree on: {mismatched_keys}")

    # Will not store key in metadata if it's missing, so that defaults can be filled by functions later?
    # TODO: don't think we are handling missing keys in function calls
    shared_metadata = {shared_key: train_dataset.metadata[shared_key]
                       for shared_key in shared_keys if shared_key in train_dataset.metadata}

    if args.llama_config is not None:
        raise NotImplementedError("Have not factorized Llama vocabulary.")

    config.use_mup = args.mu_transfer  # Note: changing this may affect pre-trained model due to attn scaling
    config.image_vocab_size = shared_metadata["vocab_size"]
    config.T = args.window_size
    config.S = shared_metadata["h"] * shared_metadata["w"]  # TODO: make STMaskGIT use h and w instead of S
    config.vae_embed_dim = shared_metadata["latent_channels"]
    model = STMAR(config)

    if config.use_actions:
        print(f"Initializing action projectors with {train_dataset.n_action}d action")
        model.init_action_projectors([args.domain], [train_dataset.n_action], [train_dataset.action_stat],
                                     config.action_network)

    if args.mu_transfer:
        model.set_mup_shapes(rescale_params=True)
        model.init_weights()  # might be unnecessary if `rescale_params` is True

    # Optimizer. Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "layer_norm.weight"]
    decay_params, no_decay_params = [], []
    for param_name, param in model.named_parameters():
        if any(nd in param_name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": args.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW
    # scale base learning rate with the effective batch size (factor clamped to [1, 8])
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
                           * accelerator.num_processes
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)

    optimizer = opt_class(optimizer_grouped_parameters, lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # DataLoaders creation (the Llama collator path is unreachable: `--llama_config` raised above).
    collate_fn = get_maskgit_collator_feature(config)
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn,
        batch_size=args.per_device_train_batch_size, num_workers=4, pin_memory=True,
    )

    # Shuffle eval dataset and then set shuffle=False on the dataloader.
    # Shuffling in the dataloader results in reshuffling with each iteration.
    eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
        torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
    ].tolist()
    eval_dataloader = DataLoader(
        eval_dataset, shuffle=False, collate_fn=collate_fn,
        batch_size=args.per_device_eval_batch_size, pin_memory=True, num_workers=4,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    scheduler_warmup_steps = args.num_warmup_steps * accelerator.num_processes
    scheduler_total_steps = args.max_train_steps if overrode_max_train_steps \
        else args.max_train_steps * accelerator.num_processes

    if args.lr_scheduler_type == "custom_cosine":  # decay to `end_ratio` of the peak learning rate
        def get_lr_wrapper(warmup_steps, max_steps, end_ratio=0.1):
            # Guard against a zero-length decay phase.
            remaining_steps = max(1, max_steps - warmup_steps)

            def get_lr(step):
                if step < warmup_steps:
                    return (step + 1) / warmup_steps
                # Clamp so the rate stays at `end_ratio` instead of rising again past `max_steps`.
                progress = min(1.0, (step - warmup_steps) / remaining_steps)
                return ((1 + math.cos(math.pi * progress)) / 2) * (1 - end_ratio) + end_ratio

            return get_lr

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, get_lr_wrapper(scheduler_warmup_steps, scheduler_total_steps)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=scheduler_warmup_steps,
            num_training_steps=scheduler_total_steps,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    experiment_config = vars(args) | vars(config)

    seq_len = shared_metadata["h"] * shared_metadata["w"] * args.window_size
    args.num_datasets = 1

    model_module = model.module if hasattr(model, "module") else model
    model_parameters = sum(p.numel() for p in model.parameters())
    trunk_parameters = sum(p.numel() for p in model_module.decoder.parameters())

    experiment_config.update(shared_metadata | {
        "model_parameters": model_parameters,
        "model_parameters_M": round(model_parameters / 1e6),
        "trunk_parameters": trunk_parameters,
        "trunk_parameters_M": round(trunk_parameters / 1e6),
        "seq_len": seq_len,
        "train_data_tokens": len(train_dataset) * seq_len,  # only one epoch
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": 1
    })

    print("============================")
    print(f"model parameters: {experiment_config['model_parameters_M']}M")
    print("============================")

    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
                                                 * experiment_config["effective_batch_size_tokens"]

    accelerator.init_trackers(project_name="video", config=experiment_config)

    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args)


if __name__ == "__main__":
    main()