Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import torch | |
def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    This function enables cross-group information flow for multiple groups
    convolution layers. The channel dimension is split into ``groups``
    groups, the group axis and the per-group channel axis are swapped, and
    the result is flattened back into a single channel dimension.

    Args:
        x (Tensor): The input tensor of shape (N, C, *), i.e. a batch
            dimension and a channel dimension followed by any number of
            trailing dimensions (for example (N, C, H, W)).
        groups (int): The number of groups to divide the input tensor
            in the channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation, with
            the same shape as ``x``.

    Raises:
        AssertionError: If the number of channels is not divisible by
            ``groups``.
    """
    # Only the first two dims are interpreted; everything after the channel
    # dim is carried through unchanged, so 3-D/4-D/5-D inputs all work.
    batch_size, num_channels = x.shape[:2]
    trailing_dims = x.shape[2:]

    # Raised explicitly (rather than via an ``assert`` statement) so the
    # check is still performed when Python runs with ``-O``; the exception
    # type and message are unchanged.
    if num_channels % groups != 0:
        raise AssertionError('num_channels should be divisible by groups')
    channels_per_group = num_channels // groups

    # (N, C, *) -> (N, groups, C // groups, *)
    x = x.view(batch_size, groups, channels_per_group, *trailing_dims)
    # Swap the group and per-group axes; ``contiguous`` is needed so the
    # transposed tensor can be viewed with a flat channel dim again.
    x = torch.transpose(x, 1, 2).contiguous()
    # Use the known channel count instead of -1 so that zero-element
    # tensors (e.g. an empty batch) do not make the view ambiguous.
    x = x.view(batch_size, num_channels, *trailing_dims)
    return x