blob: d493ecf126e7b06068e9b9d30020200d34ce314f [file] [edit]
// Copyright 2026 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$assert LMUL in [1, 2, 4, 8]
$assert OP in ["ADD", "DIV", "RDIV", "MAX", "MIN", "MUL", "SUB", "RSUB", "SQRDIFF", "PRELU", "RPRELU"]
#include <assert.h>
#include <riscv_vector.h>
#include "src/xnnpack/common.h"
#include "src/xnnpack/vbinary.h"
// Template-time table: OP selects the vector-scalar (vf) f16 intrinsic name;
// the m<LMUL> suffix is appended where the kernel calls it. SQRDIFF reuses the
// subtract and PRELU/RPRELU reuse the multiply, with the remaining steps done
// in the kernel body.
$OP_FUNC = {
$ "ADD": "__riscv_vfadd_vf_f16",
$ "DIV": "__riscv_vfdiv_vf_f16",
$ "RDIV": "__riscv_vfrdiv_vf_f16",
$ "MAX": "__riscv_vfmax_vf_f16",
$ "MIN": "__riscv_vfmin_vf_f16",
$ "MUL": "__riscv_vfmul_vf_f16",
$ "SUB": "__riscv_vfsub_vf_f16",
$ "RSUB": "__riscv_vfrsub_vf_f16",
$ "SQRDIFF": "__riscv_vfsub_vf_f16",
$ "PRELU": "__riscv_vfmul_vf_f16",
$ "RPRELU": "__riscv_vfmul_vf_f16",
$}[OP]
// Elementwise f16 binary op between a vector and one broadcast scalar,
// vectorized with RVV at the LMUL chosen when the template is instantiated.
//
//   batch   - size of input_a and output in bytes; asserted to be a non-zero
//             multiple of sizeof(uint16_t) and shifted down to an element
//             count below.
//   input_a - vector operand, read sequentially.
//   input_b - points at the single scalar operand.
//   output  - destination, written sequentially.
//   params  - not read by this kernel.
void xnn_f16_v${OP.lower()}c_ukernel__rvvfp16arith_u${LMUL}v(
    size_t batch,
    const xnn_float16* input_a,
    const xnn_float16* input_b,
    xnn_float16* output,
    const struct xnn_f16_default_params* restrict params)
{
  assert(batch != 0);
  assert(batch % sizeof(uint16_t) == 0);
  assert(input_a != NULL);
  assert(input_b != NULL);
  assert(output != NULL);

  const _Float16 scalar = *(const _Float16*) input_b;
  const _Float16* src = (const _Float16*) input_a;
  _Float16* dst = (_Float16*) output;
  size_t remaining = batch >> XNN_LOG2_SIZEOF_FLOAT16;

  $if OP == "RPRELU":
    // A non-negative scalar makes every output equal to the scalar, so
    // broadcast it once and store it without loading input_a.
    if XNN_UNLIKELY(scalar >= 0.0f) {
      const size_t vl_fill = __riscv_vsetvl_e16m${LMUL}(remaining);
      const vfloat16m${LMUL}_t vfill = __riscv_vfmv_v_f_f16m${LMUL}(scalar, vl_fill);
      do {
        const size_t vl = __riscv_vsetvl_e16m${LMUL}(remaining);
        __riscv_vse16_v_f16m${LMUL}(dst, vfill, vl);
        dst += vl;
        remaining -= vl;
      } while (remaining != 0);
      return;
    }

  do {
    // vl is the number of elements the hardware grants for this iteration.
    const size_t vl = __riscv_vsetvl_e16m${LMUL}(remaining);
    const vfloat16m${LMUL}_t va = __riscv_vle16_v_f16m${LMUL}(src, vl);
    vfloat16m${LMUL}_t vout = ${OP_FUNC}m${LMUL}(va, scalar, vl);
    $if OP == "SQRDIFF":
      // Square the difference produced above.
      vout = __riscv_vfmul_vv_f16m${LMUL}(vout, vout, vl);
    $elif OP == "PRELU":
      // Lanes with a negative input take the scaled value; the rest pass through.
      const vbool${16//LMUL}_t vneg = __riscv_vmflt_vf_f16m${LMUL}_b${16//LMUL}(va, 0.0f, vl);
      vout = __riscv_vmerge_vvm_f16m${LMUL}(va, vout, vneg, vl);
    __riscv_vse16_v_f16m${LMUL}(dst, vout, vl);
    src += vl;
    dst += vl;
    remaining -= vl;
  } while (remaining != 0);
}