// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #if defined __aarch64__ #include #endif #if defined __AVX__ || defined __AVX2__ #include #endif #include #include #include #include #include #include "gtest/gtest.h" #include "sparse_matmul/numerics/fast_transcendentals.h" #include "sparse_matmul/numerics/test_utils.h" namespace csrblocksparse { const float kExpFixedRelTolerance = .084f; #ifdef SIGMOID_AS_TANH #if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX const float kSigmoidRelTolerance = .093f; // 9.3% relative const float kSigmoidAbsTolerance = .0005f; const float kSigmoidFixedRelTolerance = .093f; const float kSigmoidFixedAbsTolerance = .0005f; #elif defined FAST_TRANSCENDENTALS const float kSigmoidRelTolerance = .09f; // 9.0% relative const float kSigmoidAbsTolerance = .003f; const float kSigmoidFixedRelTolerance = .09f; const float kSigmoidFixedAbsTolerance = .003f; #endif #elif defined FAST_TRANSCENDENTALS and defined ACCURATE_TRANSCENDENTAL_APPROX const float kSigmoidRelTolerance = .102f; // 10.2% relative const float kSigmoidAbsTolerance = .0003f; const float kSigmoidFixedRelTolerance = .102f; const float kSigmoidFixedAbsTolerance = .0003f; #elif defined FAST_TRANSCENDENTALS const float kSigmoidRelTolerance = .09f; // 9.0% relative const float kSigmoidAbsTolerance = .006f; const float kSigmoidFixedRelTolerance = .09f; const float kSigmoidFixedAbsTolerance = .006f; #else const float kSigmoidRelTolerance = .0001f; const float kSigmoidAbsTolerance = 1e-5f; const float kSigmoidFixedRelTolerance = .001f; const float kSigmoidFixedAbsTolerance = .001f; #endif #if (defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX || \ defined FASTER_TRANSCENDENTALS) const float kExpRelTolerance = .03f; // 3% relative const float kTanhRelTolerance = .006f; // .6% relative const float kTanhAbsTolerance = .0003f; #elif defined FAST_TRANSCENDENTALS const float kExpRelTolerance = .03f; // 3% relative const float kTanhRelTolerance = .091f; // .91% relative const float kTanhAbsTolerance = .00525f; #else const float kExpRelTolerance = .0001f; const float kTanhRelTolerance = .0001f; const float kTanhAbsTolerance = 1e-5f; #endif constexpr float kQuarticFloatExpRelTolerance = 8e-6f; constexpr float kQuarticFloatExpTolerance = 9e-6f; constexpr float kQuarticExpRelTolerance = 3e-5f; constexpr float kQuarticExpTolerance = 6e-5f; constexpr float kCubicExpRelTolerance = 6e-4f; constexpr float kCubicExpTolerance = 2e-3f; constexpr float kQuarticFloatTanhRelTolerance = 3e-5f; constexpr float kQuarticFloatTanhTolerance = 3e-6f; constexpr float kCubicTanhRelTolerance = 3e-3f; constexpr float kCubicTanhTolerance = 3e-4f; constexpr float kQuarticSigmoidRelTolerance = 3e-5f; constexpr float kQuarticSigmoidTolerance = 7e-6f; constexpr float kCubicSigmoidRelTolerance = 6e-4f; constexpr float kCubicSigmoidTolerance = 2e-4f; #ifdef __AVX2__ constexpr float kQuarticTanhRelTolerance = 1e-4f; constexpr float kQuarticTanhTolerance = 2e-5f; constexpr float kQuarticFloatSigmoidRelTolerance = 4e-6f; constexpr float kQuarticFloatSigmoidTolerance = 1e-6f; #endif // __AVX2__ TEST(Transcendentals, Exp) { // 132 - 127 = 5, we check between -63.99... and 63.99... const int maxExponent = 132; const int minExponent = 0; float max_error = 0.f; constexpr int kExponentBits = 7; for (int s = 0; s < 2; ++s) { for (int e = minExponent; e < maxExponent; ++e) { // Don't check every mantissa for speed reasons. for (int m = 0; m < (1 << 23); m += (1 << 10)) { uint32_t int_val = s << 31 | e << 23 | m; float x; memcpy(&x, &int_val, sizeof(float)); float exact_exp = expf(x); float approx_exp = csrblocksparse::fast_exp(x); float approx_exp_fixed = csrblocksparse::fast_exp( csrblocksparse::fixed32(x)); float rel_diff = RelDiff(exact_exp, approx_exp); float rel_diff_fixed = RelDiff(exact_exp, approx_exp_fixed); max_error = std::max(max_error, rel_diff); EXPECT_LT(rel_diff, kExpRelTolerance) << exact_exp << " " << approx_exp << " " << x; EXPECT_LT(rel_diff_fixed, kExpRelTolerance) << exact_exp << " " << approx_exp << " " << x; } } } } TEST(Transcendentals, FixedExp) { const int maxExponent = 132; const int minExponent = 120; float max_error = 0.f; float max_abs_error = 0.f; for (int s = 0; s < 2; ++s) { for (int e = minExponent; e < maxExponent; ++e) { // Don't check every mantissa for speed reasons. for (int m = 0; m < (1 << 23); m += (1 << 10)) { uint32_t int_val = s << 31 | e << 23 | m; float x; memcpy(&x, &int_val, sizeof(float)); float exact_exp = expf(x); float approx_exp = csrblocksparse::fast_exp_fixed(csrblocksparse::fixed32<16>(x)); float rel_diff = RelDiff(exact_exp, approx_exp); float abs_diff = std::abs(exact_exp - approx_exp); max_error = std::max(max_error, rel_diff); max_abs_error = std::max(max_abs_error, abs_diff); EXPECT_LT(rel_diff, kExpFixedRelTolerance) << exact_exp << " " << approx_exp << " " << x; } } } LOG(INFO) << "Max relative exp error = " << max_error << ", abs=" << max_abs_error; } template void TestExp(float abs_tolerance, float rel_tolerance) { constexpr int kMaxInput = 80 << 16; constexpr int kMinInput = -(80 << 16); constexpr int kExponentBits = 15; float max_error = 0.f; float max_abs_error = 0.f; for (int i = kMinInput; i <= kMaxInput; ++i) { csrblocksparse::fixed32 fixed_int(i); float x = static_cast(fixed_int); float exact_exp = expf(x); float approx_exp = fixed32_exp(fixed_int); float diff = exact_exp - approx_exp; float abs_diff = std::abs(diff); float rel_diff = RelDiff(exact_exp, approx_exp); max_error = std::max(max_error, rel_diff); if (x <= 1.0f) { ASSERT_LT(abs_diff, abs_tolerance) << "x=" << x << ", target=" << exact_exp << ", aprx=" << approx_exp; max_abs_error = std::max(max_abs_error, abs_diff); } ASSERT_LT(rel_diff, rel_tolerance) << "x=" << x << ", target=" << exact_exp << ", aprx=" << approx_exp; } LOG(INFO) << "Max relative error = " << max_error << ", abs=" << max_abs_error; } TEST(Transcendentals, QuarticExp) { TestExp(kQuarticFloatExpTolerance, kQuarticFloatExpRelTolerance); } TEST(Transcendentals, CubicExp) { TestExp(kCubicExpTolerance, kCubicExpRelTolerance); } template void TestTanh(float abs_tolerance, float rel_tolerance) { constexpr int kMaxInput = (40 << 16); constexpr int kMinInput = -(40 << 16); constexpr int kExponentBits = 15; float max_error = 0.f; float max_abs_error = 0.f; for (int i = kMinInput; i <= kMaxInput; ++i) { csrblocksparse::fixed32 fixed_int(i); float x = static_cast(fixed_int); float exact_tanh = tanh(x); float approx_tanh = fixed32_tanh(fixed_int); float diff = exact_tanh - approx_tanh; float abs_diff = std::abs(diff); float rel_diff = RelDiff(exact_tanh, approx_tanh); ASSERT_LT(abs_diff, abs_tolerance) << "x=" << x << ", target=" << exact_tanh << ", aprx=" << approx_tanh; max_abs_error = std::max(max_abs_error, abs_diff); max_error = std::max(max_error, rel_diff); ASSERT_LT(rel_diff, rel_tolerance) << "x=" << x << ", target=" << exact_tanh << ", aprx=" << approx_tanh; } LOG(INFO) << "Max relative error = " << max_error << ", abs=" << max_abs_error; } TEST(Transcendentals, QuarticTanh) { TestTanh(kQuarticFloatTanhTolerance, kQuarticFloatTanhRelTolerance); } TEST(Transcendentals, CubicTanh) { TestTanh(kCubicTanhTolerance, kCubicTanhRelTolerance); } template void TestSigmoid(float abs_tolerance, float rel_tolerance) { constexpr int kMaxInput = 80 << 16; constexpr int kMinInput = -(80 << 16); constexpr int kExponentBits = 15; float max_error = 0.f; float max_abs_error = 0.f; for (int i = kMinInput; i <= kMaxInput; ++i) { csrblocksparse::fixed32 fixed_int(i); float x = static_cast(fixed_int); float exact_sigmoid = 1.0f / (1.0f + exp(-x)); float approx_sigmoid = fixed32_sigmoid(fixed_int); float diff = exact_sigmoid - approx_sigmoid; float abs_diff = std::abs(diff); float rel_diff = RelDiff(exact_sigmoid, approx_sigmoid); max_error = std::max(max_error, rel_diff); ASSERT_LT(abs_diff, abs_tolerance) << "x=" << x << ", target=" << exact_sigmoid << ", aprx=" << approx_sigmoid; max_abs_error = std::max(max_abs_error, abs_diff); ASSERT_LT(rel_diff, rel_tolerance) << "x=" << x << ", target=" << exact_sigmoid << ", aprx=" << approx_sigmoid; } LOG(INFO) << "Max relative sigmoid error = " << max_error << ", abs=" << max_abs_error; } TEST(Transcendentals, QuarticSigmoidExp) { TestSigmoid(kQuarticSigmoidTolerance, kQuarticSigmoidRelTolerance); } TEST(Transcendentals, CubicSigmoidExp) { TestSigmoid(kCubicSigmoidTolerance, kCubicSigmoidRelTolerance); } TEST(Transcendentals, Sigmoid) { // 132 - 127 = 5, we check between -63.99... and 63.99... const int maxExponent = 132; const int minExponent = 0; // The mantissa bits must not exceed 23, so min exponent bits here is: // 31 - 23 = 8. constexpr int kExponentBits = 9; float max_error = 0.f; float max_abs_error = 0.f; #if defined __aarch64__ float max_vector_error = 0.f; float max_vector_abs_error = 0.f; #endif for (int s = 0; s < 2; ++s) { for (int e = minExponent; e < maxExponent; ++e) { // Don't check every mantissa for speed reasons. for (int m = 0; m < (1 << 23); m += (1 << 10)) { uint32_t int_val = s << 31 | e << 23 | m; float x; memcpy(&x, &int_val, sizeof(float)); float exact_sigmoid = 1. / (1. + expf(-x)); float approx_sigmoid = csrblocksparse::fast_sigmoid(x); float approx_sigmoid_fixed = csrblocksparse::fast_sigmoid( csrblocksparse::fixed32(x)); float rel_diff = RelDiff(exact_sigmoid, approx_sigmoid); float abs_diff = std::abs(exact_sigmoid - approx_sigmoid); float rel_diff_fixed = RelDiff(exact_sigmoid, approx_sigmoid_fixed); max_error = std::max(max_error, rel_diff); max_abs_error = std::max(max_abs_error, abs_diff); EXPECT_LT(rel_diff, kSigmoidRelTolerance) << exact_sigmoid << " " << approx_sigmoid << " " << x; EXPECT_NEAR(approx_sigmoid, exact_sigmoid, kSigmoidAbsTolerance) << x; EXPECT_LT(rel_diff_fixed, kSigmoidFixedRelTolerance) << exact_sigmoid << " " << approx_sigmoid_fixed << " " << x; EXPECT_NEAR(approx_sigmoid_fixed, exact_sigmoid, kSigmoidFixedAbsTolerance) << x; #if defined __aarch64__ constexpr int kSIMD_WIDTH = 4; float approx_results[kSIMD_WIDTH]; int32x4_t input = vdupq_n_s32(csrblocksparse::fixed32(x).raw_val()); float32x4_t result = csrblocksparse::fast_sigmoid(input); vst1q_f32(approx_results, result); for (int i = 0; i < kSIMD_WIDTH; ++i) { float rel_diff = RelDiff(exact_sigmoid, approx_results[i]); float abs_diff = std::abs(exact_sigmoid - approx_results[i]); max_vector_error = std::max(max_vector_error, rel_diff); max_vector_abs_error = std::max(max_vector_abs_error, abs_diff); EXPECT_LT(rel_diff, kSigmoidRelTolerance) << exact_sigmoid << " " << approx_sigmoid << " " << x; EXPECT_NEAR(approx_sigmoid, exact_sigmoid, kSigmoidAbsTolerance) << x; } #endif } } } LOG(INFO) << "Max relative error in float sigmoid=" << max_error; LOG(INFO) << "Max abs error in float sigmoid=" << max_abs_error; #if defined __aarch64__ LOG(INFO) << "Max relative vector error fixed sigmoid=" << max_vector_error; LOG(INFO) << "Max abs vector error fixed sigmoid=" << max_vector_abs_error; #endif } TEST(Transcendentals, Tanh) { // 132 - 127 = 5, we check between -63.99... and 63.99... const int maxExponent = 132; const int minExponent = 0; float max_error = 0.f; float max_abs_error = 0.f; for (int s = 0; s < 2; ++s) { for (int e = minExponent; e < maxExponent; ++e) { // Don't check every mantissa for speed reasons. for (int m = 0; m < (1 << 23); m += (1 << 10)) { uint32_t int_val = s << 31 | e << 23 | m; float x; memcpy(&x, &int_val, sizeof(float)); float exact_tanh = tanhf(x); float approx_tanh = csrblocksparse::fast_tanh(x); float rel_diff = RelDiff(exact_tanh, approx_tanh); float abs_diff = std::abs(exact_tanh - approx_tanh); max_error = std::max(rel_diff, max_error); max_abs_error = std::max(abs_diff, max_abs_error); EXPECT_LT(rel_diff, kTanhRelTolerance) << exact_tanh << " " << approx_tanh << " " << x; EXPECT_NEAR(approx_tanh, exact_tanh, kTanhAbsTolerance) << x; } } } LOG(INFO) << "Max relative error in float tanh=" << max_error; LOG(INFO) << "Max abs error in float tanh=" << max_abs_error; // tanh behavior is not identical across all lanes, so need to test // with some values in the linear region and some not. #if defined __aarch64__ float vals[4] = {-1.f, -.1f, .1f, 1.f}; float exact_results[4]; float approx_results[4]; max_error = 0.f; max_abs_error = 0.f; float32x4_t input = vld1q_f32(vals); float32x4_t result = csrblocksparse::fast_tanh(input); vst1q_f32(approx_results, result); for (int i = 0; i < 4; ++i) { exact_results[i] = tanh(vals[i]); float rel_diff = RelDiff(exact_results[i], approx_results[i]); float abs_diff = std::abs(exact_results[i] - approx_results[i]); max_error = std::max(rel_diff, max_error); max_abs_error = std::max(abs_diff, max_abs_error); EXPECT_LT(rel_diff, kTanhRelTolerance) << exact_results[i] << " " << approx_results[i] << " " << vals[i]; EXPECT_NEAR(approx_results[i], exact_results[i], kTanhAbsTolerance) << vals[i]; } LOG(INFO) << "Max relative vector error in float tanh=" << max_error; LOG(INFO) << "Max abs vector error in float tanh=" << max_abs_error; #endif } #if defined __AVX2__ constexpr int kSIMDSize = 8; constexpr int kNumExpBitsIn = 10; constexpr int kNumExpBitsOut = 5; TEST(Transcendentals, TanhLut) { // Test every value in (-1, 1) for round-trip exactness. constexpr int kNumMantissaBitsIn = fixed32::kMantissaBits; constexpr int kNumMantissaBitsOut = fixed16::kMantissaBits; const int32_t* tanh_table = TanhTable(kNumMantissaBitsOut); float in_factor = static_cast(1 << kNumMantissaBitsIn); float out_factor = static_cast(1 << kNumMantissaBitsOut); for (int i = 1 - (1 << kNumMantissaBitsOut); i + kSIMDSize < (1 << kNumMantissaBitsOut); i += kSIMDSize) { int32_t inputs[kSIMDSize]; int32_t outputs[kSIMDSize]; int32_t target_outputs[kSIMDSize]; for (int j = 0; j < kSIMDSize; ++j) { float target_tanh = (i + j) / out_factor; float x = atanhf(static_cast(target_tanh)); inputs[j] = static_cast(x * in_factor); target_outputs[j] = i + j; } __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(inputs)); __m256i output = fixed32_tanh_fixed16( tanh_table, x_in); _mm256_storeu_si256(reinterpret_cast<__m256i*>(outputs), output); for (int j = 0; j < kSIMDSize; ++j) { EXPECT_EQ(target_outputs[j], outputs[j]); } } } TEST(Transcendentals, SigmoidLut) { // Test every value in (-1, 1) for round-trip exactness. constexpr int kNumMantissaBitsIn = fixed32::kMantissaBits; constexpr int kNumMantissaBitsOut = fixed16::kMantissaBits; const int32_t* sigmoid_table = SigmoidTable(kNumMantissaBitsOut); float in_factor = static_cast(1 << kNumMantissaBitsIn); float out_factor = static_cast(1 << kNumMantissaBitsOut); for (int i = 1; i + kSIMDSize < (1 << kNumMantissaBitsOut); i += kSIMDSize) { int32_t inputs[kSIMDSize]; int32_t outputs[kSIMDSize]; int32_t target_outputs[kSIMDSize]; for (int j = 0; j < kSIMDSize; ++j) { float target_sigmoid = (i + j) / out_factor; float x = 2.0f * atanhf(2.0f * static_cast(target_sigmoid) - 1.0f); inputs[j] = static_cast(x * in_factor); target_outputs[j] = i + j; } __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(inputs)); __m256i output = fixed32_sigmoid_fixed16( sigmoid_table, x_in); _mm256_storeu_si256(reinterpret_cast<__m256i*>(outputs), output); for (int j = 0; j < kSIMDSize; ++j) { EXPECT_EQ(target_outputs[j], outputs[j]); } } } template static void TestExpAVX2(float abs_tolerance, float rel_tolerance) { constexpr int kMantissaBits = 20; // Test every value in [-80, 80] and report the max error. constexpr int kMinInput = -(80 << kMantissaBits); constexpr int kMaxInput = 80 << kMantissaBits; constexpr int kNumInputs = kMaxInput - kMinInput; std::vector inputs(kNumInputs); std::vector outputs(kNumInputs); std::vector target_outputs(kNumInputs); for (int i = 0; i < inputs.size(); ++i) { csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); float x = static_cast(fixed_int); inputs[i] = fixed_int.raw_val(); target_outputs[i] = expf(x); } absl::Time t_start = absl::Now(); for (int i = 0; i + kSIMDSize * 2 <= kNumInputs; i += kSIMDSize * 2) { __m256i x0 = _mm256_loadu_si256(reinterpret_cast(inputs.data() + i)); __m256i x1 = _mm256_loadu_si256( reinterpret_cast(inputs.data() + i + kSIMDSize)); __m256 y0, y1; fixed32_exp_float(x0, x1, y0, y1); _mm256_storeu_ps(outputs.data() + i, y0); _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); } LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); float max_error = 0.f; float max_abs_error = 0.f; for (int i = 0; i < kNumInputs; ++i) { float diff = target_outputs[i] - outputs[i]; float abs_diff = std::abs(diff); csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); float x = static_cast(fixed_int); float rel_diff = RelDiff(target_outputs[i], outputs[i]); max_error = std::max(max_error, rel_diff); if (x <= 1.0f) { ASSERT_LT(abs_diff, abs_tolerance) << "x=" << x << ", target=" << target_outputs[i] << ", result= " << outputs[i] << ", i=" << i; max_abs_error = std::max(max_abs_error, abs_diff); } ASSERT_LT(rel_diff, rel_tolerance) << "x=" << x << ", target=" << target_outputs[i] << ", result= " << outputs[i] << ", i=" << i; } LOG(INFO) << "Max relative error = " << max_error << ", abs=" << max_abs_error; } TEST(Transcendentals, QuarticFloatExpAVX2) { TestExpAVX2(kQuarticFloatExpTolerance, kQuarticFloatExpRelTolerance); } TEST(Transcendentals, QuarticExpAVX2) { TestExpAVX2(kQuarticExpTolerance, kQuarticExpRelTolerance); } TEST(Transcendentals, CubicExpAVX2) { TestExpAVX2(kCubicExpTolerance, kCubicExpRelTolerance); } template void TestTanhAVX2Float(float abs_tolerance, float rel_tolerance) { constexpr int kMantissaBits = 16; // Test every value in [-10, 10] and report the max error. constexpr int kMinInput = -(10 << kMantissaBits); constexpr int kMaxInput = 10 << kMantissaBits; constexpr int kNumInputs = kMaxInput - kMinInput; float max_error = 0.f; float max_abs_error = 0.f; std::vector inputs(kNumInputs); std::vector outputs(kNumInputs); std::vector target_outputs(kNumInputs); for (int i = 0; i < inputs.size(); ++i) { csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); float x = static_cast(fixed_int); float exact = tanh(x); inputs[i] = static_cast(fixed_int.raw_val()); target_outputs[i] = exact; } absl::Time t_start = absl::Now(); for (int i = 0; i + kSIMDSize * 2 <= inputs.size(); i += kSIMDSize * 2) { __m256 x0 = _mm256_loadu_ps(inputs.data() + i); __m256 x1 = _mm256_loadu_ps(inputs.data() + kSIMDSize + i); __m256 y0, y1; float_tanh_float(x0, x1, y0, y1); _mm256_storeu_ps(outputs.data() + i, y0); _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); } LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); float worst_abs_x = 0.0f, worst_rel_x = 0.0f; for (int i = 0; i < inputs.size(); ++i) { float diff = target_outputs[i] - outputs[i]; float abs_diff = std::abs(diff); csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); float x = static_cast(fixed_int); ASSERT_LT(abs_diff, abs_tolerance) << "x=" << x << ", target=" << target_outputs[i] << ", aprx=" << outputs[i]; if (abs_diff > max_abs_error) worst_abs_x = x; max_abs_error = std::max(max_abs_error, abs_diff); float rel_diff = 0.0f; rel_diff = RelDiff(target_outputs[i], outputs[i]); if (rel_diff > max_error) worst_rel_x = x; max_error = std::max(max_error, rel_diff); ASSERT_LT(rel_diff, rel_tolerance) << "x=" << x << ", target=" << target_outputs[i] << ", aprx=" << outputs[i]; } LOG(INFO) << "Max relative error = " << max_error << ", abs=" << max_abs_error; LOG(INFO) << "Worst rel x = " << worst_rel_x << ", abs=" << worst_abs_x; } TEST(Transcendentals, QuarticTanhFloatAVX2Float) { TestTanhAVX2Float(kQuarticFloatTanhTolerance, kQuarticFloatTanhRelTolerance); } TEST(Transcendentals, QuarticTanhAVX2Float) { TestTanhAVX2Float(kQuarticTanhTolerance, kQuarticTanhRelTolerance); } TEST(Transcendentals, CubicTanhAVX2Float) { TestTanhAVX2Float(kCubicTanhTolerance, kCubicTanhRelTolerance); } template void TestSigmoidAVX2Float(float abs_tolerance, float rel_tolerance) { constexpr int kMantissaBits = 20; // Test every value in [-20, 20] and report the max error. constexpr int kMaxInput = 20 << kMantissaBits; constexpr int kMinInput = -(20 << kMantissaBits); float max_error = 0.f; float max_abs_error = 0.f; std::vector inputs(kMaxInput - kMinInput); std::vector outputs(kMaxInput - kMinInput); std::vector target_outputs(kMaxInput - kMinInput); for (int i = 0; i < inputs.size(); ++i) { csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); float x = static_cast(fixed_int); float exact = 1.0f / (1.0f + expf(-x)); inputs[i] = fixed_int.raw_val(); target_outputs[i] = exact; } absl::Time t_start = absl::Now(); for (int i = 0; i + kSIMDSize * 2 <= inputs.size(); i += kSIMDSize * 2) { __m256i x0 = _mm256_loadu_si256(reinterpret_cast(inputs.data() + i)); __m256i x1 = _mm256_loadu_si256( reinterpret_cast(inputs.data() + i + kSIMDSize)); __m256 y0 = _mm256_cvtepi32_ps(x0); __m256 y1 = _mm256_cvtepi32_ps(x1); float_sigmoid_float(y0, y1); _mm256_storeu_ps(outputs.data() + i, y0); _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); } LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); for (int i = 0; i < inputs.size(); ++i) { float diff = target_outputs[i] - outputs[i]; float abs_diff = std::abs(diff); csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); float x = static_cast(fixed_int); float rel_diff = RelDiff(target_outputs[i], outputs[i]); max_error = std::max(max_error, rel_diff); ASSERT_LT(abs_diff, abs_tolerance) << "x=" << x << ", target=" << target_outputs[i] << ", aprx=" << outputs[i]; max_abs_error = std::max(max_abs_error, abs_diff); ASSERT_LT(rel_diff, rel_tolerance) << "x=" << x << ", target=" << target_outputs[i] << ", aprx=" << outputs[i]; } LOG(INFO) << "Max relative error = " << max_error << ", abs=" << max_abs_error; } TEST(Transcendentals, QuarticSigmoidFloatAVX2Float) { TestSigmoidAVX2Float(kQuarticFloatSigmoidTolerance, kQuarticFloatSigmoidRelTolerance); } TEST(Transcendentals, QuarticSigmoidAVX2Float) { TestSigmoidAVX2Float(kQuarticSigmoidTolerance, kQuarticSigmoidRelTolerance); } TEST(Transcendentals, CubicSigmoidAVX2Float) { TestSigmoidAVX2Float(kCubicSigmoidTolerance, kCubicSigmoidRelTolerance); } #endif // __AVX2__ } // namespace csrblocksparse