// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$assert BATCH_TILE % 4 == 0
$assert BATCH_TILE >= 4
$SIMD_TILE = BATCH_TILE // 4
$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF", "PRELU"]
#include <arm_neon.h>
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "src/xnnpack/common.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/vbinary.h"
// Template lookup: the NEON intrinsic applied to each pair of input vectors
// for the selected OP. SQRDIFF and PRELU only start from the subtraction or
// multiplication chosen here; the kernel body below finishes them with an
// extra squaring or select step.
$VOPQ_F32 = {
$ "ADD": "vaddq_f32",
$ "DIV": "vdivq_f32",
$ "MAX": "vmaxq_f32",
$ "MIN": "vminq_f32",
$ "MUL": "vmulq_f32",
$ "SUB": "vsubq_f32",
$ "SQRDIFF": "vsubq_f32",
$ "PRELU": "vmulq_f32",
$}[OP]
// Kernel names carry the aarch64_neon ISA tag for DIV and the neon tag for
// every other operation.
$ISA = "aarch64_neon" if OP == "DIV" else "neon"
// Elementwise binary micro-kernel over two float arrays.
//
//   batch   - number of BYTES to process; must be non-zero and a multiple of
//             sizeof(float).
//   input_a - first operand array.
//   input_b - second operand array.
//   output  - destination array.
//   params  - not read by this kernel.
//
// The remainder path loads full 4-float vectors even when only 1-3 elements
// are left, so it may read past the last valid input element; only the valid
// elements are stored.
void xnn_f32_v${OP.lower()}_ukernel__${ISA}_u${BATCH_TILE}(
    size_t batch,
    const float* input_a,
    const float* input_b,
    float* output,
    const struct xnn_f32_default_params* restrict params) XNN_OOB_READS
{
  assert(batch != 0);
  assert(batch % sizeof(float) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  $if OP == "PRELU":
    // Compared against the raw bits of input_a to test its sign bit.
    const int32x4_t vzero = vmovq_n_s32(0);

  $if BATCH_TILE > 4:
    // Unrolled main loop: one full batch tile per iteration.
    while (batch >= ${BATCH_TILE} * sizeof(float)) {
      $for N in range(SIMD_TILE):
        const float32x4_t vlhs${N} = vld1q_f32(input_a);
        input_a += 4;
        const float32x4_t vrhs${N} = vld1q_f32(input_b);
        input_b += 4;

      $for N in range(SIMD_TILE):
        float32x4_t vres${N} = ${VOPQ_F32}(vlhs${N}, vrhs${N});

      $if OP == "SQRDIFF":
        // Square the difference.
        $for N in range(SIMD_TILE):
          vres${N} = vmulq_f32(vres${N}, vres${N});
      $elif OP == "PRELU":
        // Lanes whose input_a has the sign bit set take the product; the
        // remaining lanes pass input_a through unchanged.
        $for N in range(SIMD_TILE):
          const uint32x4_t vsign${N} = vcltq_s32(vreinterpretq_s32_f32(vlhs${N}), vzero);

        $for N in range(SIMD_TILE):
          vres${N} = vbslq_f32(vsign${N}, vres${N}, vlhs${N});

      $for N in range(SIMD_TILE):
        vst1q_f32(output, vres${N});
        output += 4;

      batch -= ${BATCH_TILE} * sizeof(float);
    }

  // One vector of 4 floats per iteration.
  while (batch >= 4 * sizeof(float)) {
    const float32x4_t vlhs = vld1q_f32(input_a);
    input_a += 4;
    const float32x4_t vrhs = vld1q_f32(input_b);
    input_b += 4;

    float32x4_t vres = ${VOPQ_F32}(vlhs, vrhs);
    $if OP == "SQRDIFF":
      vres = vmulq_f32(vres, vres);
    $elif OP == "PRELU":
      const uint32x4_t vsign = vcltq_s32(vreinterpretq_s32_f32(vlhs), vzero);
      vres = vbslq_f32(vsign, vres, vlhs);

    vst1q_f32(output, vres);
    output += 4;

    batch -= 4 * sizeof(float);
  }

  // Remainder of 1 to 3 floats: compute a whole vector, store only the valid lanes.
  if XNN_UNLIKELY(batch != 0) {
    const float32x4_t vlhs = vld1q_f32(input_a);
    const float32x4_t vrhs = vld1q_f32(input_b);

    float32x4_t vres = ${VOPQ_F32}(vlhs, vrhs);
    $if OP == "SQRDIFF":
      vres = vmulq_f32(vres, vres);
    $elif OP == "PRELU":
      const uint32x4_t vsign = vcltq_s32(vreinterpretq_s32_f32(vlhs), vzero);
      vres = vbslq_f32(vsign, vres, vlhs);

    // Holds the next lanes still to be written: the low half first, then the
    // high half once two elements have been stored.
    float32x2_t vpending = vget_low_f32(vres);
    if (batch & (2 * sizeof(float))) {
      vst1_f32(output, vpending);
      output += 2;
      vpending = vget_high_f32(vres);
    }
    if (batch & (1 * sizeof(float))) {
      vst1_lane_f32(output, vpending, 0);
    }
  }
}