Spaces:
Runtime error
Runtime error
File size: 1,895 Bytes
92bea3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import argparse
import torch
# torch.set_printoptions(precision=1, threshold=10000)
from torch.autograd import gradcheck
from spatial_correlation_sampler import SpatialCorrelationSampler
# Command-line interface for the gradient-check script.
parser = argparse.ArgumentParser()
# A positional argument only honours `default` when it is optional
# (nargs='?'); without it argparse makes `backend` mandatory and the
# 'cuda' default is dead code.
parser.add_argument('backend', nargs='?', choices=['cpu', 'cuda'],
                    default='cuda',
                    help='device the test tensors are created on '
                         '(default: %(default)s)')
# Size of the random input tensors: batch, channel, height, width.
parser.add_argument('-b', '--batch-size', type=int, default=2)
# Forwarded to the SpatialCorrelationSampler constructor.
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
# Hyphenated spelling added for consistency with the other options; the
# original underscore spelling keeps working and the destination is still
# `args.patch_dilation`.
parser.add_argument('--patch-dilation', '--patch_dilation',
                    dest='patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=2)
parser.add_argument('--height', type=int, default=10)
parser.add_argument('-w', '--width', type=int, default=10)
# Forwarded to the SpatialCorrelationSampler constructor.
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=1)
parser.add_argument('-d', '--dilation', type=int, default=2)
args = parser.parse_args()

# Both inputs share the same shape, precision and device, so describe them
# once. They are created in double precision (float64) and track gradients
# from the start, ready to be handed to gradcheck.
input_shape = (args.batch_size, args.channel, args.height, args.width)
target_device = torch.device(args.backend)

input1 = torch.randn(*input_shape,
                     dtype=torch.float64,
                     device=target_device,
                     requires_grad=True)
input2 = torch.randn(*input_shape,
                     dtype=torch.float64,
                     device=target_device,
                     requires_grad=True)
# Build the module under test from the command-line settings; the arguments
# are passed positionally in the same order as before.
correlation_sampler = SpatialCorrelationSampler(
    args.kernel_size,
    args.patch,
    args.stride,
    args.pad,
    args.dilation,
    args.patch_dilation,
)

# Run autograd's gradient check on the sampler with both inputs and report
# success when it returns a truthy result.
gradients_match = gradcheck(correlation_sampler, (input1, input2))
if gradients_match:
    print('Ok')
|