ychenhq commited on
Commit
fb89f86
·
verified ·
1 Parent(s): 94c4d21

Upload folder using huggingface_hub

Browse files
utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.03 kB). View file
 
utils/utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+
8
+ def count_params(model, verbose=False):
9
+ total_params = sum(p.numel() for p in model.parameters())
10
+ if verbose:
11
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
12
+ return total_params
13
+
14
+
15
+ def check_istarget(name, para_list):
16
+ """
17
+ name: full name of source para
18
+ para_list: partial name of target para
19
+ """
20
+ istarget=False
21
+ for para in para_list:
22
+ if para in name:
23
+ return True
24
+ return istarget
25
+
26
+
27
+ def instantiate_from_config(config):
28
+ if not "target" in config:
29
+ if config == '__is_first_stage__':
30
+ return None
31
+ elif config == "__is_unconditional__":
32
+ return None
33
+ raise KeyError("Expected key `target` to instantiate.")
34
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
35
+
36
+
37
+ def get_obj_from_str(string, reload=False):
38
+ module, cls = string.rsplit(".", 1)
39
+ if reload:
40
+ module_imp = importlib.import_module(module)
41
+ importlib.reload(module_imp)
42
+ return getattr(importlib.import_module(module, package=None), cls)
43
+
44
+
45
+ def load_npz_from_dir(data_dir):
46
+ data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)]
47
+ data = np.concatenate(data, axis=0)
48
+ return data
49
+
50
+
51
+ def load_npz_from_paths(data_paths):
52
+ data = [np.load(data_path)['arr_0'] for data_path in data_paths]
53
+ data = np.concatenate(data, axis=0)
54
+ return data
55
+
56
+
57
+ def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
58
+ h, w = image.shape[:2]
59
+ if resize_short_edge is not None:
60
+ k = resize_short_edge / min(h, w)
61
+ else:
62
+ k = max_resolution / (h * w)
63
+ k = k**0.5
64
+ h = int(np.round(h * k / 64)) * 64
65
+ w = int(np.round(w * k / 64)) * 64
66
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
67
+ return image
68
+
69
+
70
+ def setup_dist(args):
71
+ if dist.is_initialized():
72
+ return
73
+ torch.cuda.set_device(args.local_rank)
74
+ torch.distributed.init_process_group(
75
+ 'nccl',
76
+ init_method='env://'
77
+ )