blob: f3549fe25fe6c458afc52a806c571975ffbf7e22 [file] [edit]
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include "src/xnnpack/pack-lh.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "src/xnnpack/common.h"
#include "src/xnnpack/config-types.h"
#include "src/xnnpack/config.h"
#include "src/xnnpack/hardware-config.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microfnptr.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/quantization.h"
namespace {
size_t xnn_pack_lh_fx_qd8_packed_size(size_t m, size_t k, size_t mr_packed,
size_t kr, size_t sr) {
// Each packed row starts with the `mr_packed` quantization params, followed
// by the `mr_packed` rows of quantized data.
m = round_up(m, mr_packed);
k = round_up(k, kr * sr);
const size_t alignment = alignof(struct xnn_qd8_quantization_params);
return m * round_up(sizeof(struct xnn_qd8_quantization_params) +
k * sizeof(int8_t),
alignment);
}
size_t xnn_pack_lh_fx_qd8_row_sums_packed_size(size_t m, size_t k,
size_t mr_packed,
size_t kr, size_t sr) {
// Each packed row starts with the `mr_packed` quantization params,
// followed by the `mr_packed` row_sums followed by the `mr_packed` rows of
// quantized data.
m = round_up(m, mr_packed);
k = round_up(k, kr * sr);
const size_t alignment = alignof(struct xnn_qd8_quantization_params);
return m * round_up(sizeof(struct xnn_qd8_quantization_params) +
sizeof(float) + k * sizeof(int8_t),
alignment);
}
size_t xnn_pack_lh_fx_qd8_packed_offset(size_t m, size_t k, size_t mr_packed,
size_t kr, size_t sr) {
// Each packed row starts with the `mr_packed` quantization params, followed
// by the `mr_packed` rows of quantized data.
assert(m % mr_packed == 0);
k = round_up(k, kr * sr);
const size_t alignment = alignof(struct xnn_qd8_quantization_params);
return m * round_up(sizeof(struct xnn_qd8_quantization_params) +
k * sizeof(int8_t),
alignment);
}
size_t xnn_pack_lh_fx_qd8_qc2w_packed_offset(size_t m, size_t k,
size_t mr_packed, size_t kr,
size_t sr) {
// Each packed row starts with the `mr_packed` quantization params, followed
// by the `mr_packed` row_sums, followed by the `mr_packed` rows of quantized
// data.
assert(m % mr_packed == 0);
k = round_up(k, kr * sr);
const size_t alignment = alignof(struct xnn_qd8_quantization_params);
return m * round_up(sizeof(struct xnn_qd8_quantization_params) +
sizeof(float) + k * sizeof(int8_t),
alignment);
}
// Names, for an element type `T`, the signature of a function that takes a
// `(min, max)` pair of `T` plus a `T*` scale out-parameter and returns
// `xnn_qd8_quantization_params`. It exists so that such a function can be
// passed as a non-type template argument whose type depends on `T`.
template <typename T>
struct InitQuantizationParams {
  using fn = struct xnn_qd8_quantization_params (*)(T min, T max, T* scale);
};
template <typename InputT, typename OutputT, typename qs8_cvt_params_t,
typename InitQuantizationParams<InputT>::fn init_quantization_params>
static void pack_lh_fx_qd(size_t m, size_t k, size_t mr_packed, size_t kr,
size_t sr, size_t m_idx_start, const InputT* lhs,
size_t lhs_stride, void* lhs_packed,
xnn_vunary_ukernel_fn convert_ukernel,
xnn_reduce_ukernel_fn minmax_ukernel,
xnn_reduce_ukernel_fn rsum_ukernel) {
assert(m_idx_start == 0);
struct xnn_f32_default_params minmax_params;
qs8_cvt_params_t convert_params;
xnn_qs8_rsum_params rsum_params;
const size_t k_scaled = k * sizeof(InputT);
const uintptr_t packed_row_stride = round_up(k, kr * sr) * sizeof(OutputT);
const size_t packed_size =
rsum_ukernel
? xnn_pack_lh_fx_qd8_row_sums_packed_size(/*m=*/mr_packed, k,
mr_packed, kr, sr)
: xnn_pack_lh_fx_qd8_packed_size(/*m=*/mr_packed, k, mr_packed, kr,
sr);
while (m) {
// Pointers to the input and output data for this set of `mr` rows.
struct xnn_qd8_quantization_params* quantization_params =
static_cast<struct xnn_qd8_quantization_params*>(lhs_packed);
assert((uintptr_t)quantization_params %
alignof(struct xnn_qd8_quantization_params) == 0);
float* row_sum = nullptr;
OutputT* packed_weights =
reinterpret_cast<OutputT*>(reinterpret_cast<uintptr_t>(lhs_packed) +
mr_packed * sizeof(struct xnn_qd8_quantization_params));
if (rsum_ukernel) {
row_sum = reinterpret_cast<float*>(
reinterpret_cast<uintptr_t>(lhs_packed) +
mr_packed * sizeof(struct xnn_qd8_quantization_params));
packed_weights = reinterpret_cast<OutputT*>(
reinterpret_cast<uintptr_t>(packed_weights) +
mr_packed * sizeof(float));
}
// For each row in this block of `mr` rows...
for (size_t row_id = 0; row_id < min(mr_packed, m); ++row_id) {
// Compute the quantization params for this row.
InputT minmax[2] = {std::numeric_limits<float>::infinity(),
-std::numeric_limits<float>::infinity()};
InputT scale;
minmax_ukernel(k_scaled, lhs, minmax, &minmax_params);
quantization_params[row_id] =
init_quantization_params(minmax[0], minmax[1], &scale);
// Quantize the row.
convert_params.scalar.scale = scale;
convert_params.scalar.output_zero_point =
quantization_params[row_id].zero_point;
convert_ukernel(k_scaled, lhs, packed_weights,
(union xnn_unary_uparams*)&convert_params);
if (row_sum) {
// Compute and store the row sum of the quantized output.
const size_t num_bytes = k * sizeof(OutputT);
int32_t row_sum_value = 0;
rsum_ukernel(num_bytes, packed_weights, &row_sum_value, &rsum_params);
row_sum[row_id] = static_cast<float>(row_sum_value);
}
// Advance the input and output pointers.
lhs = (const InputT*)((uintptr_t)lhs + lhs_stride);
packed_weights =
(OutputT*)((uintptr_t)packed_weights + packed_row_stride);
}
// Copy any extra quantization params if needed.
for (size_t row_id = m; row_id < mr_packed; ++row_id) {
quantization_params[row_id] = quantization_params[m - 1];
if (row_sum) {
row_sum[row_id] = row_sum[m - 1];
}
}
// Advance the pointers and counters.
lhs_packed = (void*)((uintptr_t)lhs_packed + packed_size);
m -= min(mr_packed, m);
}
}
// Packs an f32 LHS as `int8_t` data via `pack_lh_fx_qd`, without row sums.
void xnn_pack_lh_f32_qdint8(size_t m, size_t k, size_t mr_packed, size_t kr,
                            size_t sr, size_t m_idx_start, const void* lhs,
                            size_t lhs_stride, void* lhs_packed) {
  // The micro-kernels are looked up once, on the first call.
  static const xnn_vunary_ukernel_fn cvt_fn =
      xnn_init_f32_to_qs8_cvt_config()->ukernel;
  static const xnn_reduce_ukernel_fn minmax_fn =
      xnn_init_f32_rminmax_config()->ukernel;

  const float* input = static_cast<const float*>(lhs);
  pack_lh_fx_qd<float, int8_t, struct xnn_f32_qs8_cvt_params,
                xnn_f32_qd8_asymmetric_quantization_params>(
      m, k, mr_packed, kr, sr, m_idx_start, input, lhs_stride, lhs_packed,
      cvt_fn, minmax_fn, /*rsum_ukernel=*/nullptr);
}
// Packs an f32 LHS as `int8_t` data via `pack_lh_fx_qd`, also storing a row
// sum for every packed row.
void xnn_pack_lh_f32_qdint8_qc2w(size_t m, size_t k, size_t mr_packed,
                                 size_t kr, size_t sr, size_t m_idx_start,
                                 const void* lhs, size_t lhs_stride,
                                 void* lhs_packed) {
  // The micro-kernels are looked up once, on the first call.
  static const xnn_vunary_ukernel_fn cvt_fn =
      xnn_init_f32_to_qs8_cvt_config()->ukernel;
  static const xnn_reduce_ukernel_fn minmax_fn =
      xnn_init_f32_rminmax_config()->ukernel;
  static const xnn_reduce_ukernel_fn rsum_fn =
      xnn_init_qs8_rsum_config()->ukernel;

  const float* input = static_cast<const float*>(lhs);
  pack_lh_fx_qd<float, int8_t, struct xnn_f32_qs8_cvt_params,
                xnn_f32_qd8_asymmetric_quantization_params>(
      m, k, mr_packed, kr, sr, m_idx_start, input, lhs_stride, lhs_packed,
      cvt_fn, minmax_fn, rsum_fn);
}
// Packs an f32 LHS as `uint8_t` data via `pack_lh_fx_qd`, also storing a row
// sum for every packed row.
void xnn_pack_lh_f32_qduint8_qc2w(size_t m, size_t k, size_t mr_packed,
                                  size_t kr, size_t sr, size_t m_idx_start,
                                  const void* lhs, size_t lhs_stride,
                                  void* lhs_packed) {
  // The micro-kernels are looked up once, on the first call.
  static const xnn_vunary_ukernel_fn cvt_fn =
      xnn_init_f32_to_qu8_cvt_config()->ukernel;
  static const xnn_reduce_ukernel_fn minmax_fn =
      xnn_init_f32_rminmax_config()->ukernel;
  static const xnn_reduce_ukernel_fn rsum_fn =
      xnn_init_qu8_rsum_config()->ukernel;

  const float* input = static_cast<const float*>(lhs);
  // The unsigned path passes the same `xnn_f32_qs8_cvt_params` type as the
  // signed one.
  pack_lh_fx_qd<float, uint8_t, struct xnn_f32_qs8_cvt_params,
                xnn_f32_qdu8_asymmetric_quantization_params>(
      m, k, mr_packed, kr, sr, m_idx_start, input, lhs_stride, lhs_packed,
      cvt_fn, minmax_fn, rsum_fn);
}
// Packs an f32 LHS as `uint8_t` data via `pack_lh_fx_qd`, without row sums.
void xnn_pack_lh_f32_qduint8(size_t m, size_t k, size_t mr_packed, size_t kr,
                             size_t sr, size_t m_idx_start, const void* lhs,
                             size_t lhs_stride, void* lhs_packed) {
  // The micro-kernels are looked up once, on the first call.
  static const xnn_vunary_ukernel_fn cvt_fn =
      xnn_init_f32_to_qu8_cvt_config()->ukernel;
  static const xnn_reduce_ukernel_fn minmax_fn =
      xnn_init_f32_rminmax_config()->ukernel;

  const float* input = static_cast<const float*>(lhs);
  // The unsigned path passes the same `xnn_f32_qs8_cvt_params` type as the
  // signed one.
  pack_lh_fx_qd<float, uint8_t, struct xnn_f32_qs8_cvt_params,
                xnn_f32_qdu8_asymmetric_quantization_params>(
      m, k, mr_packed, kr, sr, m_idx_start, input, lhs_stride, lhs_packed,
      cvt_fn, minmax_fn, /*rsum_ukernel=*/nullptr);
}
void xnn_pack_lh_f16_qdint8_qc2w(size_t m, size_t k, size_t mr_packed,
size_t kr, size_t sr, size_t m_idx_start,
const void* lhs, size_t lhs_stride,
void* lhs_packed) {
static const xnn_vunary_ukernel_fn convert_ukernel =
xnn_init_f16_to_qs8_cvt_config()->ukernel;
static const xnn_reduce_ukernel_fn minmax_ukernel =
xnn_init_f16_rminmax_config()->ukernel;
static const xnn_reduce_ukernel_fn rsum_ukernel =
xnn_init_qs8_rsum_config()->ukernel;
pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/int8_t,
/*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params,
xnn_f16_qd8_asymmetric_quantization_params>(
m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride,
lhs_packed, convert_ukernel, minmax_ukernel, rsum_ukernel);
}
void xnn_pack_lh_f16_qduint8_qc2w(size_t m, size_t k, size_t mr_packed,
size_t kr, size_t sr, size_t m_idx_start,
const void* lhs, size_t lhs_stride,
void* lhs_packed) {
static const xnn_vunary_ukernel_fn convert_ukernel =
xnn_init_f16_to_qu8_cvt_config()->ukernel;
static const xnn_reduce_ukernel_fn minmax_ukernel =
xnn_init_f16_rminmax_config()->ukernel;
static const xnn_reduce_ukernel_fn rsum_ukernel =
xnn_init_qu8_rsum_config()->ukernel;
pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/uint8_t,
/*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params,
xnn_f16_qdu8_asymmetric_quantization_params>(
m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride,
lhs_packed, convert_ukernel, minmax_ukernel, rsum_ukernel);
}
void xnn_pack_lh_f16_qdint8(size_t m, size_t k, size_t mr_packed, size_t kr,
size_t sr, size_t m_idx_start, const void* lhs,
size_t lhs_stride, void* lhs_packed) {
static const xnn_vunary_ukernel_fn convert_ukernel =
xnn_init_f16_to_qs8_cvt_config()->ukernel;
static const xnn_reduce_ukernel_fn minmax_ukernel =
xnn_init_f16_rminmax_config()->ukernel;
pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/int8_t,
/*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params,
xnn_f16_qd8_asymmetric_quantization_params>(
m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride,
lhs_packed, convert_ukernel, minmax_ukernel, /*rsum_ukernel=*/nullptr);
}
void xnn_pack_lh_f16_qduint8(size_t m, size_t k, size_t mr_packed, size_t kr,
size_t sr, size_t m_idx_start, const void* lhs,
size_t lhs_stride, void* lhs_packed) {
static const xnn_vunary_ukernel_fn convert_ukernel =
xnn_init_f16_to_qu8_cvt_config()->ukernel;
static const xnn_reduce_ukernel_fn minmax_ukernel =
xnn_init_f16_rminmax_config()->ukernel;
pack_lh_fx_qd</*InputT=*/xnn_float16, /*OutputT=*/uint8_t,
/*qs8_cvt_params_t=*/struct xnn_f16_qs8_cvt_params,
xnn_f16_qdu8_asymmetric_quantization_params>(
m, k, mr_packed, kr, sr, m_idx_start, (const xnn_float16*)lhs, lhs_stride,
lhs_packed, convert_ukernel, minmax_ukernel, /*rsum_ukernel=*/nullptr);
}
} // namespace
extern "C" {
// Returns the pack-lh configuration for f16 input packed as qdint8 without
// row sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f16_qdint8_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qdint8;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size;
    cfg.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
// Returns the pack-lh configuration for f16 input packed as qdint8 with row
// sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f16_qdint8_row_sums_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qdint8_qc2w;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size;
    cfg.offset_fn =
        (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
// Returns the pack-lh configuration for f16 input packed as qduint8 with row
// sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f16_qduint8_row_sums_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qduint8_qc2w;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size;
    cfg.offset_fn =
        (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
// Returns the pack-lh configuration for f16 input packed as qduint8 without
// row sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f16_qduint8_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT16;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f16_qduint8;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size;
    cfg.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
// Returns the pack-lh configuration for f32 input packed as qdint8 without
// row sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f32_qdint8_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qdint8;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size;
    cfg.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
// Returns the pack-lh configuration for f32 input packed as qdint8 with row
// sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f32_qdint8_row_sums_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qdint8_qc2w;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size;
    cfg.offset_fn =
        (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
// Returns the pack-lh configuration for f32 input packed as qduint8 with row
// sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f32_qduint8_row_sums_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qduint8_qc2w;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_row_sums_packed_size;
    cfg.offset_fn =
        (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_qc2w_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
// Returns the pack-lh configuration for f32 input packed as qduint8 without
// row sums, or nullptr if `xnn_init_hardware_config()` returns nullptr.
const xnn_pack_lh_config* xnn_init_f32_qduint8_pack_lh_config() {
  if (xnn_init_hardware_config() == nullptr) {
    return nullptr;
  }

  // Built once on first use; the same object is returned on every call.
  static const xnn_pack_lh_config kConfig = [] {
    xnn_pack_lh_config cfg = {};
    cfg.log2_input_element_size = XNN_LOG2_SIZEOF_FLOAT;
    cfg.log2_packed_element_size = 0;
    cfg.pack_lh_fn = (xnn_pack_lh_ukernel_fn)xnn_pack_lh_f32_qduint8;
    cfg.size_fn = (xnn_pack_lh_size_fn)xnn_pack_lh_fx_qd8_packed_size;
    cfg.offset_fn = (xnn_pack_lh_offset_fn)xnn_pack_lh_fx_qd8_packed_offset;
    return cfg;
  }();
  return &kConfig;
}
} // extern "C"