Spaces:
Running
on
Zero
Running
on
Zero
import sys
# sys.path.append("src")
import shutil
import os

# Sets TOKENIZERS_PARALLELISM in the process environment before the imports
# below run. NOTE(review): presumably aimed at the HuggingFace `tokenizers`
# library used somewhere in the model/data code — confirm which dependency
# reads it. Keep this line ahead of those imports.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import argparse
import yaml
import torch
from tqdm import tqdm
from pytorch_lightning.strategies.ddp import DDPStrategy
from qa_mdt.audioldm_train.modules.latent_diffusion.ddpm import LatentDiffusion
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from qa_mdt.audioldm_train.utilities.tools import (
    listdir_nohidden,
    get_restore_step,
    copy_test_subset_data,
)
import wandb
from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config
import logging

# NOTE(review): sys, shutil, tqdm, LatentDiffusion, WeightedRandomSampler,
# listdir_nohidden, copy_test_subset_data and wandb are not referenced by any
# live code in this script (some appear only in commented-out code). They may
# be kept for import side effects — confirm before removing.

# Configure the root logger to emit records at WARNING level and above.
logging.basicConfig(level=logging.WARNING)
def convert_path(path, num_components=4):
    """Return the last ``num_components`` components of ``path`` joined by "/".

    Args:
        path: The path to shorten, either as ``bytes``/``bytearray`` (decoded
            with the default UTF-8 codec) or as ``str``.
        num_components: How many trailing "/"-separated components to keep.
            Must be >= 1. Defaults to 4.

    Returns:
        The trailing components as a "/"-joined ``str``. Paths with fewer
        components than requested are returned whole.

    Raises:
        ValueError: If ``num_components`` is less than 1.
    """
    if num_components < 1:
        # A slice of [-0:] would silently return the whole path.
        raise ValueError("num_components must be >= 1, got %r" % (num_components,))
    if isinstance(path, (bytes, bytearray)):
        path = path.decode()
    parts = path.split("/")[-num_components:]
    return "/".join(parts)
def print_on_rank0(msg):
    """Print ``msg`` only on the global rank-0 process.

    If ``torch.distributed`` is unavailable or no default process group has
    been initialised (e.g. a single-process run, or code executed before the
    trainer sets up distributed training), the current process is treated as
    rank 0 and the message is printed. Calling ``get_rank()`` in that state
    would raise instead.

    Args:
        msg: Object to print.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    print(msg)
def main(configs, config_yaml_path, exp_group_name, exp_name, perform_validation):
    """Build the data loaders, model and Lightning trainer from ``configs``, then train.

    Args:
        configs: Parsed YAML config dict. Keys read here:
            ``seed`` and ``precision`` (both optional), ``log_directory``,
            ``model`` (``params.batchsize`` is read directly; the whole section
            is passed to ``instantiate_from_config``),
            ``train_path.train_lmdb_path`` (iterated as a collection of LMDB
            directories), ``val_path.val_lmdb_path``, ``val_path.val_key_path``,
            ``mos_path``, ``reload_from_ckpt`` (optional), ``project``, and
            ``step`` (``validation_every_n_epochs``,
            ``save_checkpoint_every_n_steps``, ``max_steps``, ``save_top_k``,
            optional ``limit_val_batches``).
        config_yaml_path: Path of the YAML file ``configs`` was loaded from.
            Not used by this function; kept for interface compatibility.
        exp_group_name: Experiment group name; part of the log/checkpoint
            directory and of the wandb run name.
        exp_name: Experiment name; used the same way as ``exp_group_name``.
        perform_validation: Not used by this function; kept for interface
            compatibility.

    Training resumes from the newest checkpoint in the experiment's checkpoint
    directory when that directory is non-empty, otherwise from
    ``reload_from_ckpt`` when configured, otherwise from scratch.
    """
    print("MAIN START")

    if "seed" in configs:
        seed_everything(configs["seed"])
    else:
        print("SEED EVERYTHING TO 1234")
        seed_everything(1234)

    if "precision" in configs:
        # Accepted values: "highest", "high", "medium".
        torch.set_float32_matmul_precision(configs["precision"])

    log_path = configs["log_directory"]
    batch_size = configs["model"]["params"]["batchsize"]

    # Each training LMDB directory carries its own key file.
    train_lmdb_path = configs["train_path"]["train_lmdb_path"]
    train_key_path = [lmdb_dir + "/data_key.key" for lmdb_dir in train_lmdb_path]
    val_lmdb_path = configs["val_path"]["val_lmdb_path"]
    val_key_path = configs["val_path"]["val_key_path"]
    mos_path = configs["mos_path"]

    # Project-local dataset class, imported inside main.
    from qa_mdt.audioldm_train.utilities.data.hhhh import AudioDataset

    dataset = AudioDataset(
        config=configs,
        lmdb_path=train_lmdb_path,
        key_path=train_key_path,
        mos_path=mos_path,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        shuffle=True,
    )
    print(
        "The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s"
        % (len(dataset), len(loader), batch_size)
    )

    # Build the validation set with MOS data when possible. Fall back to
    # building it without mos_path, but say so rather than swallowing the
    # error silently. KeyboardInterrupt/SystemExit are no longer caught.
    try:
        val_dataset = AudioDataset(
            config=configs,
            lmdb_path=val_lmdb_path,
            key_path=val_key_path,
            mos_path=mos_path,
        )
    except Exception as err:
        logging.warning(
            "Building the validation dataset with mos_path failed (%r); "
            "retrying without mos_path.",
            err,
        )
        val_dataset = AudioDataset(
            config=configs, lmdb_path=val_lmdb_path, key_path=val_key_path
        )
    val_loader = DataLoader(val_dataset, batch_size=8)

    # Scratch folder for test data, a sibling of the log directory. The model
    # is told about it below via `test_data_subset_path`.
    test_data_subset_folder = os.path.join(
        os.path.dirname(log_path),
        "testset_data",
        "tmp",
    )
    os.makedirs(test_data_subset_folder, exist_ok=True)

    step_config = configs["step"]
    config_reload_from_ckpt = configs.get("reload_from_ckpt")
    limit_val_batches = step_config.get("limit_val_batches")
    validation_every_n_epochs = step_config["validation_every_n_epochs"]
    save_checkpoint_every_n_steps = step_config["save_checkpoint_every_n_steps"]
    max_steps = step_config["max_steps"]
    save_top_k = step_config["save_top_k"]

    wandb_path = os.path.join(log_path, exp_group_name, exp_name)
    checkpoint_path = os.path.join(wandb_path, "checkpoints")

    # Keep the checkpoints with the highest global_step, i.e. the most recent ones.
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="global_step",
        mode="max",
        filename="checkpoint-fad-{val/frechet_inception_distance:.2f}-global_step={global_step:.0f}",
        every_n_train_steps=save_checkpoint_every_n_steps,
        save_top_k=save_top_k,
        auto_insert_metric_name=False,
        save_last=False,
    )

    os.makedirs(checkpoint_path, exist_ok=True)
    # Resume priority: existing experiment checkpoints > configured ckpt > scratch.
    if os.listdir(checkpoint_path):
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, _ = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    else:
        print("Train from scratch")
        resume_from_checkpoint = None

    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)
    latent_diffusion.test_data_subset_path = test_data_subset_folder

    wandb_logger = WandbLogger(
        save_dir=wandb_path,
        project=configs["project"],
        config=configs,
        name="%s/%s" % (exp_group_name, exp_name),
    )

    print("==> Save checkpoint every %s steps" % save_checkpoint_every_n_steps)
    print("==> Perform validation every %s epochs" % validation_every_n_epochs)

    trainer = Trainer(
        accelerator="auto",
        devices="auto",
        logger=wandb_logger,
        max_steps=max_steps,
        num_sanity_val_steps=1,
        limit_val_batches=limit_val_batches,
        check_val_every_n_epoch=validation_every_n_epochs,
        strategy=DDPStrategy(find_unused_parameters=True),
        gradient_clip_val=2.0,
        callbacks=[checkpoint_callback],
        num_nodes=1,
    )

    # Lightning restores model, optimizer and step state from ckpt_path itself.
    # An earlier version instead filtered mismatched checkpoint keys by hand and
    # loaded them with strict=False before fitting.
    trainer.fit(latent_diffusion, loader, val_loader, ckpt_path=resume_from_checkpoint)
def derive_experiment_names(config_yaml):
    """Derive ``(exp_group_name, exp_name)`` from a config-file path.

    ``exp_name`` is the file's basename up to its first "." and
    ``exp_group_name`` is the name of the directory that contains the file.
    For example, ``"./configs/group/exp.yaml"`` gives ``("group", "exp")``.
    Splitting the basename rather than the whole path keeps dots in
    directory names (e.g. "./" or "v1.0/") from corrupting ``exp_name``.

    Args:
        config_yaml: Path to the YAML config file.

    Returns:
        A ``(exp_group_name, exp_name)`` tuple of strings.
    """
    exp_name = os.path.basename(config_yaml).split(".")[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml))
    return exp_group_name, exp_name


if __name__ == "__main__":
    print("ok")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        # Required: every later step derives paths and names from it.
        required=True,
        help="path to config .yaml file",
    )
    parser.add_argument("--val", action="store_true")
    args = parser.parse_args()
    perform_validation = args.val

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    config_yaml_path = args.config_yaml
    exp_group_name, exp_name = derive_experiment_names(config_yaml_path)

    # NOTE: FullLoader is only appropriate for trusted, local config files.
    with open(config_yaml_path, "r") as config_file:
        config_yaml = yaml.load(config_file, Loader=yaml.FullLoader)

    if perform_validation:
        config_yaml["model"]["params"]["cond_stage_config"][
            "crossattn_audiomae_generated"
        ]["params"]["use_gt_mae_output"] = False
        config_yaml["step"]["limit_val_batches"] = None

    main(config_yaml, config_yaml_path, exp_group_name, exp_name, perform_validation)