| // Copyright 2025 Google LLC |
| // |
| // This source code is licensed under the BSD-style license found in the |
| // LICENSE file in the root directory of this source tree. |
| |
| #include "src/xnnpack/pack-lh.h" |
| |
| #include <cassert> |
| #include <cstddef> |
| #include <cstdint> |
| #include <limits> |
| |
| #include "src/xnnpack/common.h" |
| #include "src/xnnpack/config-types.h" |
| #include "src/xnnpack/config.h" |
| #include "src/xnnpack/hardware-config.h" |
| #include "src/xnnpack/math.h" |
| #include "src/xnnpack/microfnptr.h" |
| #include "src/xnnpack/microparams.h" |
| #include "src/xnnpack/quantization.h" |
| |
| namespace { |
| |
| size_t xnn_pack_lh_fx_qd8_packed_size(size_t m, size_t k, size_t mr_packed, |
| size_t kr, size_t sr) { |
| // Each packed row starts with the `mr_packed` quantization params, followed |
| // by the `mr_packed` rows of quantized data. |
| m = round_up(m, mr_packed); |
| k = round_up(k, kr * sr); |
| const size_t alignment = alignof(struct xnn_qd8_quantization_params); |
| return m * round_up(sizeof(struct xnn_qd8_quantization_params) + |
| k * sizeof(int8_t), |
| alignment); |
| } |
| |
| size_t xnn_pack_lh_fx_qd8_row_sums_packed_size(size_t m, size_t k, |
| size_t mr_packed, |
| size_t kr, size_t sr) { |
| // Each packed row starts with the `mr_packed` quantization params, |
| // followed by the `mr_packed` row_sums followed by the `mr_packed` rows of |
| // quantized data. |
| m = round_up(m, mr_packed); |
| k = round_up(k, kr * sr); |
| const size_t alignment = alignof(struct xnn_qd8_quantization_params); |
| return m * round_up(sizeof(struct xnn_qd8_quantization_params) + |
| sizeof(float) + k * sizeof(int8_t), |
| alignment); |
| } |
| |
| size_t xnn_pack_lh_fx_qd8_packed_offset(size_t m, size_t k, size_t mr_packed, |
| size_t kr, size_t sr) { |
| // Each packed row starts with the `mr_packed` quantization params, followed |
| // by the `mr_packed` rows of quantized data. |
| assert(m % mr_packed == 0); |
| k = round_up(k, kr * sr); |
| const size_t alignment = alignof(struct xnn_qd8_quantization_params); |
| return m * round_up(sizeof(struct xnn_qd8_quantization_params) + |
| k * sizeof(int8_t), |
| alignment); |
| } |
| |
| size_t xnn_pack_lh_fx_qd8_qc2w_packed_offset(size_t m, size_t k, |
| size_t mr_packed, size_t kr, |
| size_t sr) { |
| // Each packed row starts with the `mr_packed` quantization params, followed |
| // by the `mr_packed` row_sums, followed by the `mr_packed` rows of quantized |
| // data. |
| assert(m % mr_packed == 0); |
| k = round_up(k, kr * sr); |
| const size_t alignment = alignof(struct xnn_qd8_quantization_params); |
| return m * round_up(sizeof(struct xnn_qd8_quantization_params) + |
| sizeof(float) + k * sizeof(int8_t), |
| alignment); |
| } |
| |
// Helper that names the function-pointer type used to derive
// `xnn_qd8_quantization_params` from a row's minimum and maximum, both of
// type `T`. The function also reports a scale through the `T*` out-parameter.
// Wrapping the type in a struct lets it be used as a non-type template
// parameter that depends on `T`.
template <typename T>
struct InitQuantizationParams {
  using fn = struct xnn_qd8_quantization_params (*)(T min, T max, T* scale);
};
| |
| template <typename InputT, typename OutputT, typename qs8_cvt_params_t, |
| typename InitQuantizationParams<InputT>::fn init_quantization_params> |
| static void pack_lh_fx_qd(size_t m, size_t k, size_t mr_packed, size_t kr, |
| size_t sr, size_t m_idx_start, const InputT* lhs, |
| size_t lhs_stride, void* lhs_packed, |
| xnn_vunary_ukernel_fn convert_ukernel, |
| xnn_reduce_ukernel_fn minmax_ukernel, |
| xnn_reduce_ukernel_fn rsum_ukernel) { |
| assert(m_idx_start == 0); |
| |
| struct xnn_f32_default_params minmax_params; |
| qs8_cvt_params_t convert_params; |
| xnn_qs8_rsum_params rsum_params; |
| |
| const size_t k_scaled = k * sizeof(InputT); |
| const uintptr_t packed_row_stride = round_up(k, kr * sr) * sizeof(OutputT); |
| const size_t packed_size = |
| rsum_ukernel |
| ? xnn_pack_lh_fx_qd8_row_sums_packed_size(/*m=*/mr_packed, k, |
| mr_packed, kr, sr) |
| : xnn_pack_lh_fx_qd8_packed_size(/*m=*/mr_packed, k, mr_packed, kr, |
| sr); |
| |
| while (m) { |
| // Pointers to the input and output data for this set of `mr` rows. |
| struct xnn_qd8_quantization_params* quantization_params = |
| static_cast<struct xnn_qd8_quantization_params*>(lhs_packed); |
| assert((uintptr_t)quantization_params % |
| alignof(struct xnn_qd8_quantization_params) == 0); |
| float* row_sum = nullptr; |
| OutputT* packed_weights = |
| reinterpret_cast<OutputT*>(reinterpret_cast<uintptr_t>(lhs_packed) + |
| mr_packed * sizeof(struct xnn_qd8_quantization_params)); |
| if (rsum_ukernel) { |
| row_sum = reinterpret_cast<float*>( |
| reinterpret_cast<uintptr_t>(lhs_packed) + |
| mr_packed * sizeof(struct xnn_qd8_quantization_params)); |
| packed_weights = reinterpret_cast<OutputT*>( |
| reinterpret_cast<uintptr_t>(packed_weights) + |
| mr_packed * sizeof(float)); |
| } |
| |
| // For each row in this block of `mr` rows... |
| for (size_t row_id = 0; row_id < min(mr_packed, m); ++row_id) { |
| // Compute the quantization params for this row. |
| InputT minmax[2] = {std::numeric_limits<float>::infinity(), |
| -std::numeric_limits<float>::infinity()}; |
| InputT scale; |
| minmax_ukernel(k_scaled, lhs, minmax, &minmax_params); |
| quantization_params[row_id] = |
| init_quantization_params(minmax[0], minmax[1], &scale); |
| |
| // Quantize the row. |
| convert_params.scalar.scale = scale; |
| convert_params.scalar.output_zero_point = |
| quantization_params[row_id].zero_point; |
| convert_ukernel(k_scaled, lhs, packed_weights, |
| (union xnn_unary_uparams*)&convert_params); |
| |
| if (row_sum) { |
| // Compute and store the row sum of the quantized output. |
| const size_t num_bytes = k * sizeof(OutputT); |
| int32_t row_sum_value = 0; |
| rsum_ukernel(num_bytes, packed_weights, &row_sum_value, &rsum_params); |
| row_sum[row_id] = static_cast<float>(row_sum_value); |
| } |
| |
| // Advance the input and output pointers. |
| lhs = (const InputT*)((uintptr_t)lhs + lhs_stride); |
| packed_weights = |
| (OutputT*)((uintptr_t)packed_weights + packed_row_stride); |
| } |
| |
| // Copy any extra quantization params if needed. |
| for (size_t row_id = m; row_id < mr_packed; ++row_id) { |
| quantization_params[row_id] = quantization_params[m - 1]; |
| |
| if (row_sum) { |
| row_sum[row_id] = row_sum[m - 1]; |
| } |
| } |
| |
| // Advance the pointers and counters. |
| lhs_packed = (void*)((uintptr_t)lhs_packed + packed_size); |
| m -= min(mr_packed, m); |
| } |
| } |
| |
| void xnn_pack_lh_f32_qdint8(size_t m, size_t k, size_t mr_packed, size_t kr, |
| size_t sr, size_t m_idx_start, const void* lhs, |
| size_t lhs_stride, void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f32_to_qs8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f32_rminmax_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/float, /*OutputT=*/int8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f32_qs8_cvt_params, |
| xnn_f32_qd8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const float*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, /*rsum_ukernel=*/nullptr); |
| } |
| |
| void xnn_pack_lh_f32_qdint8_qc2w(size_t m, size_t k, size_t mr_packed, |
| size_t kr, size_t sr, size_t m_idx_start, |
| const void* lhs, size_t lhs_stride, |
| void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f32_to_qs8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f32_rminmax_config()->ukernel; |
| static const xnn_reduce_ukernel_fn rsum_ukernel = |
| xnn_init_qs8_rsum_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/float, /*OutputT=*/int8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f32_qs8_cvt_params, |
| xnn_f32_qd8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const float*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, rsum_ukernel); |
| } |
| |
| void xnn_pack_lh_f32_qduint8_qc2w(size_t m, size_t k, size_t mr_packed, |
| size_t kr, size_t sr, size_t m_idx_start, |
| const void* lhs, size_t lhs_stride, |
| void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f32_to_qu8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f32_rminmax_config()->ukernel; |
| static const xnn_reduce_ukernel_fn rsum_ukernel = |
| xnn_init_qu8_rsum_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/float, /*OutputT=*/uint8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f32_qs8_cvt_params, |
| xnn_f32_qdu8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const float*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, rsum_ukernel); |
| } |
| |
| void xnn_pack_lh_f32_qduint8(size_t m, size_t k, size_t mr_packed, size_t kr, |
| size_t sr, size_t m_idx_start, const void* lhs, |
| size_t lhs_stride, void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f32_to_qu8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f32_rminmax_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/float, /*OutputT=*/uint8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f32_qs8_cvt_params, |
| xnn_f32_qdu8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const float*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, /*rsum_ukernel=*/nullptr); |
| } |
| |
| void xnn_pack_lh_f16_qdint8_qc2w(size_t m, size_t k, size_t mr_packed, |
| size_t kr, size_t sr, size_t m_idx_start, |
| const void* lhs, size_t lhs_stride, |
| void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f16_to_qs8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f16_rminmax_config()->ukernel; |
| static const xnn_reduce_ukernel_fn rsum_ukernel = |
| xnn_init_qs8_rsum_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/int8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params, |
| xnn_f16_qd8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, rsum_ukernel); |
| } |
| |
| void xnn_pack_lh_f16_qduint8_qc2w(size_t m, size_t k, size_t mr_packed, |
| size_t kr, size_t sr, size_t m_idx_start, |
| const void* lhs, size_t lhs_stride, |
| void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f16_to_qu8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f16_rminmax_config()->ukernel; |
| static const xnn_reduce_ukernel_fn rsum_ukernel = |
| xnn_init_qu8_rsum_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/uint8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params, |
| xnn_f16_qdu8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, rsum_ukernel); |
| } |
| |
| void xnn_pack_lh_f16_qdint8(size_t m, size_t k, size_t mr_packed, size_t kr, |
| size_t sr, size_t m_idx_start, const void* lhs, |
| size_t lhs_stride, void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f16_to_qs8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f16_rminmax_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/int8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params, |
| xnn_f16_qd8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, /*rsum_ukernel=*/nullptr); |
| } |
| |
| void xnn_pack_lh_f16_qduint8(size_t m, size_t k, size_t mr_packed, size_t kr, |
| size_t sr, size_t m_idx_start, const void* lhs, |
| size_t lhs_stride, void* lhs_packed) { |
| static const xnn_vunary_ukernel_fn convert_ukernel = |
| xnn_init_f16_to_qu8_cvt_config()->ukernel; |
| static const xnn_reduce_ukernel_fn minmax_ukernel = |
| xnn_init_f16_rminmax_config()->ukernel; |
| pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/uint8_t, |
| /*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params, |
| xnn_f16_qdu8_asymmetric_quantization_params>( |
| m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride, |
| lhs_packed, convert_ukernel, minmax_ukernel, /*rsum_ukernel=*/nullptr); |
| } |
| |
| } // namespace |
| |
| extern "C" { |
| |
| const xnn_pack_lh_config* xnn_init_f16_qdint8_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qdint8; |
| config.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size; |
| config.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| const xnn_pack_lh_config* xnn_init_f16_qdint8_row_sums_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qdint8_qc2w; |
| config.size_fn = |
| (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size; |
| config.offset_fn = |
| (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| const xnn_pack_lh_config* xnn_init_f16_qduint8_row_sums_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qduint8_qc2w; |
| config.size_fn = |
| (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size; |
| config.offset_fn = |
| (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| const xnn_pack_lh_config* xnn_init_f16_qduint8_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qduint8; |
| config.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size; |
| config.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| const xnn_pack_lh_config* xnn_init_f32_qdint8_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qdint8; |
| config.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size; |
| config.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| const xnn_pack_lh_config* xnn_init_f32_qdint8_row_sums_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qdint8_qc2w; |
| config.size_fn = |
| (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size; |
| config.offset_fn = |
| (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| const xnn_pack_lh_config* xnn_init_f32_qduint8_row_sums_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qduint8_qc2w; |
| config.size_fn = |
| (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size; |
| config.offset_fn = |
| (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| const xnn_pack_lh_config* xnn_init_f32_qduint8_pack_lh_config() { |
| const xnn_hardware_config* hardware_config = |
| xnn_init_hardware_config(); |
| if (hardware_config == nullptr) { |
| return nullptr; |
| } |
| static const xnn_pack_lh_config config = []() { |
| xnn_pack_lh_config config = {}; |
| config.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qduint8; |
| config.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size; |
| config.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset; |
| config.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT; |
| config.log2_packed_element_size = 0; |
| return config; |
| }(); |
| return &config; |
| } |
| |
| } // extern "C" |