import torch | |
def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0, *, device=None):
    """Build a batch of random ``dim x dim`` permutation matrices.

    Each matrix is a row-permuted identity matrix. Permutation indices are
    drawn with ``torch.randperm`` from the default generator.

    Args:
        dim: Size of each square matrix.
        num: Number of matrices to generate.
        keep_first: If True, row 0 stays fixed (identity row 0) and only
            rows ``1..dim-1`` are shuffled; otherwise all rows are shuffled.
        gpu_id: CUDA device index, used only when ``device`` is None.
        device: Optional device (``torch.device`` or string) for the result.
            Defaults to ``torch.device('cuda', gpu_id)`` for backward
            compatibility.

    Returns:
        Tensor of shape ``(num, dim, dim)`` in the default floating dtype.
        When ``num == 0`` an empty ``(0, dim, dim)`` tensor is returned.
    """
    if device is None:
        device = torch.device('cuda', gpu_id)
    # Build the identity once; each matrix is a row-indexed copy of it.
    eye = torch.eye(dim, device=device)
    if num <= 0:
        # torch.stack rejects an empty list, so return an empty batch directly.
        return eye.new_zeros((0, dim, dim))
    all_matrix = []
    for _ in range(num):
        if keep_first:
            # Shuffle only the rows after the first, then re-attach row 0.
            rest = eye[1:][torch.randperm(dim - 1)]
            matrix = torch.cat([eye[:1], rest], dim=0)
        else:
            matrix = eye[torch.randperm(dim)]
        all_matrix.append(matrix)
    return torch.stack(all_matrix, dim=0)
def truncated_normal_(tensor, mean=0, std=.02):
    """Fill ``tensor`` in place with a truncated normal distribution.

    Samples are drawn from a standard normal restricted to the open interval
    (-2, 2), then scaled by ``std`` and shifted by ``mean``. For ``std >= 0``
    every value therefore lies within two ``std`` of ``mean``.

    Args:
        tensor: Tensor to overwrite in place (its shape, dtype and device are
            preserved).
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.

    Returns:
        ``tensor`` itself, for chaining.
    """
    with torch.no_grad():
        # Draw 4 candidates per element and keep the first in-range one; this
        # accepts almost every element in a single vectorised pass.
        candidates = tensor.new_empty(tensor.shape + (4, )).normal_()
        valid = (candidates < 2) & (candidates > -2)
        first_valid = valid.max(-1, keepdim=True)[1]
        samples = candidates.gather(-1, first_valid).squeeze(-1)
        # If all 4 candidates were out of range, max() picked index 0 and the
        # sample is still outside (-2, 2); redraw just those until none remain.
        bad = (samples >= 2) | (samples <= -2)
        while bad.any():
            samples[bad] = samples.new_empty(int(bad.sum())).normal_()
            bad = (samples >= 2) | (samples <= -2)
        tensor.copy_(samples)
        tensor.mul_(std).add_(mean)
    return tensor