import torch
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from os.path import join as pjoin
import torch.nn.functional as F
import torch.optim as optim
import time
import numpy as np
from collections import OrderedDict, defaultdict
from utils.eval_t2m import evaluation_vqvae, evaluation_res_conv
from utils.utils import print_current_loss
import os
import sys


def def_value():
    """Default factory for the loss-accumulating defaultdicts (picklable, unlike a lambda)."""
    return 0.0


class RVQTokenizerTrainer:
    """Trainer for a residual VQ-VAE motion tokenizer.

    Optimizes an L1 (or smooth-L1) reconstruction loss plus an explicit
    local-joint-position loss and the VQ commitment loss, with linear LR
    warm-up followed by a MultiStepLR schedule, periodic TensorBoard logging,
    checkpointing, and evaluation via ``evaluation_vqvae``.
    """

    def __init__(self, args, vq_model):
        """
        Args:
            args: option namespace; must provide device, is_train, log_dir,
                recons_loss, joints_num, loss_vel, commit, lr, weight_decay,
                milestones, gamma, and the various *_every / dir options used
                in train().
            vq_model: the VQ-VAE model; calling it on a motion batch must
                return (pred_motion, loss_commit, perplexity).
        """
        self.opt = args
        self.vq_model = vq_model
        self.device = args.device

        if args.is_train:
            self.logger = SummaryWriter(args.log_dir)
            if args.recons_loss == 'l1':
                self.l1_criterion = torch.nn.L1Loss()
            elif args.recons_loss == 'l1_smooth':
                self.l1_criterion = torch.nn.SmoothL1Loss()
            else:
                # Fail fast: previously an unknown value left l1_criterion
                # undefined and surfaced later as AttributeError in forward().
                raise ValueError(f"Unsupported recons_loss: {args.recons_loss!r}")

    def forward(self, batch_data):
        """Run one forward pass.

        Returns:
            (loss, loss_rec, loss_explicit, loss_commit, perplexity) tensors,
            where ``loss`` is the weighted total used for backprop.
        """
        motions = batch_data.detach().to(self.device).float()
        pred_motion, loss_commit, perplexity = self.vq_model(motions)

        # Cached for the periodic visualization in train() (eval_every_e).
        self.motions = motions
        self.pred_motion = pred_motion

        loss_rec = self.l1_criterion(pred_motion, motions)

        # Feature slice [4 : 4 + (joints_num - 1) * 3] — presumably the local
        # joint positions of the motion representation (TODO confirm against
        # the dataset layout); supervised explicitly on top of the
        # full-feature reconstruction.
        pred_local_pos = pred_motion[..., 4 : (self.opt.joints_num - 1) * 3 + 4]
        local_pos = motions[..., 4 : (self.opt.joints_num - 1) * 3 + 4]
        loss_explicit = self.l1_criterion(pred_local_pos, local_pos)

        loss = loss_rec + self.opt.loss_vel * loss_explicit + self.opt.commit * loss_commit

        return loss, loss_rec, loss_explicit, loss_commit, perplexity

    def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
        """Linearly ramp the LR towards ``lr`` over the first ``warm_up_iter`` iterations."""
        current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
        for param_group in self.opt_vq_model.param_groups:
            param_group["lr"] = current_lr
        return current_lr

    def save(self, file_name, ep, total_it):
        """Checkpoint model, optimizer and scheduler state together with progress counters."""
        state = {
            "vq_model": self.vq_model.state_dict(),
            "opt_vq_model": self.opt_vq_model.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            'ep': ep,
            'total_it': total_it,
        }
        torch.save(state, file_name)

    def resume(self, model_dir):
        """Restore a checkpoint written by save(); return (epoch, total_it)."""
        checkpoint = torch.load(model_dir, map_location=self.device)
        self.vq_model.load_state_dict(checkpoint['vq_model'])
        self.opt_vq_model.load_state_dict(checkpoint['opt_vq_model'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        return checkpoint['ep'], checkpoint['total_it']

    def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval=None):
        """Full training loop.

        Args:
            train_loader: training motion batches.
            val_loader: validation motion batches (loss reporting).
            eval_val_loader: loader for the FID/R-precision evaluation.
            eval_wrapper: evaluator passed through to evaluation_vqvae.
            plot_eval: optional callable(data, save_dir) used every
                eval_every_e epochs to visualize GT vs. reconstructed motions.
        """
        self.vq_model.to(self.device)

        self.opt_vq_model = optim.AdamW(self.vq_model.parameters(), lr=self.opt.lr,
                                        betas=(0.9, 0.99), weight_decay=self.opt.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.opt_vq_model,
                                                              milestones=self.opt.milestones,
                                                              gamma=self.opt.gamma)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)
            print("Load model epoch:%d iterations:%d" % (epoch, it))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_loader)
        print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(eval_val_loader)))

        current_lr = self.opt.lr
        logs = defaultdict(def_value, OrderedDict())

        # Baseline evaluation before any training; save=False so the untrained
        # model never overwrites a best checkpoint.
        best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_vqvae(
            self.opt.model_dir, eval_val_loader, self.vq_model, self.logger, epoch,
            best_fid=1000, best_div=100,
            best_top1=0, best_top2=0, best_top3=0,
            best_matching=100, eval_wrapper=eval_wrapper, save=False)

        while epoch < self.opt.max_epoch:
            self.vq_model.train()
            for i, batch_data in enumerate(train_loader):
                it += 1
                if it < self.opt.warm_up_iter:
                    current_lr = self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)

                loss, loss_rec, loss_vel, loss_commit, perplexity = self.forward(batch_data)
                self.opt_vq_model.zero_grad()
                loss.backward()
                self.opt_vq_model.step()

                # The MultiStepLR schedule takes over once warm-up is done.
                if it >= self.opt.warm_up_iter:
                    self.scheduler.step()

                logs['loss'] += loss.item()
                logs['loss_rec'] += loss_rec.item()
                # NOTE: not necessarily a velocity loss (forward() returns
                # loss_explicit here); the tag is kept for logging continuity.
                logs['loss_vel'] += loss_vel.item()
                logs['loss_commit'] += loss_commit.item()
                logs['perplexity'] += perplexity.item()
                logs['lr'] += self.opt_vq_model.param_groups[0]['lr']

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        self.logger.add_scalar('Train/%s' % tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = defaultdict(def_value, OrderedDict())
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            epoch += 1

            print('Validation time:')
            self.vq_model.eval()
            val_loss_rec = []
            val_loss_vel = []
            val_loss_commit = []
            val_loss = []
            val_perplexity = []
            with torch.no_grad():
                for i, batch_data in enumerate(val_loader):
                    loss, loss_rec, loss_vel, loss_commit, perplexity = self.forward(batch_data)
                    val_loss.append(loss.item())
                    val_loss_rec.append(loss_rec.item())
                    val_loss_vel.append(loss_vel.item())
                    val_loss_commit.append(loss_commit.item())
                    val_perplexity.append(perplexity.item())

            self.logger.add_scalar('Val/loss', sum(val_loss) / len(val_loss), epoch)
            self.logger.add_scalar('Val/loss_rec', sum(val_loss_rec) / len(val_loss_rec), epoch)
            self.logger.add_scalar('Val/loss_vel', sum(val_loss_vel) / len(val_loss_vel), epoch)
            self.logger.add_scalar('Val/loss_commit', sum(val_loss_commit) / len(val_loss), epoch)
            self.logger.add_scalar('Val/loss_perplexity', sum(val_perplexity) / len(val_loss_rec), epoch)

            print('Validation Loss: %.5f Reconstruction: %.5f, Velocity: %.5f, Commit: %.5f' %
                  (sum(val_loss) / len(val_loss), sum(val_loss_rec) / len(val_loss),
                   sum(val_loss_vel) / len(val_loss), sum(val_loss_commit) / len(val_loss)))

            best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_vqvae(
                self.opt.model_dir, eval_val_loader, self.vq_model, self.logger, epoch,
                best_fid=best_fid, best_div=best_div,
                best_top1=best_top1, best_top2=best_top2, best_top3=best_top3,
                best_matching=best_matching, eval_wrapper=eval_wrapper)

            if epoch % self.opt.eval_every_e == 0:
                # First 4 GT motions stacked with their reconstructions.
                data = torch.cat([self.motions[:4], self.pred_motion[:4]], dim=0).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                # Guard restored: plot_eval defaults to None and previously
                # crashed with TypeError when not supplied.
                if plot_eval is not None:
                    plot_eval(data, save_dir)


class LengthEstTrainer(object):
    """Trainer for a motion-length classifier conditioned on text embeddings.

    The estimator predicts a distribution over length buckets
    (``m_lens // unit_length``) from frozen text-encoder embeddings and is
    trained with cross-entropy.
    """

    def __init__(self, args, estimator, text_encoder, encode_fnc):
        """
        Args:
            args: option namespace (device, is_train, log_dir, lr,
                unit_length, max_epoch, log_every, save_latest, model_dir,
                is_continue).
            estimator: classifier mapping text embeddings to bucket logits.
            text_encoder: frozen text encoder.
            encode_fnc: callable(text_encoder, conds, device) -> embeddings.
        """
        self.opt = args
        self.estimator = estimator
        self.text_encoder = text_encoder
        self.encode_fnc = encode_fnc
        self.device = args.device

        if args.is_train:
            self.logger = SummaryWriter(args.log_dir)
            self.mul_cls_criterion = torch.nn.CrossEntropyLoss()

    def resume(self, model_dir):
        """Restore estimator weights; return (epoch, niter)."""
        checkpoints = torch.load(model_dir, map_location=self.device)
        self.estimator.load_state_dict(checkpoints['estimator'])
        # BUGFIX: save() stores the iteration counter under 'niter';
        # reading 'iter' here raised KeyError on every resume.
        return checkpoints['epoch'], checkpoints['niter']

    def save(self, model_dir, epoch, niter):
        """Checkpoint estimator weights and progress counters."""
        state = {
            'estimator': self.estimator.state_dict(),
            'epoch': epoch,
            'niter': niter,
        }
        torch.save(state, model_dir)

    @staticmethod
    def zero_grad(opt_list):
        """Zero the gradients of every optimizer in opt_list."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip each network's gradient norm to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Step every optimizer in opt_list."""
        for opt in opt_list:
            opt.step()

    def train(self, train_dataloader, val_dataloader):
        """Train the length estimator; keep 'latest.tar' fresh and 'finest.tar' at best val loss."""
        self.estimator.to(self.device)
        self.text_encoder.to(self.device)

        self.opt_estimator = optim.Adam(self.estimator.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0

        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, 'latest.tar')
            epoch, it = self.resume(model_dir)

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))

        val_loss = 0
        min_val_loss = np.inf

        logs = defaultdict(float)
        while epoch < self.opt.max_epoch:
            for i, batch_data in enumerate(train_dataloader):
                self.estimator.train()

                conds, _, m_lens = batch_data
                # Text encoder stays frozen: detach the embeddings so no
                # gradients flow back into it.
                text_embs = self.encode_fnc(self.text_encoder, conds, self.opt.device).detach()
                pred_dis = self.estimator(text_embs)

                self.zero_grad([self.opt_estimator])

                # Class label = motion length quantized into unit_length buckets.
                gt_labels = m_lens // self.opt.unit_length
                gt_labels = gt_labels.long().to(self.device)
                acc = (gt_labels == pred_dis.argmax(dim=-1)).sum() / len(gt_labels)
                loss = self.mul_cls_criterion(pred_dis, gt_labels)

                loss.backward()

                self.clip_norm([self.estimator])
                self.step([self.opt_estimator])

                logs['loss'] += loss.item()
                logs['acc'] += acc.item()

                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict({'val_loss': val_loss})
                    for tag, value in logs.items():
                        self.logger.add_scalar("Train/%s" % tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = defaultdict(float)
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            epoch += 1

            print('Validation time:')

            val_loss = 0
            val_acc = 0
            with torch.no_grad():
                for i, batch_data in enumerate(val_dataloader):
                    self.estimator.eval()

                    conds, _, m_lens = batch_data
                    text_embs = self.encode_fnc(self.text_encoder, conds, self.opt.device)
                    pred_dis = self.estimator(text_embs)

                    gt_labels = m_lens // self.opt.unit_length
                    gt_labels = gt_labels.long().to(self.device)
                    loss = self.mul_cls_criterion(pred_dis, gt_labels)
                    acc = (gt_labels == pred_dis.argmax(dim=-1)).sum() / len(gt_labels)

                    val_loss += loss.item()
                    val_acc += acc.item()

            val_loss = val_loss / len(val_dataloader)
            val_acc = val_acc / len(val_dataloader)
            print('Validation Loss: %.5f Validation Acc: %.5f' % (val_loss, val_acc))

            if val_loss < min_val_loss:
                self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
                min_val_loss = val_loss