# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
import os
import re
import socket
import torch
import torch.distributed
from . import training_stats
_sync_device = None
#----------------------------------------------------------------------------
def init():
    global _sync_device
    if not torch.distributed.is_initialized():
        # Set up some reasonable defaults for env-based distributed init if
        # not set by the running environment (e.g., when launched without torchrun).
        if 'MASTER_ADDR' not in os.environ:
            os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            # Bind to port 0 so the OS assigns a free port, then advertise it.
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # must be set before bind() to take effect
            s.bind(('', 0))
            os.environ['MASTER_PORT'] = str(s.getsockname()[1])
            s.close()
        if 'RANK' not in os.environ:
            os.environ['RANK'] = '0'
        if 'LOCAL_RANK' not in os.environ:
            os.environ['LOCAL_RANK'] = '0'
        if 'WORLD_SIZE' not in os.environ:
            os.environ['WORLD_SIZE'] = '1'
        backend = 'gloo' if os.name == 'nt' else 'nccl' # NCCL is not available on Windows.
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
#----------------------------------------------------------------------------
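# Illustrative usage sketch (not part of the original module): under a launcher
# such as `torchrun --nproc_per_node=N train.py`, RANK/LOCAL_RANK/WORLD_SIZE and
# MASTER_* are pre-set and init() attaches to them; run as a single process, it
# falls back to the defaults above. The helper below is hypothetical.
def _example_usage():
    init()
    print0(f'world size = {get_world_size()}') # printed by rank 0 only
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    t = torch.ones(1, device=device) * get_rank()
    torch.distributed.all_reduce(t) # sums the rank indices across all processes
    print0(f'sum of ranks = {int(t.item())}')
#----------------------------------------------------------------------------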
def get_rank():
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
#----------------------------------------------------------------------------
def get_world_size():
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
#----------------------------------------------------------------------------
# Placeholder hooks for run control and progress reporting; no-ops in this
# version of the module.
def should_stop():
    return False
#----------------------------------------------------------------------------
def should_suspend():
    return False
#----------------------------------------------------------------------------
def request_suspend():
    pass
#----------------------------------------------------------------------------
def update_progress(cur, total):
    pass
#----------------------------------------------------------------------------
def print0(*args, **kwargs):
    # Print only on rank 0 to avoid duplicated output across processes.
    if get_rank() == 0:
        print(*args, **kwargs)
#----------------------------------------------------------------------------
class CheckpointIO:
    # Saves/restores the state of an arbitrary set of named objects, e.g.,
    # networks, optimizers, and plain dicts, passed in as keyword arguments.
    def __init__(self, **kwargs):
        self._state_objs = kwargs
    def save(self, pt_path, verbose=True):
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif hasattr(obj, '__getstate__'):
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if get_rank() == 0: # only rank 0 writes the file
            torch.save(data, pt_path)
        if verbose:
            print0('done')
    def load(self, pt_path, verbose=True):
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        for name, obj in self._state_objs.items():
            if obj is None:
                pass
            elif isinstance(obj, dict):
                obj.clear()
                obj.update(data[name])
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                obj.__dict__.clear()
                obj.__dict__.update(data[name])
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')
    def load_latest(self, run_dir, pattern=r'training-state-(\d+)\.pt', verbose=True):
        # Find all checkpoints matching the pattern and resume from the one
        # with the highest index.
        fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
        if len(fnames) == 0:
            return None
        pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
        self.load(pt_path, verbose=verbose)
        return pt_path
#----------------------------------------------------------------------------
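# Illustrative sketch (not part of the original module): a typical CheckpointIO
# round trip. Objects exposing state_dict()/load_state_dict(), such as modules
# and optimizers, are serialized via those hooks; plain dicts are stored as-is.
# The file name follows the default pattern expected by load_latest(); `run_dir`
# and the helper itself are hypothetical.
def _example_checkpointing(run_dir='.'):
    net = torch.nn.Linear(4, 4)
    opt = torch.optim.Adam(net.parameters())
    state = dict(cur_nimg=0)
    ckpt = CheckpointIO(net=net, opt=opt, state=state)
    ckpt.save(os.path.join(run_dir, 'training-state-0000000.pt')) # rank 0 writes the file
    pt_path = ckpt.load_latest(run_dir) # restores net, opt, and state in place
    print0(f'resumed from {pt_path}')
#----------------------------------------------------------------------------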