import os
import sys
import json
import argparse
import math
import time
import random
import gc

import numpy as np
import h5py
import matplotlib.pyplot as plt
import webdataset as wds
from einops import rearrange
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import ToPILImage #CHANGED (added)

from accelerate import Accelerator, DeepSpeedPlugin

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils

global_batch_size = 128

### Multi-GPU config ###
# torchrun / accelerate launch set LOCAL_RANK per process; fall back to 0 when
# running single-process (e.g. in a notebook).
local_rank = os.getenv('LOCAL_RANK')
if local_rank is None:
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)

num_devices = torch.cuda.device_count()
if num_devices == 0:
    num_devices = 1

accelerator = Accelerator(split_batches=False)

### UNCOMMENT BELOW TO USE DEEPSPEED (and comment out the "accelerator = " line immediately above) ###
# if num_devices <= 1 and utils.is_interactive():
#     # emulate a distributed environment so deepspeed works in a jupyter notebook
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
#     os.environ["RANK"] = "0"
#     os.environ["LOCAL_RANK"] = "0"
#     os.environ["WORLD_SIZE"] = "1"
#     os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
#     global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])
#
# # alter the deepspeed config according to your global and local batch size
# if local_rank == 0:
#     with open('deepspeed_config_stage2.json', 'r') as file:
#         config = json.load(file)
#     config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
#     config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
#     with open('deepspeed_config_stage2.json', 'w') as file:
#         json.dump(config, file)
# else:
#     # give the local_rank==0 process time to write the new deepspeed config file
#     time.sleep(10)
# deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
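# --- Illustrative usage sketch (an addition, not part of the original script) ---
# With Accelerator(split_batches=False), each process loads its own batches, so
# the per-device loader batch size is typically derived from the global one.
# `batch_size` is an assumed name for that per-device value:
batch_size = global_batch_size // num_devices
print("global_batch_size =", global_batch_size, "| per-device batch_size =", batch_size)

# A typical multi-GPU launch might look like the following (the script name
# `Train.py` is hypothetical; set --num_processes to your GPU count):
#
#   accelerate launch --num_processes=4 Train.py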