Spaces:
Running
on
Zero
Running
on
Zero
import wandb | |
import torch | |
from torchvision.utils import make_grid | |
import torch.distributed as dist | |
from PIL import Image | |
import os | |
import argparse | |
import hashlib | |
import math | |
def is_main_process():
    """Return True when the current process should perform logging.

    In a distributed run this is the process whose rank is 0. When
    ``torch.distributed`` is unavailable or the default process group has
    not been initialized (e.g. a plain single-process run), the current
    process is treated as the main one instead of letting
    ``dist.get_rank()`` raise.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # No process group: there is only this process, so it is "main".
        return True
    return dist.get_rank() == 0
def namespace_to_dict(namespace):
    """Recursively convert an ``argparse.Namespace`` into a plain dict.

    Attribute values that are themselves ``argparse.Namespace`` instances
    are converted recursively; every other value is stored unchanged.
    """
    converted = {}
    for attr_name, attr_value in vars(namespace).items():
        if isinstance(attr_value, argparse.Namespace):
            attr_value = namespace_to_dict(attr_value)
        converted[attr_name] = attr_value
    return converted
def generate_run_id(exp_name):
    """Derive a deterministic numeric run id (as a string) from *exp_name*.

    The SHA-256 digest of the UTF-8 encoded name is read as one big
    integer and reduced modulo 10**8, so the result has at most eight
    decimal digits (no zero padding).
    """
    # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
    digest = hashlib.sha256(exp_name.encode('utf-8')).digest()
    # Big-endian interpretation of the raw digest equals int(hexdigest, 16).
    digest_as_int = int.from_bytes(digest, 'big')
    short_id = digest_as_int % 10 ** 8
    return str(short_id)
def initialize(args, entity, exp_name, project_name):
    """Log in to Weights & Biases and start (or resume) a run.

    Args:
        args: ``argparse.Namespace`` whose (possibly nested) contents are
            stored as the run config.
        entity: value passed to ``wandb.init`` as ``entity``.
        exp_name: run name; also hashed into the run id, so the same
            experiment name always yields the same id.
        project_name: value passed to ``wandb.init`` as ``project``.

    Raises:
        KeyError: if the ``WANDB_KEY`` environment variable is not set.
    """
    run_config = namespace_to_dict(args)
    api_key = os.environ["WANDB_KEY"]
    wandb.login(key=api_key)
    # A stable id together with resume="allow" lets a restarted job with
    # the same experiment name target the same run id.
    init_kwargs = {
        "entity": entity,
        "project": project_name,
        "name": exp_name,
        "config": run_config,
        "id": generate_run_id(exp_name),
        "resume": "allow",
    }
    wandb.init(**init_kwargs)
def log(stats, step=None):
    """Send a mapping of metrics to wandb from the main process only.

    Args:
        stats: mapping of metric name to value; a shallow copy is logged.
        step: optional step forwarded to ``wandb.log``.
    """
    if not is_main_process():
        return
    metrics = dict(stats.items())
    wandb.log(metrics, step=step)
def log_image(name, sample, step=None):
    """Log a batch of images as a single grid image, main process only.

    Args:
        name: key under which the image is logged.
        sample: batch passed to ``array2grid`` (presumably an image tensor
            with the batch on dim 0 -- confirm against callers).
        step: logged as the ``"train_step"`` metric next to the image; it
            is not passed as the ``step`` argument of ``wandb.log``.
    """
    if not is_main_process():
        return
    image_key = f"{name}"
    grid = array2grid(sample)
    payload = {image_key: wandb.Image(grid), "train_step": step}
    wandb.log(payload)
def array2grid(x):
    """Tile a batch tensor into one image grid and return it as a uint8 array.

    The grid uses ``round(sqrt(batch_size))`` images per row. Values are
    normalized by ``make_grid`` using the range (-1, 1), then scaled to
    0..255 with rounding, moved to CPU and converted to NumPy.

    Args:
        x: tensor with the batch on dim 0 (presumably (N, C, H, W) images
            in [-1, 1] -- confirm against callers).

    Returns:
        NumPy ``uint8`` array; the grid's dims are permuted (1, 2, 0),
        i.e. presumably height x width x channels.
    """
    batch_size = x.size(0)
    images_per_row = round(math.sqrt(batch_size))
    grid = make_grid(x, nrow=images_per_row, normalize=True, value_range=(-1, 1))
    # mul() allocates a new tensor, so the in-place ops below never touch `grid`.
    pixels = grid.mul(255)
    pixels.add_(0.5)  # +0.5 so the truncating uint8 cast rounds to nearest
    pixels.clamp_(0, 255)
    channels_last = pixels.permute(1, 2, 0)
    return channels_last.to('cpu', torch.uint8).numpy()