File size: 2,347 Bytes
8b7c501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/math.h>


// Reduces each channel of `input` to a single float and writes it to `output`.
//
// For every channel the kernel sums `elements` bytes worth of consecutive
// floats, multiplies the sum by params->neon.multiplier, clamps the product to
// [params->neon.output_min, params->neon.output_max] and stores one float.
// Channels are read back-to-back: `input` simply keeps advancing, and `output`
// advances by one float per channel.
//
//   elements - size of one channel in BYTES; non-zero multiple of sizeof(float).
//   channels - number of channels to reduce; must be non-zero.
//   input    - channel data, read sequentially.
//   output   - receives `channels` floats.
//   params   - supplies the tail mask, multiplier and clamping bounds.
//
// When a channel is not a multiple of 4 floats, the tail is fetched with a full
// 4-float load, so up to 3 floats past the end of the channel are read (hence
// XNN_OOB_READS) and then ANDed with params->neon.mask.
// NOTE(review): the mask is presumably prepared by the caller so that the lanes
// beyond the valid tail become zero, and the multiplier is presumably
// 1 / (element count) - confirm against the params initialization code.
void xnn_f32_gavgpool_cw_ukernel__neon_x4(
    size_t elements,
    size_t channels,
    const float* input,
    float* output,
    const union xnn_f32_gavgpool_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(elements != 0);
  assert(elements % sizeof(float) == 0);
  assert(channels != 0);

  // Byte strides of one 4-float vector and of one unrolled 16-float block.
  const size_t vector_bytes = 4 * sizeof(float);
  const size_t block_bytes = 16 * sizeof(float);

  // Per-call constants, loaded once and reused for every channel.
  const float32x2_t vscale = vld1_dup_f32(&params->neon.multiplier);
  const float32x2_t vlower = vld1_dup_f32(&params->neon.output_min);
  const float32x2_t vupper = vld1_dup_f32(&params->neon.output_max);
  const uint32x4_t vtail_mask = vld1q_u32(params->neon.mask);

  do {
    size_t remaining = elements;  // bytes of the current channel not yet summed
    float32x4_t vacc = vmovq_n_f32(0.0f);

    if (remaining >= block_bytes) {
      // Unrolled path: 16 floats per iteration into two independent accumulators.
      float32x4_t vacc_alt = vmovq_n_f32(0.0f);
      do {
        const float32x4_t va = vld1q_f32(&input[0]);
        const float32x4_t vb = vld1q_f32(&input[4]);
        const float32x4_t vc = vld1q_f32(&input[8]);
        const float32x4_t vd = vld1q_f32(&input[12]);
        input += 16;
        remaining -= block_bytes;

        vacc = vaddq_f32(vacc, vaddq_f32(va, vb));
        vacc_alt = vaddq_f32(vacc_alt, vaddq_f32(vc, vd));
        // The loop only repeats while at least two full blocks are left, so up
        // to 31 floats may be handed over to the 4-wide loop below.
      } while (remaining >= 2 * block_bytes);
      vacc = vaddq_f32(vacc, vacc_alt);
    }

    // Whole 4-float vectors that are still left.
    for (; remaining >= vector_bytes; remaining -= vector_bytes) {
      vacc = vaddq_f32(vacc, vld1q_f32(input));
      input += 4;
    }

    if XNN_UNLIKELY(remaining != 0) {
      // 1-3 floats left: load a full vector (over-reads), advance only by the
      // valid byte count, and apply the mask before accumulating.
      const uint32x4_t vtail_bits = vreinterpretq_u32_f32(vld1q_f32(input));
      input = (const float*) ((uintptr_t) input + remaining);

      vacc = vaddq_f32(vacc, vreinterpretq_f32_u32(vandq_u32(vtail_mask, vtail_bits)));
    }

    // Horizontal reduction: 4 lanes -> 2 pairwise sums -> total in every lane.
    const float32x2_t vpairs = vpadd_f32(vget_low_f32(vacc), vget_high_f32(vacc));
    const float32x2_t vtotal = vpadd_f32(vpairs, vpairs);

    // Scale, then clamp (lower bound first, upper bound second).
    const float32x2_t vscaled = vmul_f32(vtotal, vscale);
    const float32x2_t vclamped = vmin_f32(vmax_f32(vscaled, vlower), vupper);

    vst1_lane_f32(output, vclamped, 0);
    output += 1;

    channels -= 1;
  } while (channels != 0);
}