blob: 1cde98cba48435dbcf05c4060a873e3579820865 [file] [edit]
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$assert(DATATYPE in ["f32", "f16", "u8", "s8"])
$assert(SIMD_SIZE > 0)
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
$DATATYPE_IMPL = {"f32": "f32", "f16": "f16" if ARCH == "neonfp16arith" else "s16", "u8": "u8", "s8": "u8" if ARCH == "sse2" else "s8"}[DATATYPE]
// On some architectures, we have max(u8, u8) but not max(s8, s8). We can emulate max(s8, s8) on these architectures by
// xoring with the sign bit mask.
$if ARCH == "sse2" and DATATYPE == "s8":
#define xnn_load_impl(x) xnn_xor_u8(xnn_loadu_u8(x), vsign_bit_mask)
#define xnn_load_tail_impl(x, c) xnn_xor_u8(xnn_load_tail_u8(x, c), vsign_bit_mask)
#define xnn_load_tail_safe_impl(x, c) xnn_xor_u8(xnn_load_tail_safe_u8(x, c), vsign_bit_mask)
#define xnn_pre_store_impl(x) xnn_xor_u8(x, vsign_bit_mask)
$elif DATATYPE == "f16" and ARCH != "neonfp16arith":
#define xnn_load_impl(x) xnn_signcomplement_s16(xnn_loadu_s16(x))
#define xnn_load_tail_impl(x, c) xnn_signcomplement_s16(xnn_load_tail_s16(x, c))
#define xnn_load_tail_safe_impl(x, c) xnn_signcomplement_s16(xnn_load_tail_safe_s16(x, c))
#define xnn_pre_store_impl(x) xnn_signcomplement_s16(x)
$else:
#define xnn_load_impl(x) xnn_loadu_${DATATYPE}(x)
#define xnn_load_tail_impl(x, c) xnn_load_tail_${DATATYPE}(x, c)
#define xnn_load_tail_safe_impl(x, c) xnn_load_tail_safe_${DATATYPE}(x, c)
#define xnn_pre_store_impl(x) x
#include "src/xnnpack/common.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/simd/${DATATYPE_IMPL}-${ARCH}.h"
$CTYPE = {"f32": "float", "f16": "xnn_float16", "u8": "uint8_t", "s8": "int8_t"}[DATATYPE]
$CTYPE_IMPL = {"f32": "float", "f16": "xnn_float16", "s16": "int16_t", "u8": "uint8_t", "s8": "int8_t"}[DATATYPE_IMPL]
// Max-pooling microkernel, 9 pooling elements per pass.
// - First pass: reduce the first min(kernel_elements, 9) input rows into the
//   output, clamping the result to [params->min, params->max].
// - Each subsequent pass folds up to 9 more input rows into the output,
//   reloading the stored values as the running maximum.
// Channels are processed SIMD_SIZE at a time, with a tail path for the
// remainder. For s8-on-SSE2 and f16-without-neonfp16arith the data is routed
// through xnn_load_impl/xnn_pre_store_impl so that the available integer
// max/min instructions order values the same way as the source type (see the
// macro definitions above).
void xnn_${DATATYPE}_maxpool_minmax_ukernel_9p__${ARCH}_u${SIMD_SIZE}(
size_t output_pixels,
size_t kernel_elements,
size_t channels,
const ${CTYPE}** input,
size_t input_offset,
size_t input_pixel_stride,
${CTYPE}* output,
size_t input_increment,
size_t output_increment,
const struct xnn_${DATATYPE}_minmax_params* restrict params)
{
assert(output_pixels != 0);
assert(channels != 0);
// NOTE(review): i0 is dereferenced unconditionally below, so this kernel also
// assumes kernel_elements != 0 — confirm against the caller contract.
$if ARCH == "sse2" and DATATYPE == "s8":
const xnn_simd_u8_t vsign_bit_mask = xnn_set1_u8(0x80);  // xor with 0x80 maps s8 values to u8 with the same ordering
XNN_FORCE_REALIZATION(vsign_bit_mask);
const xnn_simd_u8_t vmin = xnn_xor_u8(xnn_set1_u8(params->scalar.min), vsign_bit_mask);  // pre-bias the bounds so they compare correctly in u8 space
const xnn_simd_u8_t vmax = xnn_xor_u8(xnn_set1_u8(params->scalar.max), vsign_bit_mask);
XNN_FORCE_REALIZATION(vmin);
XNN_FORCE_REALIZATION(vmax);
$elif DATATYPE == "f16" and ARCH != "neonfp16arith":
const xnn_simd_s16_t vmin = xnn_signcomplement_s16(xnn_set1_s16(*(const int16_t*) &params->scalar.min));  // presumably maps f16 bit patterns to order-preserving s16 — see xnn_signcomplement_s16
const xnn_simd_s16_t vmax = xnn_signcomplement_s16(xnn_set1_s16(*(const int16_t*) &params->scalar.max));
XNN_FORCE_REALIZATION(vmin);
XNN_FORCE_REALIZATION(vmax);
$else:
const xnn_simd_${DATATYPE_IMPL}_t vmin = xnn_set1_${DATATYPE_IMPL}(params->scalar.min);
const xnn_simd_${DATATYPE_IMPL}_t vmax = xnn_set1_${DATATYPE_IMPL}(params->scalar.max);
XNN_FORCE_REALIZATION(vmin);
XNN_FORCE_REALIZATION(vmax);
do {
const ${CTYPE_IMPL}** i = (const ${CTYPE_IMPL}**) input;
// First pass: load the inputs, store the max pool in the output.
const ${CTYPE_IMPL}* i0 = *i++;
$for K in range(1, 9):
const ${CTYPE_IMPL}* i${K} = ${K} < kernel_elements ? *i++ : i0;  // pad missing rows with i0; duplicates are harmless under max
$for K in range(9):
i${K} = (const ${CTYPE_IMPL}*) ((uintptr_t) i${K} + input_offset);
${CTYPE_IMPL}* o = (${CTYPE_IMPL}*) output;
size_t c = channels;
for (; c >= ${SIMD_SIZE}; c -= ${SIMD_SIZE}) {  // full SIMD-width channel tiles
$for K in range(9):
const xnn_simd_${DATATYPE_IMPL}_t vi${K} = xnn_load_impl(i${K}); i${K} += ${SIMD_SIZE};
const xnn_simd_${DATATYPE_IMPL}_t vmax018 = xnn_max_${DATATYPE_IMPL}(xnn_max_${DATATYPE_IMPL}(vi0, vi1), vi8);
const xnn_simd_${DATATYPE_IMPL}_t vmax23 = xnn_max_${DATATYPE_IMPL}(vi2, vi3);
const xnn_simd_${DATATYPE_IMPL}_t vmax45 = xnn_max_${DATATYPE_IMPL}(vi4, vi5);
const xnn_simd_${DATATYPE_IMPL}_t vmax67 = xnn_max_${DATATYPE_IMPL}(vi6, vi7);
const xnn_simd_${DATATYPE_IMPL}_t vmax2345 = xnn_max_${DATATYPE_IMPL}(vmax23, vmax45);
const xnn_simd_${DATATYPE_IMPL}_t vmax01678 = xnn_max_${DATATYPE_IMPL}(vmax018, vmax67);
xnn_simd_${DATATYPE_IMPL}_t vacc = xnn_max_${DATATYPE_IMPL}(vmax2345, vmax01678);
vacc = xnn_max_${DATATYPE_IMPL}(vacc, vmin);  // clamp to [vmin, vmax]
vacc = xnn_min_${DATATYPE_IMPL}(vacc, vmax);
vacc = xnn_pre_store_impl(vacc);
xnn_storeu_${DATATYPE_IMPL}(o, vacc); o += ${SIMD_SIZE};
}
$if SIMD_SIZE > 1:
if (c > 0) {  // channel remainder: tail loads/stores touch only c elements
$for K in range(9):
const xnn_simd_${DATATYPE_IMPL}_t vi${K} = xnn_load_tail_impl(i${K}, c);
const xnn_simd_${DATATYPE_IMPL}_t vmax018 = xnn_max_${DATATYPE_IMPL}(xnn_max_${DATATYPE_IMPL}(vi0, vi1), vi8);
const xnn_simd_${DATATYPE_IMPL}_t vmax23 = xnn_max_${DATATYPE_IMPL}(vi2, vi3);
const xnn_simd_${DATATYPE_IMPL}_t vmax45 = xnn_max_${DATATYPE_IMPL}(vi4, vi5);
const xnn_simd_${DATATYPE_IMPL}_t vmax67 = xnn_max_${DATATYPE_IMPL}(vi6, vi7);
const xnn_simd_${DATATYPE_IMPL}_t vmax2345 = xnn_max_${DATATYPE_IMPL}(vmax23, vmax45);
const xnn_simd_${DATATYPE_IMPL}_t vmax01678 = xnn_max_${DATATYPE_IMPL}(vmax018, vmax67);
xnn_simd_${DATATYPE_IMPL}_t vacc = xnn_max_${DATATYPE_IMPL}(vmax2345, vmax01678);
vacc = xnn_max_${DATATYPE_IMPL}(vacc, vmin);
vacc = xnn_min_${DATATYPE_IMPL}(vacc, vmax);
vacc = xnn_pre_store_impl(vacc);
xnn_store_tail_${DATATYPE_IMPL}(o, vacc, c); o += c;
}
// Passes 1 - n: Max more inputs to the output.
o = (${CTYPE_IMPL}*) output;
for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 9; k > 0; k -= 9) {  // each extra pass consumes up to 9 more pooling elements
const ${CTYPE_IMPL}* i0 = *i++;
$for K in range(1, 9):
const ${CTYPE_IMPL}* i${K} = ${K} < k ? *i++ : i0;  // pad with this pass's i0 when fewer than 9 remain
$for K in range(9):
i${K} = (const ${CTYPE_IMPL}*) ((uintptr_t) i${K} + input_offset);
${CTYPE_IMPL}* o = (${CTYPE_IMPL}*) output;
size_t c = channels;
for (; c >= ${SIMD_SIZE}; c -= ${SIMD_SIZE}) {
$for K in range(9):
const xnn_simd_${DATATYPE_IMPL}_t vi${K} = xnn_load_impl(i${K}); i${K} += ${SIMD_SIZE};
const xnn_simd_${DATATYPE_IMPL}_t vprev = xnn_load_impl(o);  // running maximum from previous passes
const xnn_simd_${DATATYPE_IMPL}_t vmax018 = xnn_max_${DATATYPE_IMPL}(xnn_max_${DATATYPE_IMPL}(vi0, vi1), vi8);
const xnn_simd_${DATATYPE_IMPL}_t vmax23 = xnn_max_${DATATYPE_IMPL}(vi2, vi3);
const xnn_simd_${DATATYPE_IMPL}_t vmax45 = xnn_max_${DATATYPE_IMPL}(vi4, vi5);
const xnn_simd_${DATATYPE_IMPL}_t vmax67 = xnn_max_${DATATYPE_IMPL}(vi6, vi7);
const xnn_simd_${DATATYPE_IMPL}_t vmax2345 = xnn_max_${DATATYPE_IMPL}(vmax23, vmax45);
const xnn_simd_${DATATYPE_IMPL}_t vmax01678 = xnn_max_${DATATYPE_IMPL}(vmax018, vmax67);
const xnn_simd_${DATATYPE_IMPL}_t vmax012345678 = xnn_max_${DATATYPE_IMPL}(vmax2345, vmax01678);
xnn_simd_${DATATYPE_IMPL}_t vacc = xnn_max_${DATATYPE_IMPL}(vprev, vmax012345678);
vacc = xnn_min_${DATATYPE_IMPL}(vacc, vmax);  // no vmin clamp needed: vprev was already clamped to >= vmin in the first pass
vacc = xnn_pre_store_impl(vacc);
xnn_storeu_${DATATYPE_IMPL}(o, vacc); o += ${SIMD_SIZE};
}
$if SIMD_SIZE > 1:
if (c > 0) {
$for K in range(9):
const xnn_simd_${DATATYPE_IMPL}_t vi${K} = xnn_load_tail_impl(i${K}, c);
const xnn_simd_${DATATYPE_IMPL}_t vprev = xnn_load_tail_safe_impl(o, c);  // safe load: only c elements of o are valid here
const xnn_simd_${DATATYPE_IMPL}_t vmax018 = xnn_max_${DATATYPE_IMPL}(xnn_max_${DATATYPE_IMPL}(vi0, vi1), vi8);
const xnn_simd_${DATATYPE_IMPL}_t vmax23 = xnn_max_${DATATYPE_IMPL}(vi2, vi3);
const xnn_simd_${DATATYPE_IMPL}_t vmax45 = xnn_max_${DATATYPE_IMPL}(vi4, vi5);
const xnn_simd_${DATATYPE_IMPL}_t vmax67 = xnn_max_${DATATYPE_IMPL}(vi6, vi7);
const xnn_simd_${DATATYPE_IMPL}_t vmax2345 = xnn_max_${DATATYPE_IMPL}(vmax23, vmax45);
const xnn_simd_${DATATYPE_IMPL}_t vmax01678 = xnn_max_${DATATYPE_IMPL}(vmax018, vmax67);
const xnn_simd_${DATATYPE_IMPL}_t vmax012345678 = xnn_max_${DATATYPE_IMPL}(vmax2345, vmax01678);
xnn_simd_${DATATYPE_IMPL}_t vacc = xnn_max_${DATATYPE_IMPL}(vprev, vmax012345678);
vacc = xnn_min_${DATATYPE_IMPL}(vacc, vmax);
vacc = xnn_pre_store_impl(vacc);
xnn_store_tail_${DATATYPE_IMPL}(o, vacc, c);
}
}
// Advance to the next output pixel.
input = (const ${CTYPE}**) ((uintptr_t) input + input_increment);
input_offset += input_pixel_stride;
output = (${CTYPE}*) ((uintptr_t) output + output_increment);
} while (--output_pixels != 0);
}