blob: 82986c53ed3919b9210ec4f036b754bd5416b01e [file] [log] [blame] [edit]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <type_traits>
#include "include/xnnpack.h"
#include "src/xnnpack/config-types.h"
#include "src/xnnpack/datatype.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microfnptr.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/reference-utils.h"
using xnnpack::dequantize;
using xnnpack::euclidean_div;
using xnnpack::euclidean_mod;
using xnnpack::integer_pow;
using xnnpack::quantize;
using xnnpack::widen;
namespace {
// These "reference" microkernels are not designed to be fast, only to support
// all possible operators and datatypes with reasonable performance. We just
// intend to give the compiler a reasonable chance at optimizing them.
// Elementwise reference kernel for unquantized element types.
//
// Applies `Operator` to every pair (a[i], b[i]) and stores the result,
// converted to T, in output[i]. `batch_size_bytes` is the array length in
// bytes (not elements); the params argument is ignored.
template <typename T, typename Operator>
void binary_ukernel_unquantized(size_t batch_size_bytes, const T* a, const T* b,
                                T* output, const xnn_binary_uparams*) {
  Operator apply;
  const size_t count = batch_size_bytes / sizeof(T);
  size_t idx = 0;
  while (idx < count) {
    output[idx] = static_cast<T>(apply(a[idx], b[idx]));
    ++idx;
  }
}
// Elementwise reference kernel where the second operand is a single constant.
//
// Computes output[i] = Operator(a[i], *b) for each element. `*b` is read once,
// before any output element is written. `batch_size_bytes` is the length of
// `a`/`output` in bytes; the params argument is ignored.
template <typename T, typename Operator>
void binaryc_ukernel_unquantized(size_t batch_size_bytes, const T* a,
                                 const T* b, T* output,
                                 const xnn_binary_uparams*) {
  const T rhs = *b;
  Operator apply;
  const size_t count = batch_size_bytes / sizeof(T);
  for (size_t n = 0; n != count; ++n) {
    const T lhs = a[n];
    output[n] = static_cast<T>(apply(lhs, rhs));
  }
}
// Reversed-operand variant of the constant kernel.
//
// Computes output[i] = Operator(*b, a[i]) for each element, i.e. the single
// constant is the FIRST argument of the operator. `*b` is read once, before
// any output element is written. The params argument is ignored.
template <typename T, typename Operator>
void rbinaryc_ukernel_unquantized(size_t batch_size_bytes, const T* a,
                                  const T* b, T* output,
                                  const xnn_binary_uparams*) {
  const T lhs = *b;
  Operator apply;
  const size_t count = batch_size_bytes / sizeof(T);
  for (size_t n = 0; n != count; ++n) {
    const T rhs = a[n];
    output[n] = static_cast<T>(apply(lhs, rhs));
  }
}
// Builds the kernel table for an unquantized element type T.
//
// The (unnamed) argument is only a tag used to select this overload. Returns
// the address of a function-local static config, one per <Operator, T>
// instantiation, holding in order: the elementwise kernel, the constant-b
// kernel, the reversed constant-b kernel, a null fourth field (the quantized
// overload stores its init function there), and an element tile of 1.
template <typename Operator, typename T>
const xnn_binary_elementwise_config* get_config(T) {
  static_assert(!xnnpack::is_quantized<T>::value, "");
  static xnn_binary_elementwise_config config = {
      reinterpret_cast<xnn_vbinary_ukernel_fn>(
          binary_ukernel_unquantized<T, Operator>),
      reinterpret_cast<xnn_vbinary_ukernel_fn>(
          binaryc_ukernel_unquantized<T, Operator>),
      reinterpret_cast<xnn_vbinary_ukernel_fn>(
          rbinaryc_ukernel_unquantized<T, Operator>),
      nullptr,
      /*element_tile=*/1,
  };
  return &config;
}
// Elementwise reference kernel for quantized element types.
//
// For each element, both inputs are converted to float with dequantize()
// using their respective scale / zero point from `params->reference`, the
// operator is applied in float, and the result is converted back to T with
// quantize<T>() using the inverse output scale and output zero point.
// `batch_size_bytes` is the array length in bytes.
template <typename T, typename Operator>
void binary_ukernel_quantized(size_t batch_size_bytes, const T* a, const T* b,
                              T* output, const xnn_binary_uparams* params) {
  Operator apply;
  const size_t count = batch_size_bytes / sizeof(T);
  for (size_t n = 0; n != count; ++n) {
    // Bound inside the loop so `params` is only touched when there is work.
    const auto& q = params->reference;
    const float lhs = dequantize(a[n], q.a_scale, q.a_zero_point);
    const float rhs = dequantize(b[n], q.b_scale, q.b_zero_point);
    const float value = apply(lhs, rhs);
    output[n] = quantize<T>(value, q.inv_output_scale, q.output_zero_point);
  }
}
// Quantized reference kernel where the second operand is a single constant.
//
// `*b` is dequantized once, before the loop, with the b scale / zero point.
// Each a[i] is then dequantized, combined as Operator(a_i, b_0) in float, and
// the result is requantized into output[i].
template <typename T, typename Operator>
void binaryc_ukernel_quantized(size_t batch_size_bytes, const T* a, const T* b,
                               T* output, const xnn_binary_uparams* params) {
  const size_t count = batch_size_bytes / sizeof(T);
  Operator apply;
  const auto& q = params->reference;
  const float rhs = dequantize(*b, q.b_scale, q.b_zero_point);
  size_t n = 0;
  while (n < count) {
    const float lhs = dequantize(a[n], q.a_scale, q.a_zero_point);
    const float value = apply(lhs, rhs);
    output[n] = quantize<T>(value, q.inv_output_scale, q.output_zero_point);
    ++n;
  }
}
// Reversed-operand variant of the quantized constant kernel.
//
// `*b` is dequantized once, before the loop, and used as the FIRST operator
// argument: output[i] = quantize(Operator(b_0, dequantize(a[i]))).
template <typename T, typename Operator>
void rbinaryc_ukernel_quantized(size_t batch_size_bytes, const T* a, const T* b,
                                T* output, const xnn_binary_uparams* params) {
  const size_t count = batch_size_bytes / sizeof(T);
  Operator apply;
  const auto& q = params->reference;
  const float lhs = dequantize(*b, q.b_scale, q.b_zero_point);
  size_t n = 0;
  while (n < count) {
    const float rhs = dequantize(a[n], q.a_scale, q.a_zero_point);
    const float value = apply(lhs, rhs);
    output[n] = quantize<T>(value, q.inv_output_scale, q.output_zero_point);
    ++n;
  }
}
// Fills `params->reference` from the quantization parameters of the two
// inputs and the output, and returns the size of that params struct.
//
// The output scale is stored as its reciprocal (`inv_output_scale`), which is
// the value the quantized kernels hand to quantize<T>().
size_t init_quantized_binary_op(
    union xnn_binary_uparams* params,
    const struct xnn_quantization_params* a_params,
    const struct xnn_quantization_params* b_params,
    const struct xnn_quantization_params* output_params) {
  auto& ref = params->reference;

  // First input.
  ref.a_scale = a_params->scale;
  ref.a_zero_point = a_params->zero_point;

  // Second input.
  ref.b_scale = b_params->scale;
  ref.b_zero_point = b_params->zero_point;

  // Output.
  ref.inv_output_scale = 1.0f / output_params->scale;
  ref.output_zero_point = output_params->zero_point;

  return sizeof(ref);
}
// Builds the kernel table for a quantized element type.
//
// Selected by passing an xnnpack::quantized<T> tag value. Returns the address
// of a function-local static config, one per <Operator, T> instantiation,
// holding in order: the quantized elementwise kernel, the constant-b kernel,
// the reversed constant-b kernel, init_quantized_binary_op as the params
// initializer, and an element tile of 1.
template <typename Operator, typename T>
const xnn_binary_elementwise_config* get_config(xnnpack::quantized<T>) {
  using Element = xnnpack::quantized<T>;
  static xnn_binary_elementwise_config config = {
      reinterpret_cast<xnn_vbinary_ukernel_fn>(
          binary_ukernel_quantized<Element, Operator>),
      reinterpret_cast<xnn_vbinary_ukernel_fn>(
          binaryc_ukernel_quantized<Element, Operator>),
      reinterpret_cast<xnn_vbinary_ukernel_fn>(
          rbinaryc_ukernel_quantized<Element, Operator>),
      init_quantized_binary_op,
      /*element_tile=*/1,
  };
  return &config;
}
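// We can't use the STL's functional ops (e.g. std::plus) for integers because
// signed integer overflow is undefined behavior (it may trap, e.g. under
// sanitizers), so the int32_t specializations below do the arithmetic widened.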
// Generic addition functor: returns a + b, converted to T.
template <typename T>
struct AddOp {
  T operator()(T a, T b) const {
    T sum = a + b;
    return sum;
  }
};
// int32_t addition: both operands go through widen() before being added, and
// the sum is converted back to int32_t.
template <>
struct AddOp<int32_t> {
  int32_t operator()(int32_t a, int32_t b) const {
    const auto wide_sum = widen(a) + widen(b);
    return static_cast<int32_t>(wide_sum);
  }
};
// Generic subtraction functor: returns a - b, converted to T.
template <typename T>
struct SubOp {
  T operator()(T a, T b) const {
    T difference = a - b;
    return difference;
  }
};
// int32_t subtraction: both operands go through widen() before subtracting,
// and the difference is converted back to int32_t.
template <>
struct SubOp<int32_t> {
  int32_t operator()(int32_t a, int32_t b) const {
    const auto wide_difference = widen(a) - widen(b);
    return static_cast<int32_t>(wide_difference);
  }
};
// Generic multiplication functor: returns a * b, converted to T.
template <typename T>
struct MultiplyOp {
  T operator()(T a, T b) const {
    T product = a * b;
    return product;
  }
};
// int32_t multiplication: both operands go through widen() before being
// multiplied, and the product is converted back to int32_t.
template <>
struct MultiplyOp<int32_t> {
  int32_t operator()(int32_t a, int32_t b) const {
    const auto wide_product = widen(a) * widen(b);
    return static_cast<int32_t>(wide_product);
  }
};
// Generic division functor: returns a / b, converted to T. No check is made
// for a zero divisor.
template <typename T>
struct DivideOp {
  T operator()(T a, T b) const {
    T quotient = a / b;
    return quotient;
  }
};
// int32_t division: delegates to euclidean_div() instead of the built-in `/`.
template <>
struct DivideOp<int32_t> {
  int32_t operator()(int32_t a, int32_t b) const {
    const int32_t quotient = euclidean_div(a, b);
    return quotient;
  }
};
// Generic modulus functor. The computation is always done in float,
// regardless of T: the result is std::fmod(a, b), except that a zero divisor
// is defined here to give 0 (std::fmod is not called in that case).
template <typename T>
struct ModulusOp {
  float operator()(float a, float b) const {
    return (b == 0) ? 0.0f : std::fmod(a, b);
  }
};
// int32_t modulus: delegates to euclidean_mod() instead of the built-in `%`.
template <>
struct ModulusOp<int32_t> {
  int32_t operator()(int32_t a, int32_t b) const {
    const int32_t remainder = euclidean_mod(a, b);
    return remainder;
  }
};
// Maximum functor: returns `a` only when a > b holds, otherwise `b`. When the
// comparison is false (including equal operands, or an unordered comparison
// such as one involving NaN) the second operand is returned.
template <typename T>
struct MaxOp {
  T operator()(T a, T b) const {
    if (a > b) {
      return a;
    }
    return b;
  }
};
// Minimum functor: returns `a` only when a < b holds, otherwise `b`. When the
// comparison is false (including equal operands, or an unordered comparison
// such as one involving NaN) the second operand is returned.
template <typename T>
struct MinOp {
  T operator()(T a, T b) const {
    if (a < b) {
      return a;
    }
    return b;
  }
};
// Squared-difference functor: returns (a - b) * (a - b), converted to T.
template <typename T>
struct SquaredDifferenceOp {
  T operator()(T a, T b) const {
    const auto difference = a - b;
    return difference * difference;
  }
};
// PReLU functor: a negative `a` is scaled by `b` (with the product converted
// to T); a zero or positive `a` is returned unchanged.
template <typename T>
struct PreluOp {
  T operator()(T a, T b) const {
    if (a < 0) {
      return static_cast<T>(a * b);
    }
    return a;
  }
};
// atan2 functor. The computation is always done in float, regardless of T,
// and forwards to std::atan2(a, b).
template <typename T>
struct Atan2Op {
  float operator()(float a, float b) const {
    const float angle = std::atan2(a, b);
    return angle;
  }
};
// Power functor. The computation is always done in float, regardless of T,
// and forwards to std::pow(a, b).
template <typename T>
struct PowOp {
  float operator()(float a, float b) const {
    const float power = std::pow(a, b);
    return power;
  }
};
// int32_t power: delegates to integer_pow() instead of the floating-point
// std::pow.
template <>
struct PowOp<int32_t> {
  int32_t operator()(int32_t a, int32_t b) const {
    const int32_t power = integer_pow(a, b);
    return power;
  }
};
// Bitwise AND functor: returns a & b.
template <typename T>
struct BitwiseAndOp {
  T operator()(T a, T b) const {
    a &= b;
    return a;
  }
};
// Bitwise OR functor: returns a | b.
template <typename T>
struct BitwiseOrOp {
  T operator()(T a, T b) const {
    a |= b;
    return a;
  }
};
// Bitwise XOR functor: returns a ^ b.
template <typename T>
struct BitwiseXorOp {
  T operator()(T a, T b) const {
    a ^= b;
    return a;
  }
};
// Left-shift functor: returns `a` shifted left by `b` bits.
//
// The shift count is masked with `type_mask` (bit width of T minus one), so
// it is always in range; a count equal to the bit width therefore shifts by
// zero.
//
// The shift itself is performed on the unsigned counterpart of T and the
// result is converted back to T. Left-shifting a negative signed value, or
// shifting bits out of the signed range, is undefined behavior before C++20;
// the unsigned shift is always well defined and simply drops the high bits.
template <typename T>
struct ShiftLeftOp {
  static constexpr T type_mask = sizeof(T) * 8 - 1;
  T operator()(T a, T b) const {
    using Unsigned = typename std::make_unsigned<T>::type;
    // The inner cast reinterprets `a` as unsigned; the outer cast truncates
    // back to the width of T in case integral promotion widened the shift.
    const Unsigned shifted =
        static_cast<Unsigned>(static_cast<Unsigned>(a) << (b & type_mask));
    return static_cast<T>(shifted);
  }
};
// Logical right-shift functor: `a` is reinterpreted as the unsigned
// counterpart of T before shifting, so the vacated high bits are filled with
// zeros. The shift count is masked to the bit width of T minus one.
template <typename T>
struct ShiftRightLogicalOp {
  T operator()(T a, T b) const {
    using unsigned_t = typename std::make_unsigned<T>::type;
    constexpr T kCountMask = sizeof(T) * 8 - 1;
    const unsigned_t bits = static_cast<unsigned_t>(a);
    return bits >> (b & kCountMask);
  }
};
// Arithmetic right-shift functor: `a` is converted to the signed counterpart
// of T before shifting, so the built-in signed `>>` is used. The shift count
// is masked to the bit width of T minus one.
template <typename T>
struct ShiftRightArithmeticOp {
  T operator()(T a, T b) const {
    using signed_t = typename std::make_signed<T>::type;
    constexpr T kCountMask = sizeof(T) * 8 - 1;
    const signed_t value = static_cast<signed_t>(a);
    return value >> (b & kCountMask);
  }
};
using std::copysign;
// copysign overload for xnn_float16, implemented on the raw 16-bit patterns:
// the result keeps the low 15 bits of `a` and takes bit 15 (the sign bit)
// from `b`.
xnn_float16 copysign(xnn_float16 a, xnn_float16 b) {
  constexpr uint16_t kSignBit = 0x8000;
  constexpr uint16_t kMagnitudeBits = 0x7FFF;
  const uint16_t magnitude = xnn_float16_to_bits(a) & kMagnitudeBits;
  const uint16_t sign = xnn_float16_to_bits(b) & kSignBit;
  return xnn_float16_from_bits(magnitude | sign);
}
// copysign overload for xnn_bfloat16, implemented on the raw 16-bit patterns:
// the result keeps the low 15 bits of `a` and takes bit 15 (the sign bit)
// from `b`.
xnn_bfloat16 copysign(xnn_bfloat16 a, xnn_bfloat16 b) {
  constexpr uint16_t kSignBit = 0x8000;
  constexpr uint16_t kMagnitudeBits = 0x7FFF;
  const uint16_t magnitude = xnn_bfloat16_to_bits(a) & kMagnitudeBits;
  const uint16_t sign = xnn_bfloat16_to_bits(b) & kSignBit;
  return xnn_bfloat16_from_bits(magnitude | sign);
}
// copysign functor: forwards to an unqualified copysign(a, b) call and
// converts the result to T. The unqualified call lets overload resolution
// choose among whatever copysign overloads are visible for T.
template <typename T>
struct CopysignOp {
  T operator()(T a, T b) const {
    T result = copysign(a, b);
    return result;
  }
};
// Expands to a `switch (datatype)` that returns the config for operator
// template `op` over every datatype handled in this file: fp32, fp16, bf16,
// qint8, quint8 and int32. Any other datatype returns nullptr.
//
// The quantized cases instantiate `op<float>` and pass a
// xnnpack::quantized<...> tag, which selects the quantized get_config
// overload. Every branch returns, so the macro must be expanded inside a
// function returning `const xnn_binary_elementwise_config*`, and control
// never continues past it.
#define DISPATCH_OPERATOR_FOR_DATATYPE(datatype, op) \
switch (datatype) { \
case xnn_datatype_fp32: \
return get_config<op<float>>(float()); \
case xnn_datatype_fp16: \
return get_config<op<xnn_float16>>(xnn_float16()); \
case xnn_datatype_bf16: \
return get_config<op<xnn_bfloat16>>(xnn_bfloat16()); \
case xnn_datatype_qint8: \
return get_config<op<float>>(xnnpack::quantized<int8_t>()); \
case xnn_datatype_quint8: \
return get_config<op<float>>(xnnpack::quantized<uint8_t>()); \
case xnn_datatype_int32: \
return get_config<op<int32_t>>(int32_t()); \
default: \
return nullptr; \
}
// Same as DISPATCH_OPERATOR_FOR_DATATYPE but without the int32 case: returns
// the config for `op` over fp32, fp16, bf16, qint8 and quint8, and nullptr
// for every other datatype (including int32). Every branch returns, so the
// macro must be expanded inside a function returning
// `const xnn_binary_elementwise_config*`.
#define DISPATCH_OPERATOR_FOR_REAL_DATATYPE(datatype, op) \
switch (datatype) { \
case xnn_datatype_fp32: \
return get_config<op<float>>(float()); \
case xnn_datatype_fp16: \
return get_config<op<xnn_float16>>(xnn_float16()); \
case xnn_datatype_bf16: \
return get_config<op<xnn_bfloat16>>(xnn_bfloat16()); \
case xnn_datatype_qint8: \
return get_config<op<float>>(xnnpack::quantized<int8_t>()); \
case xnn_datatype_quint8: \
return get_config<op<float>>(xnnpack::quantized<uint8_t>()); \
default: \
return nullptr; \
}
// Integer-only dispatch: returns the config for `op<int32_t>` when `datatype`
// is int32 and nullptr for every other datatype. Every branch returns, so the
// macro must be expanded inside a function returning
// `const xnn_binary_elementwise_config*`.
#define DISPATCH_OPERATOR_FOR_INTEGRAL_DATATYPE(datatype, op) \
switch (datatype) { \
case xnn_datatype_int32: \
return get_config<op<int32_t>>(int32_t()); \
default: \
return nullptr; \
}
} // namespace
extern "C" {
// Returns the reference kernel config for binary operator `type` on element
// type `datatype`, or nullptr when the combination is not handled here (or
// `type` is xnn_binary_invalid).
//
// The returned pointer, when non-null, refers to a function-local static
// inside the matching get_config instantiation.
//
// Each DISPATCH_* macro expands to a switch in which every branch returns, so
// although the outer cases have no `break`, control never falls through from
// one operator to the next.
const struct xnn_binary_elementwise_config* xnn_init_binary_reference_config(
enum xnn_binary_operator type, enum xnn_datatype datatype) {
switch (type) {
case xnn_binary_add:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, AddOp);
case xnn_binary_subtract:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, SubOp);
case xnn_binary_multiply:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, MultiplyOp);
case xnn_binary_divide:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, DivideOp);
case xnn_binary_maximum:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, MaxOp);
case xnn_binary_minimum:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, MinOp);
// NOTE(review): copysign uses the all-datatype dispatch, so int32 instantiates
// CopysignOp<int32_t> — confirm that integer copysign is intended.
case xnn_binary_copysign:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, CopysignOp);
// The next two operators are dispatched without the int32 case.
case xnn_binary_prelu:
DISPATCH_OPERATOR_FOR_REAL_DATATYPE(datatype, PreluOp);
case xnn_binary_squared_difference:
DISPATCH_OPERATOR_FOR_REAL_DATATYPE(datatype, SquaredDifferenceOp);
case xnn_binary_modulus:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, ModulusOp);
case xnn_binary_atan2:
DISPATCH_OPERATOR_FOR_REAL_DATATYPE(datatype, Atan2Op);
case xnn_binary_pow:
DISPATCH_OPERATOR_FOR_DATATYPE(datatype, PowOp);
// Bitwise and shift operators are only dispatched for int32.
case xnn_binary_bitwise_and:
DISPATCH_OPERATOR_FOR_INTEGRAL_DATATYPE(datatype, BitwiseAndOp);
case xnn_binary_bitwise_or:
DISPATCH_OPERATOR_FOR_INTEGRAL_DATATYPE(datatype, BitwiseOrOp);
case xnn_binary_bitwise_xor:
DISPATCH_OPERATOR_FOR_INTEGRAL_DATATYPE(datatype, BitwiseXorOp);
case xnn_binary_shift_left:
DISPATCH_OPERATOR_FOR_INTEGRAL_DATATYPE(datatype, ShiftLeftOp);
case xnn_binary_shift_right_logical:
DISPATCH_OPERATOR_FOR_INTEGRAL_DATATYPE(datatype, ShiftRightLogicalOp);
case xnn_binary_shift_right_arithmetic:
DISPATCH_OPERATOR_FOR_INTEGRAL_DATATYPE(datatype, ShiftRightArithmeticOp);
case xnn_binary_invalid:
return nullptr;
}
// Reached only when `type` is not one of the enumerators listed above.
return nullptr;
}
} // extern "C"