# Spaces: Sleeping (status text captured from the hosting page; not part of the program)
# NOTE: the wildcard imports stay first so that the explicit imports below
# keep precedence over any same-named symbols the wildcards export.
from utils.dist import *
from parse import *
from utils.util import find_free_port
import torch
import torch.multiprocessing as mp
import torch.distributed
from importlib import import_module
import os
import glob
from inputs import args_parser
def main_worker(rank, opt):
    """Prepare one training process and run the configured network's trainer.

    Args:
        rank: Process index handed in by the launcher. It is used for the
            local/global rank only when ``opt`` has no ``'local_rank'`` yet,
            and is always forwarded to ``Network``.
        opt: Mutable options dict. Reads ``'distributed'`` and ``'network'``
            (plus ``'init_method'`` and ``'world_size'`` when distributed);
            writes ``'local_rank'`` / ``'global_rank'`` (if absent) and
            ``'device'``.
    """
    # Only derive the ranks from `rank` when the caller has not set them.
    if 'local_rank' not in opt:
        opt['local_rank'] = opt['global_rank'] = rank

    if opt['distributed']:
        torch.cuda.set_device(int(opt['local_rank']))
        # NOTE(review): `group_name` is documented as deprecated by
        # torch.distributed -- confirm it can be dropped.
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=opt['init_method'],
            world_size=opt['world_size'],
            rank=opt['global_rank'],
            group_name='mtorch',
        )
        print('using GPU {}-{} for training'.format(
            int(opt['global_rank']), int(opt['local_rank'])))

    # NOTE(review): the CUDA branch stores a torch.device while the fallback
    # stores the plain string 'cpu'; kept as-is for downstream consumers.
    if torch.cuda.is_available():
        opt['device'] = torch.device("cuda:{}".format(opt['local_rank']))
    else:
        opt['device'] = 'cpu'

    # The trainer class is looked up by name: networks/<opt['network']>.py
    # must expose a `Network` class with a `train()` method.
    pkg = import_module('networks.{}'.format(opt['network']))
    trainer = pkg.Network(opt, rank)
    trainer.train()
def _first_match(directory, pattern):
    """Return the first path in *directory* matching glob *pattern*, or raise."""
    matches = glob.glob(os.path.join(directory, pattern))
    if not matches:
        raise FileNotFoundError(
            "no '{}' file found in flow checkpoint directory {!r}".format(
                pattern, directory))
    return matches[0]


def main(args_obj):
    """Build the options dict from parsed arguments and launch training.

    Args:
        args_obj: Dict of command-line arguments (``vars()`` of the parsed
            namespace). It is passed to ``parse`` and must contain
            ``'finetune'``.

    Raises:
        FileNotFoundError: If the flow checkpoint directory holds no ``*.tar``
            or no ``*.yaml`` file.
    """
    opt = parse(args_obj)

    # Distributed rendezvous settings.
    opt['world_size'] = get_world_size()
    free_port = find_free_port()
    master_ip = get_master_ip()
    opt['init_method'] = "tcp://{}:{}".format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print('Import network module: ', opt['network'])

    # `flow_checkPoint` arrives as a directory; it is replaced below by the
    # checkpoint file found inside it, and the YAML next to it is loaded.
    flow_dir = opt['flow_checkPoint']
    checkpoint = _first_match(flow_dir, '*.tar')
    config = _first_match(flow_dir, '*.yaml')
    # NOTE(review): `yaml` is not imported explicitly in this file; presumably
    # it arrives through one of the wildcard imports -- confirm.
    # NOTE(review): `yaml.full_load` is not safe on untrusted files; consider
    # `yaml.safe_load` if the config can come from outside.
    with open(config, 'r') as f:
        configs = yaml.full_load(f)
    opt['flow_config'] = configs
    opt['flow_checkPoint'] = checkpoint

    # Read the flag from the function argument rather than the module-global
    # `args`, which only exists when this file is run as a script.
    opt['finetune'] = args_obj['finetune'] == 1

    # Non-empty command-line state paths override the ones under opt['path'].
    for state_key in ('gen_state', 'dis_state', 'opt_state'):
        if opt[state_key] != '':
            opt['path'][state_key] = opt[state_key]

    # Collapse the per-axis scalars into (height, width) pairs.
    opt['input_resolution'] = (opt['res_h'], opt['res_w'])
    opt['kernel_size'] = (opt['kernel_size_h'], opt['kernel_size_w'])
    opt['stride'] = (opt['stride_h'], opt['stride_w'])
    opt['padding'] = (opt['pad_h'], opt['pad_w'])
    print('model is: {}'.format(opt['model']))

    if master_ip == "127.0.0.1":
        # localhost
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # multiple processes should be launched by openmpi
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        main_worker(-1, opt)
if __name__ == '__main__':
    # Parse the command line, then hand `main` the arguments as a plain dict.
    args = args_parser()
    args_obj = vars(args)
    main(args_obj)