Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version | |
# A build whose version string is literally 'parrots' keeps that string
# as-is. Otherwise the version may look like '1.3.1+cu92'; only the leading
# (major, minor) pair is kept, as integers, so it can be compared as a tuple.
TORCH_VERSION = (
    torch.__version__
    if torch.__version__ == 'parrots'
    else tuple(int(part) for part in torch.__version__.split('.')[:2])
)
def adaptive_avg_pool2d(input, output_size):
    """Handle empty batch dimension to adaptive_avg_pool2d.

    When ``input`` has zero elements and
    ``obsolete_torch_version(TORCH_VERSION, (1, 9))`` is true, the native
    operator is bypassed and the result of ``NewEmptyTensorOp.apply`` for
    the expected output shape is returned instead. In every other case the
    call is forwarded unchanged to
    ``torch.nn.functional.adaptive_avg_pool2d``.

    Args:
        input (tensor): 4D tensor.
        output_size (int, tuple[int,int]): the target output size. A single
            int is used for both spatial dimensions. A ``None`` entry keeps
            the size of the matching input dimension.

    Returns:
        tensor: The pooled tensor, or the tensor produced by
        ``NewEmptyTensorOp`` on the empty-input fallback path.
    """
    if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
        if isinstance(output_size, int):
            output_size = [output_size, output_size]
        else:
            # ``None`` means "keep the input size" for that dimension, as in
            # the native operator and the AdaptiveAvgPool2d module; resolve
            # it here so no ``None`` leaks into the requested shape.
            output_size = [
                size if size is not None else dim
                for size, dim in zip(output_size, input.shape[-2:])
            ]
        # The two leading dimensions are carried over from the input.
        output_size = [*input.shape[:2], *output_size]
        return NewEmptyTensorOp.apply(input, output_size)
    return F.adaptive_avg_pool2d(input, output_size)
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """Handle empty batch dimension to AdaptiveAvgPool2d."""

    def forward(self, x):
        """Pool ``x``, with a fallback for inputs that have no elements.

        Args:
            x (tensor): Input tensor; its two leading dimensions are carried
                over and its last two dimensions are the ones pooled.

        Returns:
            tensor: The parent module's output, or the tensor produced by
            ``NewEmptyTensorOp`` on the empty-input fallback path.
        """
        # PyTorch 1.9 does not support empty tensor inference yet
        use_fallback = (
            x.numel() == 0
            and obsolete_torch_version(TORCH_VERSION, (1, 9)))
        if not use_fallback:
            return super().forward(x)

        target = self.output_size
        if isinstance(target, int):
            spatial = [target, target]
        else:
            # A ``None`` entry keeps the matching input dimension.
            spatial = []
            for requested, current in zip(target, x.size()[-2:]):
                spatial.append(current if requested is None else requested)

        out_shape = [*x.shape[:2], *spatial]
        return NewEmptyTensorOp.apply(x, out_shape)