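# Consistency check for the spatial_correlation_sampler package: the same
# SpatialCorrelationSampler module is run on CPU and on each available CUDA
# device, and the forward outputs and input gradients are asserted to match.
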
import argparse

import numpy as np
import torch
from spatial_correlation_sampler import SpatialCorrelationSampler


def check_equal(first, second, verbose):
    """Assert that two sequences of tensors are element-wise close."""
    if verbose:
        print()
    for i, (x, y) in enumerate(zip(first, second)):
        x = x.cpu().detach().numpy()
        y = y.cpu().detach().numpy()
        if verbose:
            print("x = {}".format(x.flatten()))
            print("y = {}".format(y.flatten()))
            print('-' * 80)
        np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i))


def zero_grad(variables):
    """Reset accumulated gradients so successive checks do not mix."""
    for variable in variables:
        if variable.grad is not None:
            variable.grad.zero_()


def get_grads(variables):
    """Return a copy of each variable's current gradient."""
    return [var.grad.clone() for var in variables]


def check_forward(input1, input2, correlation_sampler, verbose, gpu_index=0):
    """Compare the CPU forward pass against the one on a given CUDA device."""
    device = torch.device(f"cuda:{gpu_index}")
    cpu_values = correlation_sampler(input1, input2)
    cuda_values = correlation_sampler(input1.to(device), input2.to(device))
    print(f"Forward: CPU vs. CUDA device:{gpu_index} ... ", end='')
    check_equal(cpu_values, cuda_values, verbose)
    print('Ok')


def check_backward(input1, input2, correlation_sampler, verbose, gpu_index=0):
    """Compare input gradients computed on CPU against a given CUDA device."""
    device = torch.device(f"cuda:{gpu_index}")

    zero_grad([input1, input2])
    cpu_values = correlation_sampler(input1, input2)
    cpu_values.sum().backward()
    grad_cpu = get_grads([input1, input2])

    # .to(device) keeps the autograd link, so gradients still land on the
    # CPU leaf tensors and can be compared against grad_cpu.
    zero_grad([input1, input2])
    cuda_values = correlation_sampler(input1.to(device), input2.to(device))
    cuda_values.sum().backward()
    grad_cuda = get_grads([input1, input2])

    print(f"Backward: CPU vs. CUDA device:{gpu_index} ... ", end='')
    check_equal(grad_cpu, grad_cuda, verbose)
    print('Ok')


def check_multi_gpu_forward(input1, input2, correlation_sampler, verbose):
    """Run the forward check on every visible CUDA device."""
    print("Multi-GPU forward")
    for gpu in range(torch.cuda.device_count()):
        check_forward(input1, input2, correlation_sampler, verbose, gpu_index=gpu)


def check_multi_gpu_backward(input1, input2, correlation_sampler, verbose):
    """Run the backward check on every visible CUDA device."""
    print("Multi-GPU backward")
    for gpu in range(torch.cuda.device_count()):
        check_backward(input1, input2, correlation_sampler, verbose, gpu_index=gpu)


parser = argparse.ArgumentParser()
parser.add_argument('direction', choices=['forward', 'backward'], nargs='+')
parser.add_argument('-b', '--batch-size', type=int, default=1)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
parser.add_argument('--patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=10)
parser.add_argument('--height', type=int, default=10)
parser.add_argument('-w', '--width', type=int, default=10)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=5)
parser.add_argument('-v', '--verbose', action='store_true')
parser.add_argument('-d', '--dilation', type=int, default=2)
args = parser.parse_args()
print(args)

assert torch.cuda.is_available(), "CUDA is not available, no comparison to make"

# Double-precision inputs keep the CPU/CUDA difference within
# assert_allclose's default tolerance.
input1 = torch.randn(args.batch_size, args.channel, args.height, args.width,
                     dtype=torch.float64, requires_grad=True)
input2 = torch.randn(args.batch_size, args.channel, args.height, args.width,
                     dtype=torch.float64, requires_grad=True)

correlation_sampler = SpatialCorrelationSampler(
    kernel_size=args.kernel_size,
    patch_size=args.patch,
    stride=args.stride,
    padding=args.pad,
    dilation=args.dilation,
    dilation_patch=args.patch_dilation)

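# A note on what the sampler returns: per the Pytorch-Correlation-extension
# README, the output is shaped (batch, patch_size, patch_size, out_height,
# out_width); check_equal then iterates over the leading (batch) dimension.
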
if 'forward' in args.direction:
    check_forward(input1, input2, correlation_sampler, args.verbose)
    if torch.cuda.device_count() > 1:
        check_multi_gpu_forward(input1, input2, correlation_sampler, args.verbose)

if 'backward' in args.direction:
    check_backward(input1, input2, correlation_sampler, args.verbose)
    if torch.cuda.device_count() > 1:
        check_multi_gpu_backward(input1, input2, correlation_sampler, args.verbose)
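
# Example invocation, assuming the script is saved as check.py (the file name
# is not stated above, so it is only an assumption):
#   python check.py forward backward --batch-size 2 --patch 5 -v
# This runs both the forward and the backward CPU/CUDA consistency checks,
# printing the compared values because of -v.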