Spaces:
Runtime error
Runtime error
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is licensed under a Creative Commons | |
# Attribution-NonCommercial-ShareAlike 4.0 International License. | |
# You should have received a copy of the license along with this | |
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ | |
import os | |
import re | |
import socket | |
import torch | |
import torch.distributed | |
from . import training_stats | |
# Device handed to training_stats for cross-process synchronization.
# Assigned by init(): a CUDA device when world size > 1, otherwise None.
_sync_device = None

#----------------------------------------------------------------------------

def init():
    """Initialize torch.distributed (if needed) and training-stats syncing.

    If no process group is initialized yet, any of MASTER_ADDR, MASTER_PORT,
    RANK, LOCAL_RANK and WORLD_SIZE that are missing from the environment
    are filled in with single-process defaults, the process group is created
    through the ``env://`` init method ('gloo' when ``os.name == 'nt'``,
    'nccl' otherwise), and the current CUDA device is set from LOCAL_RANK.

    Afterwards the module-level ``_sync_device`` is updated and passed,
    together with the current rank, to ``training_stats.init_multiprocessing``.
    """
    global _sync_device

    if not torch.distributed.is_initialized():
        # Setup some reasonable defaults for env-based distributed init if
        # not set by the running environment.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        if 'MASTER_PORT' not in os.environ:
            # Binding to port 0 lets the OS choose a free port. The context
            # manager guarantees the probe socket is closed even when
            # bind()/getsockname() raises.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('', 0))
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                os.environ['MASTER_PORT'] = str(s.getsockname()[1])
        os.environ.setdefault('RANK', '0')
        os.environ.setdefault('LOCAL_RANK', '0')
        os.environ.setdefault('WORLD_SIZE', '1')

        backend = 'gloo' if os.name == 'nt' else 'nccl'
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))

    # Stats only need a sync device when there is more than one process.
    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
#---------------------------------------------------------------------------- | |
def get_rank():
    """Return the torch.distributed rank, or 0 if it is not initialized."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
#---------------------------------------------------------------------------- | |
def get_world_size():
    """Return the torch.distributed world size, or 1 if it is not initialized."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
#---------------------------------------------------------------------------- | |
def should_stop():
    """Tell the caller whether to stop; this implementation always says no.

    NOTE(review): looks like a placeholder hook for an external stop
    signal -- confirm against callers.
    """
    return False
#---------------------------------------------------------------------------- | |
def should_suspend():
    """Tell the caller whether to suspend; this implementation always says no.

    NOTE(review): looks like a placeholder hook for an external suspend
    signal -- confirm against callers.
    """
    return False
#---------------------------------------------------------------------------- | |
def request_suspend():
    """Do nothing and return None (no-op in this implementation)."""
#---------------------------------------------------------------------------- | |
def update_progress(cur, total):
    """Accept a progress report and ignore it (no-op in this implementation).

    Both ``cur`` and ``total`` are unused; the function returns None.
    """
#---------------------------------------------------------------------------- | |
def print0(*args, **kwargs):
    """Forward the arguments to print() on rank 0; do nothing on other ranks."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
#---------------------------------------------------------------------------- | |
class CheckpointIO:
    """Save and restore the state of several named objects as one torch file.

    Every keyword argument passed to the constructor registers one object
    under that name. An object may be ``None`` (stored as ``None``, skipped
    on load), a plain dict, or any object exposing ``state_dict`` /
    ``load_state_dict``, ``__getstate__`` / ``__setstate__``, or a
    ``__dict__``; the checks are tried in that order.
    """

    def __init__(self, **kwargs):
        # Maps checkpoint entry name -> live object whose state is tracked.
        self._state_objs = kwargs

    def save(self, pt_path, verbose=True):
        """Gather the state of all registered objects and write it to pt_path.

        The state is collected on every rank, but only rank 0 calls
        ``torch.save``. With ``verbose`` set, progress is reported via print0.

        Raises:
            ValueError: if a registered object is of no supported kind.
        """
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif hasattr(obj, '__getstate__'):
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if get_rank() == 0:
            torch.save(data, pt_path)
        if verbose:
            print0('done')

    def load(self, pt_path, verbose=True):
        """Read pt_path and push each stored entry into its registered object.

        Objects registered as ``None`` are skipped. For dicts and
        ``__dict__``-based objects the existing contents are replaced, but
        only after the entry has been found in the checkpoint, so a missing
        entry raises KeyError without wiping the live state.

        Raises:
            KeyError: if the checkpoint has no entry for a registered name.
            ValueError: if a registered object is of no supported kind.
        """
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        # NOTE: torch.load is pickle-based; only load checkpoints from
        # trusted sources.
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        for name, obj in self._state_objs.items():
            if obj is None:
                continue
            if isinstance(obj, dict):
                state = data[name]  # look up before clearing the live dict
                obj.clear()
                obj.update(state)
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                state = data[name]  # look up before clearing the attributes
                obj.__dict__.clear()
                obj.__dict__.update(state)
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')

    def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
        """Load the checkpoint in run_dir whose captured number is largest.

        Regular files in ``run_dir`` whose whole name matches ``pattern`` are
        candidates; they are ranked by ``float()`` of the first capture group.

        Returns:
            The path of the loaded file, or None if no file matched.
        """
        # NOTE(review): the '.' before 'pt' in the default pattern is not
        # escaped, so it matches any character. Left as-is to keep the
        # default argument unchanged.
        candidates = []
        with os.scandir(run_dir) as entries:  # closes the directory handle
            for entry in entries:
                if not entry.is_file():
                    continue
                match = re.fullmatch(pattern, entry.name)
                if match:
                    candidates.append((float(match.group(1)), entry.name))
        if not candidates:
            return None
        # max() keeps the first of equal keys, i.e. directory-scan order.
        _, latest_name = max(candidates, key=lambda item: item[0])
        pt_path = os.path.join(run_dir, latest_name)
        self.load(pt_path, verbose=verbose)
        return pt_path
#---------------------------------------------------------------------------- | |