import torch
import einops
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10
from torchvision import transforms
import os
import math
import random
import json
from abc import ABC
import pickle
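# The helpers below convert between checkpoint state dicts and fixed-width token
# sequences. pad_to_length() pads a flattened tensor with config["fill_value"]
# (NaN by default, per BaseDataset.config) so its length becomes a multiple of
# the token width.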
def pad_to_length(x, common_factor, **config):
    if x.numel() % common_factor == 0:
        return x.flatten()
    # print(f"padding {x.shape} according to {common_factor}")
    full_length = (x.numel() // common_factor + 1) * common_factor
    padding_length = full_length - len(x.flatten())
    padding = torch.full([padding_length, ], dtype=x.dtype, device=x.device, fill_value=config["fill_value"])
    x = torch.cat((x.flatten(), padding), dim=0)
    return x
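# layer_to_token() turns one parameter tensor into a [num_tokens, common_factor]
# matrix according to config["granularity"]. As a worked example (hypothetical
# conv weight, default dim_per_token=8192, granularity=2): a [64, 32, 3, 3]
# weight has 18432 elements and 288 elements per output channel, so the search
# loop settles on 16 output channels per token (4608 values), giving 4 tokens
# that are each NaN-padded to width 8192.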
def layer_to_token(x, common_factor, **config):
    if config["granularity"] == 2:  # split by output
        if x.numel() <= common_factor:
            return pad_to_length(x.flatten(), common_factor, **config)[None]
        dim2 = x[0].numel()
        dim1 = x.shape[0]
        if dim2 <= common_factor:
            i = int(dim1 / (common_factor / dim2))
            while True:
                if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
                    output = x.view(-1, dim2 * (dim1 // i))
                    output = [pad_to_length(item, common_factor, **config) for item in output]
                    return torch.stack(output, dim=0)
                i += 1
        else:  # dim2 > common_factor
            output = [layer_to_token(item, common_factor, **config) for item in x]
            return torch.cat(output, dim=0)
    elif config["granularity"] == 1:  # split by layer
        return pad_to_length(x.flatten(), common_factor, **config).view(-1, common_factor)
    elif config["granularity"] == 0:  # flatten directly
        return x.flatten()
    else:  # NotImplementedError
        raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")
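# token_to_layer() is the inverse of layer_to_token(): given the remaining token
# matrix and a target parameter shape, it consumes just enough tokens (dropping
# the padding) to rebuild that parameter and returns (param, leftover_tokens).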
def token_to_layer(tokens, shape, **config):
    common_factor = tokens.shape[-1]
    if config["granularity"] == 2:  # split by output
        num_element = math.prod(shape)
        if num_element <= common_factor:
            param = tokens[0][:num_element].view(shape)
            tokens = tokens[1:]
            return param, tokens
        dim2 = num_element // shape[0]
        dim1 = shape[0]
        if dim2 <= common_factor:
            i = int(dim1 / (common_factor / dim2))
            while True:
                if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
                    item_per_token = dim2 * (dim1 // i)
                    length = num_element // item_per_token
                    output = [item[:item_per_token] for item in tokens[:length]]
                    param = torch.cat(output, dim=0).view(shape)
                    tokens = tokens[length:]
                    return param, tokens
                i += 1
        else:  # dim2 > common_factor
            output = []
            for i in range(shape[0]):
                param, tokens = token_to_layer(tokens, shape[1:], **config)
                output.append(param.flatten())
            param = torch.cat(output, dim=0).view(shape)
            return param, tokens
    elif config["granularity"] == 1:  # split by layer
        num_element = math.prod(shape)
        token_num = num_element // common_factor if num_element % common_factor == 0 \
            else num_element // common_factor + 1
        param = tokens.flatten()[:num_element].view(shape)
        tokens = tokens[token_num:]
        return param, tokens
    elif config["granularity"] == 0:  # flatten directly
        num_element = math.prod(shape)
        param = tokens.flatten()[:num_element].view(shape)
        tokens = pad_to_length(tokens.flatten()[num_element:],
                               common_factor, fill_value=torch.nan).view(-1, common_factor)
        return param, tokens
    else:  # NotImplementedError
        raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")
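# Standard 2D sin/cos positional encoding over a (layer index, token-in-layer)
# grid; returns a [dim1, dim2, d_model] tensor, with half of d_model encoding
# each axis.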
def positional_embedding_2d(dim1, dim2, d_model):
    assert d_model % 4 == 0, f"Cannot use 2d sin/cos positional encoding with d_model={d_model} (must be divisible by 4)"
    pe = torch.zeros(d_model, dim1, dim2)
    d_model = int(d_model / 2)  # each spatial dimension uses half of d_model
    div_term = torch.exp(torch.arange(0., d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0., dim2).unsqueeze(1)
    pos_h = torch.arange(0., dim1).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
    pe[d_model+1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
    return pe.permute(1, 2, 0)
def positional_embedding_1d(dim1, d_model):
    pe = torch.zeros(dim1, d_model)
    position = torch.arange(0, dim1, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
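# BaseDataset flattens saved checkpoints (state dicts) into token sequences of
# width dim_per_token and maps generated token sequences back to state dicts.
# Subclasses are expected to set data_path (the checkpoint directory),
# generated_path and test_command; config controls the padding value, the
# tokenization granularity and the positional-embedding granularity.
# __getitem__ returns (tokens, index).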
class BaseDataset(Dataset, ABC):
    data_path = None
    generated_path = None
    test_command = None
    config = {
        "fill_value": torch.nan,
        "granularity": 1,  # 0: flatten directly, 1: split by layer, 2: split by output
        "pe_granularity": 2,  # 0: no embedding, 1: 1d embedding, 2: 2d embedding
    }
    def __init__(self, checkpoint_path=None, dim_per_token=8192, **kwargs):
        if not os.path.exists(self.data_path):
            os.makedirs(self.data_path, exist_ok=False)
        if self.generated_path is not None and not os.path.exists(os.path.dirname(self.generated_path)):
            os.makedirs(os.path.dirname(self.generated_path))
        self.config.update(kwargs)
        checkpoint_path = self.data_path if checkpoint_path is None else checkpoint_path
        assert os.path.exists(checkpoint_path)
        self.dim_per_token = dim_per_token
        self.structure = None  # set in get_structure()
        self.sequence_length = None  # set in get_structure()
        # load checkpoint_list
        checkpoint_list = os.listdir(checkpoint_path)
        self.checkpoint_list = list([os.path.join(checkpoint_path, item) for item in checkpoint_list])
        self.length = self.real_length = len(self.checkpoint_list)
        self.set_infinite_dataset()
        # get structure
        structure_cache_file = os.path.join(os.path.dirname(self.data_path), "structure.cache")
        try:  # try to load cache file
            assert os.path.exists(structure_cache_file)
            with open(structure_cache_file, "rb") as f:
                print(f"Loading cache from {structure_cache_file}")
                cache_file = pickle.load(f)
            if len(self.checkpoint_list) != 0:
                assert set(cache_file["checkpoint_list"]) == set(self.checkpoint_list)
                self.structure = cache_file["structure"]
            else:  # empty checkpoint_list, only generate
                print("Cannot find any trained checkpoint, loading cache file for generating!")
                self.structure = cache_file["structure"]
                fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
                torch.save(fake_diction, os.path.join(checkpoint_path, "fake_checkpoint.pth"))
                self.checkpoint_list.append(os.path.join(checkpoint_path, "fake_checkpoint.pth"))
                self.length = self.real_length = len(self.checkpoint_list)
                self.set_infinite_dataset()
                os.remove(os.path.join(checkpoint_path, "fake_checkpoint.pth"))
        except AssertionError:  # recompute cache file
            print("==> Organizing structure..")
            self.structure = self.get_structure()
            with open(structure_cache_file, "wb") as f:
                pickle.dump({"structure": self.structure, "checkpoint_list": self.checkpoint_list}, f)
        # get sequence_length
        self.sequence_length = self.get_sequence_length()
    def get_sequence_length(self):
        fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
        # get sequence_length
        param = self.preprocess(fake_diction)
        self.sequence_length = param.size(0)
        return self.sequence_length
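    # self.structure maps each state-dict key to per-layer statistics averaged
    # over all checkpoints:
    #   untrainable / scalar / integer entries -> (shape, original_value, None)
    #   "running_var" entries                  -> (shape, pre_mean, mean, std) of the log-transformed values
    #   other float tensors (conv & linear)    -> (shape, mean, std)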
    def get_structure(self):
        # get structure
        checkpoint_list = self.checkpoint_list
        structures = [{} for _ in range(len(checkpoint_list))]
        for i, checkpoint in enumerate(checkpoint_list):
            diction = torch.load(checkpoint, map_location="cpu")
            for key, value in diction.items():
                if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
                    structures[i][key] = (value.shape, value, None)
                elif "running_var" in key:
                    pre_mean = value.mean() * 0.95
                    value = torch.log(value / pre_mean + 0.05)
                    structures[i][key] = (value.shape, pre_mean, value.mean(), value.std())
                else:  # conv & linear
                    structures[i][key] = (value.shape, value.mean(), value.std())
        final_structure = {}
        structure_diction = torch.load(checkpoint_list[0], map_location="cpu")
        for key, param in structure_diction.items():
            if ("num_batches_tracked" in key) or (param.numel() == 1) or not torch.is_floating_point(param):
                final_structure[key] = (param.shape, param, None)
            elif "running_var" in key:
                value = [param.shape, 0., 0., 0.]
                for structure in structures:
                    for i in [1, 2, 3]:
                        value[i] += structure[key][i]
                for i in [1, 2, 3]:
                    value[i] /= len(structures)
                final_structure[key] = tuple(value)
            else:  # conv & linear
                value = [param.shape, 0., 0.]
                for structure in structures:
                    for i in [1, 2]:
                        value[i] += structure[key][i]
                for i in [1, 2]:
                    value[i] /= len(structures)
                final_structure[key] = tuple(value)
        self.structure = final_structure
        return self.structure
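    # The dataset reports an inflated __len__ (real_length * 1,000,000 by
    # default) so training loops can sample effectively without epoch
    # boundaries; __getitem__ wraps the index with `index % self.real_length`.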
    def set_infinite_dataset(self, max_num=None):
        if max_num is None:
            max_num = self.length * 1000000
        self.length = max_num
        return self

    @property
    def max_permutation_state(self):
        return self.real_length
    def get_position_embedding(self, positional_embedding_dim=None):
        if positional_embedding_dim is None:
            positional_embedding_dim = self.dim_per_token // 2
        assert self.structure is not None, "run get_structure before get_position_embedding"
        if self.config["pe_granularity"] == 2:
            print("Use 2d positional embedding")
            positional_embedding_index = []
            for key, item in self.structure.items():
                if ("num_batches_tracked" in key) or (item[-1] is None):
                    continue
                else:  # conv & linear
                    shape, *_ = item
                    fake_param = torch.ones(size=shape)
                    fake_param = layer_to_token(fake_param, self.dim_per_token, **self.config)
                    positional_embedding_index.append(list(range(fake_param.size(0))))
            dim1 = len(positional_embedding_index)
            dim2 = max([len(token_per_layer) for token_per_layer in positional_embedding_index])
            full_pe = positional_embedding_2d(dim1, dim2, positional_embedding_dim)
            positional_embedding = []
            for layer_index, token_indexes in enumerate(positional_embedding_index):
                for token_index in token_indexes:
                    this_pe = full_pe[layer_index, token_index]
                    positional_embedding.append(this_pe)
            positional_embedding = torch.stack(positional_embedding)
            return positional_embedding
        elif self.config["pe_granularity"] == 1:
            print("Use 1d positional embedding")
            return positional_embedding_1d(self.sequence_length, positional_embedding_dim)
        elif self.config["pe_granularity"] == 0:
            print("Not using positional embedding")
            # __getitem__ returns (param, index); use the param tensor as the shape reference
            return torch.zeros_like(self.__getitem__(0)[0])
        else:  # NotImplementedError
            raise NotImplementedError("pe_granularity: 0: no embedding, 1: 1d embedding, 2: 2d embedding")
    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index % self.real_length
        diction = torch.load(self.checkpoint_list[index], map_location="cpu")
        param = self.preprocess(diction)
        return param, index

    def save_params(self, params, save_path):
        diction = self.postprocess(params.cpu().to(torch.float32))
        torch.save(diction, save_path)
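    # preprocess(): state dict -> [sequence_length, dim_per_token] float32 tensor.
    # "running_var" entries are log-transformed (log(v / pre_mean + 0.05)) before
    # standardization, every other float tensor is standardized with its
    # per-layer (mean, std) from self.structure, and each result is tokenized by
    # layer_to_token(); untrainable / scalar / integer entries are skipped.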
    def preprocess(self, diction: dict, **kwargs) -> torch.Tensor:
        param_list = []
        for key, value in diction.items():
            if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
                continue
            elif "running_var" in key:
                shape, pre_mean, mean, std = self.structure[key]
                value = torch.log(value / pre_mean + 0.05)
            else:  # normal
                shape, mean, std = self.structure[key]
            value = (value - mean) / std
            value = layer_to_token(value, self.dim_per_token, **self.config)
            param_list.append(value)
        param = torch.cat(param_list, dim=0)
        if self.config["granularity"] == 0:  # granularity 0: pad the tail of the flattened sequence
            param = pad_to_length(param, self.dim_per_token, **self.config).view(-1, self.dim_per_token)
        # print("Sequence length:", param.size(0))
        return param.to(torch.float32)
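    # postprocess(): the inverse of preprocess(). Tokens are consumed layer by
    # layer with token_to_layer(), de-standardized with the stored (mean, std),
    # and "running_var" entries additionally invert the log transform, clipped
    # at 0.001 to keep the variance positive; skipped entries are restored from
    # the values saved in self.structure.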
    def postprocess(self, params: torch.Tensor, **kwargs) -> dict:
        diction = {}
        params = params if len(params.shape) == 2 else params.squeeze(0)
        for key, item in self.structure.items():
            if ("num_batches_tracked" in key) or (item[-1] is None):
                shape, mean, std = item  # for untrainable entries, "mean" holds the saved original tensor
                diction[key] = mean
                continue
            elif "running_var" in key:
                shape, pre_mean, mean, std = item
            else:  # conv & linear
                shape, mean, std = item
            this_param, params = token_to_layer(params, shape, **self.config)
            this_param = this_param * std + mean
            if "running_var" in key:
                this_param = torch.clip(torch.exp(this_param) - 0.05, min=0.001) * pre_mean
            diction[key] = this_param
        return diction
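# ConditionalDataset additionally returns a condition parsed from the checkpoint
# filename: the basename is split on "_", so a (hypothetical) file named
# "class0_acc0.97.pth" would yield the condition ["class0", "acc0.97.pth"].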
class ConditionalDataset(BaseDataset, ABC):
    def _extract_condition(self, index: int):
        name = self.checkpoint_list[index]
        condition_list = os.path.basename(name).split("_")
        return condition_list

    def __getitem__(self, index):
        index = index % self.real_length
        diction = torch.load(self.checkpoint_list[index], map_location="cpu")
        condition = self._extract_condition(index)
        param = self.preprocess(diction)
        return param, condition