// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <arm_neon.h>

#include <xnnpack/fft.h>
#include <xnnpack/math.h>
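
// Radix-4 FFT butterfly micro-kernel on interleaved complex int16 data,
// computing four butterflies per loop iteration with NEON.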
void xnn_cs16_bfly4_ukernel__neon_x4(
    size_t batch,
    size_t samples,
    int16_t* data,
    const int16_t* twiddle,
    size_t stride)
{
  assert(batch != 0);
  assert(samples != 0);
  assert(samples % (sizeof(int16_t) * 8) == 0);
  assert(data != NULL);
  assert(stride != 0);
  assert(twiddle != NULL);
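
  // 8191/32768 is just under 1/4 in Q15: vqrdmulh_s16 by this constant
  // scales every input down by ~4 so the radix-4 sums cannot overflow int16.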
  const int16x4_t vdiv4 = vdup_n_s16(8191);

  int16_t* data3 = data;
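  // Each batch element spans four contiguous sections of samples bytes
  // (data0..data3); data3 exits each iteration at the next element's start.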
  do {
    int16_t* data0 = data3;
    int16_t* data1 = (int16_t*) ((uintptr_t) data0 + samples);
    int16_t* data2 = (int16_t*) ((uintptr_t) data1 + samples);
    data3 = (int16_t*) ((uintptr_t) data2 + samples);
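
    // The radix-4 butterfly needs w^k, w^2k, and w^3k: three cursors over
    // the same twiddle table, advanced at 1x, 2x, and 3x the stride.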
    const int16_t* tw1 = twiddle;
    const int16_t* tw2 = twiddle;
    const int16_t* tw3 = twiddle;

    size_t s = samples;
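
    // Process 4 complex int16 values (16 bytes) from each section per iteration.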
    for (; s >= sizeof(int16_t) * 8; s -= sizeof(int16_t) * 8) {
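      // vld2 de-interleaves each section: val[0] holds four real parts,
      // val[1] the four matching imaginary parts.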
      int16x4x2_t vout0 = vld2_s16(data0);
      int16x4x2_t vout1 = vld2_s16(data1);
      int16x4x2_t vout2 = vld2_s16(data2);
      int16x4x2_t vout3 = vld2_s16(data3);
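
      // Twiddle factors are strided in memory, so gather four of each:
      // vld2_dup loads lane 0, vld2_lane fills lanes 1..3, and each cursor
      // advances by its own stride (1x, 2x, 3x) between loads.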
      int16x4x2_t vtw1 = vld2_dup_s16(tw1);
      int16x4x2_t vtw2 = vld2_dup_s16(tw2);
      int16x4x2_t vtw3 = vld2_dup_s16(tw3);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
      vtw1 = vld2_lane_s16(tw1, vtw1, 1);
      vtw2 = vld2_lane_s16(tw2, vtw2, 1);
      vtw3 = vld2_lane_s16(tw3, vtw3, 1);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
      vtw1 = vld2_lane_s16(tw1, vtw1, 2);
      vtw2 = vld2_lane_s16(tw2, vtw2, 2);
      vtw3 = vld2_lane_s16(tw3, vtw3, 2);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);
      vtw1 = vld2_lane_s16(tw1, vtw1, 3);
      vtw2 = vld2_lane_s16(tw2, vtw2, 3);
      vtw3 = vld2_lane_s16(tw3, vtw3, 3);
      tw1 = (const int16_t*) ((uintptr_t) tw1 + stride);
      tw2 = (const int16_t*) ((uintptr_t) tw2 + stride * 2);
      tw3 = (const int16_t*) ((uintptr_t) tw3 + stride * 3);

      // Scale all four inputs by ~1/4 so the radix-4 sums stay within int16.
      // Note 32767 / 4 = 8191.  Should be 8192.
      vout1.val[0] = vqrdmulh_s16(vout1.val[0], vdiv4);
      vout1.val[1] = vqrdmulh_s16(vout1.val[1], vdiv4);
      vout2.val[0] = vqrdmulh_s16(vout2.val[0], vdiv4);
      vout2.val[1] = vqrdmulh_s16(vout2.val[1], vdiv4);
      vout3.val[0] = vqrdmulh_s16(vout3.val[0], vdiv4);
      vout3.val[1] = vqrdmulh_s16(vout3.val[1], vdiv4);
      vout0.val[0] = vqrdmulh_s16(vout0.val[0], vdiv4);
      vout0.val[1] = vqrdmulh_s16(vout0.val[1], vdiv4);
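
      // Complex multiply by the twiddles in 32 bits:
      // (r + i*j)(wr + wi*j) = (r*wr - i*wi) + (r*wi + i*wr)*j.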
      int32x4_t vacc0r = vmull_s16(vout1.val[0], vtw1.val[0]);
      int32x4_t vacc1r = vmull_s16(vout2.val[0], vtw2.val[0]);
      int32x4_t vacc2r = vmull_s16(vout3.val[0], vtw3.val[0]);
      int32x4_t vacc0i = vmull_s16(vout1.val[0], vtw1.val[1]);
      int32x4_t vacc1i = vmull_s16(vout2.val[0], vtw2.val[1]);
      int32x4_t vacc2i = vmull_s16(vout3.val[0], vtw3.val[1]);
      vacc0r = vmlsl_s16(vacc0r, vout1.val[1], vtw1.val[1]);
      vacc1r = vmlsl_s16(vacc1r, vout2.val[1], vtw2.val[1]);
      vacc2r = vmlsl_s16(vacc2r, vout3.val[1], vtw3.val[1]);
      vacc0i = vmlal_s16(vacc0i, vout1.val[1], vtw1.val[0]);
      vacc1i = vmlal_s16(vacc1i, vout2.val[1], vtw2.val[0]);
      vacc2i = vmlal_s16(vacc2i, vout3.val[1], vtw3.val[0]);
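
      // The Q15 x Q15 products are Q30; a rounding shift right by 15
      // narrows them back to Q15 in int16.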
      int16x4_t vtmp0r = vrshrn_n_s32(vacc0r, 15);
      int16x4_t vtmp1r = vrshrn_n_s32(vacc1r, 15);
      int16x4_t vtmp2r = vrshrn_n_s32(vacc2r, 15);
      int16x4_t vtmp0i = vrshrn_n_s32(vacc0i, 15);
      int16x4_t vtmp1i = vrshrn_n_s32(vacc1i, 15);
      int16x4_t vtmp2i = vrshrn_n_s32(vacc2i, 15);
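
      // Radix-4 combine. The +/-j rotations on the (tw1*d1 - tw3*d3) term
      // appear as swapped real/imaginary parts with one sign flipped.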
      const int16x4_t vtmp4r = vsub_s16(vtmp0r, vtmp2r);
      const int16x4_t vtmp4i = vsub_s16(vtmp0i, vtmp2i);
      const int16x4_t vtmp3r = vadd_s16(vtmp0r, vtmp2r);
      const int16x4_t vtmp3i = vadd_s16(vtmp0i, vtmp2i);
      const int16x4_t vtmp5r = vsub_s16(vout0.val[0], vtmp1r);
      const int16x4_t vtmp5i = vsub_s16(vout0.val[1], vtmp1i);
      vout0.val[0] = vadd_s16(vout0.val[0], vtmp1r);
      vout0.val[1] = vadd_s16(vout0.val[1], vtmp1i);
      vout2.val[0] = vsub_s16(vout0.val[0], vtmp3r);
      vout2.val[1] = vsub_s16(vout0.val[1], vtmp3i);
      vout0.val[0] = vadd_s16(vout0.val[0], vtmp3r);
      vout0.val[1] = vadd_s16(vout0.val[1], vtmp3i);
      vout1.val[0] = vadd_s16(vtmp5r, vtmp4i);
      vout1.val[1] = vsub_s16(vtmp5i, vtmp4r);
      vout3.val[0] = vsub_s16(vtmp5r, vtmp4i);
      vout3.val[1] = vadd_s16(vtmp5i, vtmp4r);
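
      // Re-interleave real/imaginary pairs, store, and advance each section
      // by 4 complex values (8 int16 elements).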
      vst2_s16(data0, vout0);  data0 += 8;
      vst2_s16(data1, vout1);  data1 += 8;
      vst2_s16(data2, vout2);  data2 += 8;
      vst2_s16(data3, vout3);  data3 += 8;
    }
  } while (--batch != 0);
}