| // Copyright (c) Facebook, Inc. and its affiliates. |
| // All rights reserved. |
| // |
| // Copyright 2019 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| #include <assert.h> |
| #include <limits.h> |
| #include <math.h> |
| #include <stddef.h> |
| #include <stdint.h> |
| #include <string.h> |
| |
| #include "include/xnnpack.h" |
| #include "src/xnnpack/common.h" |
| #include "src/xnnpack/compute.h" |
| #include "src/xnnpack/indirection.h" |
| #include "src/xnnpack/log.h" |
| #include "src/xnnpack/math.h" |
| #include "src/xnnpack/microfnptr.h" |
| #include "src/xnnpack/microkernel-type.h" |
| #include "src/xnnpack/microparams.h" |
| #include "src/xnnpack/operator-utils.h" |
| #include "src/xnnpack/operator.h" |
| #include "src/xnnpack/packq.h" |
| #include "src/xnnpack/quantization.h" |
| #include <pthreadpool.h> |
| |
| #if XNN_MAX_UARCH_TYPES > 1 |
| #include "src/xnnpack/config-types.h" |
| #include "src/xnnpack/microparams-init.h" |
| #endif // XNN_MAX_UARCH_TYPES > 1 |
| |
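| // Tiled transpose with a fixed (compile-time) element size ("transposec"): |
| // each invocation copies one tile spanning the two innermost dimensions to |
| // its transposed position in the output. The innermost input stride and the |
| // second-innermost output stride are the microkernel's leading dimensions. |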
| void xnn_compute_transposec_2d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t tile_i, |
| size_t tile_j) { |
| const size_t ld_input = context->input_stride[1]; |
| const size_t ld_output = context->output_stride[0]; |
| context->const_size_ukernel( |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1]), |
| (void*)((uintptr_t)context->y + j * context->output_stride[1] + |
| i * context->output_stride[0]), |
| ld_input, ld_output, tile_i, tile_j); |
| } |
| |
| void xnn_compute_transposec_3d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t tile_j, |
| size_t tile_k) { |
| const size_t ld_input = context->input_stride[2]; |
| const size_t ld_output = context->output_stride[1]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2]); |
| void* y = |
| (void*)((uintptr_t)context->y + i * context->output_stride[0] + |
| j * context->output_stride[1] + k * context->output_stride[2]); |
| |
| context->const_size_ukernel(x, y, ld_input, ld_output, tile_j, tile_k); |
| } |
| |
| void xnn_compute_transposec_4d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t l, |
| size_t tile_k, size_t tile_l) { |
| const size_t ld_input = context->input_stride[3]; |
| const size_t ld_output = context->output_stride[2]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2] + |
| l * context->input_stride[3]); |
| void* y = |
| (void*)((uintptr_t)context->y + i * context->output_stride[0] + |
| j * context->output_stride[1] + k * context->output_stride[2] + |
| l * context->output_stride[3]); |
| |
| context->const_size_ukernel(x, y, ld_input, ld_output, tile_k, tile_l); |
| } |
| |
| void xnn_compute_transposec_5d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t l, size_t m, |
| size_t tile_l, size_t tile_m) { |
| const size_t ld_input = context->input_stride[4]; |
| const size_t ld_output = context->output_stride[3]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2] + |
| l * context->input_stride[3] + |
| m * context->input_stride[4]); |
| void* y = |
| (void*)((uintptr_t)context->y + i * context->output_stride[0] + |
| j * context->output_stride[1] + k * context->output_stride[2] + |
| l * context->output_stride[3] + m * context->output_stride[4]); |
| |
| context->const_size_ukernel(x, y, ld_input, ld_output, tile_l, tile_m); |
| } |
| |
| void xnn_compute_transposec_6d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t l, size_t m, |
| size_t n, size_t tile_m, size_t tile_n) { |
| const size_t ld_input = context->input_stride[5]; |
| const size_t ld_output = context->output_stride[4]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2] + |
| l * context->input_stride[3] + |
| m * context->input_stride[4] + |
| n * context->input_stride[5]); |
| void* y = |
| (void*)((uintptr_t)context->y + i * context->output_stride[0] + |
| j * context->output_stride[1] + k * context->output_stride[2] + |
| l * context->output_stride[3] + m * context->output_stride[4] + |
| n * context->output_stride[5]); |
| |
| context->const_size_ukernel(x, y, ld_input, ld_output, tile_m, tile_n); |
| } |
| |
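| // Variable element size transpose ("transposev"): the element size is taken |
| // from the innermost output stride and forwarded to the microkernel together |
| // with additional input/output strides, so arbitrary element sizes are |
| // supported. |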
| void xnn_compute_transposev_2d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t tile_i, |
| size_t tile_j) { |
| const size_t element_size = context->output_stride[1]; |
| const size_t ld_input = context->input_stride[1]; |
| const size_t ld_output = context->output_stride[0]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1]); |
| void* y = (void*)((uintptr_t)context->y + context->output_stride[1] * j + |
| i * context->output_stride[0]); |
| |
| context->variable_size_ukernel( |
| x, y, ld_input, ld_output, context->input_stride[0], |
| context->output_stride[1], element_size, tile_i, tile_j); |
| } |
| |
| void xnn_compute_transposev_3d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t tile_j, |
| size_t tile_k) { |
| const size_t element_size = context->output_stride[2]; |
| const size_t ld_input = context->input_stride[2]; |
| const size_t ld_output = context->output_stride[1]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2]); |
| void* y = |
| (void*)((uintptr_t)context->y + i * context->output_stride[0] + |
| j * context->output_stride[1] + k * context->output_stride[2]); |
| |
| context->variable_size_ukernel( |
| x, y, ld_input, ld_output, context->input_stride[1], |
| context->output_stride[2], element_size, tile_j, tile_k); |
| } |
| |
| void xnn_compute_transposev_4d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t l, |
| size_t tile_k, size_t tile_l) { |
| const size_t element_size = context->output_stride[3]; |
| const size_t ld_input = context->input_stride[3]; |
| const size_t ld_output = context->output_stride[2]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2] + |
| l * context->input_stride[3]); |
| void* y = |
| (void*)((uintptr_t)context->y + context->output_stride[3] * l + |
| i * context->output_stride[0] + j * context->output_stride[1] + |
| k * context->output_stride[2]); |
| |
| context->variable_size_ukernel( |
| x, y, ld_input, ld_output, context->input_stride[2], |
| context->output_stride[3], element_size, tile_k, tile_l); |
| } |
| |
| void xnn_compute_transposev_5d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t l, size_t m, |
| size_t tile_l, size_t tile_m) { |
| const size_t element_size = context->output_stride[4]; |
| const size_t ld_input = context->input_stride[4]; |
| const size_t ld_output = context->output_stride[3]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2] + |
| l * context->input_stride[3] + |
| m * context->input_stride[4]); |
| void* y = |
| (void*)((uintptr_t)context->y + context->output_stride[4] * m + |
| i * context->output_stride[0] + j * context->output_stride[1] + |
| k * context->output_stride[2] + l * context->output_stride[3]); |
| |
| context->variable_size_ukernel( |
| x, y, ld_input, ld_output, context->input_stride[3], |
| context->output_stride[4], element_size, tile_l, tile_m); |
| } |
| |
| void xnn_compute_transposev_6d(struct transpose_context* restrict context, |
| size_t i, size_t j, size_t k, size_t l, size_t m, |
| size_t n, size_t tile_m, size_t tile_n) { |
| const size_t element_size = context->output_stride[5]; |
| const size_t ld_input = context->input_stride[5]; |
| const size_t ld_output = context->output_stride[4]; |
| const void* x = |
| (const void*)((uintptr_t)context->x + i * context->input_stride[0] + |
| j * context->input_stride[1] + |
| k * context->input_stride[2] + |
| l * context->input_stride[3] + |
| m * context->input_stride[4] + |
| n * context->input_stride[5]); |
| void* y = |
| (void*)((uintptr_t)context->y + context->output_stride[5] * n + |
| i * context->output_stride[0] + j * context->output_stride[1] + |
| k * context->output_stride[2] + l * context->output_stride[3] + |
| m * context->output_stride[4]); |
| |
| context->variable_size_ukernel( |
| x, y, ld_input, ld_output, context->input_stride[4], |
| context->output_stride[5], element_size, tile_m, tile_n); |
| } |
| |
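| // Packs an `n_block_size`-wide block of GIO-layout weights (and the matching |
| // bias block, if present) for one batch element into the packed-weights |
| // buffer, using either the generic pack_weights_and_biases callback or the |
| // packw_gemm_gio microkernel followed by an optional per-channel scale |
| // initialization. |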
| void xnn_compute_batched_packw_gemm_gio( |
| struct packw_gemm_gio_context* restrict context, size_t batch_index, |
| size_t n_block_start, size_t n_block_size) { |
| const void* kernel = (const void*)((uintptr_t)context->kernel + |
| n_block_start * context->n_stride + |
| batch_index * context->gk_stride); |
| const void* bias = context->bias; |
| if (bias != NULL) { |
| bias = (const void*)((uintptr_t)bias + n_block_start * context->b_stride + |
| batch_index * context->gb_stride); |
| } |
| void* packed_weights = (void*)((uintptr_t)context->packed_weights + |
| n_block_start * context->w_stride + |
| batch_index * context->gc_stride); |
| |
| if (context->pack_weights_and_biases) { |
| context->pack_weights_and_biases( |
| /*flags=*/XNN_FLAG_TRANSPOSE_WEIGHTS, context->gemm_config, context->kc, |
| n_block_size, /*groups=*/1, /*block_size=*/0, |
| /*k_stride=*/context->k_stride_elements, /*accumulator_init=*/bias, |
| kernel, /*init_extra_data0_fn=*/context->init_scale_b, |
| /*extra_data0=*/context->scale_b, |
| /*extra_data0_element_size=*/context->scale_b_size, |
| /*init_extra_data1_fn=*/NULL, /*extra_data1=*/NULL, |
| /*extra_data1_element_size=*/0, packed_weights, |
| /*params=*/context->params); |
| } else { |
| context->packw_gemm_gio( |
| /*groups=*/1, n_block_size, context->kc, context->nr, context->kr, |
| context->sr, context->k_stride_elements, kernel, bias, /*scale=*/NULL, |
| packed_weights, /*extra_bytes=*/context->nr * context->scale_b_size, |
| /*params=*/context->params); |
| |
| if (context->scale_b != NULL) { |
| assert(context->init_scale_b != NULL); |
| void* weights = |
|           (void*)((uintptr_t)packed_weights + |
|                   context->nr * (context->w_stride - context->scale_b_size)); |
| context->init_scale_b(n_block_size, context->nr, |
| context->nr * context->w_stride, |
| context->scale_b, weights); |
| } |
| } |
| } |
| |
| void xnn_compute_packw_gemm_gio( |
| struct packw_gemm_gio_context* restrict context, size_t n_block_start, |
| size_t n_block_size) { |
| xnn_compute_batched_packw_gemm_gio(context, /*batch_index=*/0, n_block_start, |
| n_block_size); |
| } |
| |
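| // Same as above, but for GOI-layout weights (output channels outermost), so |
| // the kernel pointer advances by `k_stride` per output channel and no |
| // transpose flag is needed. |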
| void xnn_compute_batched_packw_gemm_goi( |
| struct packw_gemm_goi_context* restrict context, size_t batch_index, |
| size_t n_block_start, size_t n_block_size) { |
| const void* kernel = (const void*)((uintptr_t)context->kernel + |
| context->k_stride * n_block_start + |
| batch_index * context->gk_stride); |
| const void* bias = context->bias; |
| if (bias != NULL) { |
| bias = (const void*)((uintptr_t)bias + n_block_start * context->b_stride + |
| batch_index * context->gb_stride); |
| } |
| void* packed_weights = (void*)((uintptr_t)context->packed_weights + |
| context->w_stride * n_block_start + |
| batch_index * context->gc_stride); |
| |
| if (context->pack_weights_and_biases) { |
| context->pack_weights_and_biases( |
| /*flags=*/0, context->gemm_config, context->kc, n_block_size, |
| /*groups=*/1, /*block_size=*/0, /*k_stride=*/context->kc, |
| /*accumulator_init=*/bias, kernel, |
| /*init_extra_data0_fn=*/context->init_scale_b, |
| /*extra_data0=*/context->scale_b, |
| /*extra_data0_element_size=*/context->scale_b_size, |
| /*init_extra_data1_fn=*/NULL, /*extra_data1=*/NULL, |
| /*extra_data1_element_size=*/0, packed_weights, |
| /*params=*/context->params); |
| } else { |
| context->packw_gemm_goi( |
| /*groups=*/1, n_block_size, context->kc, context->nr, context->kr, |
| context->sr, kernel, bias, /*scale=*/NULL, packed_weights, |
| /*extra_bytes=*/context->nr * context->scale_b_size, |
| /*params=*/context->params); |
| |
| if (context->scale_b != NULL) { |
| assert(context->init_scale_b != NULL); |
| void* weights = |
| (void*)((uintptr_t)packed_weights + context->nr * |
| (context->w_stride - context->scale_b_size)); |
| context->init_scale_b(n_block_size, context->nr, |
| context->nr * context->w_stride, context->scale_b, |
| weights); |
| } |
| } |
| } |
| |
| void xnn_compute_packw_gemm_goi(struct packw_gemm_goi_context* restrict context, |
| size_t n_block_start, size_t n_block_size) { |
| xnn_compute_batched_packw_gemm_goi(context, /*batch_index=*/0, n_block_start, |
| n_block_size); |
| } |
| |
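| // Decomposes the flattened output batch index `group_index` into the |
| // corresponding flattened batch indices of the (possibly broadcast) A and B |
| // operands: `batch_strides_c[k]` is the product of the output batch dims |
| // after dim k, and a size-1 batch dim in A or B broadcasts via the modulo. |
| // For example (hypothetical shapes), with batch_dims_a = {2, 1}, |
| // batch_dims_b = {1, 3} and batch_strides_c = {3, 1}, group_index 4 |
| // decomposes into output batch indices {1, 1}, yielding group_index_a = 1 |
| // and group_index_b = 1. |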
| static void compute_group_indices(struct gemm_context* context, |
| size_t group_index, size_t* group_index_a, |
| size_t* group_index_b) { |
| const size_t num_batch_dims = context->num_batch_dims; |
| *group_index_a = 0; |
| *group_index_b = 0; |
|   for (size_t k = 0; k < num_batch_dims; k++) { |
| // Extract the kth batch index from the group_index. |
| const size_t index = group_index / context->batch_strides_c[k]; |
| group_index %= context->batch_strides_c[k]; |
| |
| // Compute the corresponding kth group index offsets into A and B. |
| *group_index_a = (index % context->batch_dims_a[k]) + |
| context->batch_dims_a[k] * *group_index_a; |
| *group_index_b = (index % context->batch_dims_b[k]) + |
| context->batch_dims_b[k] * *group_index_b; |
| } |
| } |
| |
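| // One (group, N-block, M-block) tile of a grouped/batched GEMM on the |
| // requested microarchitecture. The M block is processed in steps of at most |
| // `mr`; when the last step is shorter than `mr` and dynamic quantization is |
| // in use, the per-row quantization params are padded to a full `mr` rows so |
| // the microkernel can read them safely. |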
| void xnn_compute_hmp_grouped_gemm(struct gemm_context* restrict context, |
| uint32_t uarch_index, size_t group_index, |
| size_t nr_block_start, size_t mr_block_start, |
| size_t nr_block_size, size_t mr_block_size) { |
| const size_t k_scaled = context->k_scaled; |
| const size_t a_stride = context->a_stride; |
| const size_t cm_stride = context->cm_stride; |
| const size_t group_index_c = group_index; |
| |
| // Compute the group index offsets into A and B. |
| size_t group_index_a = 0; |
| size_t group_index_b = 0; |
| compute_group_indices(context, group_index, &group_index_a, &group_index_b); |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, context->mr); |
| if (context->quantization_params != NULL) { |
| // If the effective `mr_block_size` is smaller than the kernel's `mr`, |
| // create a padded copy of the dynamic quantization params. |
| const struct xnn_qd8_quantization_params* quantization_params = |
| &context->quantization_params[group_index_a * context->gq_stride + |
| mr_block_start]; |
| struct xnn_qd8_quantization_params padded_quantization_params[XNN_MAX_MR]; |
| if (mr_step < context->mr) { |
| for (size_t i = 0; i < mr_step; i++) { |
| padded_quantization_params[i] = quantization_params[i]; |
| } |
| for (size_t i = mr_step; i < context->mr; i++) { |
| padded_quantization_params[i] = |
| padded_quantization_params[mr_step - 1]; |
| } |
| quantization_params = padded_quantization_params; |
|       } |
| |
| context->dq_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, k_scaled, |
| (const void*)((uintptr_t)context->a + mr_block_start * a_stride + |
| group_index_a * context->ga_stride), |
| a_stride, |
| (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride + |
| group_index_b * context->gw_stride), |
| (void*)((uintptr_t)context->c + mr_block_start * cm_stride + |
| (nr_block_start << context->log2_csize) + |
| group_index_c * context->gc_stride), |
| cm_stride, context->cn_stride, &context->params, quantization_params); |
| } else { |
| context->ukernel.function[uarch_index]( |
| mr_step, nr_block_size, k_scaled, |
| (const void*)((uintptr_t)context->a + mr_block_start * a_stride + |
| group_index_a * context->ga_stride), |
| a_stride, |
| (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride + |
| group_index_b * context->gw_stride), |
| (void*)((uintptr_t)context->c + mr_block_start * cm_stride + |
| (nr_block_start << context->log2_csize) + |
| group_index_c * context->gc_stride), |
| cm_stride, context->cn_stride, &context->params); |
| } |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
| void xnn_compute_grouped_gemm(struct gemm_context* restrict context, |
| size_t group_index, size_t nr_block_start, |
| size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| xnn_compute_hmp_grouped_gemm(context, XNN_UARCH_DEFAULT, group_index, |
| nr_block_start, mr_block_start, nr_block_size, |
| mr_block_size); |
| } |
| |
| void xnn_compute_gemm(struct gemm_context* restrict context, |
| size_t nr_block_start, size_t mr_block_start, |
| size_t nr_block_size, size_t mr_block_size) { |
| xnn_compute_hmp_gemm(context, XNN_UARCH_DEFAULT, nr_block_start, |
| mr_block_start, nr_block_size, mr_block_size); |
| } |
| |
| void xnn_compute_dqgemm(struct gemm_context* restrict context, |
| size_t nr_block_start, size_t mr_block_start, |
| size_t nr_block_size, size_t mr_block_size) { |
| xnn_compute_hmp_dqgemm(context, XNN_UARCH_DEFAULT, nr_block_start, |
| mr_block_start, nr_block_size, mr_block_size); |
| } |
| |
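| // Grouped GEMM over a left-hand side that was pre-packed into the qp8 layout: |
| // the offset of each M block within the packed buffer comes from the |
| // packed_lh_config's offset_fn. For dynamically quantized inputs, each packed |
| // block starts with `mr` quantization-params entries followed by the packed |
| // rows. |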
| void xnn_compute_hmp_grouped_qp8gemm(struct gemm_context* restrict context, |
| uint32_t uarch_index, size_t group_index, |
| size_t nr_block_start, |
| size_t mr_block_start, |
| size_t nr_block_size, |
| size_t mr_block_size) { |
| const size_t cm_stride = context->cm_stride; |
| const size_t cn_stride = context->cn_stride; |
| |
| // Compute the group index offsets into A and B. |
| const size_t group_index_c = group_index; |
| size_t group_index_a = 0; |
| size_t group_index_b = 0; |
| compute_group_indices(context, group_index, &group_index_a, &group_index_b); |
| |
| const size_t mr = context->mr; |
| const size_t mr_packed = context->mr_packed; |
| const size_t kr = context->kr; |
| const size_t sr = context->sr; |
| const size_t kc = context->kc; |
| const size_t k_scaled = |
| kc << context->packed_lh_config->log2_packed_element_size; |
| const uintptr_t a = |
| (uintptr_t)context->a + group_index_a * context->ga_stride; |
| const uintptr_t c = (uintptr_t)context->c + |
| group_index_c * context->gc_stride + |
| (nr_block_start << context->log2_csize); |
| const void* packed_w = (const void*)((uintptr_t)context->packed_w + |
| group_index_b * context->gw_stride + |
| nr_block_start * context->w_stride); |
| const uintptr_t packed_input_stride = round_up(kc, kr * sr) * sizeof(int8_t); |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, mr); |
| const size_t a_offset = context->packed_lh_config->offset_fn( |
| mr_block_start, kc, mr_packed, kr, sr); |
| |
| if (context->dynamic_quantization) { |
| const void* workspace = (const void*)((uintptr_t)a + a_offset); |
| const struct xnn_qd8_quantization_params* quantization_params = workspace; |
| const void* packed_inputs = |
| (const void*)((uintptr_t)workspace + |
| mr * sizeof(struct xnn_qd8_quantization_params)); |
| |
| context->dq_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, k_scaled, packed_inputs, packed_input_stride, |
| packed_w, (void*)(c + mr_block_start * cm_stride), cm_stride, |
| cn_stride, context->fused_params, quantization_params); |
| } else { |
| context->qp8_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, k_scaled, (const void*)(a + a_offset), |
| packed_w, (void*)(c + mr_block_start * cm_stride), cm_stride, |
| /*dst_stride_col=*/sizeof(float), context->fused_params); |
| } |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
| void xnn_compute_grouped_qp8gemm(struct gemm_context* restrict context, |
| size_t group_index, size_t nr_block_start, |
| size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| xnn_compute_hmp_grouped_qp8gemm(context, XNN_UARCH_DEFAULT, group_index, |
| nr_block_start, mr_block_start, nr_block_size, |
| mr_block_size); |
| } |
| |
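| // Non-grouped counterpart of the grouped qp8 GEMM above. It additionally |
| // supports the qc2w row-sum layout, where `mr` per-row sums are stored |
| // between the quantization params and the packed input rows. |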
| XNN_INLINE static void compute_hmp_qp8gemm( |
| struct gemm_context* restrict context, uint32_t uarch_index, |
| size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| const size_t cm_stride = context->cm_stride; |
| const size_t cn_stride = context->cn_stride; |
| const size_t mr_packed = context->mr_packed; |
| const size_t mr = context->mr; |
| const size_t kr = context->kr; |
| const size_t sr = context->sr; |
| const size_t kc = context->kc; |
| const size_t k_scaled = |
| kc << context->packed_lh_config->log2_packed_element_size; |
| const uintptr_t a = (uintptr_t)context->a; |
| const uintptr_t c = |
| (uintptr_t)context->c + (nr_block_start << context->log2_csize); |
| const uintptr_t packed_input_stride = round_up(kc, kr * sr) * sizeof(int8_t); |
| const void* packed_w = (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride); |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, mr); |
| const size_t a_offset = context->packed_lh_config->offset_fn( |
| mr_block_start, context->kc, mr_packed, context->kr, context->sr); |
| |
| if (context->dynamic_quantization) { |
| const void* workspace = (const void*)((uintptr_t)a + a_offset); |
| const struct xnn_qd8_quantization_params* quantization_params = workspace; |
| |
| if (context->with_row_sum) { |
| const float* row_sum = (const float*)((uintptr_t)workspace + |
| mr * sizeof(struct xnn_qd8_quantization_params)); |
| const void* packed_inputs = |
| (const void*)((uintptr_t)row_sum + mr * sizeof(float)); |
| context->dq_qc2w_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, k_scaled, packed_inputs, |
| packed_input_stride, packed_w, |
| (void*)(c + mr_block_start * cm_stride), cm_stride, cn_stride, |
| context->fused_params, row_sum, quantization_params); |
| } else { |
| const void* packed_inputs = |
| (const void*)((uintptr_t)workspace + |
| mr * sizeof(struct xnn_qd8_quantization_params)); |
| context->dq_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, k_scaled, packed_inputs, |
| packed_input_stride, packed_w, |
| (void*)(c + mr_block_start * cm_stride), cm_stride, cn_stride, |
| context->fused_params, quantization_params); |
| } |
| } else { |
| context->qp8_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, k_scaled, |
| (const void*)((uintptr_t)a + a_offset), packed_w, |
| (void*)((uintptr_t)c + mr_block_start * cm_stride), cm_stride, |
| /*dst_stride_col=*/sizeof(float), context->fused_params); |
| } |
| |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
| void xnn_compute_hmp_qp8gemm(struct gemm_context* restrict context, |
| uint32_t uarch_index, size_t nr_block_start, |
| size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| compute_hmp_qp8gemm(context, uarch_index, nr_block_start, mr_block_start, |
| nr_block_size, mr_block_size); |
| } |
| |
| void xnn_compute_qp8gemm(struct gemm_context* restrict context, |
| size_t nr_block_start, size_t mr_block_start, |
| size_t nr_block_size, size_t mr_block_size) { |
| compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, nr_block_start, |
| mr_block_start, nr_block_size, mr_block_size); |
| } |
| |
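| // Sparse-weight matrix multiplication (SpMM) for one batch element over an |
| // M block of `mr_block_size` elements starting at `mr_block_start`. |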
| void xnn_compute_spmm(struct spmm_context* restrict context, size_t batch_index, |
| size_t mr_block_start, size_t mr_block_size) { |
| context->ukernel( |
| mr_block_size, context->n, |
| (const void*)((uintptr_t)context->input + |
| batch_index * context->batched_input_stride + |
| mr_block_start), |
| context->nonzero_weights, context->input_increments, |
| context->output_channel_nonzeros, |
| (void*)((uintptr_t)context->output + |
| batch_index * context->batched_output_stride + mr_block_start), |
| context->scaled_m, &context->params); |
| } |
| |
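| // Fused "pack the LHS, then GEMM" path: unless packing is a no-op (gemv with |
| // mr == 1), each thread packs up to `mr` rows of the left-hand side into its |
| // own slice of the shared workspace and immediately runs the qp8 (or |
| // dynamically quantized) microkernel on the freshly packed block. |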
| XNN_INLINE static void compute_inline_packed_qp8gemm( |
| struct gemm_context* context, uint32_t uarch_index, size_t thread_id, |
| size_t group_index_a, size_t group_index_b, size_t group_index_c, |
| size_t mr_block_start, size_t mr_block_size) { |
| assert(context->packed_lh_config != NULL); |
| assert(context->packed_lh_config->offset_fn != NULL); |
| assert(context->packed_lh_config->pack_lh_fn != NULL); |
| |
| const size_t cm_stride = context->cm_stride; |
| const size_t cn_stride = context->cn_stride; |
| const size_t mr = context->mr; |
| const size_t mr_packed = context->mr_packed; |
| const size_t kr = context->kr; |
| const size_t sr = context->sr; |
| const size_t kc = context->kc; |
| const size_t nc = context->nc; |
| const uintptr_t a = |
| (uintptr_t)context->a + group_index_a * context->ga_stride; |
| const size_t a_stride = context->a_stride; |
| const void* packed_w = (const void*)((uintptr_t)context->packed_w + |
| group_index_b * context->gw_stride); |
| const uintptr_t c = |
| (uintptr_t)context->c + group_index_c * context->gc_stride; |
| const size_t k_scaled = |
| context->kc << context->packed_lh_config->log2_packed_element_size; |
| const uintptr_t packed_input_stride = round_up(kc, kr * sr) * sizeof(int8_t); |
| |
| const bool skip_lhs_packing = context->packed_lh_config->gemv_noop && mr == 1; |
| void* workspace = |
| skip_lhs_packing |
| ? NULL |
| : (void*)((uintptr_t)context->workspace + context->workspace_offset + |
| context->packed_lh_config->offset_fn(thread_id * mr, kc, |
| mr_packed, kr, sr)); |
| const void* packed_lhs = workspace; |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, mr); |
| |
| // Pack the `mr_step` rows of the left-hand operand into the workspace. |
| if (skip_lhs_packing) { |
| packed_lhs = (const void*)(a + mr_block_start * a_stride); |
| } else { |
| context->packed_lh_config->pack_lh_fn( |
| /*m=*/mr_step, kc, mr_packed, kr, sr, |
| /*m_idx_start=*/0, (const void*)(a + mr_block_start * a_stride), |
| a_stride, workspace); |
| } |
| |
| // Call the appropriate GEMM kernel. |
| if (context->dynamic_quantization) { |
| const struct xnn_qd8_quantization_params* quantization_params = |
| packed_lhs; |
| |
| if (context->with_row_sum) { |
| const float* row_sum = (const float*)((uintptr_t)packed_lhs + |
| mr * sizeof(struct xnn_qd8_quantization_params)); |
| const void* packed_inputs = |
| (const void*)((uintptr_t)row_sum + mr * sizeof(float)); |
| context->dq_qc2w_ukernel.function[uarch_index]( |
| mr_step, nc, k_scaled, packed_inputs, packed_input_stride, packed_w, |
| (void*)(c + mr_block_start * cm_stride), cm_stride, cn_stride, |
| context->fused_params, row_sum, quantization_params); |
| } else { |
| const void* packed_inputs = |
| (const void*)((uintptr_t)packed_lhs + |
| mr * sizeof(struct xnn_qd8_quantization_params)); |
| context->dq_ukernel.function[uarch_index]( |
| mr_step, nc, k_scaled, packed_inputs, packed_input_stride, packed_w, |
| (void*)(c + mr_block_start * cm_stride), cm_stride, cn_stride, |
| context->fused_params, quantization_params); |
| } |
| } else { |
| context->qp8_ukernel.function[uarch_index]( |
| mr_step, nc, k_scaled, packed_lhs, packed_w, |
| (void*)(c + mr_block_start * cm_stride), cm_stride, |
| /*dst_stride_col=*/1 << context->log2_csize, context->fused_params); |
| } |
| |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
| void xnn_compute_hmp_inline_packed_qp8gemm(struct gemm_context* context, |
| uint32_t uarch_index, |
| size_t thread_id, |
| size_t mr_block_start, |
| size_t mr_block_size) { |
| compute_inline_packed_qp8gemm(context, uarch_index, thread_id, |
| /*group_index_a=*/0, /*group_index_b=*/0, |
| /*group_index_c=*/0, mr_block_start, |
| mr_block_size); |
| } |
| |
| void xnn_compute_inline_packed_qp8gemm(struct gemm_context* context, |
| uint32_t thread_id, |
| size_t mr_block_start, |
| size_t mr_block_size) { |
| compute_inline_packed_qp8gemm(context, XNN_UARCH_DEFAULT, thread_id, |
| /*group_index_a=*/0, /*group_index_b=*/0, |
| /*group_index_c=*/0, mr_block_start, |
| mr_block_size); |
| } |
| |
| void xnn_compute_hmp_grouped_inline_packed_qp8gemm( |
| struct gemm_context* context, uint32_t uarch_index, uint32_t thread_id, |
| size_t group_index, size_t mr_block_start, size_t mr_block_size) { |
| // Compute the group index offsets into A and B. |
| size_t group_index_a = 0; |
| size_t group_index_b = 0; |
| compute_group_indices(context, group_index, &group_index_a, &group_index_b); |
| |
| compute_inline_packed_qp8gemm(context, uarch_index, thread_id, group_index_a, |
| group_index_b, group_index, mr_block_start, |
| mr_block_size); |
| } |
| |
| void xnn_compute_grouped_inline_packed_qp8gemm(struct gemm_context* context, |
| uint32_t thread_id, |
| size_t group_index, |
| size_t mr_block_start, |
| size_t mr_block_size) { |
| // Compute the group index offsets into A and B. |
| size_t group_index_a = 0; |
| size_t group_index_b = 0; |
| compute_group_indices(context, group_index, &group_index_a, &group_index_b); |
| |
| compute_inline_packed_qp8gemm(context, XNN_UARCH_DEFAULT, thread_id, |
| group_index_a, group_index_b, group_index, |
| mr_block_start, mr_block_size); |
| } |
| |
| void xnn_compute_igemm(struct igemm_context* restrict context, |
| size_t batch_index, size_t group_index, |
| size_t nr_block_start, size_t mr_block_start, |
| size_t nr_block_size, size_t mr_block_size) { |
| xnn_compute_hmp_igemm(context, XNN_UARCH_DEFAULT, batch_index, group_index, |
| nr_block_start, mr_block_start, nr_block_size, |
| mr_block_size); |
| } |
| |
| void xnn_compute_dq_zero_buffer_igemm(struct igemm_context* restrict context, |
| size_t batch_index) { |
| memset(context->zero_buffers[batch_index], |
| context->quantization_params[batch_index].zero_point, |
| context->zero_size); |
| } |
| |
| void xnn_compute_dq_zero_buffer_subconv( |
| struct subconv_context* restrict context, size_t batch_index, |
| size_t batch_size) { |
| for (size_t k = 0; k < batch_size; k++) { |
| memset(context->zero_buffers[batch_index + k], |
| context->quantization_params[batch_index + k].zero_point, |
| context->zero_size); |
| } |
| } |
| |
| void xnn_compute_dqigemm(struct igemm_context* restrict context, |
| size_t batch_index, size_t group_index, |
| size_t nr_block_start, size_t mr_block_start, |
| size_t nr_block_size, size_t mr_block_size) { |
| xnn_compute_hmp_dqigemm(context, XNN_UARCH_DEFAULT, batch_index, group_index, |
| nr_block_start, mr_block_start, nr_block_size, |
| mr_block_size); |
| } |
| |
| void xnn_compute_inline_packed_igemm( |
| struct igemm_context* restrict context, uint32_t thread_id, |
| size_t batch_index, size_t group_index, size_t mr_block_start, |
| size_t mr_block_size) { |
| xnn_compute_hmp_inline_packed_igemm(context, XNN_UARCH_DEFAULT, thread_id, |
| batch_index, group_index, mr_block_start, |
| mr_block_size); |
| } |
| |
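| // Indirect GEMM with on-the-fly LHS packing: for each step of up to `mr` |
| // output pixels, the indirection pointers are gathered and packed into this |
| // thread's workspace, then the packed-LHS IGEMM microkernel is invoked. |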
| void xnn_compute_hmp_inline_packed_igemm(struct igemm_context* restrict context, |
| uint32_t uarch_index, size_t thread_id, |
| size_t batch_index, size_t group_index, |
| size_t mr_block_start, |
| size_t mr_block_size) { |
| const size_t mr = context->mr; |
| const size_t mr_packed = context->mr_packed; |
| const size_t kc = context->kc; |
| const size_t ks = context->ks; |
| const size_t cm_stride = context->cm_stride; |
| const size_t kc_elems = |
| kc >> context->packed_lh_config->log2_input_element_size; |
| const size_t cm_stride_elems = cm_stride >> context->log2_csize; |
| const size_t cm_stride_kernel = |
| (context->log2_csize == 0) ? cm_stride : cm_stride_elems; |
| const size_t a_offset = context->a_offset + batch_index * context->ba_stride + |
| group_index * context->ga_stride; |
| const void* packed_w = (const void*)((uintptr_t)context->packed_w + |
| group_index * context->gw_stride); |
| const uintptr_t c = (uintptr_t)context->c + batch_index * context->bc_stride + |
| group_index * context->gc_stride; |
| void* workspace = |
| (void*)((uintptr_t)context->workspace + context->workspace_offset + |
| thread_id * context->per_thread_workspace_size); |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, mr); |
| |
| // Pack the LHS data into the workspace. |
| context->packed_lh_config->pack_lh_for_igemm_fn( |
| mr_step, kc_elems, ks, mr_packed, context->kr, context->sr, |
| /*a=*/ |
| (const void**)((uintptr_t)context->indirect_a + |
| mr_block_start * ks * sizeof(void*)), |
| a_offset, context->zero, workspace); |
| |
| // Compute the iGEMM on the packed LHS data. |
| context->ukernel.packed_lhs_function[uarch_index]( |
| mr_step, context->nc, kc_elems, ks, /*packed_lhs=*/workspace, packed_w, |
|         (void*)(c + mr_block_start * cm_stride), cm_stride_kernel, |
|         &context->params); |
| |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
| // `output_tile_start` should be a multiple of igemm.mr (tile size). |
| void xnn_compute_conv2d_igemm_indirection( |
| struct conv2d_igemm_indirection_init_context* restrict context, |
| size_t output_tile_start, size_t output_tile_size) { |
| while (output_tile_size > 0) { |
| const size_t mr_step = min(output_tile_size, context->mr); |
| xnn_indirection_init_conv2d( |
| mr_step, output_tile_start, output_tile_start + mr_step, |
| context->indirection_buffer, context->input, context->zero_buffer, |
| context->input_pixel_stride, context->input_height, |
| context->input_width, context->output_height, context->output_width, |
| context->kernel_height, context->kernel_width, context->stride_height, |
| context->stride_width, context->dilation_height, |
| context->dilation_width, context->input_padding_top, |
| context->input_padding_left); |
| output_tile_size -= mr_step; |
| output_tile_start += mr_step; |
| } |
| } |
| |
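| // Sub-convolution tile of a deconvolution: each subkernel handles the output |
| // positions that share the same stride phase; slices that fall outside the |
| // subkernel's output extent are skipped early. |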
| void xnn_compute_subconv2d(struct subconv_context* restrict context, |
| size_t batch_index, size_t group_index, |
| size_t subkernel_index, size_t slice_y, |
| size_t slice_x_start, size_t nc_block_start, |
| size_t slice_x_max, size_t nc_block_size) { |
| const struct subconvolution_params* subconvolution_params = |
| &context->subconvolution_params[subkernel_index]; |
| |
| if XNN_UNLIKELY (slice_y >= subconvolution_params->slice_height) { |
| return; |
| } |
| |
| const size_t slice_width = subconvolution_params->slice_width; |
| if XNN_UNLIKELY (slice_x_start >= slice_width) { |
| return; |
| } |
| const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start); |
| |
| const size_t cx_stride = context->cx_stride; |
| context->ukernel.function[XNN_UARCH_DEFAULT]( |
| slice_x_size, nc_block_size, context->kc, |
| subconvolution_params->scaled_kernel_size, |
| (const void**)((uintptr_t)subconvolution_params->indirection_buffer + |
| slice_y * subconvolution_params->indirection_y_stride + |
| slice_x_start * |
| subconvolution_params->indirection_x_stride), |
| (const void*)((uintptr_t)subconvolution_params->weights + |
| nc_block_start * subconvolution_params->w_stride + |
| group_index * context->gw_stride), |
| (void*)((uintptr_t)subconvolution_params->output + |
| group_index * context->gc_stride + slice_y * context->cy_stride + |
| slice_x_start * cx_stride + batch_index * context->bc_stride + |
| (nc_block_start << context->log2_csize)), |
| cx_stride, context->cn_stride, |
| context->a_offset + group_index * context->ga_stride + |
| batch_index * context->ba_stride, |
| context->zero, &context->params); |
| } |
| |
| void xnn_compute_dqsubconv2d(struct subconv_context* restrict context, |
| size_t batch_index, size_t group_index, |
| size_t subkernel_index, size_t slice_y, |
| size_t slice_x_start, size_t nc_block_start, |
| size_t slice_x_max, size_t nc_block_size) { |
| const struct subconvolution_params* subconvolution_params = |
| &context->subconvolution_params[subkernel_index]; |
| |
| if XNN_UNLIKELY (slice_y >= subconvolution_params->slice_height) { |
| return; |
| } |
| |
| const size_t slice_width = subconvolution_params->slice_width; |
| if XNN_UNLIKELY (slice_x_start >= slice_width) { |
| return; |
| } |
| const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start); |
| |
| const size_t cx_stride = context->cx_stride; |
| context->dq_ukernel.function[XNN_UARCH_DEFAULT]( |
| slice_x_size, nc_block_size, context->kc, |
| subconvolution_params->scaled_kernel_size, |
| (const void**)((uintptr_t)subconvolution_params->indirection_buffer + |
| slice_y * subconvolution_params->indirection_y_stride + |
| slice_x_start * |
| subconvolution_params->indirection_x_stride), |
| (const void*)((uintptr_t)subconvolution_params->weights + |
| nc_block_start * subconvolution_params->w_stride + |
| group_index * context->gw_stride), |
| (void*)((uintptr_t)subconvolution_params->output + |
| group_index * context->gc_stride + slice_y * context->cy_stride + |
| slice_x_start * cx_stride + batch_index * context->bc_stride + |
| (nc_block_start << context->log2_csize)), |
| cx_stride, context->cn_stride, |
| context->a_offset + group_index * context->ga_stride + |
| batch_index * context->ba_stride, |
| context->zero, context->zero_buffers[batch_index], &context->params, |
| &context->quantization_params[batch_index]); |
| } |
| |
| void xnn_compute_conv2d_hwc2chw(struct conv2d_context* restrict context, |
| size_t batch_index, size_t output_y_start, |
| size_t output_y_slice) { |
| context->hwc2chw_ukernel( |
| context->input_height, context->input_width, output_y_start, |
| output_y_start + output_y_slice, |
| (const void*)((uintptr_t)context->input + |
| batch_index * context->input_batch_stride), |
| context->zero, context->packed_weights, |
| (void*)((uintptr_t)context->output + |
| batch_index * context->output_batch_stride), |
| context->input_padding_top, context->output_channels, |
| context->output_height_stride, context->output_channel_stride, |
| &context->params); |
| } |
| |
| void xnn_compute_dwconv_indirection( |
| struct dwconv_indirection_init_context* restrict context, |
| size_t output_y_start, size_t output_y_tile) { |
| xnn_indirection_init_dwconv2d( |
| output_y_start, output_y_start + output_y_tile, |
| context->indirection_buffer, context->input, context->input_pixel_stride, |
| context->zero_buffer, context->input_height, context->input_width, |
| context->output_height, context->output_width, context->kernel_height, |
| context->kernel_width, context->stride_height, context->stride_width, |
| context->dilation_height, context->dilation_width, |
| context->input_padding_top, context->input_padding_left, |
| context->step_height, context->step_width, context->tile_size); |
| } |
| |
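| // Single-pass depthwise convolution for one output row of one batch element: |
| // the microkernel walks the row's indirection pointers and produces |
| // `output_c_tile` channels starting at channel `output_c_start`. |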
| void xnn_compute_dwconv_unipass(struct dwconv_context* restrict context, |
| size_t batch_index, size_t output_y, |
| size_t output_c_start, size_t output_c_tile) { |
| const void** indirect_input = |
| (const void**)((uintptr_t)context->indirect_input + |
| output_y * context->indirect_input_height_stride); |
| const size_t input_offset = context->input_offset + |
| batch_index * context->input_batch_stride + |
| output_c_start * context->input_channel_stride; |
| void* output = (void*)((uintptr_t)context->output + |
| batch_index * context->output_batch_stride + |
| output_y * context->output_height_stride + |
| output_c_start * context->output_channel_stride); |
| void* weights = (void*)((uintptr_t)context->packed_weights + |
| output_c_start * context->weights_channel_stride); |
| const size_t output_increment = |
| context->output_pixel_stride - |
| output_c_tile * context->output_channel_stride; |
| |
| context->ukernel(output_c_tile, context->output_width, indirect_input, |
| weights, output, context->indirect_input_width_stride, |
| output_increment, input_offset, /*input_pixel_stride=*/0, |
| context->zero, &context->params); |
| } |
| |
| void xnn_compute_dwconv2d_chw(struct dwconv2d_context* restrict context, |
| size_t batch_index, size_t channel) { |
| context->chw_ukernel(context->input_height, context->input_width, |
| (const void*)((uintptr_t)context->input + |
| channel * context->input_channel_stride + |
| batch_index * context->input_batch_stride), |
| (const void*)((uintptr_t)context->packed_weights + |
| channel * context->weights_channel_stride), |
| context->zero, |
| (void*)((uintptr_t)context->output + |
| channel * context->output_channel_stride + |
| batch_index * context->output_batch_stride), |
| context->input_padding_top, &context->params); |
| } |
| |
| void xnn_compute_argmax_pooling(struct argmax_pooling_context* restrict context, |
| size_t batch_index, size_t output_y) { |
| const void** indirect_input = |
| (const void**)((uintptr_t)context->indirect_input + |
| output_y * context->indirect_input_height_stride); |
| const size_t input_offset = |
| context->input_offset + batch_index * context->input_batch_stride; |
| void* output = (void*)((uintptr_t)context->output + |
| batch_index * context->output_batch_stride + |
| output_y * context->output_height_stride); |
| uint32_t* index = (uint32_t*)((uintptr_t)context->index + |
| batch_index * context->index_batch_stride + |
| output_y * context->index_height_stride); |
| |
| context->ukernel(context->output_width, context->pooling_size, |
| context->channels, indirect_input, input_offset, |
| /*input_pixel_stride=*/0, output, index, |
| context->input_increment, context->output_increment, |
| context->index_increment); |
| } |
| |
| void xnn_compute_max_pooling(struct max_pooling_context* restrict context, |
| size_t batch_index, size_t output_y) { |
| const void** indirect_input = |
| (const void**)((uintptr_t)context->indirect_input + |
| output_y * context->indirect_input_height_stride); |
| const size_t input_offset = |
| context->input_offset + batch_index * context->input_batch_stride; |
| void* output = (void*)((uintptr_t)context->output + |
| batch_index * context->output_batch_stride + |
| output_y * context->output_height_stride); |
| |
| context->ukernel(context->output_width, context->pooling_size, |
| context->channels, indirect_input, input_offset, |
| /*input_pixel_stride=*/0, output, context->input_increment, |
| context->output_increment, &context->params); |
| } |
| |
| void xnn_compute_unpooling(struct unpooling_context* restrict context, |
| size_t input_y, size_t input_x) { |
| const void* input = (const void*)((uintptr_t)context->input + |
| input_y * context->input_height_stride + |
| input_x * context->input_width_stride); |
| const uint32_t* index = |
| (const uint32_t*)((uintptr_t)context->index + |
| input_y * context->index_height_stride + |
| input_x * context->index_width_stride); |
| void** indirect_output = |
| (void**)((uintptr_t)context->indirect_output + |
| input_y * context->indirect_output_height_stride + |
| input_x * context->indirect_output_width_stride); |
| |
| context->ukernel(context->pooling_size, context->channels, |
| context->fill_value, input, index, indirect_output); |
| } |
| |
| void xnn_compute_average_pooling( |
| struct average_pooling_context* restrict context, size_t batch_index, |
| size_t output_y) { |
|   // The indirection buffer is compressed: output rows in the middle section |
|   // share a single row of pointers, and the rows elided by the compression |
|   // are recovered by advancing the input offset by `input_y_stride` per |
|   // compressed output row. |
| const size_t indirect_y = min(output_y, context->indirect_top_height) + |
| doz(output_y + 1, context->indirect_bot_start); |
| const void** indirect_input = |
|       (const void**)((uintptr_t)context->indirect_input + |
|                      indirect_y * context->indirect_input_height_stride); |
| const size_t input_offset_for_compressed_section = |
| (output_y - indirect_y) * (output_y < context->indirect_bot_start) * |
| context->input_y_stride; |
| const size_t input_offset = context->input_offset + |
| batch_index * context->input_batch_stride + |
| input_offset_for_compressed_section; |
| |
| const void* pixelwise_buffer = |
| context->pixelwise_buffer |
| ? (const void*)((uintptr_t)context->pixelwise_buffer + |
| output_y * context->pixelwise_buffer_height_stride) |
| : NULL; |
| void* output = (void*)((uintptr_t)context->output + |
| batch_index * context->output_batch_stride + |
| output_y * context->output_height_stride); |
| |
| context->ukernel(context->output_width, context->pooling_size, |
| context->channels, indirect_input, input_offset, |
| /*input_pixel_stride=*/0, context->zero, pixelwise_buffer, |
| output, context->input_increment, context->output_increment, |
| &context->params); |
| } |
| |
| void xnn_compute_resize_bilinear_indirection( |
| struct resize_bilinear_nhwc_indirection_init_context* restrict context, |
| size_t output_y_start, size_t output_y_tile) { |
| void* buffer = context->buffer; |
| |
| context->indirection_init( |
| output_y_start, output_y_start + output_y_tile, |
| context->input_pixel_stride, context->input_height, context->input_width, |
| context->output_height, context->output_width, context->input, |
|       /*indirection_buffer=*/ |
| (const void**)((uintptr_t)buffer + context->indirect_input_offset), |
| /*packed_weights=*/(void*)buffer, context->align_corners, |
| context->tensorflow_legacy_mode); |
| } |
| |
| void xnn_compute_resize_bilinear( |
| struct resize_bilinear_context* restrict context, size_t batch_index, |
| size_t pixel_start, size_t pixel_range) { |
| void* output = (void*)((uintptr_t)context->output + |
| pixel_start * context->output_pixel_stride + |
| batch_index * context->output_batch_stride); |
| |
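|   // Each output pixel consumes 4 indirection pointers (its bilinear corners) |
|   // and one packed weight entry of (1 << log2_wsize) bytes. |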
| context->ukernel( |
| pixel_range, context->scaled_channels, |
| context->indirect_input + pixel_start * 4, |
| context->input_offset + batch_index * context->input_batch_stride, |
| (const void*)((uintptr_t)context->packed_weights + |
| (pixel_start << context->log2_wsize)), |
| output, context->output_pixel_stride - context->scaled_channels); |
| } |
| |
| void xnn_compute_resize_bilinear_chw( |
| struct resize_bilinear_chw_context* restrict context, size_t batch_index, |
| size_t channel_start, size_t channel_range) { |
| void* output = (void*)((uintptr_t)context->output + |
| channel_start * context->output_channel_stride + |
| batch_index * context->output_batch_stride); |
| const size_t input_offset = context->input_offset + |
| batch_index * context->input_batch_stride + |
| channel_start * context->input_channel_stride; |
| |
| context->ukernel(context->output_pixels, channel_range, |
| context->indirect_input, input_offset, |
| context->packed_weights, output, |
| context->input_channel_stride); |
| } |
| |
| void xnn_compute_pad_5d(struct pad_context* restrict context, size_t i, |
| size_t j, size_t k, size_t l, size_t m) { |
| const void* input = |
| (const void*)((uintptr_t)context->input + i * context->input_stride[4] + |
| j * context->input_stride[3] + |
| k * context->input_stride[2] + |
| l * context->input_stride[1] + |
| m * context->input_stride[0]); |
| void* output = |
| (void*)((uintptr_t)context->output + i * context->output_stride[4] + |
| j * context->output_stride[3] + k * context->output_stride[2] + |
| l * context->output_stride[1] + m * context->output_stride[0]); |
| |
| const size_t i_padding = context->pre_paddings[5]; |
| const size_t j_padding = context->pre_paddings[4]; |
| const size_t k_padding = context->pre_paddings[3]; |
| const size_t l_padding = context->pre_paddings[2]; |
| const size_t m_padding = context->pre_paddings[1]; |
| |
| const size_t i_size = context->input_size[5]; |
| const size_t j_size = context->input_size[4]; |
| const size_t k_size = context->input_size[3]; |
| const size_t l_size = context->input_size[2]; |
| const size_t m_size = context->input_size[1]; |
| |
| if XNN_LIKELY (i - i_padding < i_size && j - j_padding < j_size && |
| k - k_padding < k_size && l - l_padding < l_size && |
| m - m_padding < m_size) { |
|     context->pad_ukernel(/*rows=*/1, context->input_size[0], |
|                          context->pre_paddings[0], context->post_paddings[0], |
|                          input, /*input_stride=*/0, output, |
|                          /*output_stride=*/0, context->padding_value); |
|   } else { |
|     context->fill_ukernel(/*rows=*/1, context->output_size[0], output, |
|                           /*output_stride=*/0, context->padding_value); |
| } |
| } |
| |
| void xnn_compute_slice_1d(struct slice_context* restrict context, size_t i) { |
| const void* input = |
| (const void*)((uintptr_t)context->input + i * context->input_stride[0]); |
| void* output = |
| (void*)((uintptr_t)context->output + i * context->output_stride[0]); |
| |
| context->ukernel(context->contiguous_size, input, output, NULL); |
| } |
| |
| void xnn_compute_slice_2d(struct slice_context* restrict context, size_t i, |
| size_t j) { |
| const void* input = |
| (const void*)((uintptr_t)context->input + i * context->input_stride[1] + |
| j * context->input_stride[0]); |
| void* output = |
| (void*)((uintptr_t)context->output + i * context->output_stride[1] + |
| j * context->output_stride[0]); |
| |
| context->ukernel(context->contiguous_size, input, output, NULL); |
| } |
| |
| void xnn_compute_slice_3d(struct slice_context* restrict context, size_t i, |
| size_t j, size_t k) { |
| const void* input = |
| (const void*)((uintptr_t)context->input + i * context->input_stride[2] + |
| j * context->input_stride[1] + |
| k * context->input_stride[0]); |
| void* output = |
| (void*)((uintptr_t)context->output + i * context->output_stride[2] + |
| j * context->output_stride[1] + k * context->output_stride[0]); |
| |
| context->ukernel(context->contiguous_size, input, output, NULL); |
| } |
| |
| void xnn_compute_slice_4d(struct slice_context* restrict context, size_t i, |
| size_t j, size_t k, size_t l) { |
| const void* input = |
| (const void*)((uintptr_t)context->input + i * context->input_stride[3] + |
| j * context->input_stride[2] + |
| k * context->input_stride[1] + |
| l * context->input_stride[0]); |
| void* output = |
| (void*)((uintptr_t)context->output + i * context->output_stride[3] + |
| j * context->output_stride[2] + k * context->output_stride[1] + |
| l * context->output_stride[0]); |
| |
| context->ukernel(context->contiguous_size, input, output, NULL); |
| } |
| |
| void xnn_compute_slice_5d(struct slice_context* restrict context, size_t i, |
| size_t j, size_t k, size_t l, size_t m) { |
| const void* input = |
| (const void*)((uintptr_t)context->input + i * context->input_stride[4] + |
| j * context->input_stride[3] + |
| k * context->input_stride[2] + |
| l * context->input_stride[1] + |
| m * context->input_stride[0]); |
| void* output = |
| (void*)((uintptr_t)context->output + i * context->output_stride[4] + |
| j * context->output_stride[3] + k * context->output_stride[2] + |
| l * context->output_stride[1] + m * context->output_stride[0]); |
| |
| context->ukernel(context->contiguous_size, input, output, NULL); |
| } |
| |
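| // Elementwise binary kernels. The 1D tile variant works on a contiguous byte |
| // range; a zero innermost stride marks a broadcast operand, whose offset is |
| // therefore kept at 0. The 1d..5d strided variants loop over the remaining |
| // dimensions and call the microkernel once per run of `context->elements` |
| // contiguous elements. |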
| void xnn_compute_elementwise_binary_1d_tile( |
| struct elementwise_binary_context* restrict context, size_t offset, |
| size_t count) { |
|   const size_t a_offset = context->a_stride[4] == 0 ? 0 : offset; |
|   const size_t b_offset = context->b_stride[4] == 0 ? 0 : offset; |
| const void* a = (const void*)((uintptr_t)context->a + a_offset); |
| const void* b = (const void*)((uintptr_t)context->b + b_offset); |
| void* y = (void*)((uintptr_t)context->y + offset); |
| context->ukernel(count, a, b, y, &context->params); |
| } |
| |
| void xnn_compute_elementwise_binary_1d( |
| struct elementwise_binary_context* context, size_t offset, size_t count) { |
| for (size_t i = offset; i < offset + count; i++) { |
| const void* a = |
| (const void*)((uintptr_t)context->a + i * context->a_stride[4]); |
| const void* b = |
| (const void*)((uintptr_t)context->b + i * context->b_stride[4]); |
| void* y = (void*)((uintptr_t)context->y + i * context->y_stride[4]); |
| context->ukernel(context->elements, a, b, y, &context->params); |
| } |
| } |
| |
| void xnn_compute_elementwise_binary_2d( |
| struct elementwise_binary_context* context, size_t i, size_t offset, |
| size_t count) { |
| uintptr_t a = (uintptr_t)context->a + i * context->a_stride[3]; |
| uintptr_t b = (uintptr_t)context->b + i * context->b_stride[3]; |
| uintptr_t y = (uintptr_t)context->y + i * context->y_stride[3]; |
| for (size_t j = offset; j < offset + count; j++) { |
| context->ukernel(context->elements, |
| (const void*)(a + j * context->a_stride[4]), |
| (const void*)(b + j * context->b_stride[4]), |
| (void*)(y + j * context->y_stride[4]), &context->params); |
| } |
| } |
| |
| void xnn_compute_elementwise_binary_3d( |
| struct elementwise_binary_context* context, size_t i, size_t offset_j, |
| size_t offset_k, size_t count_j, size_t count_k) { |
| uintptr_t a = (uintptr_t)context->a + i * context->a_stride[2]; |
| uintptr_t b = (uintptr_t)context->b + i * context->b_stride[2]; |
| uintptr_t y = (uintptr_t)context->y + i * context->y_stride[2]; |
| for (size_t j = offset_j; j < offset_j + count_j; j++) { |
| for (size_t k = offset_k; k < offset_k + count_k; k++) { |
| context->ukernel( |
| context->elements, |
| (const void*)(a + j * context->a_stride[3] + |
| k * context->a_stride[4]), |
| (const void*)(b + j * context->b_stride[3] + |
| k * context->b_stride[4]), |
| (void*)(y + j * context->y_stride[3] + k * context->y_stride[4]), |
| &context->params); |
| } |
| } |
| } |
| |
| void xnn_compute_elementwise_binary_4d( |
| struct elementwise_binary_context* context, size_t i, size_t j, |
| size_t offset_k, size_t offset_l, size_t count_k, size_t count_l) { |
|   uintptr_t a = (uintptr_t)context->a + i * context->a_stride[1] + |
| j * context->a_stride[2]; |
| uintptr_t b = (uintptr_t)context->b + i * context->b_stride[1] + |
| j * context->b_stride[2]; |
| uintptr_t y = (uintptr_t)context->y + i * context->y_stride[1] + |
| j * context->y_stride[2]; |
| for (size_t k = offset_k; k < offset_k + count_k; k++) { |
| for (size_t l = offset_l; l < offset_l + count_l; l++) { |
| context->ukernel( |
| context->elements, |
| (const void*)(a + k * context->a_stride[3] + |
| l * context->a_stride[4]), |
| (const void*)(b + k * context->b_stride[3] + |
| l * context->b_stride[4]), |
| (void*)(y + k * context->y_stride[3] + l * context->y_stride[4]), |
| &context->params); |
| } |
| } |
| } |
| |
| void xnn_compute_elementwise_binary_5d( |
| struct elementwise_binary_context* restrict context, size_t i, size_t j, |
| size_t k, size_t l, size_t m) { |
| const void* a = |
| (const void*)((uintptr_t)context->a + i * context->a_stride[0] + |
| j * context->a_stride[1] + k * context->a_stride[2] + |
| l * context->a_stride[3] + m * context->a_stride[4]); |
| const void* b = |
| (const void*)((uintptr_t)context->b + i * context->b_stride[0] + |
| j * context->b_stride[1] + k * context->b_stride[2] + |
| l * context->b_stride[3] + m * context->b_stride[4]); |
| void* y = (void*)((uintptr_t)context->y + i * context->y_stride[0] + |
| j * context->y_stride[1] + k * context->y_stride[2] + |
| l * context->y_stride[3] + m * context->y_stride[4]); |
| context->ukernel(context->elements, a, b, y, &context->params); |
| } |
| |
| void xnn_compute_lut_strided(struct lut_strided_context* restrict context, |
| size_t batch_offset, size_t batch_range) { |
| for (size_t batch_index = batch_offset; |
| batch_index < batch_offset + batch_range; batch_index++) { |
| const void* x = |
| (const void*)((uintptr_t)context->x + context->x_stride * batch_index); |
| void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index); |
| |
| context->ukernel(context->n, x, y, context->t); |
| } |
| } |
| |
| void xnn_compute_lut_contiguous(struct lut_contiguous_context* restrict context, |
| size_t offset, size_t size) { |
| const void* x = (const void*)((uintptr_t)context->x + offset); |
| void* y = (void*)((uintptr_t)context->y + offset); |
| |
| context->ukernel(size, x, y, context->t); |
| } |
| |
| void xnn_compute_univector_strided( |
| struct univector_strided_context* restrict context, size_t batch_index, |
| size_t batch_range) { |
| const size_t x_stride = context->x_stride; |
| const size_t y_stride = context->y_stride; |
| |
| const void* x = (const void*)((uintptr_t)context->x + x_stride * batch_index); |
| void* y = (void*)((uintptr_t)context->y + y_stride * batch_index); |
| do { |
| context->ukernel(context->n, x, y, &context->params); |
| x = (const void*)((uintptr_t)x + x_stride); |
| y = (void*)((uintptr_t)y + y_stride); |
| } while (--batch_range != 0); |
| } |
| |
| void xnn_compute_univector_contiguous( |
| struct univector_contiguous_context* restrict context, size_t offset, |
| size_t size) { |
| const uint32_t log2_xsize = context->log2_xsize; |
| const uint32_t log2_ysize = context->log2_ysize; |
| const void* x = (const void*)((uintptr_t)context->x + offset); |
| void* y = |
| (void*)((uintptr_t)context->y + ((offset >> log2_xsize) << log2_ysize)); |
| context->ukernel(size, x, y, &context->params); |
| } |
| |
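| // Reduction over a 6D view of the input where dimensions 1, 3 & 5 are |
| // reduced and dimensions 0, 2 & 4 map to the three output indices; the |
| // innermost reduced dimension (5) is contiguous, so the microkernel reduces |
| // it directly. Partial results accumulate into the workspace (if present) |
| // or the output, and are converted to the output datatype afterwards when a |
| // workspace is used. |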
| void xnn_compute_contiguous_reduce(struct reduce_context* restrict context, |
| size_t output_idx0, size_t output_idx1, |
| size_t output_idx2, |
| size_t output2_block_size) { |
| const size_t* input_stride = context->input_stride; |
| const size_t* output_stride = context->output_stride; |
| |
| // Input dimensions 1, 3 & 5 are reduced: each of these dimensions is |
| // processed in its entirety, so their indices are always 0. |
| size_t input_offset = input_stride[0] * output_idx0 + |
| input_stride[2] * output_idx1 + |
| input_stride[4] * output_idx2; |
| size_t output_offset = |
| (output_stride[0] * output_idx0 + output_stride[1] * output_idx1 + |
| output_stride[2] * output_idx2) * |
| context->output_element_size; |
| size_t workspace_offset = |
| (output_stride[0] * output_idx0 + output_stride[1] * output_idx1 + |
| output_stride[2] * output_idx2) * |
| context->accumulation_element_size; |
| int input_shape1 = context->input_shape[1]; |
| int input_shape3 = context->input_shape[3]; |
| |
| void* output_ptr = NULL; |
| if (context->workspace) { |
| output_ptr = context->workspace; |
| } else { |
| output_ptr = context->output; |
| } |
| void* output = (void*)((uintptr_t)output_ptr + workspace_offset); |
| // Reduce microkernels accumulate into the output buffer. |
| if (context->identity_value == 0) { |
| memset(output, 0, context->accumulation_element_size * output2_block_size); |
| } else { |
| context->fill_ukernel( |
| 1, context->accumulation_element_size * output2_block_size, output, |
| context->accumulation_element_size, context->identity_value); |
| } |
| |
| if (context->channels != 0) { |
| // Input dimension 1 is reduced. |
| for (size_t i = 0; i < input_shape1; ++i) { |
| const void* input = (const void*)((uintptr_t)context->input + input_offset); |
| // Input dimension 3 is reduced. |
| for (size_t j = 0; j < input_shape3; ++j) { |
| const void* input_row = input; |
| // output2_block_size output elements are written. |
| for (size_t k = 0; k < output2_block_size; ++k) { |
| // The microkernel reduces input dimension 5. |
| context->ukernel.contiguous_reduce(context->channels, input_row, output, |
| &context->params); |
| // input_stride[4] is the number of bytes of input which have been |
| // processed by the microkernel call. |
| input_row = (const void*)((uintptr_t)input_row + input_stride[4]); |
| // Increment output pointer by the number of output bytes which have |
| // been written. |
| output = |
| (void*)((uintptr_t)output + context->accumulation_element_size); |
| } |
| // Reset the output pointer. |
| output = (void*)((uintptr_t)output_ptr + workspace_offset); |
| // Iterating over input_shape[3]. |
| input = (const void*)((uintptr_t)input + input_stride[3]); |
| } |
| // Iterating over input_shape[1]. |
| input_offset += input_stride[1]; |
| } |
| } |
| // Convert to output datatype if accumulation type != output type. |
| if (context->workspace) { |
| void* workspace_ptr = |
| (void*)((uintptr_t)context->workspace + workspace_offset); |
| output_ptr = (void*)((uintptr_t)context->output + output_offset); |
| |
| context->cvt_ukernel( |
| context->accumulation_element_size * output2_block_size, workspace_ptr, |
| output_ptr, &context->cvt_params); |
| } |
| } |
| |
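| // Reduction over a 6D view of the input where dimensions 0, 2 & 4 are |
| // reduced and dimensions 1, 3 & 5 map to the output indices; here the |
| // innermost dimension (5) is kept, so the reduced data is strided rather |
| // than contiguous. Workspace and conversion handling mirrors |
| // xnn_compute_contiguous_reduce. |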
| void xnn_compute_discontiguous_reduce(struct reduce_context* restrict context, |
| size_t output_idx0, size_t output_idx1, |
| size_t output_idx2, |
| size_t output2_block_size) { |
| const size_t* input_stride = context->input_stride; |
| const size_t* output_stride = context->output_stride; |
| |
| // Input dimensions 0, 2 & 4 are reduced: each of these dimensions is |
| // processed in its entirety, so their indices are always 0. |
| size_t input_offset = input_stride[1] * output_idx0 + |
| input_stride[3] * output_idx1 + |
| input_stride[5] * output_idx2; |
| size_t output_offset = |
| (output_stride[0] * output_idx0 + output_stride[1] * output_idx1 + |
| output_stride[2] * output_idx2) * |
| context->output_element_size; |
| size_t workspace_offset = |
| (output_stride[0] * output_idx0 + output_stride[1] * output_idx1 + |
| output_stride[2] * output_idx2) * |
| context->accumulation_element_size; |
| int input_shape0 = context->input_shape[0]; |
| int input_shape2 = context->input_shape[2]; |
| |
| void* output_ptr = NULL; |
| if (context->workspace) { |
| output_ptr = context->workspace; |
| } else { |
| output_ptr = context->output; |
| } |
| void* output = (void*)((uintptr_t)output_ptr + workspace_offset); |
| // Discontiguous reduce microkernels accumulate into the output buffer. |
| if (context->identity_value == 0) { |
| memset(output, 0, context->accumulation_element_size * output2_block_size); |
| } else { |
| context->fill_ukernel( |
| 1, context->accumulation_element_size * output2_block_size, output, |
| context->accumulation_element_size, context->identity_value); |
| } |
| |
| if (context->channels != 0) { |
| if (context->is_old_reduce) { |
| // Input dimension 0 is reduced. |
| for (size_t i = 0; i < input_shape0; ++i) { |
| // Input dimension 2 is reduced. |
| for (size_t j = 0; j < input_shape2; ++j) { |
| // The microkernel reduces input dimension 4 and iterates over |
| // output2_block_size elements of dimension 5. |
| context->ukernel.discontiguous_reduce( |
| context->channels, output2_block_size, |
| (const void*)((uintptr_t)context->input + input_offset + |
| i * input_stride[0] + j * input_stride[2]), |
| input_stride[4], context->zero, |
| (void*)((uintptr_t)output_ptr + workspace_offset), |
| &context->params); |
| } |
| } |
| } else { |
| // The microkernel reduces input dimensions 0, 2 & 4 and iterates over |
| // output2_block_size elements of dimension 5. |
| context->ukernel.discontiguous_reduce2( |
| output2_block_size, context->channels, input_shape2, input_shape0, |
| (const void*)((uintptr_t)context->input + input_offset), |
| input_stride[4], input_stride[2], input_stride[0], context->zero, |
| output, &context->params); |
| } |
| } |
| // Convert to output datatype if accumulation type != output type. |
| if (context->workspace) { |
| void* workspace_ptr = |
| (void*)((uintptr_t)context->workspace + workspace_offset); |
| output_ptr = (void*)((uintptr_t)context->output + output_offset); |
| |
| context->cvt_ukernel( |
| context->accumulation_element_size * output2_block_size, workspace_ptr, |
| output_ptr, &context->cvt_params); |
| } |
| } |
| |
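| // Replicates the last batch element's quantization parameters (and row sum, |
| // when present) into the XNN_EXTRA_QUANTIZATION_PARAMS padding slots, |
| // presumably so that kernels reading slightly past the end of the batch see |
| // valid values. |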
| void xnn_compute_pad_qd8_params( |
| struct f32_qd8_convert_context* restrict context, size_t batch_index) { |
| const size_t batch_size = context->batch_size; |
| for (size_t i = 0; i < XNN_EXTRA_QUANTIZATION_PARAMS; ++i) { |
| context->quantization_params[batch_size + i].zero_point = |
| context->quantization_params[batch_size - 1].zero_point; |
| context->quantization_params[batch_size + i].inv_scale = |
| context->quantization_params[batch_size - 1].inv_scale; |
| |
| if (context->rsum_ukernel) { |
| context->row_sum[batch_size + i] = context->row_sum[batch_size - 1]; |
| } |
| } |
| } |
| |
| typedef struct xnn_qd8_quantization_params(f16_quantization_params_fn)( |
| xnn_float16 min, xnn_float16 max, xnn_float16* f32_scale); |
| typedef struct xnn_qd8_quantization_params(f32_quantization_params_fn)( |
| float min, float max, float* f32_scale); |
| |
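| // Dynamically quantizes one row of f16 input to 8 bits: computes the row |
| // min/max, derives per-row quantization parameters through the supplied |
| // callback, converts the row, and optionally records the row sum of the |
| // quantized output. |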
| void xnn_compute_f16_qx8_convert( |
| struct f16_qd8_convert_context* restrict context, |
| f16_quantization_params_fn quantization_params_function, |
| size_t batch_index) { |
| const size_t x_stride = context->x_stride; |
| const size_t y_stride = context->y_stride; |
| const size_t n = context->n; |
| const void* input = |
| (const void*)((uintptr_t)context->x + x_stride * batch_index); |
| void* output = (void*)((uintptr_t)context->y + y_stride * batch_index); |
| |
| xnn_float16 minmax[2] = {xnn_float16_from_bits(UINT16_C(0x7c00)), |
| xnn_float16_from_bits(UINT16_C(0xfc00))}; |
| context->rminmax_ukernel(n, input, minmax, &context->params); |
| xnn_float16 f16_scale; |
| context->quantization_params[batch_index] = |
| quantization_params_function(minmax[0], minmax[1], &f16_scale); |
| |
| struct xnn_f16_qs8_cvt_params params; |
| params.scalar.scale = f16_scale; |
| params.scalar.output_zero_point = |
| context->quantization_params[batch_index].zero_point; |
| context->convert_ukernel(n, input, output, (union xnn_unary_uparams*)¶ms); |
| |
| if (context->rsum_ukernel) { |
| // Compute and store the row sum of the quantized output. |
| const size_t num_bytes = n / sizeof(xnn_float16) * sizeof(int8_t); |
| int32_t row_sum = 0; |
| struct xnn_qs8_rsum_params rsum_params = {0,}; |
| context->rsum_ukernel(num_bytes, output, &row_sum, &rsum_params); |
| context->row_sum[batch_index] = (float)row_sum; |
| } |
| } |
| |
| void xnn_compute_f16_qd8_convert( |
| struct f16_qd8_convert_context* restrict context, size_t batch_offset, |
| size_t batch_range) { |
| for (size_t batch_index = batch_offset; |
| batch_index < batch_offset + batch_range; batch_index++) { |
| xnn_compute_f16_qx8_convert( |
| context, xnn_f16_qd8_asymmetric_quantization_params, batch_index); |
| } |
| } |
| |
| void xnn_compute_f16_qdu8_convert( |
| struct f16_qd8_convert_context* restrict context, size_t batch_offset, |
| size_t batch_range) { |
| for (size_t batch_index = batch_offset; |
| batch_index < batch_offset + batch_range; batch_index++) { |
| xnn_compute_f16_qx8_convert( |
| context, xnn_f16_qdu8_asymmetric_quantization_params, batch_index); |
| } |
| } |
| |
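| // f32 counterpart of xnn_compute_f16_qx8_convert: per-row min/max |
| // reduction, per-row quantization parameters, conversion to 8 bits, and an |
| // optional row sum of the quantized output. |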
| void xnn_compute_f32_qx8_convert( |
| struct f32_qd8_convert_context* restrict context, |
| f32_quantization_params_fn quantization_params_function, |
| size_t batch_index) { |
| const size_t x_stride = context->x_stride; |
| const size_t y_stride = context->y_stride; |
| const size_t n = context->n; |
| const void* input = |
| (const void*)((uintptr_t)context->x + x_stride * batch_index); |
| void* output = (void*)((uintptr_t)context->y + y_stride * batch_index); |
| |
| float minmax[2] = {INFINITY, -INFINITY}; |
| context->rminmax_ukernel(n, input, minmax, &context->params); |
| float scale; |
| context->quantization_params[batch_index] = |
| quantization_params_function(minmax[0], minmax[1], &scale); |
| |
| struct xnn_f32_qs8_cvt_params params; |
| params.scalar.scale = scale; |
| params.scalar.output_zero_point = |
| context->quantization_params[batch_index].zero_point; |
| context->convert_ukernel(n, input, output, (union xnn_unary_uparams*)¶ms); |
| |
| if (context->rsum_ukernel) { |
| // Compute and store the row sum of the quantized output. |
| const size_t num_bytes = n / sizeof(float) * sizeof(int8_t); |
| int32_t row_sum = 0; |
| struct xnn_qs8_rsum_params rsum_params = {0,}; |
| context->rsum_ukernel(num_bytes, output, &row_sum, &rsum_params); |
| context->row_sum[batch_index] = (float)row_sum; |
| } |
| } |
| |
| void xnn_compute_f32_qd8_convert( |
| struct f32_qd8_convert_context* restrict context, size_t batch_offset, |
| size_t batch_range) { |
| for (size_t batch_index = batch_offset; |
| batch_index < batch_offset + batch_range; batch_index++) { |
| xnn_compute_f32_qx8_convert( |
| context, xnn_f32_qd8_asymmetric_quantization_params, batch_index); |
| } |
| } |
| |
| void xnn_compute_f32_qdu8_convert( |
| struct f32_qd8_convert_context* restrict context, size_t batch_offset, |
| size_t batch_range) { |
| for (size_t batch_index = batch_offset; |
| batch_index < batch_offset + batch_range; batch_index++) { |
| xnn_compute_f32_qx8_convert( |
| context, xnn_f32_qdu8_asymmetric_quantization_params, batch_index); |
| } |
| } |
| |
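| // Packs a tile of left-hand-side (LHS) rows into the packed layout expected |
| // by the corresponding GEMM microkernels, using the kernel-provided offset |
| // function to locate the tile within the packed buffer. |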
| void xnn_compute_pack_lh(struct pack_lh_context* restrict context, |
| size_t group_idx, size_t m_idx_start, size_t tile) { |
| const void* lhs = |
| (const void*)((uintptr_t)context->lhs + group_idx * context->gi_stride + |
| m_idx_start * context->lhs_stride); |
| const size_t offset = context->packed_offset_fn( |
| m_idx_start, context->k, context->mr, context->kr, context->sr); |
| void* lhs_packed = (void*)((uintptr_t)context->lhs_packed + |
| group_idx * context->gp_stride + offset); |
| |
| context->pack_lh_ukernel(/*m=*/tile, context->k, context->mr, context->kr, |
| context->sr, /*m_idx_start=*/0, lhs, |
| context->lhs_stride, lhs_packed); |
| } |
| |
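| // Quantizes and packs a tile of f32 LHS rows into the qp8 packed format, |
| // handling at most context->mr rows per packq microkernel call. |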
| void xnn_compute_f32_qp8_convert( |
| struct f32_qp8_convert_context* restrict context, size_t group_idx, |
| size_t m_idx_start, size_t m_tile) { |
| const size_t m_end = m_idx_start + m_tile; |
| while (m_idx_start < m_end) { |
| const size_t m_step = min(context->mr, m_end - m_idx_start); |
| const float* lhs = (const float*)((const char*)context->lhs + |
| (group_idx * context->m + m_idx_start) * |
| context->lhs_stride); |
| int8_t* lhs_packed = (int8_t*)((uintptr_t)context->lhs_packed + |
| group_idx * context->group_stride + |
| xnn_x8_packq_f32qp8_packed_offset( |
| m_idx_start, context->k, context->mr, |
| context->kr, context->sr)); |
| |
| context->packq_ukernel(/*m=*/m_step, context->k, context->mr, context->kr, |
| context->sr, m_idx_start, lhs, context->lhs_stride, |
| lhs_packed); |
| m_idx_start += m_step; |
| } |
| } |
| |
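| // Softmax over uint8 inputs: finds the row maximum, then runs a LUT-based |
| // normalization kernel with the table pointer advanced by 255 - x_max |
| // entries (computed as x_max ^ 255, which is equivalent for 8-bit values). |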
| void xnn_compute_u8_softmax(struct u8_softmax_context* restrict context, |
| size_t batch_index) { |
| const uint8_t* x = |
| (const uint8_t*)((uintptr_t)context->x + context->x_stride * batch_index); |
| uint8_t* y = |
| (uint8_t*)((uintptr_t)context->y + context->y_stride * batch_index); |
| const size_t n = context->n; |
| |
| uint8_t x_max = 0; |
| context->rmax_ukernel(n, x, &x_max, /*params=*/NULL); |
| const size_t adjustment = x_max ^ 255; |
| const uint32_t* t = (const uint32_t*)context->t + adjustment; |
| context->lut_norm_ukernel(n, x, t, y); |
| } |
| |
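| // Softmax over one row of floating-point data, computed in three passes |
| // (reduce-max, exp-and-accumulate, scale); x_max and y_scale live in unions |
| // so the same code path serves both f16 and f32. |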
| void xnn_compute_floating_point_softmax( |
| struct floating_point_softmax_context* restrict context, |
| size_t batch_index) { |
| const void* x = |
| (const void*)((uintptr_t)context->x + context->x_stride * batch_index); |
| void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index); |
| const size_t n = context->n; |
| |
| // First pass: reduce-max |
| union { |
| float as_float; |
| xnn_float16 as_half; |
| } x_max; |
| memcpy(&x_max, &context->rmax_init, sizeof(x_max)); |
| context->rmax_ukernel(n, x, &x_max, &context->rmax_params); |
| |
| // Second pass: reduce-add & store exp(x-x_max) |
| float y_sum; |
| context->raddstoreexpminusmax_ukernel(n, x, &x_max, y, &y_sum, |
| &context->expminus_params); |
| |
| // Third pass: scale y |
| union { |
| float as_float; |
| xnn_float16 as_half; |
| } y_scale; |
| context->compute_reciprocal(&y_sum, &y_scale); |
| context->vmulc_ukernel(n, y, &y_scale, y, &context->minmax_params); |
| } |
| |
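| // Applies a per-channel multiply-then-add (vmulcaddc) microkernel to |
| // batch_size rows starting at batch_start; context->w holds the packed |
| // per-channel multipliers and addends. |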
| void xnn_compute_vmulcaddc(struct vmulcaddc_context* restrict context, |
| size_t batch_start, size_t batch_size) { |
| const size_t x_stride = context->x_stride; |
| const size_t y_stride = context->y_stride; |
| |
| const void* x = (const void*)((uintptr_t)context->x + x_stride * batch_start); |
| void* y = (void*)((uintptr_t)context->y + y_stride * batch_start); |
| |
| context->ukernel(batch_size, context->n, x, x_stride, context->w, y, y_stride, |
| &context->params); |
| } |
| |
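| // Rotary position embedding (RoPE): multiplies each (batch, head, sequence) |
| // slice by the weights for that sequence position using a complex-multiply |
| // (vcmul) microkernel; the weights occupy 2 * scaled_channels bytes per |
| // sequence position. |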
| void xnn_compute_rope(struct rope_context* restrict context, size_t batch_index, |
| size_t head_index, size_t sequence_index) { |
| const size_t scaled_channels = context->scaled_channels; |
| const size_t offset = batch_index * context->batch_stride + |
| head_index * context->head_stride + |
| sequence_index * context->sequence_stride; |
| const void* input = (const void*)((uintptr_t)context->input + offset); |
| const void* weights = |
| (const void*)((uintptr_t)context->weights + |
| sequence_index * (scaled_channels + scaled_channels)); |
| void* output = (void*)((uintptr_t)context->output + offset); |
| |
| context->vcmul(scaled_channels, input, weights, output, NULL); |
| } |
| |
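| // Heterogeneous-multiprocessing (HMP) GEMM: runs the GEMM microkernel |
| // variant selected by uarch_index over an mr_block_size x nr_block_size |
| // block of the output, splitting the M range into chunks of at most |
| // context->mr rows. |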
| void xnn_compute_hmp_gemm(struct gemm_context* restrict context, |
| uint32_t uarch_index, size_t nr_block_start, |
| size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| const size_t a_stride = context->a_stride; |
| const size_t cm_stride = context->cm_stride; |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, context->mr); |
| context->ukernel.function[uarch_index]( |
| mr_step, nr_block_size, context->k_scaled, |
| (const void*)((uintptr_t)context->a + mr_block_start * a_stride), |
| a_stride, |
| (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride), |
| (void*)((uintptr_t)context->c + mr_block_start * cm_stride + |
| (nr_block_start << context->log2_csize)), |
| cm_stride, context->cn_stride, context->fused_params); |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
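| // HMP GEMM over dynamically quantized activations: as xnn_compute_hmp_gemm, |
| // but forwards the per-row quantization parameters (and, when |
| // context->with_row_sum is set, the precomputed row sums) to the |
| // microkernel. |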
| void xnn_compute_hmp_dqgemm(struct gemm_context* restrict context, |
| uint32_t uarch_index, size_t nr_block_start, |
| size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| const size_t a_stride = context->a_stride; |
| const size_t cm_stride = context->cm_stride; |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, context->mr); |
| |
| if (context->with_row_sum) { |
| context->dq_qc2w_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, context->k_scaled, |
| (const void*)((uintptr_t)context->a + mr_block_start * a_stride), |
| a_stride, |
| (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride), |
| (void*)((uintptr_t)context->c + mr_block_start * cm_stride + |
| (nr_block_start << context->log2_csize)), |
| cm_stride, context->cn_stride, context->fused_params, |
| &context->row_sum[mr_block_start], |
| &context->quantization_params[mr_block_start]); |
| } else { |
| context->dq_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, context->k_scaled, |
| (const void*)((uintptr_t)context->a + mr_block_start * a_stride), |
| a_stride, |
| (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride), |
| (void*)((uintptr_t)context->c + mr_block_start * cm_stride + |
| (nr_block_start << context->log2_csize)), |
| cm_stride, context->cn_stride, context->fused_params, |
| &context->quantization_params[mr_block_start]); |
| } |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
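| // HMP indirect GEMM (IGEMM): like xnn_compute_hmp_gemm, but the A rows are |
| // gathered through the indirection buffer (context->indirect_a) plus |
| // a_offset, with per-batch and per-group strides applied to the weights and |
| // the output. |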
| void xnn_compute_hmp_igemm(struct igemm_context* restrict context, |
| uint32_t uarch_index, size_t batch_index, |
| size_t group_index, size_t nr_block_start, |
| size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| const size_t ks = context->ks; |
| const size_t cm_stride = context->cm_stride; |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, context->mr); |
| context->ukernel.function[uarch_index]( |
| mr_step, nr_block_size, context->kc, context->ks_scaled, |
| (const void**)((uintptr_t)context->indirect_a + |
| mr_block_start * ks * sizeof(void*)), |
| (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride + |
| group_index * context->gw_stride), |
| (void*)((uintptr_t)context->c + group_index * context->gc_stride + |
| batch_index * context->bc_stride + mr_block_start * cm_stride + |
| (nr_block_start << context->log2_csize)), |
| cm_stride, context->cn_stride, |
| context->a_offset + group_index * context->ga_stride + |
| batch_index * context->ba_stride, |
| context->zero, &context->params); |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
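| // HMP indirect GEMM over dynamically quantized activations: additionally |
| // forwards the per-batch zero buffer and quantization parameters to the |
| // microkernel. |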
| void xnn_compute_hmp_dqigemm(struct igemm_context* restrict context, |
| uint32_t uarch_index, size_t batch_index, |
| size_t group_index, size_t nr_block_start, |
| size_t mr_block_start, size_t nr_block_size, |
| size_t mr_block_size) { |
| const size_t ks = context->ks; |
| const size_t cm_stride = context->cm_stride; |
| |
| while (mr_block_size > 0) { |
| const size_t mr_step = min(mr_block_size, context->mr); |
| context->dq_ukernel.function[uarch_index]( |
| mr_step, nr_block_size, context->kc, context->ks_scaled, |
| (const void**)((uintptr_t)context->indirect_a + |
| mr_block_start * ks * sizeof(void*)), |
| (const void*)((uintptr_t)context->packed_w + |
| nr_block_start * context->w_stride + |
| group_index * context->gw_stride), |
| (void*)((uintptr_t)context->c + group_index * context->gc_stride + |
| batch_index * context->bc_stride + mr_block_start * cm_stride + |
| (nr_block_start << context->log2_csize)), |
| cm_stride, context->cn_stride, |
| context->a_offset + group_index * context->ga_stride + |
| batch_index * context->ba_stride, |
| context->zero, context->zero_buffers[batch_index], &context->params, |
| &context->quantization_params[batch_index]); |
| mr_block_size -= mr_step; |
| mr_block_start += mr_step; |
| } |
| } |
| |
| enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool) { |
| return xnn_run_operator_with_index(op, 0, 0, threadpool); |
| } |
| |
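| // Dispatches every compute invocation recorded for the operator to the |
| // matching pthreadpool_parallelize_* entry point; opdata_index and |
| // operator_object_index are used only for logging. |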
| enum xnn_status xnn_run_operator_with_index(xnn_operator_t op, |
| size_t opdata_index, |
| size_t operator_object_index, |
| pthreadpool_t threadpool) { |
| switch (op->state) { |
| case xnn_run_state_invalid: |
| xnn_log_error( |
| "failed to run operator: operator was not successfully setup"); |
| return xnn_status_invalid_state; |
| case xnn_run_state_ready: |
| xnn_log_debug("running operator %zu:%zu (%s %s)", opdata_index, |
| operator_object_index, xnn_operator_type_to_string_v2(op), |
| xnn_microkernel_type_to_string(op->ukernel.type)); |
| break; |
| case xnn_run_state_skip: |
| xnn_log_debug("skip running operator %zu:%zu (%s %s)", opdata_index, |
| operator_object_index, xnn_operator_type_to_string_v2(op), |
| xnn_microkernel_type_to_string(op->ukernel.type)); |
| return xnn_status_success; |
| case xnn_run_state_needs_setup: |
| xnn_log_error( |
| "failed to run operator %zu:%zu (%s %s): operator has been reshaped " |
| "but not yet setup", |
| opdata_index, operator_object_index, |
| xnn_operator_type_to_string_v2(op), |
| xnn_microkernel_type_to_string(op->ukernel.type)); |
| return xnn_status_invalid_state; |
| } |
| |
| uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS; |
| if (op->flags & XNN_FLAG_DONT_SPIN_WORKERS) { |
| flags |= PTHREADPOOL_FLAG_YIELD_WORKERS; |
| } |
| for (size_t i = 0; i < op->num_compute_invocations; i++) { |
| struct compute_parameters* compute = &op->compute[i]; |
| if (compute->type == xnn_parallelization_type_invalid) { |
| break; |
| } |
| void* context = |
| (void*)((uintptr_t)(op->dynamic_context.gemm ? op->dynamic_context.gemm |
| : (void*)&op->context) + |
| compute->context_offset); |
| switch (compute->type) { |
| case xnn_parallelization_type_1d: |
| assert(compute->range[0] != 0); |
| pthreadpool_parallelize_1d(threadpool, compute->task_1d, context, |
| compute->range[0], flags); |
| break; |
| case xnn_parallelization_type_1d_with_thread: |
| assert(compute->range[0] != 0); |
| pthreadpool_parallelize_1d_with_thread( |
| threadpool, compute->task_1d_with_thread, context, |
| compute->range[0], flags); |
| break; |
| case xnn_parallelization_type_1d_tile_1d: |
| assert(compute->range[0] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_1d_tile_1d(threadpool, compute->task_1d_tile_1d, |
| context, compute->range[0], |
| compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_1d_tile_1d_dynamic: |
| assert(compute->range[0] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_1d_tile_1d_dynamic( |
| threadpool, compute->task_1d_tile_1d_dynamic, context, |
| compute->range[0], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_1d_tile_1d_dynamic_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_1d_tile_1d_dynamic_with_thread( |
| threadpool, compute->task_1d_tile_1d_dynamic_with_id, context, |
| compute->range[0], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_2d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| pthreadpool_parallelize_2d(threadpool, compute->task_2d, context, |
| compute->range[0], compute->range[1], flags); |
| break; |
| case xnn_parallelization_type_2d_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| pthreadpool_parallelize_2d_with_thread( |
| threadpool, compute->task_2d_with_thread, context, |
| compute->range[0], compute->range[1], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_1d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_2d_tile_1d( |
| threadpool, compute->task_2d_tile_1d, context, compute->range[0], |
| compute->range[1], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_2d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_2d_tile_2d( |
| threadpool, compute->task_2d_tile_2d, context, compute->range[0], |
| compute->range[1], compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_1d_dynamic: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_2d_tile_1d_dynamic( |
| threadpool, compute->task_2d_tile_1d_dynamic, context, |
| compute->range[0], compute->range[1], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_1d_dynamic_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_2d_tile_1d_dynamic_with_thread( |
| threadpool, compute->task_2d_tile_1d_dynamic_with_id, context, |
| compute->range[0], compute->range[1], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_2d_dynamic: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_2d_tile_2d_dynamic( |
| threadpool, compute->task_2d_tile_2d_dynamic, context, |
| compute->range[0], compute->range[1], compute->tile[0], |
| compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_2d_dynamic_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_2d_tile_2d_dynamic_with_thread( |
| threadpool, compute->task_2d_tile_2d_dynamic_with_id, context, |
| compute->range[0], compute->range[1], compute->tile[0], |
| compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_3d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| pthreadpool_parallelize_3d(threadpool, compute->task_3d, context, |
| compute->range[0], compute->range[1], |
| compute->range[2], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_1d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_3d_tile_1d( |
| threadpool, compute->task_3d_tile_1d, context, compute->range[0], |
| compute->range[1], compute->range[2], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_1d_dynamic: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_3d_tile_1d_dynamic( |
| threadpool, compute->task_3d_tile_1d_dynamic, context, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_1d_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_3d_tile_1d_with_thread( |
| threadpool, compute->task_3d_tile_1d_with_thread, context, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_1d_dynamic_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_3d_tile_1d_dynamic_with_thread( |
| threadpool, compute->task_3d_tile_1d_dynamic_with_id, context, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_2d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_3d_tile_2d( |
| threadpool, compute->task_3d_tile_2d, context, compute->range[0], |
| compute->range[1], compute->range[2], compute->tile[0], |
| compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_2d_dynamic: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_3d_tile_2d_dynamic( |
| threadpool, compute->task_3d_tile_2d_dynamic, context, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_2d_dynamic_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_3d_tile_2d_dynamic_with_thread( |
| threadpool, compute->task_3d_tile_2d_dynamic_with_id, context, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_4d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| pthreadpool_parallelize_4d(threadpool, compute->task_4d, context, |
| compute->range[0], compute->range[1], |
| compute->range[2], compute->range[3], flags); |
| break; |
| case xnn_parallelization_type_4d_tile_2d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_4d_tile_2d( |
| threadpool, compute->task_4d_tile_2d, context, compute->range[0], |
| compute->range[1], compute->range[2], compute->range[3], |
| compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_4d_tile_2d_dynamic: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_4d_tile_2d_dynamic( |
| threadpool, compute->task_4d_tile_2d_dynamic, context, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->range[3], compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_5d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| assert(compute->range[4] != 0); |
| pthreadpool_parallelize_5d(threadpool, compute->task_5d, context, |
| compute->range[0], compute->range[1], |
| compute->range[2], compute->range[3], |
| compute->range[4], flags); |
| break; |
| case xnn_parallelization_type_5d_tile_2d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| assert(compute->range[4] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_5d_tile_2d( |
| threadpool, compute->task_5d_tile_2d, context, compute->range[0], |
| compute->range[1], compute->range[2], compute->range[3], |
| compute->range[4], compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_6d_tile_2d: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| assert(compute->range[4] != 0); |
| assert(compute->range[5] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_6d_tile_2d( |
| threadpool, compute->task_6d_tile_2d, context, compute->range[0], |
| compute->range[1], compute->range[2], compute->range[3], |
| compute->range[4], compute->range[5], compute->tile[0], |
| compute->tile[1], flags); |
| break; |
| #if XNN_MAX_UARCH_TYPES > 1 |
| case xnn_parallelization_type_2d_tile_1d_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_2d_tile_1d_with_uarch( |
| threadpool, compute->task_2d_tile_1d_with_id, context, |
| 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_1d_tile_1d_dynamic_with_uarch_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_1d_tile_1d_dynamic_with_uarch_with_thread( |
| threadpool, compute->task_1d_tile_1d_dynamic_with_id_with_thread, |
| context, |
| /*default_uarch_index=*/0, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_1d_dynamic_with_uarch_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_2d_tile_1d_dynamic_with_uarch_with_thread( |
| threadpool, compute->task_2d_tile_1d_dynamic_with_id_with_thread, |
| context, |
| /*default_uarch_index=*/0, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_2d_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_2d_tile_2d_with_uarch( |
| threadpool, compute->task_2d_tile_2d_with_id, context, |
| 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->tile[0], |
| compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_2d_tile_2d_dynamic_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_2d_tile_2d_dynamic_with_uarch( |
| threadpool, compute->task_2d_tile_2d_dynamic_with_id, context, |
| /*default_uarch_index=*/0, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->tile[0], |
| compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_1d_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_3d_tile_1d_with_uarch( |
| threadpool, compute->task_3d_tile_1d_with_id, context, |
| 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_1d_with_uarch_with_thread: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| pthreadpool_parallelize_3d_tile_1d_with_uarch_with_thread( |
| threadpool, compute->task_3d_tile_1d_with_id_with_thread, context, |
| 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_2d_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_3d_tile_2d_with_uarch( |
| threadpool, compute->task_3d_tile_2d_with_id, context, |
| 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_3d_tile_2d_dynamic_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_3d_tile_2d_dynamic_with_uarch( |
| threadpool, compute->task_3d_tile_2d_dynamic_with_id, context, |
| /*default_uarch_index=*/0, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_4d_tile_2d_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_4d_tile_2d_with_uarch( |
| threadpool, compute->task_4d_tile_2d_with_id, context, |
| 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->range[3], compute->tile[0], compute->tile[1], flags); |
| break; |
| case xnn_parallelization_type_4d_tile_2d_dynamic_with_uarch: |
| assert(compute->range[0] != 0); |
| assert(compute->range[1] != 0); |
| assert(compute->range[2] != 0); |
| assert(compute->range[3] != 0); |
| assert(compute->tile[0] != 0); |
| assert(compute->tile[1] != 0); |
| pthreadpool_parallelize_4d_tile_2d_dynamic_with_uarch( |
| threadpool, compute->task_4d_tile_2d_dynamic_with_id, context, |
| 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1, |
| compute->range[0], compute->range[1], compute->range[2], |
| compute->range[3], compute->tile[0], compute->tile[1], flags); |
| break; |
| #endif // XNN_MAX_UARCH_TYPES > 1 |
| default: |
| XNN_UNREACHABLE; |
| } |
| } |
| return xnn_status_success; |
| } |