Infinity / utils /wandb_utils.py
MohamedRashad's picture
Add initial project structure with requirements and utility functions
32287b3
import wandb
import torch
from torchvision.utils import make_grid
import torch.distributed as dist
from PIL import Image
import os
import argparse
import hashlib
import math
def is_main_process():
return dist.get_rank() == 0
def namespace_to_dict(namespace):
return {
k: namespace_to_dict(v) if isinstance(v, argparse.Namespace) else v
for k, v in vars(namespace).items()
}
def generate_run_id(exp_name):
# https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
return str(int(hashlib.sha256(exp_name.encode('utf-8')).hexdigest(), 16) % 10 ** 8)
def initialize(args, entity, exp_name, project_name):
config_dict = namespace_to_dict(args)
wandb.login(key=os.environ["WANDB_KEY"])
wandb.init(
entity=entity,
project=project_name,
name=exp_name,
config=config_dict,
id=generate_run_id(exp_name),
resume="allow",
)
def log(stats, step=None):
if is_main_process():
wandb.log({k: v for k, v in stats.items()}, step=step)
def log_image(name, sample, step=None):
if is_main_process():
sample = array2grid(sample)
wandb.log({f"{name}": wandb.Image(sample), "train_step": step})
def array2grid(x):
nrow = round(math.sqrt(x.size(0)))
x = make_grid(x, nrow=nrow, normalize=True, value_range=(-1,1))
x = x.mul(255).add_(0.5).clamp_(0,255).permute(1,2,0).to('cpu', torch.uint8).numpy()
return x