# Copyright (c) 2023, Tri Dao.

"""Useful functions for writing test code."""

import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
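
# A minimal usage sketch (assumes a CUDA device is available; `torch.matmul` is
# just a stand-in for whatever function you want to time):
#
#     x = torch.randn(4096, 4096, device="cuda")
#     t, m = benchmark_forward(torch.matmul, x, x, desc="matmul", repeats=30)
#     print(m.mean)  # mean time per call, in seconds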


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid the extra gradient-accumulation operation
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
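
# Backward-pass sketch with an explicit upstream gradient (hypothetical shapes;
# any tensor matching the output shape works):
#
#     x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
#     g = torch.ones(1024, 1024, device="cuda")
#     benchmark_backward(torch.square, x, grad=g, desc="square")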


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the combined forward + backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        # Clear accumulated gradients so each run measures the same work
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward and backward passes of an arbitrary function, timed separately."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward pass, backward pass, and combined forward + backward pass of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )
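
# Sketch of running all three benchmarks at once under autocast (fp16 is just an
# illustrative choice); each element of the result is a (Timer, Measurement) pair:
#
#     x = torch.randn(2048, 2048, device="cuda", requires_grad=True)
#     fwd, bwd, combined = benchmark_all(
#         torch.matmul, x, x, desc="matmul", amp=True, amp_dtype=torch.float16
#     )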


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap a function call in the PyTorch profiler to inspect CUDA kernel information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
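
# Profiling sketch: prints a kernel table and optionally dumps a Chrome trace
# you can open at chrome://tracing (the filename is just an example):
#
#     x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
#     pytorch_profiler(torch.matmul, x, x, backward=True, trace_filename="matmul.json")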


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Measure the peak CUDA memory allocated while running fn once."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    mem = torch.cuda.max_memory_allocated() / (2**30)  # bytes -> GiB
    if verbose:
        print(f"{desc} max memory: {mem}GiB")
    torch.cuda.empty_cache()
    return mem
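

if __name__ == "__main__":
    # Smoke-test sketch tying the helpers together (assumes a CUDA GPU; the
    # matmul workload and sizes are arbitrary choices, not part of the module):
    a = torch.randn(2048, 2048, device="cuda", requires_grad=True)
    b = torch.randn(2048, 2048, device="cuda", requires_grad=True)
    benchmark_all(torch.matmul, a, b, desc="matmul 2048x2048")
    benchmark_memory(torch.matmul, a, b, desc="matmul 2048x2048")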