MotionLCM / train_motionlcm.py
wxDai's picture
init
6b1e9f7
raw
history blame
No virus
18.5 kB
import os
import sys
import logging
import datetime
import os.path as osp
from typing import Generator
import numpy as np
from tqdm.auto import tqdm
from omegaconf import OmegaConf
import torch
import diffusers
import transformers
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from diffusers.optimization import get_scheduler
from mld.config import parse_args, instantiate_from_config
from mld.data.get_data import get_datasets
from mld.models.modeltype.mld import MLD
from mld.utils.utils import print_table, set_seed, move_batch_to_device
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512,
dtype: torch.dtype = torch.float32) -> torch.Tensor:
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
def scalings_for_boundary_conditions(timestep: torch.Tensor, sigma_data: float = 0.5,
timestep_scaling: float = 10.0) -> tuple:
c_skip = sigma_data ** 2 / ((timestep * timestep_scaling) ** 2 + sigma_data ** 2)
c_out = (timestep * timestep_scaling) / ((timestep * timestep_scaling) ** 2 + sigma_data ** 2) ** 0.5
return c_skip, c_out
def predicted_origin(
model_output: torch.Tensor,
timesteps: torch.Tensor,
sample: torch.Tensor,
prediction_type: str,
alphas: torch.Tensor,
sigmas: torch.Tensor
) -> torch.Tensor:
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "v_prediction":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = alphas * sample - sigmas * model_output
else:
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
return pred_x_0
def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
class DDIMSolver:
def __init__(self, alpha_cumprods: np.ndarray, timesteps: int = 1000, ddim_timesteps: int = 50) -> None:
# DDIM sampling parameters
step_ratio = timesteps // ddim_timesteps
self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
self.ddim_alpha_cumprods_prev = np.asarray(
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
)
# convert to torch tensors
self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
def to(self, device: torch.device) -> "DDIMSolver":
self.ddim_timesteps = self.ddim_timesteps.to(device)
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
return self
def ddim_step(self, pred_x0: torch.Tensor, pred_noise: torch.Tensor,
timestep_index: torch.Tensor) -> torch.Tensor:
alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
return x_prev
@torch.no_grad()
def update_ema(target_params: Generator, source_params: Generator, rate: float = 0.99) -> None:
for tgt, src in zip(target_params, source_params):
tgt.detach().mul_(rate).add_(src, alpha=1 - rate)
def main():
cfg = parse_args()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
set_seed(cfg.TRAIN.SEED_VALUE)
name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
output_dir = osp.join(cfg.FOLDER, name_time_str)
os.makedirs(output_dir, exist_ok=False)
os.makedirs(f"{output_dir}/checkpoints", exist_ok=False)
writer = SummaryWriter(output_dir)
stream_handler = logging.StreamHandler(sys.stdout)
file_handler = logging.FileHandler(osp.join(output_dir, 'output.log'))
handlers = [file_handler, stream_handler]
logging.basicConfig(level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=handlers)
logger = logging.getLogger(__name__)
OmegaConf.save(cfg, osp.join(output_dir, 'config.yaml'))
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
logger.info(f'Training guidance scale range (w): [{cfg.TRAIN.w_min}, {cfg.TRAIN.w_max}]')
logger.info(f'EMA rate (mu): {cfg.TRAIN.ema_decay}')
logger.info(f'Skipping interval (k): {1000 / cfg.TRAIN.num_ddim_timesteps}')
logger.info(f'Loss type (huber or l2): {cfg.TRAIN.loss_type}')
datasets = get_datasets(cfg)[0]
train_dataloader = datasets.train_dataloader()
val_dataloader = datasets.val_dataloader()
logger.info(f"Loading pretrained model: {cfg.TRAIN.PRETRAINED}")
state_dict = torch.load(cfg.TRAIN.PRETRAINED, map_location="cpu")["state_dict"]
base_model = MLD(cfg, datasets)
base_model.load_state_dict(state_dict)
noise_scheduler = base_model.noise_scheduler
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=cfg.TRAIN.num_ddim_timesteps,
)
base_model.to(device)
vae = base_model.vae
text_encoder = base_model.text_encoder
teacher_unet = base_model.denoiser
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)
# Apply CFG here (Important!!!)
cfg.model.denoiser.params.time_cond_proj_dim = cfg.TRAIN.unet_time_cond_proj_dim
unet = instantiate_from_config(cfg.model.denoiser)
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
target_unet = instantiate_from_config(cfg.model.denoiser)
target_unet.load_state_dict(teacher_unet.state_dict(), strict=False)
# Only evaluate the online network
base_model.denoiser = unet
unet = unet.to(device)
target_unet = target_unet.to(device)
target_unet.requires_grad_(False)
# Also move the alpha and sigma noise schedules to device
alpha_schedule = alpha_schedule.to(device)
sigma_schedule = sigma_schedule.to(device)
solver = solver.to(device)
optimizer = torch.optim.AdamW(
unet.parameters(),
lr=cfg.TRAIN.learning_rate,
betas=(cfg.TRAIN.adam_beta1, cfg.TRAIN.adam_beta2),
weight_decay=cfg.TRAIN.adam_weight_decay,
eps=cfg.TRAIN.adam_epsilon)
if cfg.TRAIN.max_train_steps == -1:
assert cfg.TRAIN.max_train_epochs != -1
cfg.TRAIN.max_train_steps = cfg.TRAIN.max_train_epochs * len(train_dataloader)
if cfg.TRAIN.checkpointing_steps == -1:
assert cfg.TRAIN.checkpointing_epochs != -1
cfg.TRAIN.checkpointing_steps = cfg.TRAIN.checkpointing_epochs * len(train_dataloader)
if cfg.TRAIN.validation_steps == -1:
assert cfg.TRAIN.validation_epochs != -1
cfg.TRAIN.validation_steps = cfg.TRAIN.validation_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
cfg.TRAIN.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.TRAIN.lr_warmup_steps,
num_training_steps=cfg.TRAIN.max_train_steps)
uncond_prompt_embeds = text_encoder([""] * cfg.TRAIN.BATCH_SIZE)
# Train!
logger.info("***** Running training *****")
logging.info(f" Num examples = {len(train_dataloader.dataset)}")
logging.info(f" Num Epochs = {cfg.TRAIN.max_train_epochs}")
logging.info(f" Instantaneous batch size per device = {cfg.TRAIN.BATCH_SIZE}")
logging.info(f" Total optimization steps = {cfg.TRAIN.max_train_steps}")
global_step = 0
first_epoch = 0
progress_bar = tqdm(range(0, cfg.TRAIN.max_train_steps), desc="Steps")
@torch.no_grad()
def validation():
base_model.eval()
for val_batch in tqdm(val_dataloader):
val_batch = move_batch_to_device(val_batch, device)
base_model.allsplit_step('test', val_batch)
metrics = base_model.allsplit_epoch_end()
max_val_rp1 = metrics['Metrics/R_precision_top_1']
min_val_fid = metrics['Metrics/FID']
print_table(f'Metrics@Step-{global_step}', metrics)
for k, v in metrics.items():
writer.add_scalar(k, v, global_step=global_step)
base_model.train()
return max_val_rp1, min_val_fid
max_rp1, min_fid = validation()
for epoch in range(first_epoch, cfg.TRAIN.max_train_epochs):
for step, batch in enumerate(train_dataloader):
batch = move_batch_to_device(batch, device)
feats_ref = batch["motion"]
lengths = batch["length"]
text = batch['text']
# Encode motions to latents
with torch.no_grad():
latents, _ = vae.encode(feats_ref, lengths)
latents = latents.permute(1, 0, 2)
prompt_embeds = text_encoder(text)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
topk = noise_scheduler.config.num_train_timesteps // cfg.TRAIN.num_ddim_timesteps
index = torch.randint(0, cfg.TRAIN.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
# Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (cfg.TRAIN.w_max - cfg.TRAIN.w_min) * torch.rand((bsz,)) + cfg.TRAIN.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=cfg.TRAIN.unet_time_cond_proj_dim)
w = append_dims(w, latents.ndim)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
# Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
noise_pred = unet(
noisy_model_input,
start_timesteps,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds)
pred_x_0 = predicted_origin(
noise_pred,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule)
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
# Use the ODE solver to predict the k-th step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
cond_teacher_output = teacher_unet(
noisy_model_input,
start_timesteps,
encoder_hidden_states=prompt_embeds)
cond_pred_x0 = predicted_origin(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule)
# Get teacher model prediction on noisy_latents and unconditional embedding
uncond_teacher_output = teacher_unet(
noisy_model_input,
start_timesteps,
encoder_hidden_states=uncond_prompt_embeds[:bsz])
uncond_pred_x0 = predicted_origin(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule)
# Perform "CFG" to get z_prev estimate (using the LCM paper's CFG formulation)
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
# Get target LCM prediction on z_prev, w, c, t_n
with torch.no_grad():
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds)
pred_x_0 = predicted_origin(
target_noise_pred,
timesteps,
x_prev,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule)
target = c_skip * x_prev + c_out * pred_x_0
# Calculate loss
if cfg.TRAIN.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif cfg.TRAIN.loss_type == "huber":
loss = torch.mean(
torch.sqrt(
(model_pred.float() - target.float()) ** 2 + cfg.TRAIN.huber_c ** 2) - cfg.TRAIN.huber_c
)
# Back propagate on the online student model (`unet`)
loss.backward()
torch.nn.utils.clip_grad_norm_(unet.parameters(), cfg.TRAIN.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Make EMA update to target student model parameters
update_ema(target_unet.parameters(), unet.parameters(), cfg.TRAIN.ema_decay)
progress_bar.update(1)
global_step += 1
if global_step % cfg.TRAIN.checkpointing_steps == 0:
save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-{global_step}.ckpt")
ckpt = dict(state_dict=base_model.state_dict())
base_model.on_save_checkpoint(ckpt)
torch.save(ckpt, save_path)
logger.info(f"Saved state to {save_path}")
if global_step % cfg.TRAIN.validation_steps == 0:
cur_rp1, cur_fid = validation()
if cur_rp1 > max_rp1:
max_rp1 = cur_rp1
save_path = os.path.join(output_dir, 'checkpoints',
f"checkpoint-{global_step}-rp1-{round(cur_rp1, 3)}.ckpt")
ckpt = dict(state_dict=base_model.state_dict())
base_model.on_save_checkpoint(ckpt)
torch.save(ckpt, save_path)
logger.info(f"Saved state to {save_path} with rp1:{round(cur_rp1, 3)}")
if cur_fid < min_fid:
min_fid = cur_fid
save_path = os.path.join(output_dir, 'checkpoints',
f"checkpoint-{global_step}-fid-{round(cur_fid, 3)}.ckpt")
ckpt = dict(state_dict=base_model.state_dict())
base_model.on_save_checkpoint(ckpt)
torch.save(ckpt, save_path)
logger.info(f"Saved state to {save_path} with fid:{round(cur_fid, 3)}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
writer.add_scalar('loss', logs['loss'], global_step=global_step)
writer.add_scalar('lr', logs['lr'], global_step=global_step)
if global_step >= cfg.TRAIN.max_train_steps:
break
save_path = os.path.join(output_dir, 'checkpoints', f"checkpoint-last.ckpt")
ckpt = dict(state_dict=base_model.state_dict())
base_model.on_save_checkpoint(ckpt)
torch.save(ckpt, save_path)
if __name__ == "__main__":
main()