# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
import argparse
import logging
import os

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
logger = logging.getLogger("repcodec_train")  # init logger before other modules

import random

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader

from dataloader import ReprDataset, ReprCollater
from losses.repr_reconstruct_loss import ReprReconstructLoss
from repcodec.RepCodec import RepCodec
from trainer.autoencoder import Trainer


class TrainMain:
    def __init__(self, args):
        # fix seeds for reproducibility
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
            logger.info("device: cpu")
        else:
            self.device = torch.device('cuda:0')  # only supports single gpu for now
            logger.info("device: gpu")
            torch.cuda.manual_seed_all(args.seed)
            if args.disable_cudnn == "False":
                # benchmark mode speeds up fixed-size inputs but is nondeterministic;
                # pass --disable_cudnn True to keep runs reproducible
                torch.backends.cudnn.benchmark = True

        # initialize config
        with open(args.config, 'r') as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        self.config.update(vars(args))

        # initialize model folder
        expdir = os.path.join(args.exp_root, args.tag)
        os.makedirs(expdir, exist_ok=True)
        self.config["outdir"] = expdir

        # save config
        with open(os.path.join(expdir, "config.yml"), "w") as f:
            yaml.dump(self.config, f, Dumper=yaml.Dumper)
        for key, value in self.config.items():
            logger.info(f"{key} = {value}")

        # initialize attributes
        self.resume: str = args.resume
        self.data_loader = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.trainer = None

        # initialize batch_length
        self.batch_length: int = self.config['batch_length']
        self.data_path: str = self.config['data']['path']
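
    # For reference, a minimal sketch of the config keys this script reads.
    # The values below are illustrative assumptions, not defaults shipped with
    # the repo; model_params must match the keyword arguments accepted by
    # repcodec.RepCodec (input_channels is a hypothetical example).
    #
    #   batch_length: 96
    #   batch_size: 32
    #   num_workers: 4
    #   pin_memory: true
    #   train_max_steps: 200000
    #   data:
    #     path: /path/to/extracted/representations
    #     subset:
    #       train: train
    #       valid: valid
    #   model_params:
    #     input_channels: 768
    #   model_optimizer_type: Adam
    #   model_optimizer_params:
    #     lr: 1.0e-4
    #   model_scheduler_type: StepLR
    #   model_scheduler_params:
    #     step_size: 200000
    #     gamma: 0.1
    #   repr_reconstruct_loss_params: {}
    #   initial: ""    # optional: checkpoint to initialize parameters from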

    def initialize_data_loader(self):
        train_set = self._build_dataset("train")
        valid_set = self._build_dataset("valid")
        collater = ReprCollater()

        logger.info(f"The number of training files = {len(train_set)}.")
        logger.info(f"The number of validation files = {len(valid_set)}.")

        dataset = {"train": train_set, "dev": valid_set}
        self._set_data_loader(dataset, collater)

    def define_model_optimizer_scheduler(self):
        # model arch
        self.model = {
            "repcodec": RepCodec(**self.config["model_params"]).to(self.device)
        }
        logger.info(f"Model Arch:\n{self.model['repcodec']}")

        # optimizer
        optimizer_class = getattr(
            torch.optim,
            self.config["model_optimizer_type"]
        )
        self.optimizer = {
            "repcodec": optimizer_class(
                self.model["repcodec"].parameters(),
                **self.config["model_optimizer_params"]
            )
        }

        # scheduler
        scheduler_class = getattr(
            torch.optim.lr_scheduler,
            self.config.get("model_scheduler_type", "StepLR"),
        )
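        # model_optimizer_type / model_scheduler_type name classes inside
        # torch.optim and torch.optim.lr_scheduler: e.g. "Adam" resolves to
        # torch.optim.Adam, and "StepLR" (the fallback above) to
        # torch.optim.lr_scheduler.StepLR.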
        self.scheduler = {
            "repcodec": scheduler_class(
                optimizer=self.optimizer["repcodec"],
                **self.config["model_scheduler_params"]
            )
        }

    def define_criterion(self):
        self.criterion = {
            "repr_reconstruct_loss": ReprReconstructLoss(
                **self.config.get("repr_reconstruct_loss_params", {}),
            ).to(self.device)
        }

    def define_trainer(self):
        self.trainer = Trainer(
            steps=0,
            epochs=0,
            data_loader=self.data_loader,
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            config=self.config,
            device=self.device
        )

    def initialize_model(self):
        initial = self.config.get("initial", "")
        if os.path.exists(self.resume):  # resume from a trained model
            self.trainer.load_checkpoint(self.resume)
            logger.info(f"Successfully resumed from {self.resume}.")
        elif os.path.exists(initial):  # initialize a new model with pre-trained parameters
            self.trainer.load_checkpoint(initial, load_only_params=True)
            logger.info(f"Successfully initialized parameters from {initial}.")
        else:
            logger.info("Training from scratch.")

    def run(self):
        assert self.trainer is not None
        self.trainer: Trainer
        try:
            logger.info(f"The current training step: {self.trainer.steps}")
            self.trainer.train_max_steps = self.config["train_max_steps"]
            if not self.trainer._check_train_finish():
                self.trainer.run()
        finally:
            # always save a checkpoint, even if training was interrupted
            self.trainer.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")

    def _build_dataset(
            self, subset: str
    ) -> ReprDataset:
        data_dir = os.path.join(
            self.data_path, self.config['data']['subset'][subset]
        )
        params = {
            "data_dir": data_dir,
            "batch_len": self.batch_length
        }
        return ReprDataset(**params)

    def _set_data_loader(self, dataset, collater):
        self.data_loader = {
            "train": DataLoader(
                dataset=dataset["train"],
                shuffle=True,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=self.config["num_workers"],
                pin_memory=self.config["pin_memory"],
            ),
            "dev": DataLoader(
                dataset=dataset["dev"],
                shuffle=False,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=0,
                pin_memory=False,  # save some memory. set to True if you have enough memory.
            ),
        }


def train():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, required=True,
        help="the path of the config yaml file."
    )
    parser.add_argument(
        "--tag", type=str, required=True,
        help="the outputs will be saved to exp_root/tag/."
    )
    parser.add_argument(
        "--exp_root", type=str, default="exp"
    )
    parser.add_argument(
        "--resume", default="", type=str, nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument("--seed", default=1337, type=int)
    parser.add_argument(
        "--disable_cudnn", choices=("True", "False"), default="False",
        help="whether to disable cudnn benchmark mode."
    )
    args = parser.parse_args()

    train_main = TrainMain(args)
    train_main.initialize_data_loader()
    train_main.define_model_optimizer_scheduler()
    train_main.define_criterion()
    train_main.define_trainer()
    train_main.initialize_model()
    train_main.run()


if __name__ == '__main__':
    train()
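
# Example invocation (script name, config path, and tag are illustrative):
#   python train.py -c configs/repcodec.yaml --tag my_run --exp_root exp
#
# Resuming from a saved checkpoint:
#   python train.py -c configs/repcodec.yaml --tag my_run \
#       --resume exp/my_run/checkpoint-50000steps.pkl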