// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu
// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory).
/**
 * From PyTorch:
 *
 * Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
 * Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
 * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
 * Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
 * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
 * Copyright (c) 2011-2013 NYU                      (Clement Farabet)
 * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
 * Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
 * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
 *
 * From Caffe2:
 *
 * Copyright (c) 2016-present, Facebook Inc. All rights reserved.
 *
 * All contributions by Facebook:
 * Copyright (c) 2016 Facebook Inc.
 *
 * All contributions by Google:
 * Copyright (c) 2015 Google Inc.
 * All rights reserved.
 *
 * All contributions by Yangqing Jia:
 * Copyright (c) 2015 Yangqing Jia
 * All rights reserved.
 *
 * All contributions from Caffe:
 * Copyright(c) 2013, 2014, 2015, the respective contributors
 * All rights reserved.
 *
 * All other contributions:
 * Copyright(c) 2015, 2016 the respective contributors
 * All rights reserved.
 *
 * Caffe2 uses a copyright model similar to Caffe: each contributor holds
 * copyright over their contributions to Caffe2. The project versioning records
 * all such contribution and copyright details. If a contributor wants to further
 * mark their specific copyright on a particular contribution, they should
 * indicate their copyright solely in the commit message of the change when it is
 * committed.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
 *    and IDIAP Research Institute nor the names of its contributors may be
 *    used to endorse or promote products derived from this software without
 *    specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh>

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t_##LEVEL = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// #else
// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \
//   switch(TYPE) \
//   { \
//     case at::ScalarType::Float: \
//     { \
//       using scalar_t_##LEVEL = float; \
//       __VA_ARGS__; \
//       break; \
//     } \
//     case at::ScalarType::Half: \
//     { \
//       using scalar_t_##LEVEL = at::Half; \
//       __VA_ARGS__; \
//       break; \
//     } \
//     default: \
//       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
//   }
// #endif
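
// Minimal usage sketch for the dispatch macro above (illustrative only; `my_op`
// and `my_kernel` are hypothetical names, not part of this file): the macro
// instantiates the body once per supported dtype and binds the concrete element
// type to `scalar_t_<LEVEL>`.
//
//   DISPATCH_FLOAT_AND_HALF_AND_BF16(tensor.scalar_type(), 0, "my_op",
//     using accscalar_t = at::acc_type<scalar_t_0, true>;
//     my_kernel<scalar_t_0, accscalar_t><<<grid, block, 0, stream>>>(
//       tensor.data_ptr<scalar_t_0>(), tensor.numel());
//   );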

#define ALIGN_BYTES 16

using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
using at::acc_type;

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxForwardEpilogue {
  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum)
    : logsum(max_input + std::log(sum)) {}

  __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp)
    : logsum(max_log_sum_exp) {}

  __device__ __forceinline__ OutT operator()(T input) const {
    return static_cast<OutT>(input - logsum);
  }

  const AccumT logsum;
};
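
// The forward epilogue turns the row statistics into log-probabilities:
//   logsum = max_i x_i + log(sum_j exp(x_j - max_i x_i))   (the log-sum-exp)
//   log p_k = x_k - logsum
// Subtracting the row max before exponentiating is the standard
// numerical-stability trick and does not change the result.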

template<typename T, typename AccumT, typename OutT>
struct LogSoftMaxBackwardEpilogue {
  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
    : sum(sum) {}

  __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const {
    return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum);
  }

  const AccumT sum;
};
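
// The backward epilogue is the plain log-softmax gradient,
//   grad_in_k = grad_out_k - exp(log p_k) * sum_j grad_out_j,
// with `sum` being the pre-reduced sum of incoming gradients. Note that the
// label-smoothing backward in this file computes its gradient directly in
// `apply`/`aligned_apply` and does not call this functor.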

const int max_threads = 1024;

inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
  while (block_size < (max_block_size/2)) block_size *= 2;
  // Launch at least a single warp - the kernel assumes that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
}
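
// Worked example of the heuristic above: with fp16 inputs the host code picks
// ILP = sizeof(float4) / sizeof(at::Half) = 8, so for dim_size = 4096 we get
// max_block_size = min(4096 / 8, 1024) = 512; block_size doubles while it is
// below 256 and ends up at 256 threads. Tiny rows are still rounded up to one
// full warp (32 threads), which the block reductions below require.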

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

////////////////////////////////////////////////////////////////////////////////
// Regular kernel (fast when dim_size is large; requires inner_size == 1)
////////////////////////////////////////////////////////////////////////////////

template <typename T, typename AccumT>
struct MaxFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const {
    return ::max(max, (AccumT)v);
  }
};

template<typename T, typename AccumT>
struct AddFloat
{
  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + v;
  }
};

template<typename T, typename AccumT>
struct SumExpFloat
{
  __device__ __forceinline__ SumExpFloat(AccumT v)
    : max_k(v) {}

  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
    return sum + std::exp(v - max_k);
  }

  const AccumT max_k;
};

template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
            const Reduction<AccumT>& r,
            AccumT defaultVal)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val;

  __syncthreads();

  AccumT warpVal = defaultVal;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal = r(warpVal, smem[lane * 32 + i]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal = defaultVal;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = blockVal;
  }

  // Sync and broadcast
  __syncthreads();
  return smem[0];
}
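
// blockReduce is a two-stage shared-memory reduction:
//   1. every thread writes its partial value to smem;
//   2. the lanes of warp 0 each combine one contiguous 32-element chunk;
//   3. thread 0 combines the per-warp partials and broadcasts smem[0].
// It assumes blockDim.x is a multiple of 32, which SoftMax_getBlockSize
// guarantees (power of two, at least one warp).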

template <template<typename> class Reduction1, template<typename> class Reduction2, typename AccumT>
__device__ __forceinline__ void
blockReduce(AccumT* smem,
            AccumT* reducVal1,
            AccumT val1,
            const Reduction1<AccumT>& r1,
            AccumT defaultVal1,
            AccumT* reducVal2,
            AccumT val2,
            const Reduction2<AccumT>& r2,
            AccumT defaultVal2)
{
  // To avoid RaW races from chaining blockReduce calls together, we need a sync here
  __syncthreads();

  smem[threadIdx.x] = val1;
  smem[blockDim.x + threadIdx.x] = val2;

  __syncthreads();

  AccumT warpVal1 = defaultVal1;
  AccumT warpVal2 = defaultVal2;

  // First warp will perform per-warp reductions for the remaining warps
  uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1;
  if (threadIdx.x < 32) {
    int lane = threadIdx.x % 32;
    if (lane < blockDim.x / 32) {
#pragma unroll
      for (int i = 0; i < 32; ++i) {
        warpVal1 = r1(warpVal1, smem[lane * 32 + i]);
        warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]);
      }
      __syncwarp(mask);
      smem[lane] = warpVal1;
      smem[lane + blockDim.x] = warpVal2;
    }
  }

  __syncthreads();

  // First thread will perform a reduction of the above per-warp reductions
  AccumT blockVal1 = defaultVal1;
  AccumT blockVal2 = defaultVal2;

  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / 32; ++i) {
      blockVal1 = r1(blockVal1, smem[i]);
      blockVal2 = r2(blockVal2, smem[i + blockDim.x]);
    }
    smem[0] = blockVal1;
    smem[blockDim.x] = blockVal2;
  }

  // Sync and broadcast
  __syncthreads();
  *reducVal1 = smem[0];
  *reducVal2 = smem[blockDim.x];

  __syncthreads();
}
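
// Fused variant of blockReduce: the row max and the plain sum are reduced in
// one pass, with the second operand kept in the upper half of the
// 2 * blockDim.x element shared buffer (hence the forward launch allocating
// 2 * block.x * sizeof(accscalar_t) bytes of dynamic shared memory).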

template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(int shift,
          T* data,
          int size,
          const Reduction<T, AccumT>& r,
          AccumT defaultVal)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;

  AccumT threadVal = defaultVal;
  int offset = threadIdx.x;

  // shift and do 1
  if (shift > 0) {
    data -= shift;
    size += shift;
    if (threadIdx.x >= shift) {
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }

  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x)
    threadVal = r(threadVal, data[offset]);

  return threadVal;
}
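
// ilpReduce streams one row with ILP-wide vectorized loads (a LoadT packs
// ILP elements, i.e. 16 bytes for the ILP picked by the host code). If the row
// pointer is misaligned, `shift` counts the leading scalar elements: those are
// handled one per thread first, the pointer is then realigned, and the main
// loop loads whole LoadT packets; the final `size % (ILP * blockDim.x)`
// elements go through the scalar epilogue.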

template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2,
          int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(int shift,
          T* data,
          int size,
          AccumT* reducVal1,
          const Reduction1<T, AccumT>& r1,
          AccumT defaultVal1,
          AccumT* reducVal2,
          const Reduction2<T, AccumT>& r2,
          AccumT defaultVal2)
{
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;

  AccumT threadVal1 = defaultVal1;
  AccumT threadVal2 = defaultVal2;
  int offset = threadIdx.x;

  // shift and do 1
  if (shift > 0) {
    data -= shift;
    size += shift;
    if (threadIdx.x >= shift) {
      threadVal1 = r1(threadVal1, data[offset]);
      threadVal2 = r2(threadVal2, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }

  int last = size % (ILP * blockDim.x);

  T v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);

  for (; offset * ILP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(data)[offset];

    for (int j = 0; j < ILP; ++j) {
      threadVal1 = r1(threadVal1, v[j]);
      threadVal2 = r2(threadVal2, v[j]);
    }
  }

  offset = size - last + threadIdx.x;
  // Epilogue
  for (; offset < size; offset += blockDim.x) {
    threadVal1 = r1(threadVal1, data[offset]);
    threadVal2 = r2(threadVal2, data[offset]);
  }

  *reducVal1 = threadVal1;
  *reducVal2 = threadVal2;
}
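
// Fused variant of ilpReduce: same traversal, but two functors (MaxFloat and
// AddFloat in the forward kernel) are applied to each loaded element, so the
// row is read from global memory only once for both statistics.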

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template <typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyForward(
    accscalar_t *losses,
    outscalar_t *max_log_sum_exp,
    scalar_t *input,
    int64_t *labels,
    int64_t classes,
    const float smoothing,
    const int total_classes)
{
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);

  // forward pointers to batch[blockIdx.x]
  // each block handles a sample in the mini-batch
  input += blockIdx.x * classes;
  //output += blockIdx.x * classes;
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
  int64_t label = labels[blockIdx.x];

  // find the max and sum
  accscalar_t threadMax, threadSum, max_k, sum_k;
  ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes,
      &threadMax, MaxFloat<scalar_t, accscalar_t>(),
      -at::numeric_limits<accscalar_t>::max(),
      &threadSum, AddFloat<scalar_t, accscalar_t>(),
      static_cast<accscalar_t>(0));

  blockReduce<Max, Add, accscalar_t>(
      sdata,
      &max_k, threadMax, Max<accscalar_t>(),
      -at::numeric_limits<accscalar_t>::max(),
      &sum_k, threadSum, Add<accscalar_t>(),
      static_cast<accscalar_t>(0));

  accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
      shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k),
      static_cast<accscalar_t>(0));
  accscalar_t sumAll = blockReduce<Add, accscalar_t>(
      sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));

  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

  // calculate per element loss with label smoothing
  // reserve max + log_sum_exp for bprop
  if (threadIdx.x == 0) {
    accscalar_t lse = max_k + std::log(sumAll);
    accscalar_t log_prob = (label >= 0 && label < classes)
        ? epilogue(static_cast<accscalar_t>(input[label])) : 0.f;
    losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing
        - log_prob * (1 - smoothing);
    max_log_sum_exp[blockIdx.x] = lse;
  }
}
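
// Per-row loss computed by the forward kernel above, with smoothing s and
// lse = max_k + log(sum_j exp(x_j - max_k)):
//   -log p_j = lse - x_j
//   loss = (1 - s) * (lse - x_label) + (s / total_classes) * sum_j (lse - x_j)
//        = (lse - sum_j x_j / total_classes) * s + (lse - x_label) * (1 - s)
// which matches `(lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing)`
// since log_prob = x_label - lse. When total_classes exceeds `classes` (the
// tensor-parallel case), sum_k and lse only cover the local shard of logits,
// so each shard contributes a partial term of the full smoothed loss.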

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
      scalar_t *logits,
      outscalar_t *max_log_sum_exp,
      outscalar_t *gradOutput,
      int64_t *labels,
      const float smoothing,
      int classes,
      const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);

  for (; offset < classes - last; offset += blockDim.x * ILP) {
    accscalar_t tmpLogits[ILP];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpLogits[j] = static_cast<accscalar_t>(logits[offset + j * blockDim.x]);
    }

#pragma unroll
    for (int j = 0; j < ILP; ++j)
      gradInput[offset + j * blockDim.x] = tmpGradOutput * (
          std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
              (offset + j * blockDim.x == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
  }

  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>((offset == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}
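
// Gradient formula implemented by `apply` (and by `aligned_apply` below):
//   d loss / d x_j = gradOutput * ( exp(x_j - lse)
//                                   - (1 - smoothing) * [j == label]
//                                   - smoothing / total_classes )
// where exp(x_j - lse) = softmax(x)_j and `coeff` holds the lse saved by the
// forward pass, so the backward needs no further reductions over the row.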

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
              scalar_t *gradInput,
              scalar_t *logits,
              outscalar_t *max_log_sum_exp,
              outscalar_t *gradOutput,
              int64_t *labels,
              const float smoothing,
              int classes,
              const int total_classes)
{
  accscalar_t smooth_positives = 1.0 - smoothing;
  accscalar_t smooth_negatives = smoothing / total_classes;
  accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
  int64_t label = labels[blockIdx.x];
  accscalar_t coeff = max_log_sum_exp[blockIdx.x];

  int offset = threadIdx.x;

  // shift and do 1
  if (shift > 0) {
    logits -= shift;
    gradInput -= shift;
    classes += shift;
    if (threadIdx.x >= shift) {
      gradInput[offset] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(logits[offset]) - coeff) -
          static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    classes -= blockDim.x;
    gradInput += blockDim.x;
    logits += blockDim.x;
    shift -= blockDim.x;
  }

  int last = classes % (ILP * blockDim.x);

  typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
  // input
  scalar_t v[ILP];
  LoadT* value = reinterpret_cast<LoadT*>(&v);
  // output
  scalar_t r[ILP];
  LoadT* result = reinterpret_cast<LoadT*>(&r);

  for (; offset * ILP < (classes - last); offset += blockDim.x) {
    *value = reinterpret_cast<LoadT*>(logits)[offset];

#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      r[j] = tmpGradOutput * (std::exp(
          static_cast<accscalar_t>(v[j]) - coeff) -
          static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
          smooth_positives - smooth_negatives);
    }
    reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
  }

  offset = classes - last + threadIdx.x;
  for (; offset < classes; offset += blockDim.x)
    gradInput[offset] = tmpGradOutput * (std::exp(
        static_cast<accscalar_t>(logits[offset]) - coeff) -
        static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
        smooth_positives - smooth_negatives);
}
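
// `aligned_apply` above is the vectorized backward path: when logits and
// gradInput share the same 16-byte misalignment, the leading misaligned
// elements are handled scalarly (as in ilpReduce) and the rest is read and
// written as full LoadT packets. After the scalar prologue, `shift` has had
// blockDim.x subtracted (so it is <= 0), which is why the label test uses
// `offset - shift` / `ILP * offset + j - shift` to recover the original class
// index.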

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
          template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
    scalar_t *gradInput,
    scalar_t *logits,
    outscalar_t *max_log_sum_exp,
    outscalar_t *gradOutput,
    int64_t *labels,
    const float smoothing,
    int classes,
    const int total_classes)
{
  gradInput += blockIdx.x * classes;
  logits += blockIdx.x * classes;

  // Do vectorized load/store when input/output have same alignment
  const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
  const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
  if (shift == shift_) {
    aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing,
        classes, total_classes <= 0 ? classes : total_classes);
  }
  else {
    apply<ILP, scalar_t, accscalar_t, outscalar_t>(
        gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing,
        classes, total_classes <= 0 ? classes : total_classes);
  }
}

template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
    const Tensor & input_,
    const Tensor & labels_,
    const float smoothing,
    const int total_classes) {
  // For tensor parallel cross entropy with smoothing, we want to pass in the total number
  // of classes so that smoothing can be applied correctly. If total_classes=-1, use the
  // last dimension of the input tensor.
  AT_ASSERTM(labels_.scalar_type() == ScalarType::Long, "Label type should be CUDA Long");

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)input_.get_device()};

  auto input = input_.contiguous();
  Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
  Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float));

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0");

  const int64_t dim = 1;
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= input.size(i);
  for (int64_t i = dim + 1; i < input.dim(); ++i)
    inner_size *= input.size(i);
  // This kernel spawns a block per each element in the batch.
  // XXX: it assumes that inner_size == 1
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  using namespace at;
  DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy",
    using accscalar_t = at::acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
        losses.data_ptr<accscalar_t>(), max_log_sum_exp.data_ptr<accscalar_t>(),
        input.data_ptr<scalar_t_0>(), labels_.data_ptr<int64_t>(),
        dim_size, smoothing, total_classes <= 0 ? dim_size : total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());

  std::vector<at::Tensor> ret = {losses, max_log_sum_exp};
  return ret;
}

template<template<typename, typename, typename> class Epilogue>
Tensor host_softmax_xentropy_backward(
    const at::Tensor &grad_loss,
    at::Tensor &logits_,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    bool inplace,
    const int total_classes) {
  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};

  const int64_t dim = 1;
  Tensor gI = inplace ? logits_ : at::empty_like(logits_);
  if (grad_loss.numel() == 0) {
    return gI;
  }

  auto grad = grad_loss.contiguous();
  auto logits = logits_.contiguous();

  static_assert(std::is_same<acc_type<at::Half, true>, float>::value ||
                std::is_same<acc_type<at::Half, true>, double>::value,
                "accscalar_t for half should be float or double");
  if (grad.dim() == 0) grad = grad.view(1);

  AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported");
  AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional");
  AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0");
  AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples");
  AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples");

  int64_t outer_size = 1;
  int64_t dim_size = logits.size(dim);
  int64_t inner_size = 1;
  for (int64_t i = 0; i < dim; ++i)
    outer_size *= logits.size(i);
  for (int64_t i = dim + 1; i < logits.dim(); ++i)
    inner_size *= logits.size(i);
  // See descriptions of kernels above.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");

  dim3 grid(outer_size);

  DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
    using accscalar_t = acc_type<scalar_t_0, true>;
    const int ILP = sizeof(float4)/sizeof(scalar_t_0);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, accscalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
        gI.data_ptr<scalar_t_0>(), logits.data_ptr<scalar_t_0>(),
        max_log_sum_exp.data_ptr<accscalar_t>(),
        grad.data_ptr<accscalar_t>(), labels.data_ptr<int64_t>(),
        smoothing, dim_size, total_classes
    );
  );

  C10_CUDA_CHECK(cudaGetLastError());
  return gI;
}

std::vector<Tensor> softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes) {
  return host_softmax_xentropy<LogSoftMaxForwardEpilogue>(input, labels, smoothing, total_classes);
}

at::Tensor softmax_xentropy_backward_cuda(
    const at::Tensor &grad_loss,
    at::Tensor &logits,
    const at::Tensor &max_log_sum_exp,
    const at::Tensor &labels,
    const float smoothing,
    const bool inplace,
    const int total_classes) {
  AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float");
  return host_softmax_xentropy_backward<LogSoftMaxBackwardEpilogue>(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes);
}