|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.autograd import Function |
|
|
|
from ..utils import ext_loader |
|
|
|
# Bind the 'tin_shift_forward' / 'tin_shift_backward' ops from the '_ext'
# extension module via ext_loader; TINShiftFunction calls both.
ext_module = ext_loader.load_ext(
    '_ext', ['tin_shift_forward', 'tin_shift_backward'])
|
|
|
|
|
class TINShiftFunction(Function):
    """Autograd function wrapping the ``tin_shift`` extension ops.

    ``forward`` validates the shapes of ``input`` and ``shift`` and then
    delegates to ``ext_module.tin_shift_forward``; ``backward`` delegates to
    ``ext_module.tin_shift_backward``.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Shift ``input`` according to ``shift``.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for backward.
            input (Tensor): Feature map. ``size(0)`` is the batch size and
                ``size(2)`` is the channel count ``C`` (the module docs
                describe it as [N, num_segments, C, H * W]).
            shift (Tensor): Shift tensor. ``size(0)`` must equal the batch
                size of ``input`` and ``size(1)`` must be a positive divisor
                of ``C``.

        Returns:
            Tensor: A tensor shaped like ``input`` (same dtype/device),
            zero-initialised and then filled by ``tin_shift_forward``.

        Raises:
            ValueError: If the batch sizes of ``input`` and ``shift``
                differ, if ``shift.size(1)`` is zero, or if ``C`` is not a
                positive multiple of ``shift.size(1)``.
        """
        # The extension indexes ``shift`` per batch element, so a batch
        # mismatch must be rejected before handing the tensors over.
        if input.size(0) != shift.size(0):
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be '
                f'same, but got {input.size(0)} and {shift.size(0)}.')

        C = input.size(2)
        # NOTE(review): despite the name, this is shift's second dim and is
        # only required to divide C here; it is not compared to input.size(1).
        num_segments = shift.size(1)
        # Test num_segments first so an empty second dim yields a ValueError
        # instead of a ZeroDivisionError from the floor division.
        if (num_segments <= 0 or C // num_segments <= 0
                or C % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={C} and num_segments={num_segments}.')

        ctx.save_for_backward(shift)

        out = torch.zeros_like(input)
        # The extension writes the shifted result into ``out``.
        ext_module.tin_shift_forward(input, shift, out)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate gradients back through the shift.

        Args:
            ctx: Autograd context holding the ``shift`` saved in forward.
            grad_output (Tensor): Gradient w.r.t. the forward output.

        Returns:
            tuple: ``(data_grad_input, shift_grad_input)``. The first is
            filled by ``tin_shift_backward``; the second is all zeros, i.e.
            no gradient is propagated into ``shift``.
        """
        shift = ctx.saved_tensors[0]
        # new_zeros gives a contiguous zero tensor with the same dtype and
        # device, matching the previous ``new(*size).zero_()`` result.
        data_grad_input = grad_output.new_zeros(grad_output.size())
        # A zero tensor (rather than None) is kept for backward
        # compatibility with callers that inspect ``shift.grad``.
        shift_grad_input = shift.new_zeros(shift.size())
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)

        return data_grad_input, shift_grad_input
|
|
|
|
|
# Functional entry point: ``tin_shift(input, shift)`` runs TINShiftFunction
# through autograd's ``Function.apply``.
tin_shift = TINShiftFunction.apply
|
|
|
|
|
class TINShift(nn.Module):
    """Temporal Interlace Shift module.

    A differentiable, temporal-wise frame shifting operation introduced in
    "Temporal Interlacing Network" (https://arxiv.org/abs/2001.06499).
    The implementation is adapted from
    https://github.com/mit-han-lab/temporal-shift-module.

    The module defines no parameters of its own; ``forward`` simply applies
    the functional ``tin_shift``.
    """

    def forward(self, input, shift):
        """Apply temporal interlace shift to ``input``.

        Args:
            input (Tensor): Feature map with shape [N, num_segments, C, H * W].
            shift (Tensor): Shift tensor with shape [N, num_segments]. Its
                second dimension must evenly divide ``C`` (validated in
                ``TINShiftFunction.forward``).

        Returns:
            Tensor: Feature map after temporal interlace shift.
        """
        shifted = tin_shift(input, shift)
        return shifted
|
|