blob: 75f6cb0147fc6de7465ad48c3f6e1fcd8a8d8e69 [file] [log] [blame] [edit]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <assert.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include "include/xnnpack.h"
#include "src/xnnpack/allocator.h"
#include "src/xnnpack/compute.h"
#include "src/xnnpack/config-types.h"
#include "src/xnnpack/config.h"
#include "src/xnnpack/datatype.h"
#include "src/xnnpack/log.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/operator-type.h"
#include "src/xnnpack/operator-utils.h"
#include "src/xnnpack/operator.h"
#include "src/xnnpack/params.h"
#include "src/xnnpack/reference-config.h"
#include <pthreadpool.h>
// Returns the specialized microkernel config for the given binary operator
// and datatype, or NULL when no specialized implementation exists (the
// caller then falls back to a reference kernel).
//
// Quantized subtraction has no dedicated kernels: it reuses the quantized
// addition kernels with the sign of input 2's quantization scale flipped.
// In that case *sign_b is set to -1 so the caller can negate the scale.
static const struct xnn_binary_elementwise_config* init_config(
    enum xnn_binary_operator type, enum xnn_datatype datatype, int* sign_b) {
  switch (type) {
    case xnn_binary_add:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vadd_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vadd_config();
      } else if (datatype == xnn_datatype_qint8) {
        return xnn_init_qs8_vadd_config();
      } else if (datatype == xnn_datatype_quint8) {
        return xnn_init_qu8_vadd_config();
      }
      return NULL;
    case xnn_binary_subtract:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vsub_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vsub_config();
      } else if (datatype == xnn_datatype_qint8) {
        // Subtraction via addition with the second operand's scale negated.
        *sign_b = -1;
        return xnn_init_qs8_vadd_config();
      } else if (datatype == xnn_datatype_quint8) {
        *sign_b = -1;
        return xnn_init_qu8_vadd_config();
      }
      return NULL;
    case xnn_binary_multiply:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vmul_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vmul_config();
      } else if (datatype == xnn_datatype_qint8) {
        return xnn_init_qs8_vmul_config();
      } else if (datatype == xnn_datatype_quint8) {
        return xnn_init_qu8_vmul_config();
      }
      return NULL;
    case xnn_binary_divide:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vdiv_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vdiv_config();
      }
      return NULL;
    case xnn_binary_maximum:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vmax_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vmax_config();
      }
      return NULL;
    case xnn_binary_minimum:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vmin_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vmin_config();
      }
      return NULL;
    case xnn_binary_copysign:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vcopysign_config();
      }
      return NULL;
    case xnn_binary_squared_difference:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vsqrdiff_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vsqrdiff_config();
      }
      return NULL;
    case xnn_binary_prelu:
      if (datatype == xnn_datatype_fp32) {
        return xnn_init_f32_vprelu_config();
      } else if (datatype == xnn_datatype_fp16) {
        return xnn_init_f16_vprelu_config();
      } else if (datatype == xnn_datatype_qint8) {
        return xnn_init_qs8_vprelu_config();
      } else if (datatype == xnn_datatype_quint8) {
        return xnn_init_qu8_vprelu_config();
      }
      return NULL;
    default:
      return NULL;
  }
}
// Shared initialization for binary elementwise operators.
//
// Selects a microkernel config for (type, datatype) — falling back to a
// reference kernel when no specialized one exists — validates quantization
// parameters for quantized datatypes, and initializes the operator's
// microkernel parameters. `op->params` receives params for the (a, b)
// operand order and `op->extra_params[0]` for the swapped (b, a) order, so
// the reshape step can flip operands when broadcasting requires it.
//
// Returns xnn_status_success, or an error status on unsupported datatype or
// invalid quantization parameters.
static enum xnn_status init_binary_elementwise_nd(
    xnn_operator_t op, enum xnn_binary_operator type,
    enum xnn_datatype datatype,
    const struct xnn_quantization_params* a_quantization,
    const struct xnn_quantization_params* b_quantization,
    const struct xnn_quantization_params* output_quantization, uint32_t flags) {
  // init_config sets sign_b to -1 when quantized subtraction is implemented
  // via the addition kernels (see init_config).
  int sign_b = 1;
  const struct xnn_binary_elementwise_config* config =
      init_config(type, datatype, &sign_b);
  if (config == NULL ||
      config->op_ukernel == NULL ||
      config->opc_ukernel == NULL ||
      config->ropc_ukernel == NULL) {
    xnn_log_debug(
        "unsupported operator %s for datatype %s, falling back to reference kernel",
        xnn_binary_operator_to_string(type), xnn_datatype_to_string(datatype));
    config = xnn_init_binary_reference_config(type, datatype);
  }
  if (config == NULL) {
    xnn_log_error(
        "failed to create %s operator: unsupported datatype %s",
        xnn_binary_operator_to_string(type),
        xnn_datatype_to_string(datatype));
    return xnn_status_unsupported_parameter;
  }
  // Zero-initialize both parameter unions: when config->init is NULL they
  // are still copied into the operator below, and must not carry
  // indeterminate stack bytes (reading uninitialized memory is UB).
  union xnn_binary_uparams uparams;
  union xnn_binary_uparams uparams2;
  memset(&uparams, 0, sizeof(uparams));
  memset(&uparams2, 0, sizeof(uparams2));
  if (config->init != NULL) {
    if (datatype == xnn_datatype_qint8 || datatype == xnn_datatype_quint8) {
      if (!a_quantization || !b_quantization || !output_quantization) {
        xnn_log_error(
            "failed to create %s operator with NULL quantization params",
            xnn_binary_operator_to_string(type));
        return xnn_status_invalid_parameter;
      }
      // All three quantization pointers are non-NULL past this point.
      const float a_scale = a_quantization->scale;
      const float b_scale = b_quantization->scale;
      const float output_scale = output_quantization->scale;
      if (a_scale <= 0.0f || !isnormal(a_scale)) {
        xnn_log_error(
            "failed to create %s operator with %.7g input 1 scale: scale must "
            "be finite and positive",
            xnn_binary_operator_to_string(type), a_scale);
        return xnn_status_invalid_parameter;
      }
      if (b_scale <= 0.0f || !isnormal(b_scale)) {
        xnn_log_error(
            "failed to create %s operator with %.7g input 2 scale: scale must "
            "be finite and positive",
            xnn_binary_operator_to_string(type), b_scale);
        return xnn_status_invalid_parameter;
      }
      if (output_scale <= 0.0f || !isnormal(output_scale)) {
        xnn_log_error(
            "failed to create %s operator with %.7g output scale: scale must "
            "be finite and positive",
            xnn_binary_operator_to_string(type), output_scale);
        return xnn_status_invalid_parameter;
      }
      // For subtraction-via-addition, flip the sign of input 2's scale.
      struct xnn_quantization_params b_quantization_with_sign = *b_quantization;
      b_quantization_with_sign.scale *= sign_b;
      config->init(&uparams, a_quantization, &b_quantization_with_sign,
                   output_quantization);
      // Precompute params with operands swapped, used when reshape flips
      // a and b for broadcasting.
      config->init(&uparams2, &b_quantization_with_sign, a_quantization,
                   output_quantization);
    } else {
      config->init(&uparams, NULL, NULL, NULL);
      config->init(&uparams2, NULL, NULL, NULL);
    }
  }
  memcpy(&op->params, &uparams, sizeof(uparams));
  memcpy(op->extra_params, &uparams2, sizeof(uparams2));
  op->binary_elementwise_config = config;
  op->binary_elementwise.log2_element_size =
      xnn_datatype_log2_size_bytes(datatype);
  op->type = xnn_operator_type_binary_elementwise;
  op->binary_elementwise.op_type = type;
  op->flags = flags;
  op->state = xnn_run_state_invalid;
  return xnn_status_success;
}
// Creates a binary elementwise operator of the given type and datatype.
//
// Allocates the operator descriptor and its compute parameters, then
// delegates microkernel selection and parameter setup to
// init_binary_elementwise_nd. On success, ownership of *binary_op_out
// transfers to the caller (released via xnn_delete_operator).
enum xnn_status xnn_create_binary_elementwise_nd(
    enum xnn_binary_operator type, enum xnn_datatype datatype,
    const struct xnn_quantization_params* a_quantization,
    const struct xnn_quantization_params* b_quantization,
    const struct xnn_quantization_params* output_quantization, uint32_t flags,
    xnn_operator_t* binary_op_out) {
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
                  xnn_binary_operator_to_string(type));
    return xnn_status_uninitialized;
  }
  xnn_operator_t op =
      xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (op == NULL) {
    xnn_log_error("failed to allocate %zu bytes for %s operator descriptor",
                  sizeof(struct xnn_operator),
                  xnn_binary_operator_to_string(type));
    return xnn_status_out_of_memory;
  }
  op->compute = xnn_allocate_zero_memory(sizeof(struct compute_parameters));
  if (op->compute == NULL) {
    xnn_log_error("failed to allocate %zu bytes for %s operator descriptor",
                  sizeof(struct compute_parameters),
                  xnn_binary_operator_to_string(type));
    // Release the already-allocated operator descriptor to avoid leaking it.
    xnn_delete_operator(op);
    return xnn_status_out_of_memory;
  }
  op->num_compute_invocations = 1;
  xnn_allocate_extra_params(op, /*num_extra_params=*/1);
  enum xnn_status status =
      init_binary_elementwise_nd(op, type, datatype, a_quantization,
                                 b_quantization, output_quantization, flags);
  if (status != xnn_status_success) {
    xnn_delete_operator(op);
    return status;
  }
  *binary_op_out = op;
  return xnn_status_success;
}
// Computes the per-task tile size in bytes for parallelization: roughly
// 16 KiB, rounded up to a whole number of microkernel tiles.
static size_t get_tile_size(xnn_operator_t op) {
  // When the config does not report an element tile, assume a default
  // unrolling factor of 32.
  size_t element_tile = op->binary_elementwise_config->element_tile;
  if (element_tile == 0) {
    element_tile = 32;
  }
  const size_t tile_bytes =
      element_tile << op->binary_elementwise.log2_element_size;
  return round_up(16 * 1024, tile_bytes);
}
// Reshapes a binary elementwise operator for the given (numpy-style
// broadcastable) input shapes.
//
// The two shapes are first "compressed": starting from the innermost
// dimension, runs of adjacent dimensions with the same broadcast behavior
// (both broadcast input 1, both broadcast input 2, or neither) are merged
// into a single dimension by multiplying their extents. This minimizes the
// number of effective dimensions and maximizes the contiguous inner span.
// The compressed shapes then determine the microkernel variant (op/opc/ropc),
// the per-dimension strides, and the pthreadpool parallelization strategy.
//
// Returns xnn_status_success, or an error status on rank overflow or
// incompatible shapes. A zero-sized dimension puts the operator in the
// "skip" state and succeeds without setting up any compute context.
enum xnn_status xnn_reshape_binary_elementwise_nd(xnn_operator_t op,
size_t num_input1_dims,
const size_t* input1_shape,
size_t num_input2_dims,
const size_t* input2_shape,
pthreadpool_t threadpool) {
op->state = xnn_run_state_invalid;
if (max(num_input1_dims, num_input2_dims) > XNN_MAX_TENSOR_DIMS) {
xnn_log_error(
"failed to reshape %s operator with %zu and %zu dimensions in input "
"shapes: the number of input dimensions must not exceed %d",
xnn_operator_type_to_string_v2(op), num_input1_dims, num_input2_dims,
XNN_MAX_TENSOR_DIMS);
return xnn_status_unsupported_parameter;
}
size_t num_compressed_dims = 0;
size_t compressed_input1_shape[XNN_MAX_TENSOR_DIMS];
size_t compressed_input2_shape[XNN_MAX_TENSOR_DIMS];
size_t compressed_output_shape[XNN_MAX_TENSOR_DIMS];
// All compressed dimensions start at 1; merged extents multiply in below.
for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
compressed_input1_shape[i] = 1;
compressed_input2_shape[i] = 1;
compressed_output_shape[i] = 1;
}
// Track which input the current compressed dimension broadcasts, so
// adjacent dimensions with the same behavior merge into one.
bool broadcast_input1 = false;
bool broadcast_input2 = false;
bool first_nonunit = true;
bool degenerate_shape = false;
// Walk the common trailing dimensions from innermost to outermost
// (i counts from the end of each shape).
const size_t num_common_dims = min(num_input1_dims, num_input2_dims);
for (size_t i = 1; i <= num_common_dims; i++) {
const size_t input1_dim = input1_shape[num_input1_dims - i];
const size_t input2_dim = input2_shape[num_input2_dims - i];
degenerate_shape |= input1_dim == 0;
degenerate_shape |= input2_dim == 0;
if (input1_dim == 1 && input2_dim == 1) {
// 1x1 dimensions are no-ops: skip without opening a compressed dim.
continue;
}
assert(!broadcast_input1 || !broadcast_input2);
if (input1_dim == 1) {
// Input 1 is broadcast along this dimension; open a new compressed
// dimension only when the broadcast direction changes.
if (!broadcast_input1) {
broadcast_input1 = true;
broadcast_input2 = false;
num_compressed_dims++;
}
compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
} else if (input2_dim == 1) {
// Input 2 is broadcast along this dimension.
if (!broadcast_input2) {
broadcast_input1 = false;
broadcast_input2 = true;
num_compressed_dims++;
}
compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
} else if (input1_dim == input2_dim) {
// Matching extents: merge into the current non-broadcast compressed
// dimension, starting a new one after any broadcast run.
if (broadcast_input1 || broadcast_input2 || first_nonunit) {
broadcast_input1 = false;
broadcast_input2 = false;
num_compressed_dims++;
}
compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
compressed_input2_shape[num_compressed_dims - 1] *= input1_dim;
compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
} else {
// Extents differ and neither is 1: shapes are not broadcastable.
xnn_log_error(
"failed to reshape %s operator: shape dimension #%zu of input1 (%zu) "
"does not match shape dimension #%zu of input2 (%zu)",
xnn_operator_type_to_string_v2(op), num_input1_dims - i, input1_dim,
num_input2_dims - i, input2_dim);
return xnn_status_invalid_parameter;
}
first_nonunit = false;
}
// Fold any extra leading dimensions of the higher-rank input into the
// outermost compressed dimension (the shorter input is implicitly
// broadcast across them).
if (num_input1_dims > num_input2_dims) {
if (!broadcast_input2) {
num_compressed_dims++;
}
for (size_t i = 0; i < num_input1_dims - num_input2_dims; i++) {
const size_t input1_dim = input1_shape[i];
degenerate_shape |= input1_dim == 0;
compressed_input1_shape[num_compressed_dims - 1] *= input1_dim;
compressed_output_shape[num_compressed_dims - 1] *= input1_dim;
}
} else if (num_input2_dims > num_input1_dims) {
if (!broadcast_input1) {
num_compressed_dims++;
}
for (size_t i = 0; i < num_input2_dims - num_input1_dims; i++) {
const size_t input2_dim = input2_shape[i];
degenerate_shape |= input2_dim == 0;
compressed_input2_shape[num_compressed_dims - 1] *= input2_dim;
compressed_output_shape[num_compressed_dims - 1] *= input2_dim;
}
}
// Scalar (rank-0) inputs compress to zero dimensions; treat as one.
num_compressed_dims = max(num_compressed_dims, 1);
// Early exit without setting up context if any shape dimension is zero.
if (degenerate_shape) {
op->state = xnn_run_state_skip;
return xnn_status_success;
}
const uint32_t log2_element_size = op->binary_elementwise.log2_element_size;
op->context.elementwise_binary = (struct elementwise_binary_context){
.elements = compressed_output_shape[0] << log2_element_size,
};
memcpy(&op->context.elementwise_binary.params, &op->params.binary,
sizeof(op->params.binary));
// Pick the microkernel variant based on which operand (if any) is a scalar
// along the innermost compressed dimension:
//   ropc: input 1 is scalar (operands are flipped, swapped params used),
//   opc:  input 2 is scalar,
//   op:   both are full vectors.
const size_t* compressed_a_shape = compressed_input1_shape;
const size_t* compressed_b_shape = compressed_input2_shape;
if (compressed_input1_shape[0] == 1) {
op->context.elementwise_binary.flip_a_b = true;
op->context.elementwise_binary.ukernel =
op->binary_elementwise_config->ropc_ukernel;
compressed_a_shape = compressed_input2_shape;
compressed_b_shape = compressed_input1_shape;
memcpy(&op->context.elementwise_binary.params, op->extra_params,
sizeof(op->extra_params->binary));
} else if (compressed_input2_shape[0] == 1) {
op->context.elementwise_binary.ukernel =
op->binary_elementwise_config->opc_ukernel;
} else if (compressed_input1_shape[0] == compressed_input2_shape[0]) {
op->context.elementwise_binary.ukernel =
op->binary_elementwise_config->op_ukernel;
}
// Compute byte strides for the outer compressed dimensions; a broadcast
// (extent-1) dimension keeps stride 0 from the zero-initialized context.
// Strides are stored innermost-last (index XNN_MAX_TENSOR_DIMS - 1 - i).
size_t a_stride = compressed_a_shape[0];
size_t b_stride = compressed_b_shape[0];
size_t y_stride = compressed_output_shape[0];
for (size_t i = 1; i < num_compressed_dims; i++) {
if (compressed_a_shape[i] != 1) {
op->context.elementwise_binary.a_stride[XNN_MAX_TENSOR_DIMS - 1 - i] =
a_stride << log2_element_size;
}
if (compressed_b_shape[i] != 1) {
op->context.elementwise_binary.b_stride[XNN_MAX_TENSOR_DIMS - 1 - i] =
b_stride << log2_element_size;
}
op->context.elementwise_binary.y_stride[XNN_MAX_TENSOR_DIMS - 1 - i] =
y_stride << log2_element_size;
a_stride *= compressed_a_shape[i];
b_stride *= compressed_b_shape[i];
y_stride *= compressed_output_shape[i];
}
// Choose the parallelization strategy by the number of non-unit compressed
// dimensions, from fully-contiguous 1D down to 5D.
// NOTE(review): the hard-coded indices [5]..[1] assume XNN_MAX_TENSOR_DIMS
// is 6 — confirm against the header if that constant ever changes.
if (compressed_output_shape[5] == 1) {
if (compressed_output_shape[4] == 1) {
if (compressed_output_shape[3] == 1) {
if (compressed_output_shape[2] == 1) {
if (compressed_output_shape[1] == 1) {
// Single contiguous span: tile directly over bytes, presenting it
// as unit-element rows so the 1D-tile task can split it.
op->context.elementwise_binary.a_stride[4] =
compressed_a_shape[0] == 1 ? 0 : (1 << log2_element_size);
op->context.elementwise_binary.b_stride[4] =
compressed_b_shape[0] == 1 ? 0 : (1 << log2_element_size);
op->context.elementwise_binary.y_stride[4] =
(1 << log2_element_size);
op->context.elementwise_binary.elements = (1 << log2_element_size);
op->compute[0].type = xnn_parallelization_type_1d_tile_1d_dynamic;
op->compute[0].task_1d_tile_1d_dynamic =
(pthreadpool_task_1d_tile_1d_dynamic_t)
xnn_compute_elementwise_binary_1d_tile;
op->compute[0].range[0] = compressed_output_shape[0]
<< log2_element_size;
op->compute[0].tile[0] = get_tile_size(op);
} else {
// 2 effective dims: parallelize over rows of `elements` bytes.
op->compute[0].type = xnn_parallelization_type_1d_tile_1d_dynamic;
op->compute[0].task_1d_tile_1d_dynamic =
(pthreadpool_task_1d_tile_1d_dynamic_t)
xnn_compute_elementwise_binary_1d;
op->compute[0].range[0] = compressed_output_shape[1];
op->compute[0].tile[0] = divide_round_up(
get_tile_size(op), op->context.elementwise_binary.elements);
}
} else {
// 3 effective dims.
op->compute[0].type = xnn_parallelization_type_2d_tile_1d_dynamic;
op->compute[0].task_2d_tile_1d_dynamic =
(pthreadpool_task_2d_tile_1d_dynamic_t)
xnn_compute_elementwise_binary_2d;
op->compute[0].range[0] = compressed_output_shape[2];
op->compute[0].range[1] = compressed_output_shape[1];
op->compute[0].tile[0] = divide_round_up(
get_tile_size(op), op->context.elementwise_binary.elements);
}
} else {
// 4 effective dims.
op->compute[0].type = xnn_parallelization_type_3d_tile_2d_dynamic;
op->compute[0].task_3d_tile_2d_dynamic =
(pthreadpool_task_3d_tile_2d_dynamic_t)
xnn_compute_elementwise_binary_3d;
op->compute[0].range[0] = compressed_output_shape[3];
op->compute[0].range[1] = compressed_output_shape[2];
op->compute[0].range[2] = compressed_output_shape[1];
op->compute[0].tile[0] = 1;
op->compute[0].tile[1] = 1;
}
} else {
// 5 effective dims.
op->compute[0].type = xnn_parallelization_type_4d_tile_2d_dynamic;
op->compute[0].task_4d_tile_2d_dynamic =
(pthreadpool_task_4d_tile_2d_dynamic_t)
xnn_compute_elementwise_binary_4d;
op->compute[0].range[0] = compressed_output_shape[4];
op->compute[0].range[1] = compressed_output_shape[3];
op->compute[0].range[2] = compressed_output_shape[2];
op->compute[0].range[3] = compressed_output_shape[1];
op->compute[0].tile[0] = 1;
op->compute[0].tile[1] = 1;
}
} else {
// 6 effective dims: untiled 5D parallelism over the outer dimensions.
op->compute[0].type = xnn_parallelization_type_5d;
op->compute[0].task_5d =
(pthreadpool_task_5d_t)xnn_compute_elementwise_binary_5d;
op->compute[0].range[0] = compressed_output_shape[5];
op->compute[0].range[1] = compressed_output_shape[4];
op->compute[0].range[2] = compressed_output_shape[3];
op->compute[0].range[3] = compressed_output_shape[2];
op->compute[0].range[4] = compressed_output_shape[1];
}
op->state = xnn_run_state_needs_setup;
return xnn_status_success;
}
// Binds the input and output pointers to a reshaped binary elementwise
// operator. May be called repeatedly after a reshape to run the same shapes
// on different buffers.
enum xnn_status xnn_setup_binary_elementwise_nd(xnn_operator_t op,
                                                const void* input1,
                                                const void* input2,
                                                void* output) {
  // A skipped operator (zero-sized shape) needs no pointers at all.
  if (op->state == xnn_run_state_skip) {
    return xnn_status_success;
  }
  if (op->state == xnn_run_state_invalid) {
    xnn_log_error(
        "failed to setup %s operator: operator has not been reshaped yet",
        xnn_operator_type_to_string_v2(op));
    return xnn_status_invalid_state;
  }
  // Any other state (needs_setup, or ready with new pointers) proceeds.
  // Reshape may have flipped the operands so that the broadcast operand is
  // always `b`; honor that here.
  const void* a = input1;
  const void* b = input2;
  if (op->context.elementwise_binary.flip_a_b) {
    a = input2;
    b = input1;
  }
  op->context.elementwise_binary.a = a;
  op->context.elementwise_binary.b = b;
  op->context.elementwise_binary.y = output;
  op->state = xnn_run_state_ready;
  return xnn_status_success;
}
// Single-shot convenience entry point: creates a stack-allocated binary
// elementwise operator, reshapes it for the given input shapes, binds the
// pointers, runs it on the threadpool, and destroys it.
enum xnn_status xnn_run_binary_elementwise_nd(
    enum xnn_binary_operator type, enum xnn_datatype datatype,
    const struct xnn_quantization_params* input1_quantization,
    const struct xnn_quantization_params* input2_quantization,
    const struct xnn_quantization_params* output_quantization, uint32_t flags,
    size_t num_input1_dims, const size_t* input1_shape, size_t num_input2_dims,
    const size_t* input2_shape, const void* input1, const void* input2,
    void* output, pthreadpool_t threadpool) {
  struct xnn_operator op;
  memset(&op, 0, sizeof(op));
  op.compute = xnn_allocate_zero_memory(sizeof(struct compute_parameters));
  if (op.compute == NULL) {
    xnn_log_error("failed to allocate %zu bytes for %s operator descriptor",
                  sizeof(struct compute_parameters),
                  xnn_binary_operator_to_string(type));
    return xnn_status_out_of_memory;
  }
  op.num_compute_invocations = 1;
  xnn_allocate_extra_params(&op, /*num_extra_params=*/1);
  memset(op.extra_params, 0, sizeof(union xnn_params));
  // goto-based cleanup: every path below must destroy the operator exactly
  // once before returning.
  enum xnn_status status = init_binary_elementwise_nd(
      &op, type, datatype, input1_quantization, input2_quantization,
      output_quantization, flags);
  if (status != xnn_status_success) {
    goto cleanup;
  }
  status = xnn_reshape_binary_elementwise_nd(&op, num_input1_dims, input1_shape,
                                             num_input2_dims, input2_shape,
                                             threadpool);
  if (status != xnn_status_success) {
    goto cleanup;
  }
  status = xnn_setup_binary_elementwise_nd(&op, input1, input2, output);
  if (status != xnn_status_success) {
    goto cleanup;
  }
  status = xnn_run_operator(&op, threadpool);
cleanup:
  xnn_destroy_operator(&op);
  return status;
}