Spaces:
Runtime error
Runtime error
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//------------------------------------------------------------------------
// Fused filtered leaky-ReLU entry point: validates inputs, sizes and allocates
// the output (and optionally a sign tensor), selects a CUDA kernel variant and
// launches the filter-setup kernel followed by the main kernel.
// NOTE(review): the exact math (order of bias / up-filter / lrelu / down-filter)
// lives in the kernels, which are not visible here.
//
// Arguments:
//   x             Input, CUDA, rank 4 laid out as [batch, channels, height, width], float16 or float32.
//   fu, fd        Upsampling / downsampling filters, float32, rank 1 (separable) or rank 2.
//   b             Bias vector with x.size(1) elements, same dtype as x.
//   si            Sign tensor to read (uint8, rank 4, contiguous, same batch/channels as x),
//                 or an empty tensor to not read signs.
//   up, down      Up- and downsampling factors, both >= 1.
//   px0, px1      Horizontal padding added to the upsampled buffer width (px0 is passed as pad0.x).
//   py0, py1      Vertical padding added to the upsampled buffer height (py0 is passed as pad0.y).
//   sx, sy        Forwarded as p.sOfs (presumably the sign tensor offset - confirm in kernel).
//   gain, slope, clamp  Forwarded unchanged to the kernel parameters.
//   flip_filters  Forwarded to the kernel as p.flip.
//   writeSigns    Allocate and fill a new sign tensor. Must not be combined with a non-empty si.
//
// Returns (y, so, code):
//   code ==  0  y is the output [batch, channels, yh, yw]; so is the written sign tensor
//               (an undefined tensor when writeSigns is false).
//   code == -1  no CUDA kernel exists for this up/down/filter configuration; y and so are
//               undefined tensors and the caller is expected to fall back.
// Throws (via TORCH_CHECK) on invalid arguments.
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
    int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
    // Set CUDA device.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Validate arguments.
    TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
    TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
    TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
    TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
    TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
    TORCH_CHECK(fu.numel() > 0, "fu is empty");
    TORCH_CHECK(fd.numel() > 0, "fd is empty");
    TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
    TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");

    // Figure out how much shared memory is available on the device.
    int maxSharedBytes = 0;
    AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
    int sharedKB = maxSharedBytes >> 10;

    // Populate enough launch parameters to check if a CUDA kernel exists.
    // Value-initialize so that fields assigned only later (e.g. blockZofs, set just
    // before each main-kernel launch) or never assigned here are zero rather than
    // garbage when the parameter block is passed to the setup kernel.
    filtered_lrelu_kernel_params p = {};
    p.up = up;
    p.down = down;
    p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
    p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
    filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
    if (!test_spec.exec)
    {
        // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
        return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
    }

    // Input/output element size.
    int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;

    // Input sizes.
    int64_t xw = (int)x.size(3);
    int64_t xh = (int)x.size(2);
    int64_t fut_w = (int)fu.size(-1) - 1;
    int64_t fut_h = (int)fu.size(0)  - 1;
    int64_t fdt_w = (int)fd.size(-1) - 1;
    int64_t fdt_h = (int)fd.size(0)  - 1;

    // Logical size of upsampled buffer.
    int64_t cw = xw * up + (px0 + px1) - fut_w;
    int64_t ch = xh * up + (py0 + py1) - fut_h;
    TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
    TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");

    // Compute output size and allocate.
    int64_t yw = (cw - fdt_w + (down - 1)) / down;
    int64_t yh = (ch - fdt_h + (down - 1)) / down;
    TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
    TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());

    // Allocate sign tensor.
    torch::Tensor so;
    torch::Tensor s = si;
    bool readSigns = !!s.numel();

    // There is no kernel variant that both reads and writes signs (see the dispatch
    // below), so reject that combination here with a clear message instead of
    // failing later with "internal error - CUDA kernel not found".
    TORCH_CHECK(!(readSigns && writeSigns), "signs cannot be both read (non-empty si) and written (writeSigns) in the same call");

    int64_t sw_active = 0; // Active width of sign tensor, in elements.
    if (writeSigns)
    {
        sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
        int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
        int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
        TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
        // Stored width is sw >> 2 bytes, i.e. 4 elements per byte.
        s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
    }

    // Validate sign tensor if in use.
    if (readSigns || writeSigns)
    {
        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
        TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
    }

    // Active width of a caller-provided sign tensor (4 elements per byte). Read only
    // after validation so a malformed si reports "signs must be rank 4" instead of
    // an out-of-range dimension error.
    if (readSigns)
        sw_active = s.size(3) << 2;

    // Populate rest of CUDA kernel parameters.
    p.x = x.data_ptr();
    p.y = y.data_ptr();
    p.b = b.data_ptr();
    p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : nullptr;
    p.fu = fu.data_ptr<float>();
    p.fd = fd.data_ptr<float>();
    p.pad0 = make_int2(px0, py0);
    p.gain = gain;
    p.slope = slope;
    p.clamp = clamp;
    p.flip = (flip_filters) ? 1 : 0;
    p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
    p.sOfs = make_int2(sx, sy);
    p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.

    // x, y, b strides are in bytes.
    p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
    p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
    p.bStride = sz * b.stride(0);

    // fu, fd strides are in elements.
    p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
    p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);

    // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
    // Each test sums the most negative / most positive byte offsets reachable over all dimensions.
    bool index64b = false;
    if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
    if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) >  INT_MAX) index64b = true;
    if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) >  INT_MAX) index64b = true;
    if (s.numel() > INT_MAX) index64b = true;

    // Choose CUDA kernel.
    filtered_lrelu_kernel_spec spec = { 0 };
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
    {
        if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
        {
            // Choose kernel based on index type, datatype and sign read/write modes.
            if      (!index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t,  true, false>(p, sharedKB);
            else if (!index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
            else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
            else if ( index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t,  true, false>(p, sharedKB);
            else if ( index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
            else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
        }
    });
    TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.

    // Launch CUDA kernel.
    void* args[] = {&p};
    int bx = spec.numWarps * 32;
    int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
    int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
    int gz = p.yShape.z * p.yShape.w;

    // Repeat multiple horizontal tiles in a CTA?
    if (spec.xrep)
    {
        p.tilesXrep = spec.xrep;
        p.tilesXdim = gx;

        gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
        std::swap(gx, gy);
    }
    else
    {
        p.tilesXrep = 0;
        p.tilesXdim = 0;
    }

    // Launch filter setup kernel.
    AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));

    // Copy kernels to constant memory. The read+write combination was rejected above.
    if      ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true,  false>(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns &&  readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));

    // Set cache and shared memory configurations for main kernel.
    AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
    if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
        AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
    AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));

    // Launch main kernel.
    const int maxSubGz = 65535; // CUDA maximum for block z dimension.
    for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
    {
        // Kernel arguments are captured at launch time, so updating p between launches is safe.
        p.blockZofs = zofs;
        int subGz = std::min(maxSubGz, gz - zofs);
        AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
    }

    // Done.
    return std::make_tuple(y, so, 0);
}
//------------------------------------------------------------------------
// Activation-only counterpart of filtered_lrelu: runs the gain/slope/clamp
// kernel directly on x's storage (no filtering, no output allocation for x),
// optionally reading or writing a packed uint8 sign tensor.
//
// Arguments:
//   x           CUDA tensor, rank 4 [batch, channels, height, width]; float16, float32 or float64.
//   si          Sign tensor to read (uint8, rank 4, contiguous), or empty to not read signs.
//   sx, sy      Forwarded as params.sOfs.
//   gain, slope, clamp  Forwarded unchanged to the kernel.
//   writeSigns  Allocate and fill a new sign tensor; takes precedence over reading si.
//
// Returns the newly allocated sign tensor when writeSigns is set, otherwise an
// undefined tensor. Throws (via TORCH_CHECK) on invalid arguments.
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
{
    // Bind the current CUDA device to x's device for the rest of the call.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Argument checks on x.
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");

    // Sign tensor: a caller-provided one to read, or a fresh one to write.
    const bool readSigns = si.numel() != 0;
    const bool useSigns = readSigns || writeSigns;
    torch::Tensor signsOut;
    torch::Tensor signs = si;
    if (writeSigns)
    {
        // Element width padded to a multiple of 16 for coalescing; 4 elements per byte.
        const int64_t paddedWidth = (x.size(3) + 15) & ~15;
        signsOut = torch::empty({x.size(0), x.size(1), x.size(2), paddedWidth >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
        signs = signsOut;
    }

    if (useSigns)
    {
        TORCH_CHECK(signs.is_contiguous(), "signs must be contiguous");
        TORCH_CHECK(signs.dtype() == torch::kUInt8, "signs must be uint8");
        TORCH_CHECK(signs.device() == x.device(), "signs must reside on the same device as x");
        TORCH_CHECK(signs.dim() == 4, "signs must be rank 4");
        TORCH_CHECK(signs.size(0) == x.size(0) && signs.size(1) == x.size(1), "signs must have same batch & channels as x");
        TORCH_CHECK(signs.size(2) <= INT_MAX && (signs.size(3) << 2) <= INT_MAX, "signs tensor is too large");
    }

    // Kernel parameter block; shapes and strides are listed innermost dimension first.
    filtered_lrelu_act_kernel_params params;
    params.x = x.data_ptr();
    params.s = useSigns ? signs.data_ptr<unsigned char>() : nullptr;
    params.gain = gain;
    params.slope = slope;
    params.clamp = clamp;
    params.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    params.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
    params.sShape = useSigns ? make_int2((int)signs.size(3) << 2, (int)signs.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
    params.sOfs = make_int2(sx, sy);

    // Pick the kernel instantiation for this dtype and sign mode.
    void* kernel = nullptr;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
    {
        if (writeSigns)     kernel = choose_filtered_lrelu_act_kernel<scalar_t, true,  false>();
        else if (readSigns) kernel = choose_filtered_lrelu_act_kernel<scalar_t, false, true >();
        else                kernel = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
    });
    TORCH_CHECK(kernel, "internal error - CUDA kernel not found");

    // Launch geometry spans the sign tensor when writing signs, otherwise x.
    const int threadsPerBlock = 128;     // 4 warps per block.
    const uint32_t maxGridYZ = 65535;    // CUDA grid y/z limit; the kernel loops internally beyond it.
    const uint32_t extentX = writeSigns ? params.sShape.x : params.xShape.x;
    const uint32_t extentY = writeSigns ? params.sShape.y : params.xShape.y;
    const uint32_t planes = params.xShape.z * params.xShape.w; // Same as in sShape if signs are in use.
    const dim3 grid((extentX - 1) / threadsPerBlock + 1, std::min(extentY, maxGridYZ), std::min(planes, maxGridYZ));

    void* args[] = {&params};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, grid, threadsPerBlock, args, 0, at::cuda::getCurrentCUDAStream()));
    return signsOut;
}
//------------------------------------------------------------------------
// Python bindings for the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod)
{
    // Full fused op; returns (y, signs_out, return_code), where -1 means no kernel for the configuration.
    mod.def("filtered_lrelu", &filtered_lrelu);
    // Activation and sign tensor handling only; modifies the data tensor in place.
    mod.def("filtered_lrelu_act_", &filtered_lrelu_act);
}
//------------------------------------------------------------------------