#include "sparse_matmul/compute/matmul_fixed_avx2.h" |
|
|
|
#include <cstdint> |
|
|
|
#if defined __AVX__ |
|
#include <immintrin.h> |
|
#endif |
|
|
|
#include "sparse_matmul/compute/matmul.h" |
|
|
|
namespace csrblocksparse { |
|
namespace detail { |
|
|
|
static const int32_t kint32min = static_cast<int32_t>(~0x7FFFFFFF); |
|
static const int32_t kint32max = static_cast<int32_t>(0x7FFFFFFF); |
|
|
|
#if defined __AVX2__ |
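
// Computes the dot products for one 4x4 block row, seeded with the four
// biases in |bias128|, and returns the four int32 results in the low 128 bits
// (duplicated in the high 128 bits). |weights_ptr| is passed by reference and
// advanced past the 16 weights consumed per nonzero block.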
inline __m256i ComputeRowResults(const __m128i& bias128, const int16_t* rhs,
                                 const int16_t* rhs_indices, int nnz,
                                 int16_t const*& weights_ptr) {
  // Expand the 4 biases to 64 bits each, giving 32-bit lanes
  // [b0 0 b1 0 b2 0 b3 0], so each row accumulates two 32-bit partial sums.
  __m256i sum = _mm256_cvtepu32_epi64(bias128);

  for (int c = 0; c < nnz; ++c) {
    int rhs_index = rhs_indices[c];
    // Load all 16 weights of the 4x4 block.
    __m256i weights =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
    // Load the 4 rhs values for this block into the low 64 bits.
    __m128i rhs_64 = _mm_loadl_epi64(
        reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
    // Broadcast the 4 rhs values across the register: [0123 0123 0123 0123].
    __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
    weights_ptr += 16;
    sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs_value));
  }
  // |sum| now holds two partial sums per row: [0 0 1 1 2 2 3 3]. hadd does not
  // cross 128-bit lanes, so this yields [0 1 0 1 2 3 2 3].
  sum = _mm256_hadd_epi32(sum, sum);
  // Swap the middle two 64-bit groups to get [0 1 2 3] in the low half.
  return _mm256_permute4x64_epi64(sum, 0xd8);
}
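
// Computes a sparse matrix-vector product over 4x4 blocks for rows in
// [|start_row|, |end_row|), adds the bias, shifts the fixed-point result right
// by |shift_out| with rounding, optionally applies ReLU, and writes it to
// |output|, replicated |replicas| times at intervals of |stride|. kReplicas is
// a compile-time hint (1, 2, or 3+) that keeps the replica branches cheap;
// |replicas| carries the actual count when kReplicas > 2.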
template <typename OutType, int kReplicas>
void MatVec4x4FixedAVX2Template(const int16_t* weights_ptr, const int16_t* rhs,
                                const int32_t* bias, const int32_t* nnz_per_row,
                                const int16_t* rhs_indices, int start_row,
                                int end_row, bool relu, int shift_out,
                                int replicas, int stride, OutType* output) {
  int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
  __m256i rounding = _mm256_set1_epi32(rounding_addon);
  // With ReLU we clamp at zero; otherwise clamp at kint32min, which is a no-op.
  __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
  for (int row_block = start_row; row_block < end_row; ++row_block) {
    // Load the 4 biases for this block row.
    __m128i bias128 = _mm_load_si128(reinterpret_cast<__m128i const*>(bias));
    bias += kBlockSize;
    int nnz = nnz_per_row[row_block];
    __m256i sum =
        ComputeRowResults(bias128, rhs, rhs_indices, nnz, weights_ptr);
    rhs_indices += nnz;
    // Round and shift down to the output fixed-point format.
    sum = _mm256_add_epi32(sum, rounding);
    sum = _mm256_srai_epi32(sum, shift_out);
    // Optional ReLU.
    sum = _mm256_max_epi32(sum, zero);
    if (sizeof(OutType) == 2) {
      // Saturate to int16 and pack; the 4 results land in the low 64 bits.
      sum = _mm256_packs_epi32(sum, sum);
      int64_t result = _mm256_extract_epi64(sum, 0);
      *reinterpret_cast<int64_t*>(output) = result;
      if (kReplicas > 1) {
        *reinterpret_cast<int64_t*>(output + stride) = result;
        if (kReplicas > 2) {
          for (int r = 2; r < replicas; ++r) {
            *reinterpret_cast<int64_t*>(output + r * stride) = result;
          }
        }
      }
    } else {
      // int32 output: store the 4 results from the low 128 bits.
      __m128i result = _mm256_extractf128_si256(sum, 0);
      _mm_store_si128(reinterpret_cast<__m128i*>(output), result);
      if (kReplicas > 1) {
        _mm_store_si128(reinterpret_cast<__m128i*>(output + stride), result);
        if (kReplicas > 2) {
          for (int r = 2; r < replicas; ++r) {
            _mm_store_si128(reinterpret_cast<__m128i*>(output + r * stride),
                            result);
          }
        }
      }
    }
    output += kBlockSize;
  }
}
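
// Fixed-point 4x4-block sparse matrix * vector multiply with int16_t output.
// Dispatches on |replicas| so that the common 1- and 2-copy cases get a
// specialized template instantiation.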
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
                        const int32_t* bias, const int32_t* nnz_per_row,
                        const int16_t* rhs_indices, int start_row, int end_row,
                        bool relu, int shift_out, int replicas, int stride,
                        int16_t* output) {
  if (replicas <= 1) {
    MatVec4x4FixedAVX2Template<int16_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
                                           rhs_indices, start_row, end_row,
                                           relu, shift_out, 1, stride, output);
  } else if (replicas == 2) {
    MatVec4x4FixedAVX2Template<int16_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
                                           rhs_indices, start_row, end_row,
                                           relu, shift_out, 2, stride, output);
  } else {
    MatVec4x4FixedAVX2Template<int16_t, 3>(
        weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
        relu, shift_out, replicas, stride, output);
  }
}
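
// As above, but with int32_t output.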
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
                        const int32_t* bias, const int32_t* nnz_per_row,
                        const int16_t* rhs_indices, int start_row, int end_row,
                        bool relu, int shift_out, int replicas, int stride,
                        int32_t* output) {
  if (replicas <= 1) {
    MatVec4x4FixedAVX2Template<int32_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
                                           rhs_indices, start_row, end_row,
                                           relu, shift_out, 1, stride, output);
  } else if (replicas == 2) {
    MatVec4x4FixedAVX2Template<int32_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
                                           rhs_indices, start_row, end_row,
                                           relu, shift_out, 2, stride, output);
  } else {
    MatVec4x4FixedAVX2Template<int32_t, 3>(
        weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
        relu, shift_out, replicas, stride, output);
  }
}
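
// As ComputeRowResults above, but for 8x4 blocks: seeds two accumulators with
// the 8 biases in |bias256| and consumes 32 weights per nonzero block,
// returning the 8 int32 results in order.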
inline __m256i Compute8RowResults(const __m256i& bias256, const int16_t* rhs,
                                  const int16_t* rhs_indices, int nnz,
                                  int16_t const*& weights_ptr) {
  // Expand the biases to 64 bits each: rows 0-3 in |sum1|, rows 4-7 in |sum2|,
  // so each row accumulates two 32-bit partial sums.
  __m256i sum1 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(bias256));
  __m256i sum2 = _mm256_cvtepu32_epi64(_mm256_extractf128_si256(bias256, 1));

  for (int c = 0; c < nnz; ++c) {
    int rhs_index = rhs_indices[c];
    // Load the 16 weights for rows 0-3 of the 8x4 block.
    __m256i weights =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
    // Load the 4 rhs values for this block into the low 64 bits.
    __m128i rhs_64 = _mm_loadl_epi64(
        reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
    // Broadcast the 4 rhs values across the register: [0123 0123 0123 0123].
    __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
    weights_ptr += 16;
    sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(weights, rhs_value));
    // Load the 16 weights for rows 4-7 and accumulate into |sum2|.
    weights = _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
    weights_ptr += 16;
    sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(weights, rhs_value));
  }
  // Each accumulator holds two partial sums per row. hadd does not cross
  // 128-bit lanes, so this yields [0 1 4 5 2 3 6 7].
  sum1 = _mm256_hadd_epi32(sum1, sum2);
  // Swap the middle two 64-bit groups to get [0 1 2 3 4 5 6 7].
  return _mm256_permute4x64_epi64(sum1, 0xd8);
}
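
// Fixed-point 8x4-block sparse matrix * vector multiply with int32_t output.
// Processes 8 output rows per block row and always writes a single copy.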
void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
                        const int32_t* bias, const int32_t* nnz_per_row,
                        const int16_t* rhs_indices, int start_row, int end_row,
                        bool relu, int shift_out, int32_t* output) {
  int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
  __m256i rounding = _mm256_set1_epi32(rounding_addon);
  // With ReLU we clamp at zero; otherwise clamp at kint32min, which is a no-op.
  __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
  for (int row_block = start_row; row_block < end_row; ++row_block) {
    // Load the 8 biases for this block row.
    __m256i bias256 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(bias));
    bias += kBlockSize * 2;
    int nnz = nnz_per_row[row_block];
    __m256i sum =
        Compute8RowResults(bias256, rhs, rhs_indices, nnz, weights_ptr);
    rhs_indices += nnz;
    // Round and shift down to the output fixed-point format.
    sum = _mm256_add_epi32(sum, rounding);
    sum = _mm256_srai_epi32(sum, shift_out);
    // Optional ReLU, then store all 8 results.
    sum = _mm256_max_epi32(sum, zero);
    _mm256_store_si256(reinterpret_cast<__m256i*>(output), sum);
    output += kBlockSize * 2;
  }
}

#endif  // __AVX2__

}  // namespace detail
}  // namespace csrblocksparse