File size: 829 Bytes
c985ba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch


def generate_permute_matrix(dim, num, keep_first=True, gpu_id=0, *, device=None):
    """Build a batch of random row-permutation matrices.

    Each matrix is the ``dim x dim`` identity with its rows shuffled by
    ``torch.randperm``. The result is therefore a 0/1 permutation matrix in
    the default floating dtype of ``torch.eye``.

    Args:
        dim: Size of each square matrix. Expected to be >= 1 when
            ``keep_first`` is True, because row 0 is held fixed.
        num: Number of matrices to generate. ``0`` yields an empty batch.
        keep_first: If True, row 0 stays the first basis vector and only
            rows ``1..dim-1`` are shuffled. If False, all rows are shuffled.
        gpu_id: CUDA device index, used only when ``device`` is None.
        device: Optional explicit device (e.g. ``'cpu'``). It overrides
            ``gpu_id`` and allows use on machines without CUDA.

    Returns:
        Tensor of shape ``(num, dim, dim)`` on the selected device.
    """
    if device is None:
        device = torch.device('cuda', gpu_id)
    if num <= 0:
        # torch.stack([]) raises, so return an empty batch explicitly.
        return torch.empty((0, dim, dim), device=device)

    # Advanced indexing and torch.cat below always allocate new tensors,
    # so one identity matrix can be shared across all iterations.
    identity = torch.eye(dim, device=device)
    all_matrix = []
    for _ in range(num):
        if keep_first:
            # Shuffle only rows 1..dim-1; row 0 stays in place.
            shuffled = identity[1:][torch.randperm(dim - 1)]
            random_matrix = torch.cat([identity[0:1], shuffled], dim=0)
        else:
            random_matrix = identity[torch.randperm(dim)]
        all_matrix.append(random_matrix)
    return torch.stack(all_matrix, dim=0)


def truncated_normal_(tensor, mean=0, std=.02):
    """Fill ``tensor`` in place with truncated-normal samples.

    Standard-normal values are restricted to the open interval (-2, 2) and
    then scaled to ``value * std + mean``. For ``std > 0`` the result lies
    in ``(mean - 2*std, mean + 2*std)``.

    Writes go through ``tensor.data``, so autograd does not record them. This
    lets the function initialize leaf parameters that have
    ``requires_grad=True``.

    Args:
        tensor: Floating-point tensor to overwrite.
        mean: Mean of the resulting distribution.
        std: Scale applied to the standard-normal samples.

    Returns:
        The same ``tensor`` object, for chaining.
    """
    size = tensor.shape
    # Draw 4 candidates per element and keep the first one inside (-2, 2).
    # torch.max returns the index of the first maximal value.
    tmp = tensor.new_empty(size + (4, )).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    flat = tmp.gather(-1, ind).reshape(-1)

    # All 4 candidates can fall outside the range (p ~= 4.3e-6 per element).
    # In that case index 0 was picked and the value is out of range, so
    # redraw just those entries until every value is valid.
    bad = ~((flat < 2) & (flat > -2))
    while bad.any():
        flat[bad] = flat.new_empty(int(bad.sum())).normal_()
        bad = ~((flat < 2) & (flat > -2))

    tensor.data.copy_(flat.view(size))
    tensor.data.mul_(std).add_(mean)
    return tensor