from __future__ import division
from __future__ import print_function

import argparse
import time

import torch
from spatial_correlation_sampler import SpatialCorrelationSampler
from tqdm import trange

TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}
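
# Example invocation (illustrative; the script's file name is an assumption here):
#   python benchmark.py cuda -b 4 -c 64 --height 128 -w 128 --patch 9 --scale ms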

parser = argparse.ArgumentParser()
# 'backend' takes nargs='?' so that the declared default of 'cuda' actually
# applies when the argument is omitted on the command line.
parser.add_argument('backend', choices=['cpu', 'cuda'], nargs='?', default='cuda')
parser.add_argument('-b', '--batch-size', type=int, default=16)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
parser.add_argument('--patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=64)
parser.add_argument('--height', type=int, default=100)
parser.add_argument('-w', '--width', type=int, default=100)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=1)
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
parser.add_argument('-r', '--runs', type=int, default=100)
parser.add_argument('--dilation', type=int, default=2)
# An explicit default keeps the previous behaviour: without --dtype the
# selection below fell through to double precision.
parser.add_argument('-d', '--dtype', choices=['half', 'float', 'double'],
                    default='double')

args = parser.parse_args()

device = torch.device(args.backend)

# Map the --dtype flag to the corresponding torch dtype.
if args.dtype == 'half':
    dtype = torch.float16
elif args.dtype == 'float':
    dtype = torch.float32
else:
    dtype = torch.float64

# Two random feature maps to correlate; gradients are tracked on input1 so the
# sampler's backward pass is exercised as well.
input1 = torch.randn(args.batch_size,
                     args.channel,
                     args.height,
                     args.width,
                     dtype=dtype,
                     device=device,
                     requires_grad=True)
input2 = torch.randn_like(input1)

correlation_sampler = SpatialCorrelationSampler(
    args.kernel_size,
    args.patch,            # patch size
    args.stride,
    args.pad,              # padding
    args.dilation,
    args.patch_dilation)   # dilation applied to the patch
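
# One pass outside the timed loop serves as a warm-up (CUDA context and kernel
# initialisation) and sanity-checks the forward/backward path. Per the library's
# documentation the output is expected to be 5-dimensional:
# (batch, patch, patch, out_height, out_width).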
output = correlation_sampler(input1, input2)
print(output.size())
output.mean().backward()

forward_min = float('inf')
forward_time = 0
backward_min = float('inf')
backward_time = 0

for _ in trange(args.runs):
    # The sampler has no learnable parameters, so zero_grad() clears nothing;
    # note that input1.grad keeps accumulating across runs.
    correlation_sampler.zero_grad()

    start = time.time()
    output = correlation_sampler(input1, input2)
    if args.backend == 'cuda':
        # CUDA kernels are launched asynchronously; synchronize so the measured
        # time covers the actual computation rather than just the launch.
        torch.cuda.synchronize()
    elapsed = time.time() - start
    forward_min = min(forward_min, elapsed)
    forward_time += elapsed
    output = output.mean()

    start = time.time()
    output.backward()
    if args.backend == 'cuda':
        torch.cuda.synchronize()
    elapsed = time.time() - start
    backward_min = min(backward_min, elapsed)
    backward_time += elapsed

scale = TIME_SCALES[args.scale]
forward_min *= scale
backward_min *= scale
forward_average = forward_time / args.runs * scale
backward_average = backward_time / args.runs * scale

print('Forward: min {0:.3f} / avg {1:.3f} {4} | Backward: min {2:.3f} / avg {3:.3f} {4}'.format(
    forward_min, forward_average, backward_min, backward_average,
    args.scale))
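
# Optional, for the CUDA backend only (a minimal sketch; the function name is
# illustrative and it is not called by the benchmark above): torch.cuda.Event
# measures on the device itself, with no host synchronization inside the timed
# region.
def time_forward_with_cuda_events(sampler, a, b):
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    out = sampler(a, b)
    end_evt.record()
    torch.cuda.synchronize()  # wait for both events before reading the timing
    return out, start_evt.elapsed_time(end_evt)  # elapsed time in milliseconds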