|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import argparse |
|
import numpy as np |
|
import torch |
|
|
|
from spatial_correlation_sampler import SpatialCorrelationSampler |
|
|
|
|
|
def check_equal(first, second, verbose):
    """Assert that paired tensors from two sequences are numerically close.

    Each pair is moved to host memory, detached from the autograd graph,
    converted to NumPy, and compared with ``np.testing.assert_allclose``,
    which raises ``AssertionError`` (tagged with the pair index) on mismatch.
    When *verbose* is truthy, the flattened values are printed first.
    """
    if verbose:
        print()
    for idx, (a, b) in enumerate(zip(first, second)):
        a_np = a.cpu().detach().numpy()
        b_np = b.cpu().detach().numpy()
        if verbose:
            print("x = {}".format(a_np.flatten()))
            print("y = {}".format(b_np.flatten()))
            print('-' * 80)
        np.testing.assert_allclose(a_np, b_np, err_msg="Index: {}".format(idx))
|
|
|
|
|
def zero_grad(variables):
    """Zero the accumulated ``.grad`` of every tensor that has one.

    Tensors whose ``.grad`` is still ``None`` (never backpropagated into)
    are skipped untouched.
    """
    for v in variables:
        if v.grad is None:
            continue
        v.grad.zero_()
|
|
|
|
|
def get_grads(variables):
    """Return independent copies (clones) of each tensor's ``.grad``.

    Cloning snapshots the gradients so later ``zero_grad``/``backward``
    calls on the originals do not affect the returned tensors.
    """
    snapshots = []
    for var in variables:
        snapshots.append(var.grad.clone())
    return snapshots
|
|
|
|
|
def check_forward(input1, input2, correlation_sampler, verbose, gpu_index=0):
    """Run the sampler on CPU and on CUDA device *gpu_index*; compare outputs.

    Raises ``AssertionError`` (via ``check_equal``) if the two results differ.
    """
    dev = torch.device(f"cuda:{gpu_index}")

    out_cpu = correlation_sampler(input1, input2)
    out_gpu = correlation_sampler(input1.to(dev), input2.to(dev))

    print(f"Forward: CPU vs. CUDA device:{gpu_index} ... ", end='')
    check_equal(out_cpu, out_gpu, verbose)
    print('Ok')
|
|
|
|
|
def check_backward(input1, input2, correlation_sampler, verbose, gpu_index=0):
    """Compare input gradients produced by the CPU and CUDA backward passes.

    Both passes backpropagate ``output.sum()`` into *input1*/*input2*.
    Gradients from the CUDA run flow back to the CPU leaf tensors through
    ``.to(device)``, so both snapshots are taken from the same leaves.
    Raises ``AssertionError`` (via ``check_equal``) on mismatch.
    """
    dev = torch.device(f"cuda:{gpu_index}")

    # CPU pass: clear any stale gradients, then snapshot the fresh ones.
    zero_grad([input1, input2])
    out_cpu = correlation_sampler(input1, input2)
    out_cpu.sum().backward()
    grad_cpu = get_grads([input1, input2])

    # CUDA pass: same procedure on the selected device.
    zero_grad([input1, input2])
    out_gpu = correlation_sampler(input1.to(dev), input2.to(dev))
    out_gpu.sum().backward()
    grad_cuda = get_grads([input1, input2])

    print(f"Backward: CPU vs. CUDA device:{gpu_index} ... ", end='')
    check_equal(grad_cpu, grad_cuda, verbose)
    print('Ok')
|
|
|
|
|
def check_multi_gpu_forward(correlation_sampler, verbose):
    """Repeat the CPU-vs-CUDA forward check on every visible GPU.

    NOTE: reads the module-level ``input1``/``input2`` tensors defined in
    the script body below.
    """
    print("Multi-GPU forward")
    for gpu_index in range(torch.cuda.device_count()):
        check_forward(input1, input2, correlation_sampler, verbose,
                      gpu_index=gpu_index)
|
|
|
def check_multi_gpu_backward(correlation_sampler, verbose):
    """Repeat the CPU-vs-CUDA backward check on every visible GPU.

    NOTE: reads the module-level ``input1``/``input2`` tensors defined in
    the script body below.
    """
    print("Multi-GPU backward")
    for gpu_index in range(torch.cuda.device_count()):
        check_backward(input1, input2, correlation_sampler, verbose,
                       gpu_index=gpu_index)
|
|
|
|
|
# ---------------------------------------------------------------------------
# CLI entry point: build random double-precision inputs, instantiate the
# sampler, and compare the CPU and CUDA implementations in the requested
# direction(s).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# One or both of 'forward'/'backward' may be requested.
parser.add_argument('direction', choices=['forward', 'backward'], nargs='+')
parser.add_argument('-b', '--batch-size', type=int, default=1)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
parser.add_argument('--patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=10)
parser.add_argument('--height', type=int, default=10)
parser.add_argument('-w', '--width', type=int, default=10)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=5)
parser.add_argument('-v', '--verbose', action='store_true', default=False)
parser.add_argument('-d', '--dilation', type=int, default=2)
args = parser.parse_args()
print(args)

# Explicit runtime check rather than `assert`: asserts are stripped under
# `python -O`, and missing CUDA must always abort the comparison.
if not torch.cuda.is_available():
    raise RuntimeError("no comparison to make")

# Double precision keeps the CPU/CUDA comparison tight; requires_grad is
# set at construction so both tensors are grad-requiring leaves.
shape = (args.batch_size, args.channel, args.height, args.width)
input1 = torch.randn(shape, dtype=torch.float64, requires_grad=True)
input2 = torch.randn(shape, dtype=torch.float64, requires_grad=True)

correlation_sampler = SpatialCorrelationSampler(
    args.kernel_size,
    args.patch,
    args.stride,
    args.pad,
    args.dilation,
    args.patch_dilation)

if 'forward' in args.direction:
    check_forward(input1, input2, correlation_sampler, args.verbose)
    if torch.cuda.device_count() > 1:
        check_multi_gpu_forward(correlation_sampler, args.verbose)

if 'backward' in args.direction:
    check_backward(input1, input2, correlation_sampler, args.verbose)
    if torch.cuda.device_count() > 1:
        check_multi_gpu_backward(correlation_sampler, args.verbose)
|
|