File size: 1,801 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from mmengine.dist import all_gather, broadcast, get_rank


@torch.no_grad()
def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shuffle a batch across all processes, for making use of BatchNorm.

    The per-process batches are gathered and concatenated along dim 0. A
    random permutation of the combined batch is drawn and broadcast from
    rank 0, so every process indexes with the same permutation. Each process
    then keeps its own contiguous slice of the permuted combined batch.

    Args:
        x (torch.Tensor): Data on this process; dim 0 is the batch
            dimension. The slicing below assumes every process contributes
            the same batch size.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A pair of
            - the shuffled data for this process (its slice of the permuted
              gathered batch), and
            - the inverse permutation over the gathered batch, used later to
              restore the original order.
    """
    local_size = x.shape[0]
    gathered = torch.cat(all_gather(x), dim=0)
    total_size = gathered.shape[0]
    world = total_size // local_size

    # Every rank draws a permutation, but the broadcast from rank 0
    # overwrites it, so all ranks end up with rank 0's permutation.
    perm = torch.randperm(total_size)
    broadcast(perm, src=0)

    # The argsort of a permutation is its inverse permutation.
    inverse_perm = torch.argsort(perm)

    # Row `rank` of the (world, local_size) view is this process's share.
    local_perm = perm.view(world, -1)[get_rank()]
    return gathered[local_perm], inverse_perm


@torch.no_grad()
def batch_unshuffle_ddp(x: torch.Tensor,
                        idx_unshuffle: torch.Tensor) -> torch.Tensor:
    """Undo the cross-process shuffle done by ``batch_shuffle_ddp``.

    Args:
        x (torch.Tensor): Shuffled data on this process; dim 0 is the batch
            dimension. The slicing below assumes every process has the same
            batch size.
        idx_unshuffle (torch.Tensor): Inverse permutation over the gathered
            batch, as returned by ``batch_shuffle_ddp``.

    Returns:
        torch.Tensor: This process's slice of the gathered batch, reordered
            by ``idx_unshuffle``.
    """
    pieces = all_gather(x)
    gathered = torch.cat(pieces, dim=0)
    world = gathered.shape[0] // x.shape[0]

    # Split the inverse permutation into one row per process and pick ours.
    rows = idx_unshuffle.view(world, -1)
    return gathered[rows[get_rank()]]