blob: 6c449d7959068f5b747ee4e0d3f9fbd6416cceea [file] [log] [blame] [edit]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include "include/xnnpack.h"
#include "src/xnnpack/common.h"
#include "src/xnnpack/config-types.h"
#include "src/xnnpack/log.h"
#include "src/xnnpack/math.h"
#include "src/xnnpack/microfnptr.h"
#include "src/xnnpack/microparams-init.h"
#include "src/xnnpack/microparams.h"
#include "src/xnnpack/pack.h"
#include "src/xnnpack/unaligned.h"
#if XNN_ENABLE_KLEIDIAI
#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme.h"
#include "src/xnnpack/allocator.h"
#endif // XNN_ENABLE_KLEIDIAI
class unaligned_int32_t {
public:
XNN_INLINE unaligned_int32_t( // NOLINT(google-explicit-constructor)
int32_t v) {
memcpy(value_, &v, sizeof(v));
}
XNN_INLINE operator int32_t() const { // NOLINT(google-explicit-constructor)
int32_t v;
memcpy(&v, value_, sizeof(v));
return v;
}
private:
char value_[sizeof(int32_t)];
};
template <typename Src, typename Dst>
void copy_bias(const Src* b, size_t b_offset, size_t n, Dst* packed_b) {
if (b) {
std::copy_n(b + b_offset, n, packed_b);
} else {
std::fill_n(packed_b, n, static_cast<Dst>(0));
}
}
template <typename Src, typename Dst>
void copy_bias(const Src* b, size_t b_offset, size_t n, Dst* packed_b,
Src zero_point) {
if (b) {
for (size_t i = 0; i < n; ++i) {
*packed_b++ = zero_point + b[b_offset + i];
}
} else {
std::fill_n(packed_b, n, zero_point);
}
}
template <typename Src, typename Dst>
int32_t copy_n_and_sum(const Src* src, size_t n, Dst* dst) {
int32_t sum = 0;
for (size_t i = 0; i < n; ++i) {
const auto v = *src++;
sum += (int32_t)v;
*dst++ = v;
}
return sum;
}
extern "C" {
void xnn_pack_f32_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, const float* k,
const float* b, const void* scale,
float* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
float* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, 0.0f);
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
packed_weights = (float*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_bf16_f32_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, const xnn_bfloat16* k,
const float* bias, const void* scale,
void* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
float* packed_weights_float = (float*)packed_weights;
copy_bias(bias, nr_block_start, nr_block_size, packed_weights_float);
packed_weights = (void*)((uintptr_t)packed_weights + nr * sizeof(float));
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
xnn_bfloat16* end = (xnn_bfloat16*)packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, (xnn_bfloat16*)packed_weights);
packed_weights = (xnn_bfloat16*)packed_weights + kc_end - kc_begin;
}
std::fill((xnn_bfloat16*)packed_weights, end, xnn_bfloat16(0.0f));
packed_weights = end;
}
packed_weights = (void*)((uintptr_t)packed_weights +
(nr - nr_block_size) * kr * sizeof(uint16_t));
}
packed_weights = (void*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (bias != nullptr) {
bias += nc;
}
} while (--g != 0);
}
void xnn_pack_f16_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, const uint16_t* k,
const uint16_t* b, const void* scale,
uint16_t* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
uint16_t* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, 0);
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
packed_weights = (uint16_t*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f32_to_f16_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, const float* k,
const float* b, const void* scale,
xnn_float16* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
xnn_float16* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, xnn_float16(0.0f));
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
packed_weights = (xnn_float16*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qu8_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, const uint8_t* k,
const int32_t* b, const void* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qu8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const int32_t izp = (int32_t)params->input_zero_point;
const int32_t bzp = (int32_t)kc * izp * (int32_t)params->kernel_zero_point;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b, bzp);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
uint8_t* end = (uint8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
int32_t ksum = copy_n_and_sum(
&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, (uint8_t*)packed_weights);
packed_weights = (int8_t*)packed_weights + kc_end - kc_begin;
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
}
std::fill((uint8_t*)packed_weights, end, params->kernel_zero_point);
packed_weights = end;
}
packed_weights = (uint8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qs8_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, const int8_t* k,
const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const uint32_t izp = (uint32_t)params->input_zero_point;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
int8_t* end = (int8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
uint32_t ksum = copy_n_and_sum(
&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, (int8_t*)packed_weights);
packed_weights = (int8_t*)packed_weights + kc_end - kc_begin;
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
}
std::fill((int8_t*)packed_weights, end, INT8_C(0));
packed_weights = end;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qs8_to_qu8_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const int8_t* k, const int32_t* b, const float* scale, void* packed_weights,
size_t extra_bytes, const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const uint32_t izp = (uint32_t)params->input_zero_point + 128;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
int8_t* end = (int8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
uint32_t ksum = copy_n_and_sum(
&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, (int8_t*)packed_weights);
packed_weights = (int8_t*)packed_weights + kc_end - kc_begin;
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
}
std::fill((int8_t*)packed_weights, end, INT8_C(0));
packed_weights = end;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
namespace {
// Packs the weights so as to maximize performance in kernels.
static int8_t sign_extend_int4(int8_t value) { return (value ^ 0x8) - 8; }
static int8_t sign_extend_int2(int8_t value) { return (value ^ 0x2) - 2; }
void pack_qs8_qc4w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
uint32_t izp, uint32_t kernel_zero_point) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(kernel_zero_point == 8 || kernel_zero_point == 0);
const size_t skr = sr * kr;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
if (b) {
for (size_t i = 0; i < nr_block_size; ++i) {
packed_b[i] = b[nr_block_start + i] * 16;
}
} else {
std::fill_n(packed_b, nr_block_size, 0);
}
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 2);
kr_block_start += kr * 2) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
int32_t ksum_lo = 0;
int32_t ksum_hi = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
(nr_block_start + nr_block_offset) * kc + kc_idx;
const size_t kh_offset = k_offset + kr;
if (kernel_zero_point == 0) {
int8_t kv_lo = 0;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
int8_t kv_hi = 0;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
const int8_t kv = (kv_lo | (kv_hi << 4));
kv_lo = sign_extend_int4(kv_lo);
kv_hi = sign_extend_int4(kv_hi);
ksum_lo += kv_lo;
ksum_hi += kv_hi;
((int8_t*)packed_weights)[kr_block_offset] = kv;
} else {
uint8_t kv_lo = kernel_zero_point;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
uint8_t kv_hi = kernel_zero_point;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
ksum_lo += kv_lo - kernel_zero_point;
ksum_hi += kv_hi - kernel_zero_point;
((uint8_t*)packed_weights)[kr_block_offset] = kv;
}
}
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - (ksum_lo + ksum_hi) * izp * 16;
packed_weights = (uint8_t*)packed_weights + kr; // kr * 2 nibbles
}
packed_weights = (uint8_t*)packed_weights +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 2 nibbles
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
} // namespace
void xnn_pack_qs8_qc4w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
pack_qs8_qc4w_gemm_goi_w(
g, nc, kc, nr, kr, sr,
k, b, scale, packed_weights, extra_bytes,
params->input_zero_point, params->kernel_zero_point);
}
void xnn_pack_qs8_to_qu8_qc4w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
uint32_t input_zero_point = (int32_t)params->input_zero_point + 0x80;
pack_qs8_qc4w_gemm_goi_w(
g, nc, kc, nr, kr, sr,
k, b, scale, packed_weights, extra_bytes,
input_zero_point, params->kernel_zero_point);
}
void pack_qs8_qc2w_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride, const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes, uint32_t izp,
const struct xnn_qs8_qc2w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
assert(params->kernel_zero_point == 0);
// row sums, weights, extra data
const size_t skr = sr * kr;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b =
static_cast<unaligned_int32_t*>(packed_weights);
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = static_cast<int32_t*>(packed_weights) + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 4);
kr_block_start += kr * 4) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
++nr_block_offset) {
int32_t ksum = 0;
for (size_t kr_block_offset = 0; kr_block_offset < kr;
++kr_block_offset) {
const size_t kc_idx =
round_down_po2(kr_block_start, skr) +
((kr_block_start + kr_block_offset + nr_block_offset * kr) &
(skr - 1));
const size_t oc = nr_block_start + nr_block_offset;
int8_t kv_0 = 0, kv_1 = 0, kv_2 = 0, kv_3 = 0;
if (kc_idx < kc) {
const size_t k_element_offset = kc_idx * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_0 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
if (kc_idx + kr < kc) {
const size_t k_element_offset = (kc_idx + kr) * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_1 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
if (kc_idx + 2 * kr < kc) {
const size_t k_element_offset = (kc_idx + 2 * kr) * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_2 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
if (kc_idx + 3 * kr < kc) {
const size_t k_element_offset = (kc_idx + 3 * kr) * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_3 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
const int8_t kv = (kv_0 | (kv_1 << 2) | (kv_2 << 4) | (kv_3 << 6));
kv_0 = sign_extend_int2(kv_0);
kv_1 = sign_extend_int2(kv_1);
kv_2 = sign_extend_int2(kv_2);
kv_3 = sign_extend_int2(kv_3);
ksum += kv_0 + kv_1 + kv_2 + kv_3;
static_cast<int8_t*>(packed_weights)[kr_block_offset] = kv;
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
// kr * 4 crumbs
packed_weights = static_cast<uint8_t*>(packed_weights) + kr;
}
packed_weights = static_cast<uint8_t*>(packed_weights) +
(nr - nr_block_size) * kr;
}
packed_weights = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights) + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 4 crumbs
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void pack_qs8_qc2w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes, uint32_t izp,
bool make_weights_unsigned,
const struct xnn_qs8_qc2w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
assert(params->kernel_zero_point == 0);
const size_t skr = sr * kr;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b =
static_cast<unaligned_int32_t*>(packed_weights);
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = static_cast<int32_t*>(packed_weights) + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 4);
kr_block_start += kr * 4) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
++nr_block_offset) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
++kr_block_offset) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
(nr_block_start + nr_block_offset) * kc + kc_idx;
int8_t kv_0 = 0, kv_1 = 0, kv_2 = 0, kv_3 = 0;
if (kc_idx < kc) {
kv_0 = (k[k_offset >> 2] >> ((k_offset & 3) * 2)) & 0x3;
}
if (kc_idx + kr < kc) {
const size_t offset = k_offset + kr;
kv_1 = (k[offset >> 2] >> ((offset & 3) * 2)) & 0x3;
}
if (kc_idx + 2 * kr < kc) {
const size_t offset = k_offset + 2 * kr;
kv_2 = (k[offset >> 2] >> ((offset & 3) * 2)) & 0x3;
}
if (kc_idx + 3 * kr < kc) {
const size_t offset = k_offset + 3 * kr;
kv_3 = (k[offset >> 2] >> ((offset & 3) * 2)) & 0x3;
}
const int8_t kv = (kv_0 | (kv_1 << 2) | (kv_2 << 4) | (kv_3 << 6));
kv_0 = sign_extend_int2(kv_0);
kv_1 = sign_extend_int2(kv_1);
kv_2 = sign_extend_int2(kv_2);
kv_3 = sign_extend_int2(kv_3);
ksum += kv_0 + kv_1 + kv_2 + kv_3;
if (make_weights_unsigned) {
static_cast<int8_t*>(packed_weights)[kr_block_offset] = kv ^ 0xAA;
} else {
static_cast<int8_t*>(packed_weights)[kr_block_offset] = kv;
}
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
// kr * 4 crumbs
packed_weights = static_cast<uint8_t*>(packed_weights) + kr;
}
packed_weights = static_cast<uint8_t*>(packed_weights) +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights) + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 4 crumbs
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qs8_qc2w_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride, const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc2w_packing_params* params) {
assert(params != nullptr);
pack_qs8_qc2w_gemm_gio_w(
g, nc, kc, nr, kr, sr, k_stride, k, b, scale, packed_weights, extra_bytes,
params->input_zero_point, params);
}
void xnn_pack_qs8_qc2w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc2w_packing_params* params) {
assert(params != nullptr);
pack_qs8_qc2w_gemm_goi_w(g, nc, kc, nr, kr, sr, k, b, scale, packed_weights,
extra_bytes, params->input_zero_point,
/*make_weights_unsigned=*/false, params);
}
void xnn_pack_qs8_to_qu8_qc2w_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride, const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc2w_packing_params* params) {
assert(params != nullptr);
uint32_t input_zero_point = (int32_t)params->input_zero_point + 0x80;
pack_qs8_qc2w_gemm_gio_w(
g, nc, kc, nr, kr, sr, k_stride, k, b, scale, packed_weights, extra_bytes,
input_zero_point, params);
}
void xnn_pack_qs8_to_qu8_qc2w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc2w_packing_params* params) {
assert(params != nullptr);
uint32_t input_zero_point = (int32_t)params->input_zero_point + 0x80;
pack_qs8_qc2w_gemm_goi_w(
g, nc, kc, nr, kr, sr, k, b, scale, packed_weights, extra_bytes,
input_zero_point, /*make_weights_unsigned=*/true, params);
}
void xnn_pack_qd8_qc2w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qd8_qc2w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
const size_t skr = sr * kr;
const int32_t izp = static_cast<int32_t>(params->input_zero_point);
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b =
static_cast<unaligned_int32_t*>(packed_weights);
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = static_cast<int32_t*>(packed_weights) + nr;
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
++nr_block_offset) {
unaligned_store_f32(
packed_weights,
params->kernel_zero_point == nullptr
? 0.0f
: params->kernel_zero_point[nr_block_start + nr_block_offset]);
packed_weights = static_cast<float*>(packed_weights) + 1;
}
packed_weights =
static_cast<float*>(packed_weights) + (nr - nr_block_size);
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 4);
kr_block_start += kr * 4) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
++nr_block_offset) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
++kr_block_offset) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
(nr_block_start + nr_block_offset) * kc + kc_idx;
int8_t kv_0 = 0, kv_1 = 0, kv_2 = 0, kv_3 = 0;
if (kc_idx < kc) {
kv_0 = (k[k_offset >> 2] >> ((k_offset & 3) * 2)) & 0x3;
}
if (kc_idx + kr < kc) {
const size_t offset = k_offset + kr;
kv_1 = (k[offset >> 2] >> ((offset & 3) * 2)) & 0x3;
}
if (kc_idx + 2 * kr < kc) {
const size_t offset = k_offset + 2 * kr;
kv_2 = (k[offset >> 2] >> ((offset & 3) * 2)) & 0x3;
}
if (kc_idx + 3 * kr < kc) {
const size_t offset = k_offset + 3 * kr;
kv_3 = (k[offset >> 2] >> ((offset & 3) * 2)) & 0x3;
}
const int8_t kv = (kv_0 | (kv_1 << 2) | (kv_2 << 4) | (kv_3 << 6));
kv_0 = sign_extend_int2(kv_0);
kv_1 = sign_extend_int2(kv_1);
kv_2 = sign_extend_int2(kv_2);
kv_3 = sign_extend_int2(kv_3);
ksum += kv_0 + kv_1 + kv_2 + kv_3;
static_cast<int8_t*>(packed_weights)[kr_block_offset] = kv ^ 0xAA;
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
// kr * 4 crumbs
packed_weights = static_cast<uint8_t*>(packed_weights) + kr;
}
packed_weights = static_cast<uint8_t*>(packed_weights) +
(nr - nr_block_size) * kr;
}
packed_weights = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights) + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 4 crumbs
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qd8_qc2w_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride, const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qd8_qc2w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
// row sums, weights, zero points, extra data
const size_t skr = sr * kr;
const int32_t izp = static_cast<int32_t>(params->input_zero_point);
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b =
static_cast<unaligned_int32_t*>(packed_weights);
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = static_cast<int32_t*>(packed_weights) + nr;
// Skip another nr for the float zero points
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
++nr_block_offset) {
unaligned_store_f32(
packed_weights,
params->kernel_zero_point == nullptr
? 0.0f
: params->kernel_zero_point[nr_block_start + nr_block_offset]);
packed_weights = static_cast<float*>(packed_weights) + 1;
}
packed_weights =
static_cast<float*>(packed_weights) + (nr - nr_block_size);
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 4);
kr_block_start += kr * 4) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
++nr_block_offset) {
int32_t ksum = 0;
for (size_t kr_block_offset = 0; kr_block_offset < kr;
++kr_block_offset) {
const size_t kc_idx =
round_down_po2(kr_block_start, skr) +
((kr_block_start + kr_block_offset + nr_block_offset * kr) &
(skr - 1));
const size_t oc = nr_block_start + nr_block_offset;
int8_t kv_0 = 0, kv_1 = 0, kv_2 = 0, kv_3 = 0;
if (kc_idx < kc) {
const size_t k_element_offset = kc_idx * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_0 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
if (kc_idx + kr < kc) {
const size_t k_element_offset = (kc_idx + kr) * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_1 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
if (kc_idx + 2 * kr < kc) {
const size_t k_element_offset = (kc_idx + 2 * kr) * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_2 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
if (kc_idx + 3 * kr < kc) {
const size_t k_element_offset = (kc_idx + 3 * kr) * k_stride + oc;
const int crumb_shift = (k_element_offset & 3) * 2;
kv_3 = (k[k_element_offset >> 2] >> crumb_shift) & 0x3;
}
const int8_t kv = (kv_0 | (kv_1 << 2) | (kv_2 << 4) | (kv_3 << 6));
kv_0 = sign_extend_int2(kv_0);
kv_1 = sign_extend_int2(kv_1);
kv_2 = sign_extend_int2(kv_2);
kv_3 = sign_extend_int2(kv_3);
ksum += kv_0 + kv_1 + kv_2 + kv_3;
static_cast<int8_t*>(packed_weights)[kr_block_offset] = kv ^ 0xAA;
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
// kr * 4 crumbs
packed_weights = static_cast<uint8_t*>(packed_weights) + kr;
}
packed_weights = static_cast<uint8_t*>(packed_weights) +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights) + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 4 crumbs
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
namespace {
// Packs the weights so as to minimize register usage in kernels.
// For example:
// 0 1
// 2 3
// 4 5
// 6 7
// 8 9
// A B
// C D
// E F
//
// is packed for a Mx8c4 microkernel as:
// (row sums) 1 5 9 13 17 21 2 29 | (packed weights) 08 19 00 00 | 2A 3B 00 00 |
// 4C 5D 00 | 6E 7F 00 00 The row sums are packed first. In contrast to planar
// packing which packs the weights from the same channel side by side, so
// position + kr. The register bytes parameter is needed so that we know the
// offset between each weight's load.
void xnn_pack_qs8_qc4w_gemm_goi_w_non_planar(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t register_bytes, const uint8_t* k, const int32_t* b,
const float* scale, void* packed_weights, size_t extra_bytes,
uint32_t input_zero_point, uint32_t kernel_zero_point) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(kernel_zero_point == 8 || kernel_zero_point == 0);
const size_t skr = sr * kr;
int row_offset = register_bytes / kr;
do {
size_t nr_block_start = 0;
do {
size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
if (b) {
for (size_t i = 0; i < nr_block_size; ++i) {
packed_b[i] = b[nr_block_start + i] * 16;
}
} else {
std::fill_n(packed_b, nr_block_size, 0);
}
packed_weights = (int32_t*)packed_weights + nr;
size_t num_k_blocks = round_up_po2(kc, skr * 1);
for (size_t kr_block_start = 0; kr_block_start < num_k_blocks;
kr_block_start += kr * 1) {
void* pw = packed_weights;
for (size_t nr_block_offset_ = 0; nr_block_offset_ < nr_block_size;
nr_block_offset_ += row_offset * 2) {
for (size_t inner_nr_block_offset = 0;
inner_nr_block_offset < row_offset; inner_nr_block_offset += 1) {
size_t actual_nr_block_offset =
inner_nr_block_offset + nr_block_offset_;
int32_t ksum_lo = 0;
int32_t ksum_hi = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + actual_nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
(nr_block_start + actual_nr_block_offset) * kc + kc_idx;
const size_t kh_offset = k_offset + kc * row_offset;
if (kernel_zero_point == 0) {
int8_t kv_lo = 0;
if ((nr_block_start + actual_nr_block_offset) < nc) {
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
}
int8_t kv_hi = 0;
if ((nr_block_start + actual_nr_block_offset + row_offset) <
nc) {
if (kc_idx < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
}
// Pack and flip the sign bit.
const int8_t kv = (kv_lo | (kv_hi << 4));
kv_lo = sign_extend_int4(kv_lo);
kv_hi = sign_extend_int4(kv_hi);
ksum_lo += kv_lo;
ksum_hi += kv_hi;
((uint8_t*)pw)[kr_block_offset] = kv;
} else {
uint8_t kv_lo = kernel_zero_point;
if ((nr_block_start + actual_nr_block_offset) < nc) {
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
}
uint8_t kv_hi = kernel_zero_point;
if ((nr_block_start + actual_nr_block_offset + row_offset) <
nc) {
if (kc_idx < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
}
// Pack and flip the sign bit.
const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
ksum_lo += kv_lo - kernel_zero_point;
ksum_hi += kv_hi - kernel_zero_point;
((uint8_t*)pw)[kr_block_offset] = kv;
}
}
packed_b[actual_nr_block_offset] =
packed_b[actual_nr_block_offset] -
ksum_lo * input_zero_point * 16;
packed_b[actual_nr_block_offset + row_offset] =
packed_b[actual_nr_block_offset + row_offset] -
ksum_hi * input_zero_point * 16;
pw = (uint8_t*)pw + kr; // kr * 2 nibbles
}
}
packed_weights =
(uint8_t*)packed_weights + nr * kr / 2; // skip NR remainder
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 2 nibbles
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
} // namespace
void xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_scalar(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar(
g, nc, kc, nr, kr, sr,
/*register_bytes=*/1, k, b, scale, packed_weights, extra_bytes,
params->input_zero_point, params->kernel_zero_point);
}
void xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_aarch64(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar(
g, nc, kc, nr, kr, sr,
/*register_bytes=*/16, k, b, scale, packed_weights, extra_bytes,
params->input_zero_point, params->kernel_zero_point);
}
void xnn_pack_qs8_qc4w_gemm_goi_w_non_planar_avx512(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar(
g, nc, kc, nr, kr, sr,
/*register_bytes=*/64, k, b, scale, packed_weights, extra_bytes,
params->input_zero_point, params->kernel_zero_point);
}
void xnn_pack_qs8_to_qu8_qc4w_gemm_goi_w_non_planar_avx512(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
uint32_t input_zero_point = (int32_t)params->input_zero_point + 0x80;
xnn_pack_qs8_qc4w_gemm_goi_w_non_planar(
g, nc, kc, nr, kr, sr,
/*register_bytes=*/64, k, b, scale, packed_weights, extra_bytes,
input_zero_point, params->kernel_zero_point);
}
// Same as qc4w but unsigned 4 bit output
// Applies kv ^ 0x88 to convert int4 to uint4
// Does not multiply bias by 16
static void pack_qs8_qc4uw_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
uint32_t izp, uint32_t kernel_zero_point) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(kernel_zero_point == 8 || kernel_zero_point == 0);
const size_t skr = sr * kr;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 2);
kr_block_start += kr * 2) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
(nr_block_start + nr_block_offset) * kc + kc_idx;
const size_t kh_offset = k_offset + kr;
if (kernel_zero_point == 0) {
int8_t kv_lo = 0;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
int8_t kv_hi = 0;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
const int8_t kv = (kv_lo | (kv_hi << 4));
kv_lo = sign_extend_int4(kv_lo);
kv_hi = sign_extend_int4(kv_hi);
ksum += kv_lo + kv_hi;
((int8_t*)packed_weights)[kr_block_offset] =
kv ^ 0x88; // Convert to uint4
} else {
uint8_t kv_lo = kernel_zero_point;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
uint8_t kv_hi = kernel_zero_point;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
ksum += kv_lo + kv_hi -
2 * kernel_zero_point; // subtract 2 zero points
((uint8_t*)packed_weights)[kr_block_offset] =
kv ^ 0x88; // Convert to uint4
}
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
packed_weights = (uint8_t*)packed_weights + kr; // kr * 2 nibbles
}
packed_weights = (uint8_t*)packed_weights +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 2 nibbles
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
// For qd8_qc4w madd
void xnn_pack_qs8_qc4uw_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
pack_qs8_qc4uw_gemm_goi_w(
g, nc, kc, nr, kr, sr,
k, b, scale, packed_weights, extra_bytes,
params->input_zero_point, params->kernel_zero_point);
}
// For qs8_qc4w madd
void xnn_pack_qs8_to_qu8_qc4uw_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(params != nullptr);
pack_qs8_qc4uw_gemm_goi_w(
g, nc, kc, nr, kr, sr,
k, b, scale, packed_weights, extra_bytes,
params->input_zero_point + 0x80, params->kernel_zero_point);
}
void xnn_pack_qs8_qb4w_gemm_goi_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t bl, // blocksize
const uint8_t* k, // kernel
const float* bias, const xnn_bfloat16* scale, void* packed_weights,
size_t extra_bytes_bl, // extra bytes per block
size_t extra_bytes_n, // extra bytes per n
const struct xnn_qs8_qc4w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0);
assert(bias == nullptr); // Not used here. Must be updated outside.
const size_t skr = sr * kr;
// Constraints for blocksize
// These need to be reevaluated in the future.
assert(bl != 0);
assert(round_up_po2(kc, skr) % bl ==
0); // must be round number of blocks inside a column
assert(bl % skr == 0); // must be round number of kr * sr
assert(bl <= round_up_po2(kc, skr)); // must not be larger than K
assert(2 * skr <=
bl); // must be at least two skr to avoid back-to-back extra_bytes
const size_t num_blocks = round_up_po2(kc, skr) / bl;
const int32_t izp = (int32_t)params->input_zero_point;
const uint32_t kernel_zero_point = (uint32_t)params->kernel_zero_point;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
float* packed_b = (float*)packed_weights;
std::fill_n(packed_b, nr, 0.0f);
packed_weights = (float*)packed_weights + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 2);
kr_block_start += kr * 2) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
(nr_block_start + nr_block_offset) * kc + kc_idx;
const size_t kh_offset = k_offset + kr;
if (kernel_zero_point == 0) {
int8_t kv_lo = 0;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
int8_t kv_hi = 0;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
const int8_t kv = (kv_lo | (kv_hi << 4));
kv_lo = sign_extend_int4(kv_lo);
kv_hi = sign_extend_int4(kv_hi);
ksum += kv_lo + kv_hi;
((int8_t*)packed_weights)[kr_block_offset] = kv;
} else {
uint8_t kv_lo = 8;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
uint8_t kv_hi = 8;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
ksum += kv_lo + kv_hi - 16; // subtract 2 zero points (8)
const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
((uint8_t*)packed_weights)[kr_block_offset] = kv;
}
}
size_t block_index = kr_block_start / bl;
size_t scale_index =
(nr_block_start + nr_block_offset) * num_blocks + block_index;
unaligned_indexed_store_f32(
packed_b, nr_block_offset,
unaligned_indexed_load_f32(packed_b, nr_block_offset) -
(float)ksum * izp *
xnn_bfloat16_to_float(scale[scale_index]));
packed_weights = (uint8_t*)packed_weights + kr; // kr * 2 nibbles
}
if (((2 * kr) + kr_block_start) % bl == 0) {
packed_weights = (void*)((uintptr_t)packed_weights + extra_bytes_bl);
}
packed_weights = (uint8_t*)packed_weights +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights = (void*)((uintptr_t)packed_weights + extra_bytes_n);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 2 nibbles
} while (--g != 0);
}
void xnn_pack_qs8_qb4w_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride,
size_t bl, // block size
const uint8_t* k, // kernel
const float* bias,
const xnn_bfloat16* scale, // block scales (bf16 format)
void* packed_weights,
size_t extra_bytes_bl, // extra bytes per block
size_t extra_bytes_n, // extra bytes per n
const struct xnn_qs8_qc4w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
assert(params->kernel_zero_point == 8);
assert(bias == nullptr); // Not used here. Must be updated outside.
const size_t skr = sr * kr;
// Constraints for blocksize
// These need to be reevaluated in the future.
assert(bl != 0);
assert(round_up_po2(kc, skr) % bl ==
0); // must be round number of blocks inside a column
assert(bl % skr == 0); // must be round number of kr * sr
assert(bl <= round_up_po2(kc, skr)); // must not be larger than K
assert(2 * skr <=
bl); // must be at least two skr to avoid back-to-back extra_bytes
const size_t num_blocks = round_up_po2(kc, skr) / bl;
const int32_t izp = (int32_t)params->input_zero_point;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
int32_t* packed_b = (int32_t*)packed_weights;
std::fill_n(packed_b, nr, 0);
packed_weights = (float*)packed_weights + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 2);
kr_block_start += kr * 2) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
(nr_block_start + nr_block_offset + kc_idx * k_stride);
const size_t kh_offset = k_offset + (kr * k_stride);
uint8_t kv_lo = 8;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
uint8_t kv_hi = 8;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
ksum += kv_lo + kv_hi - 16; // subtract 2 zero points (8)
const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
((uint8_t*)packed_weights)[kr_block_offset] = kv;
}
size_t block_index = kr_block_start / bl;
size_t scale_index =
(nr_block_start + nr_block_offset) * num_blocks + block_index;
unaligned_indexed_store_f32(
packed_b, nr_block_offset,
unaligned_indexed_load_f32(packed_b, nr_block_offset) -
(float)ksum * izp *
xnn_bfloat16_to_float(scale[scale_index]));
packed_weights = (uint8_t*)packed_weights + kr; // kr * 2 nibbles
}
if (((2 * kr) + kr_block_start) % bl == 0) {
packed_weights = (void*)((uintptr_t)packed_weights + extra_bytes_bl);
}
packed_weights = (uint8_t*)packed_weights +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights = (void*)((uintptr_t)packed_weights + extra_bytes_n);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 2 nibbles
} while (--g != 0);
}
void xnn_pack_qs8_qc4w_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride, const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0);
const size_t skr = sr * kr;
const uint32_t izp = (uint32_t)params->input_zero_point;
const uint32_t kernel_zero_point = (uint32_t)params->kernel_zero_point;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
if (b) {
for (size_t i = 0; i < nr_block_size; ++i) {
packed_b[i] = b[nr_block_start + i] * 16;
}
} else {
std::fill_n(packed_b, nr_block_size, 0);
}
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 2);
kr_block_start += kr * 2) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
kc_idx * k_stride + (nr_block_start + nr_block_offset);
const size_t kh_offset =
(kc_idx + kr) * k_stride + (nr_block_start + nr_block_offset);
if (kernel_zero_point == 0) {
int8_t kv_lo = 0;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
int8_t kv_hi = 0;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
const int8_t kv = (kv_lo | (kv_hi << 4));
kv_lo = sign_extend_int4(kv_lo);
kv_hi = sign_extend_int4(kv_hi);
ksum += kv_lo + kv_hi;
((int8_t*)packed_weights)[kr_block_offset] = kv;
} else {
uint8_t kv_lo = kernel_zero_point;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
uint8_t kv_hi = kernel_zero_point;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
ksum += kv_lo + kv_hi -
2 * kernel_zero_point; // subtract 2 zero points
const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
((uint8_t*)packed_weights)[kr_block_offset] = kv;
}
}
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - ksum * izp * 16;
packed_weights = (uint8_t*)packed_weights + kr; // kr * 2 nibbles
}
packed_weights = (uint8_t*)packed_weights +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 2 nibbles
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
// Same as qc4w but unsigned 4 bit output
// Applies kv ^ 0x88 to convert int4 to uint4
// Does not multiply bias by 16
void xnn_pack_qs8_qc4uw_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride, const uint8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_qc4w_packing_params* params) {
assert(g != 0);
assert(nc != 0);
assert(kc != 0);
assert(nr >= sr);
assert(kr >= 1 && kr <= 16);
assert(sr >= 1 && sr <= 16);
assert(k != nullptr);
assert(packed_weights != nullptr);
assert(params != nullptr);
assert(params->kernel_zero_point == 8 || params->kernel_zero_point == 0);
const size_t skr = sr * kr;
const uint32_t izp = (uint32_t)params->input_zero_point;
const uint32_t kernel_zero_point = (uint32_t)params->kernel_zero_point;
do {
size_t nr_block_start = 0;
do {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr * 2);
kr_block_start += kr * 2) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const size_t k_offset =
kc_idx * k_stride + (nr_block_start + nr_block_offset);
const size_t kh_offset =
(kc_idx + kr) * k_stride + (nr_block_start + nr_block_offset);
if (kernel_zero_point == 0) {
int8_t kv_lo = 0;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
int8_t kv_hi = 0;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
const int8_t kv = (kv_lo | (kv_hi << 4));
kv_lo = sign_extend_int4(kv_lo);
kv_hi = sign_extend_int4(kv_hi);
ksum += kv_lo + kv_hi;
((int8_t*)packed_weights)[kr_block_offset] =
kv ^ 0x88; // Convert to uint4
} else {
uint8_t kv_lo = kernel_zero_point;
if (kc_idx < kc) {
kv_lo = ((k_offset & 1) ? (k[k_offset >> 1] >> 4)
: (k[k_offset >> 1] & 0xF));
}
uint8_t kv_hi = kernel_zero_point;
if ((kc_idx + kr) < kc) {
kv_hi = ((kh_offset & 1) ? (k[kh_offset >> 1] >> 4)
: (k[kh_offset >> 1] & 0xF));
}
ksum += kv_lo + kv_hi -
2 * kernel_zero_point; // subtract 2 zero points
const uint8_t kv = (kv_lo | (kv_hi << 4)) ^ 0x88;
((uint8_t*)packed_weights)[kr_block_offset] =
kv ^ 0x88; // Convert to uint4
}
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
packed_weights = (uint8_t*)packed_weights + kr; // kr * 2 nibbles
}
packed_weights = (uint8_t*)packed_weights +
(nr - nr_block_size) * kr; // skip NR remainder
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
nr_block_start += nr;
} while (nr_block_start < nc);
k += nc * kc; // kc * 2 nibbles
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f32_qs8w_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, const int8_t* k,
const float* bias, const float* scale,
void* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const int32_t* b = (const int32_t*)bias;
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
int8_t* end = (int8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(&k[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, (int8_t*)packed_weights);
packed_weights = (int8_t*)packed_weights + kc_end - kc_begin;
}
std::fill((int8_t*)packed_weights, end, INT8_C(0));
packed_weights = end;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
// qs4 packs 2 columns into 2 rows.
// kc can be odd. assume k values in a row are padded to a byte boundary
void xnn_pack_f32_qc4w_gemm_goi_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr,
const void* k, // 4 bit values
const float* bias, const float* scale,
void* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
kc = (kc + 1) >> 1;
const int32_t* b = (const int32_t*)bias;
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
uint8_t* end = (uint8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(
&((const uint8_t*)
k)[(nr_block_start + nr_block_offset) * kc + kc_begin],
kc_end - kc_begin, (uint8_t*)packed_weights);
packed_weights = (uint8_t*)packed_weights + kc_end - kc_begin;
}
std::fill((uint8_t*)packed_weights, end, UINT8_C(0));
packed_weights = end;
}
packed_weights = (uint8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k = (const uint8_t*)k + nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f32_gemm_gio_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, size_t k_stride,
const float* k, const float* b, const void* scale,
float* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
// Special case for trivial packings.
if (skr == 1) {
for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start++) {
const size_t kc_idx = round_down_po2(kr_block_start, skr);
if (kc_idx < kc) {
std::copy_n(&k[kc_idx * k_stride + nr_block_start], nr_block_size,
packed_weights);
}
packed_weights += nr;
}
} else {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
packed_weights[kr_block_offset] =
kc_idx < kc
? k[kc_idx * k_stride + nr_block_start + nr_block_offset]
: 0.0f;
}
packed_weights += kr;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
packed_weights = (float*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_bf16_f32_gemm_gio_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, size_t k_stride,
const xnn_bfloat16* k, const float* b,
const void* scale, void* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, (float*)packed_weights);
packed_weights = (float*)packed_weights + nr;
// Special case for trivial packings.
if (skr == 1) {
for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start++) {
const size_t kc_idx = round_down_po2(kr_block_start, skr);
if (kc_idx < kc) {
std::copy_n(&k[kc_idx * k_stride + nr_block_start], nr_block_size,
(xnn_bfloat16*)packed_weights);
}
packed_weights = (xnn_bfloat16*)packed_weights + nr;
}
} else {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
((xnn_bfloat16*)packed_weights)[kr_block_offset] =
kc_idx < kc
? k[kc_idx * k_stride + nr_block_start + nr_block_offset]
: static_cast<xnn_bfloat16>(0.0f);
}
packed_weights = (xnn_bfloat16*)packed_weights + kr;
}
packed_weights =
(xnn_bfloat16*)packed_weights + (nr - nr_block_size) * kr;
}
}
packed_weights = (float*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f16_gemm_gio_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, size_t k_stride,
const uint16_t* k, const uint16_t* b,
const void* scale, uint16_t* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
// Special case for trivial packings.
if (skr == 1) {
for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start++) {
const size_t kc_idx = round_down_po2(kr_block_start, skr);
if (kc_idx < kc) {
std::copy_n(&k[kc_idx * k_stride + nr_block_start], nr_block_size,
packed_weights);
}
packed_weights += nr;
}
} else {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
packed_weights[kr_block_offset] =
kc_idx < kc
? k[kc_idx * k_stride + nr_block_start + nr_block_offset]
: UINT16_C(0);
}
packed_weights += kr;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
packed_weights = (uint16_t*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f32_to_f16_gemm_gio_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, size_t k_stride,
const float* k, const float* b,
const void* scale,
xnn_float16* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
packed_weights[kr_block_offset] =
kc_idx < kc
? k[kc_idx * k_stride + nr_block_start + nr_block_offset]
: 0.0f;
}
packed_weights += kr;
}
packed_weights += (nr - nr_block_size) * kr;
}
packed_weights = (xnn_float16*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qu8_gemm_gio_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, size_t k_stride,
const uint8_t* k, const int32_t* b,
const void* scale, void* packed_weights,
size_t extra_bytes,
const struct xnn_qu8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const int32_t izp = (int32_t)params->input_zero_point;
const int32_t bzp = (int32_t)kc * izp * (int32_t)params->kernel_zero_point;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b, bzp);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
int32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
if (kc_idx < kc) {
const uint8_t kv =
k[kc_idx * k_stride + (nr_block_start + nr_block_offset)];
ksum += (int32_t)kv;
((uint8_t*)packed_weights)[kr_block_offset] = kv;
} else {
((uint8_t*)packed_weights)[kr_block_offset] =
params->kernel_zero_point;
}
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
packed_weights = (uint8_t*)packed_weights + kr;
}
packed_weights = (uint8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qs8_to_qu8_gemm_gio_w(
size_t g, size_t nc, size_t kc, size_t nr, size_t kr, size_t sr,
size_t k_stride, const int8_t* k, const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const uint32_t izp = (uint32_t)params->input_zero_point + 128;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (uint32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
uint32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const int8_t kv =
kc_idx < kc
? k[kc_idx * k_stride + (nr_block_start + nr_block_offset)]
: INT8_C(0);
ksum += (uint32_t)kv;
((int8_t*)packed_weights)[kr_block_offset] = kv;
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
packed_weights = (int8_t*)packed_weights + kr;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qs8_gemm_gio_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, size_t k_stride,
const int8_t* k, const int32_t* b,
const float* scale, void* packed_weights,
size_t extra_bytes,
const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const uint32_t izp = (uint32_t)params->input_zero_point;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (uint32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
uint32_t ksum = 0;
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const int8_t kv =
kc_idx < kc
? k[kc_idx * k_stride + (nr_block_start + nr_block_offset)]
: INT8_C(0);
ksum += (uint32_t)kv;
((int8_t*)packed_weights)[kr_block_offset] = kv;
}
packed_b[nr_block_offset] = packed_b[nr_block_offset] - ksum * izp;
packed_weights = (int8_t*)packed_weights + kr;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
static void pack_weights_and_biases(
uint32_t flags, //
const struct xnn_gemm_config* gemm_config, //
size_t input_channels, //
size_t output_channels, //
size_t groups, //
size_t unused_block_size, //
size_t weights_stride, //
xnn_packw_gemm_gio_ukernel_fn pack_gemm_gio_w, //
xnn_packw_gemm_goi_ukernel_fn pack_gemm_goi_w, //
const void* accumulator_init, //
const void* weights, //
xnn_init_scale_params_fn init_extra_data0_fn, //
const void* extra_data0, //
size_t extra_data0_element_size, //
xnn_init_scale_params_fn init_extra_data1_fn, //
const void* extra_data1, //
size_t extra_data1_element_size, //
void* packed_weights_ptr, //
size_t extra_bytes, //
const void* params) {
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const size_t n_stride = round_up(output_channels, nr);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
pack_gemm_gio_w(groups, output_channels, input_channels, nr, kr, sr,
output_channels, weights, accumulator_init,
/*scale=*/nullptr, packed_weights_ptr, nr * extra_bytes,
params);
} else {
pack_gemm_goi_w(groups, output_channels, input_channels, nr, kr, sr,
weights, accumulator_init, /*scale=*/nullptr,
packed_weights_ptr, nr * extra_bytes, params);
}
if (extra_data1 != nullptr) {
assert(init_extra_data1_fn != nullptr);
for (size_t group = 0; group < groups; group++) {
void* packed_group_ptr = (void*)((char*)packed_weights_ptr +
group * n_stride * weights_stride);
void* weights = (void*)((uintptr_t)packed_group_ptr +
nr * (weights_stride - extra_bytes));
void* extra_data_ptr =
(void*)((uintptr_t)extra_data1 +
extra_data1_element_size * output_channels * group);
init_extra_data1_fn(output_channels, nr, nr * weights_stride,
extra_data_ptr, weights);
}
}
if (extra_data0 != nullptr) {
assert(init_extra_data0_fn != nullptr);
for (size_t group = 0; group < groups; group++) {
void* packed_group_ptr = (void*)((char*)packed_weights_ptr +
group * n_stride * weights_stride);
void* weights = (void*)((uintptr_t)packed_group_ptr +
nr * (weights_stride - extra_bytes));
if (extra_data1 != nullptr) {
weights = (void*)((uintptr_t)weights + nr * sizeof(float));
}
void* extra_data_ptr =
(void*)((uintptr_t)extra_data0 +
extra_data0_element_size * output_channels * group);
init_extra_data0_fn(output_channels, nr, nr * weights_stride,
extra_data_ptr, weights);
}
}
}
size_t xnn_packed_stride_qs8_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t unused_k,
size_t unused_block_size, size_t k_stride, size_t extra_bytes) {
const size_t bias_element_size = sizeof(int32_t);
const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
return (k_stride << log2_filter_element_size) + bias_element_size +
extra_bytes;
}
void xnn_pack_qs8_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t unused_k_stride,
const void* accumulator_init, const void* weights,
xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const size_t packed_k_stride = round_up_po2(input_channels, kr * sr);
const size_t extra_bytes =
extra_data0_element_size + extra_data1_element_size;
const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases(
gemm_config, input_channels, unused_block_size, packed_k_stride,
extra_bytes);
return pack_weights_and_biases(
flags, gemm_config, input_channels, output_channels, groups,
unused_block_size, weights_stride,
(xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_gemm_gio_w,
(xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qs8_gemm_goi_w, accumulator_init,
weights, init_extra_data0_fn, extra_data0, extra_data0_element_size,
init_extra_data1_fn, extra_data1, extra_data1_element_size,
packed_weights_ptr, extra_bytes, params);
}
size_t xnn_packed_stride_qs4_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t unused_k,
size_t unused_block_size, size_t k_stride, size_t extra_bytes) {
const size_t bias_element_size = sizeof(int32_t);
const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
return (k_stride << log2_filter_element_size) + bias_element_size +
extra_bytes;
}
void xnn_pack_qs4_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t unused_k_stride,
const void* accumulator_init, const void* weights,
xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const uint32_t planes = gemm_config->planes;
size_t k_stride = round_up_po2(input_channels, kr * sr * planes);
k_stride = round_up_po2(k_stride, 2) >> 1;
const size_t extra_bytes =
extra_data0_element_size + extra_data1_element_size;
const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases(
gemm_config, input_channels, unused_block_size, k_stride, extra_bytes);
return pack_weights_and_biases(
flags, gemm_config, input_channels, output_channels, groups,
unused_block_size, weights_stride,
(xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qs8_qc4w_gemm_gio_w,
(xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qs8_qc4w_gemm_goi_w,
accumulator_init, weights, init_extra_data0_fn, extra_data0,
extra_data0_element_size, init_extra_data1_fn, extra_data1,
extra_data1_element_size, packed_weights_ptr, extra_bytes, params);
}
size_t xnn_packed_stride_qb4_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size,
size_t k_stride, size_t extra_bytes) {
const size_t planes = gemm_config->planes;
size_t input_channels = round_up_po2(k, planes);
size_t block_scale_bytes = 0;
size_t num_blocks = 0;
const bool block_wise = (block_size != 0);
if (block_wise) {
num_blocks = input_channels / block_size;
block_scale_bytes += num_blocks * sizeof(uint16_t);
}
const size_t bias_element_size = sizeof(int32_t);
const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
return (k_stride << log2_filter_element_size) + bias_element_size +
extra_bytes + block_scale_bytes;
}
void xnn_pack_qb4_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t block_size, size_t unused_k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const uint32_t planes = gemm_config->planes;
size_t k_stride = round_up_po2(input_channels, kr * sr * planes);
k_stride = round_up_po2(k_stride, 2) >> 1;
const size_t extra_bytes_bl = sizeof(uint16_t);
const size_t extra_bytes_n = sizeof(uint32_t);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
xnn_pack_qs8_qb4w_gemm_gio_w(
/*g=*/groups,
/*nc=*/output_channels,
/*kc=*/input_channels,
/*nr=*/nr,
/*kr=*/kr,
/*sr=*/sr,
/*k_stride=*/k_stride,
/*bl=*/block_size,
/*kernel=*/(const uint8_t*)weights,
/*bias=*/nullptr,
/*scale=*/(const xnn_bfloat16*)extra_data1,
/*packed_weights=*/packed_weights_ptr,
/*extra_bytes_bl=*/nr * extra_bytes_bl,
/*extra_bytes_n=*/nr * extra_bytes_n,
/*params*/ (const struct xnn_qs8_qc4w_packing_params*)params);
} else {
bool has_fast_packing_ukernel = gemm_config->pack_gemm_goi_bl != nullptr;
xnn_packw_gemm_goi_bl_ukernel_fn pack_gemm_goi =
has_fast_packing_ukernel
? (xnn_packw_gemm_goi_bl_ukernel_fn)gemm_config->pack_gemm_goi_bl
: (xnn_packw_gemm_goi_bl_ukernel_fn)xnn_pack_qs8_qb4w_gemm_goi_w;
// Fast Packing ukernel initializes scales and bias, so we pass in
// bias to packing fn if we use the fast packing ukernel, nullptr otherwise
pack_gemm_goi(
/*g=*/groups,
/*nc=*/output_channels,
/*kc=*/input_channels,
/*nr=*/nr,
/*kr=*/kr,
/*sr=*/sr,
/*bl=*/block_size,
/*kernel=*/(const uint8_t*)weights,
/*bias=*/
has_fast_packing_ukernel ? (const int32_t*)accumulator_init : nullptr,
/*scale=*/(const xnn_bfloat16*)extra_data1,
/*packed_weights=*/packed_weights_ptr,
/*extra_bytes_bl=*/nr * extra_bytes_bl,
/*extra_bytes_n=*/nr * extra_bytes_n,
/*params*/ (const struct xnn_qs8_qc4w_packing_params*)params);
if (has_fast_packing_ukernel) {
// Fast Packing UKernel initializes scales and bias, so can early exit
return;
}
}
// fill in kernel scales
const size_t num_blocks = input_channels / block_size;
const size_t weights_stride = xnn_packed_stride_qb4_weights_and_biases(
gemm_config, input_channels, block_size, k_stride, extra_bytes_n);
void* weights_start =
(void*)((uintptr_t)packed_weights_ptr +
nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2)));
const size_t block_stride = /*weights*/ block_size / 2 + sizeof(uint16_t);
xnn_init_blockwise_scale_bf16_params(
output_channels, nr, nr * weights_stride,
/*num_blocks=*/num_blocks,
/*block_stride=*/gemm_config->nr * block_stride,
(const xnn_bfloat16*)extra_data1, weights_start);
// fill in bias if not null
if (accumulator_init != nullptr) {
weights_start = (void*)((uintptr_t)packed_weights_ptr +
gemm_config->nr * (weights_stride - sizeof(float)));
xnn_init_qs8_qc8w_scale_fp32_params(
output_channels, gemm_config->nr, gemm_config->nr * weights_stride,
(const float*)accumulator_init, weights_start);
}
}
size_t xnn_packed_stride_qu8_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t unused_k,
size_t unused_block_size, size_t k_stride, size_t extra_bytes) {
const size_t bias_element_size = sizeof(int32_t);
const size_t log2_filter_element_size = XNN_LOG2_SIZEOF_INT8_T;
return (k_stride << log2_filter_element_size) + bias_element_size +
extra_bytes;
}
void xnn_pack_qu8_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t unused_k_stride,
const void* accumulator_init, const void* weights,
xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const size_t packed_k_stride = round_up_po2(input_channels, kr * sr);
const size_t extra_bytes =
extra_data0_element_size + extra_data1_element_size;
const size_t weights_stride = xnn_packed_stride_qs8_weights_and_biases(
gemm_config, input_channels, unused_block_size, packed_k_stride,
extra_bytes);
return pack_weights_and_biases(
flags, gemm_config, input_channels, output_channels, groups,
unused_block_size, weights_stride,
(xnn_packw_gemm_gio_ukernel_fn)xnn_pack_qu8_gemm_gio_w,
(xnn_packw_gemm_goi_ukernel_fn)xnn_pack_qu8_gemm_goi_w, accumulator_init,
weights, init_extra_data0_fn, extra_data0, extra_data0_element_size,
init_extra_data1_fn, extra_data1, extra_data1_element_size,
packed_weights_ptr, extra_bytes, params);
}
#if XNN_ENABLE_KLEIDIAI
size_t xnn_packed_stride_kai_qs4_weights_and_biases_sme(
const struct xnn_gemm_config* gemm_config, size_t k, size_t unused_k_stride,
size_t unused_block_size, size_t extra_bytes) {
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
return kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
k, /*nr=*/1, kr, sr);
}
void xnn_pack_kai_qs4_weights_and_biases_sme(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t unused_k_stride,
const void* accumulator_init, const void* weights,
xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const struct xnn_qs8_qc4w_packing_params* xnn_params =
reinterpret_cast<const struct xnn_qs8_qc4w_packing_params*>(params);
struct kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon_params kai_params;
kai_params.lhs_zero_point = xnn_params->input_zero_point;
kai_params.rhs_zero_point = xnn_params->kernel_zero_point;
bool free_accumulator_init = false;
if (extra_data0 == nullptr) {
extra_data0 = calloc(output_channels, sizeof(float));
free_accumulator_init = true;
}
kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon(
groups, output_channels, input_channels, nr, kr, sr,
/*rhs=*/reinterpret_cast<const uint8_t*>(weights),
/*bias=*/reinterpret_cast<const float*>(extra_data0),
/*scale=*/reinterpret_cast<const float*>(extra_data1),
/*rhs_packed=*/packed_weights_ptr,
/*extra_bytes=*/0, &kai_params);
if (free_accumulator_init) {
free((void*)extra_data0);
}
}
size_t xnn_packed_stride_kai_qs4_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t k,
size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) {
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
return kai_get_rhs_packed_stride_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(k, /*nr=*/1,
kr, sr);
}
void xnn_pack_kai_qs4_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const struct xnn_qs8_qc4w_packing_params* xnn_params =
reinterpret_cast<const struct xnn_qs8_qc4w_packing_params*>(params);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
// Repack the packing params.
struct kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0_params kai_params;
kai_params.lhs_zero_point = xnn_params->input_zero_point;
kai_params.rhs_zero_point = xnn_params->kernel_zero_point;
kai_run_rhs_pack_kxn_qsi4cxp_qs4cxs1s0(
groups, output_channels, input_channels, nr, kr, sr,
/*rhs=*/reinterpret_cast<const uint8_t*>(weights),
/*bias=*/reinterpret_cast<const float*>(extra_data0),
/*scale=*/reinterpret_cast<const float*>(extra_data1),
/*rhs_packed=*/packed_weights_ptr,
/*extra_bytes=*/0, &kai_params);
} else {
// Repack the packing params.
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params kai_params;
kai_params.lhs_zero_point = xnn_params->input_zero_point;
kai_params.rhs_zero_point = xnn_params->kernel_zero_point;
kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
groups, output_channels, input_channels, nr, kr, sr,
/*rhs=*/reinterpret_cast<const uint8_t*>(weights),
/*bias=*/reinterpret_cast<const float*>(extra_data0),
/*scale=*/reinterpret_cast<const float*>(extra_data1),
/*rhs_packed=*/packed_weights_ptr,
/*extra_bytes=*/0, &kai_params);
}
}
size_t xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme(
const struct xnn_gemm_config* gemm_config, size_t k,
size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) {
size_t ret_val =
kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
k) /
kai_get_n_step_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme();
return ret_val;
}
void transpose_weights_x8(const int8_t* in, int8_t* out, size_t height,
size_t width) {
for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < width; ++j) {
out[j * height + i] = in[i * width + j];
}
}
}
static void transpose_weights_x16(const uint16_t* in, uint16_t* out,
size_t height, size_t width) {
for (size_t j = 0; j < width; ++j) {
for (size_t i = 0; i < height; ++i) {
out[j * height + i] = in[i * width + j];
}
}
}
void xnn_pack_kai_qs8_qc8w_weights_and_biases_sme(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const size_t rhs_stride = output_channels * sizeof(int8_t);
// Some packing kernels assume that the bias is non-null. Allocate a zero
// initialized array as a workaround if bias is null.
bool free_accumulator_init = false;
if (accumulator_init == NULL) {
accumulator_init = calloc(output_channels, sizeof(int32_t));
free_accumulator_init = true;
}
const struct xnn_qs8_packing_params* xnn_params =
reinterpret_cast<const struct xnn_qs8_packing_params*>(params);
struct kai_rhs_pack_qsi8cx_params kai_params;
kai_params.lhs_zero_point = xnn_params->input_zero_point;
kai_params.scale_multiplier = 1.f;
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
groups, output_channels, input_channels, nr, kr, sr, rhs_stride,
/*rhs=*/weights,
/*bias=*/accumulator_init,
/*scale=*/extra_data0,
/*rhs_packed=*/packed_weights_ptr,
/*extra_bytes=*/0, &kai_params);
} else {
// Transpose the weights until the transpose packing function is ready.
int8_t* tmp_data =
(int8_t*)malloc(input_channels * output_channels * sizeof(int8_t));
transpose_weights_x8((const int8_t*)weights, tmp_data, output_channels,
input_channels);
kai_run_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
groups, output_channels, input_channels, nr, kr, sr, rhs_stride,
/*rhs=*/tmp_data,
/*bias=*/accumulator_init,
/*scale=*/extra_data0,
/*rhs_packed=*/packed_weights_ptr,
/*extra_bytes=*/0, &kai_params);
free(tmp_data);
}
if (free_accumulator_init) {
free((void*)accumulator_init);
}
}
size_t xnn_packed_stride_kai_qs8_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t k,
size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) {
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
return kai_get_rhs_packed_stride_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(k, /*nr=*/1,
kr, sr);
}
void xnn_pack_kai_qs8_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t unused_k_stride,
const void* accumulator_init, const void* weights,
xnn_init_scale_params_fn init_extra_data0_fn, const void* extra_data0,
size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const struct xnn_qs8_qc8w_packing_params* xnn_params =
reinterpret_cast<const struct xnn_qs8_qc8w_packing_params*>(params);
// Repack the packing params.
struct kai_rhs_pack_qsi8cx_params kai_params;
kai_params.lhs_zero_point = xnn_params->input_zero_point;
kai_params.scale_multiplier = xnn_params->scale_multiplier;
const size_t weights_group_stride =
sizeof(int8_t) * input_channels * output_channels;
const size_t n_stride = round_up(output_channels, nr);
const size_t packed_weights_group_stride =
n_stride * xnn_packed_stride_kai_qs8_weights_and_biases(
gemm_config, input_channels, unused_block_size,
/*unused_k_stride=*/0,
extra_data0_element_size + extra_data1_element_size);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
for (size_t group = 0; group < groups; group++) {
kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(
/*groups=*/1, output_channels, input_channels, nr, kr, sr,
/*rhs=*/
reinterpret_cast<const int8_t*>((uintptr_t)weights +
group * weights_group_stride),
/*bias=*/
extra_data0 ? reinterpret_cast<const float*>(extra_data0) +
group * output_channels
: NULL,
/*scale=*/
extra_data1 ? reinterpret_cast<const float*>(extra_data1) +
group * output_channels
: NULL,
/*rhs_packed=*/
(void*)((uintptr_t)packed_weights_ptr +
group * packed_weights_group_stride),
/*extra_bytes=*/0, &kai_params);
}
} else {
for (size_t group = 0; group < groups; group++) {
kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon(
/*groups=*/1, output_channels, input_channels, nr, kr, sr,
/*rhs=*/
reinterpret_cast<const int8_t*>((uintptr_t)weights +
group * weights_group_stride),
/*bias=*/
extra_data0 ? reinterpret_cast<const float*>(extra_data0) +
group * output_channels
: NULL,
/*scale=*/
extra_data1 ? reinterpret_cast<const float*>(extra_data1) +
group * output_channels
: NULL,
/*rhs_packed=*/
(void*)((uintptr_t)packed_weights_ptr +
group * packed_weights_group_stride),
/*extra_bytes=*/0, &kai_params);
}
}
}
size_t xnn_packed_stride_kai_f16_weights_and_biases(
const struct xnn_gemm_config* unused_gemm_config, size_t k,
size_t unused_block_size, size_t unused_k_stride,
size_t unused_extra_bytes) {
size_t ret_val =
kai_get_rhs_packed_stride_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(k) /
kai_get_n_step_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme();
return ret_val;
}
void xnn_pack_kai_f16_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
assert(extra_data0 == nullptr);
assert(extra_data1 == nullptr);
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
// Some packing kernels assume that the bias is non-null. Allocate a zero
// initialized array as a workaround if bias is null.
bool free_accumulator_init = false;
if (accumulator_init == NULL) {
accumulator_init = calloc(output_channels, sizeof(float));
free_accumulator_init = true;
}
const size_t rhs_stride = k_stride * sizeof(xnn_float16);
const size_t weights_group_stride =
sizeof(xnn_float16) * input_channels * output_channels;
const size_t n_stride = round_up(output_channels, nr);
const size_t packed_weights_group_stride =
n_stride * xnn_packed_stride_kai_f16_weights_and_biases(
gemm_config, input_channels, unused_block_size,
/*unused_k_stride=*/0,
/*unused_extra_bytes=*/0);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
for (size_t group = 0; group < groups; group++) {
kai_run_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme(
/*groups=*/1, output_channels, input_channels, nr, kr, sr, rhs_stride,
/*rhs=*/
(const void*)((uintptr_t)weights + group * weights_group_stride),
/*bias=*/
free_accumulator_init
? accumulator_init
: (const float*)(accumulator_init) + group * output_channels,
/*scale=*/NULL,
/*rhs_packed=*/
(void*)((uintptr_t)packed_weights_ptr +
group * packed_weights_group_stride),
/*extra_bytes=*/0, /*params=*/NULL);
}
} else {
for (size_t group = 0; group < groups; group++) {
kai_run_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme(
/*groups=*/1, output_channels, input_channels, nr, kr, sr, rhs_stride,
/*rhs=*/
(const void*)((uintptr_t)weights + group * weights_group_stride),
/*bias=*/
free_accumulator_init
? accumulator_init
: (const float*)(accumulator_init) + group * output_channels,
/*scale=*/NULL,
/*rhs_packed=*/
(void*)((uintptr_t)packed_weights_ptr +
group * packed_weights_group_stride),
/*extra_bytes=*/0, /*params=*/NULL);
}
}
if (free_accumulator_init) {
free((void*)accumulator_init);
}
}
size_t xnn_packed_stride_kai_f32_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t k,
size_t unused_block_size, size_t unused_k_stride, size_t extra_bytes) {
size_t ret_val =
kai_get_rhs_packed_stride_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(k) /
kai_get_n_step_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme();
return ret_val;
}
void xnn_pack_kai_f32_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
assert(extra_data0 == nullptr);
assert(extra_data1 == nullptr);
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
// Some packing kernels assume that the bias is non-null. Allocate a zero
// initialized array as a workaround if bias is null.
bool free_accumulator_init = false;
if (accumulator_init == NULL) {
accumulator_init = calloc(output_channels, sizeof(float));
free_accumulator_init = true;
}
const size_t rhs_stride = k_stride * sizeof(float);
const size_t weights_group_stride =
sizeof(float) * input_channels * output_channels;
const size_t n_stride = round_up(output_channels, nr);
const size_t packed_weights_group_stride =
n_stride * xnn_packed_stride_kai_f32_weights_and_biases(
gemm_config, input_channels, unused_block_size,
/*unused_k_stride=*/0,
/*unused_extra_bytes=*/0);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
for (size_t group = 0; group < groups; group++) {
kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(
/*groups=*/1, output_channels, input_channels, nr, kr, sr, rhs_stride,
/*rhs=*/
(const void*)((uintptr_t)weights + group * weights_group_stride),
/*bias=*/
free_accumulator_init
? accumulator_init
: (const float*)(accumulator_init) + group * output_channels,
/*scale=*/NULL,
/*rhs_packed=*/
(void*)((uintptr_t)packed_weights_ptr +
group * packed_weights_group_stride),
/*extra_bytes=*/0, /*params=*/NULL);
}
} else {
for (size_t group = 0; group < groups; group++) {
kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(
/*groups=*/1, output_channels, input_channels, nr, kr, sr, rhs_stride,
/*rhs=*/
(const void*)((uintptr_t)weights + group * weights_group_stride),
/*bias=*/
free_accumulator_init
? accumulator_init
: (const float*)(accumulator_init) + group * output_channels,
/*scale=*/NULL,
/*rhs_packed=*/
(void*)((uintptr_t)packed_weights_ptr +
group * packed_weights_group_stride),
/*extra_bytes=*/0, /*params=*/NULL);
}
}
if (free_accumulator_init) {
free((void*)accumulator_init);
}
}
size_t xnn_packed_stride_kai_qb4_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t k, size_t block_size,
size_t unused_k_stride, size_t extra_bytes) {
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const uint32_t nr = gemm_config->nr;
// We want the weight stride with nr = 1, but kleidi enforces a constraint
// where nr % 4 == 0. So instead we give nr to get the nr-scaled stride, and
// divide by nr to scaled down the stride.
const size_t nr_scaled_packed_stride =
kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
k, nr, kr, sr, block_size, kai_datatype::kai_dt_bf16);
return nr_scaled_packed_stride / nr;
}
void xnn_pack_kai_qb4_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t block_size, size_t k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const uint32_t planes = gemm_config->planes;
const struct xnn_qs8_qc4w_packing_params* xnn_params =
reinterpret_cast<const struct xnn_qs8_qc4w_packing_params*>(params);
size_t rhs_stride = (k_stride + 1) / 2;
size_t blocks_per_row = (input_channels + block_size - 1) / block_size;
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
struct kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0_params kai_params;
kai_params.lhs_zero_point = xnn_params->input_zero_point;
kai_params.rhs_zero_point = xnn_params->kernel_zero_point;
kai_params.scale_dt = kai_datatype::kai_dt_bf16;
kai_run_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0(
groups, output_channels, input_channels, nr, kr, sr,
/*bl=*/block_size,
/*rhs=*/reinterpret_cast<const uint8_t*>(weights), rhs_stride,
/*bias=*/reinterpret_cast<const float*>(extra_data0),
/*scale=*/reinterpret_cast<const uint16_t*>(extra_data1),
/*scale_stride=*/blocks_per_row * sizeof(uint16_t),
/*rhs_packed*/ packed_weights_ptr,
/*extra_bytes=*/0, &kai_params);
} else {
// Repack the packing params.
struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params kai_params;
kai_params.lhs_zero_point = xnn_params->input_zero_point;
kai_params.rhs_zero_point = xnn_params->kernel_zero_point;
kai_params.scale_dt = kai_datatype::kai_dt_bf16;
kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
groups, output_channels, input_channels, nr, kr, sr,
/*bl=*/block_size,
/*rhs=*/reinterpret_cast<const uint8_t*>(weights), rhs_stride,
/*bias=*/reinterpret_cast<const float*>(extra_data0),
/*scale=*/reinterpret_cast<const uint16_t*>(extra_data1),
/*scale_stride=*/blocks_per_row * sizeof(uint16_t),
/*rhs_packed*/ packed_weights_ptr,
/*extra_bytes=*/0, &kai_params);
}
// init bias
size_t packed_k_stride = round_up_po2(input_channels, kr * sr);
if (1 < planes) {
input_channels = round_up_po2(input_channels, planes);
packed_k_stride = round_up_po2(input_channels, kr * sr * planes);
packed_k_stride = round_up_po2(packed_k_stride, 2) >> 1;
}
const size_t weights_stride = xnn_packed_stride_kai_qb4_weights_and_biases(
gemm_config, input_channels, block_size, packed_k_stride, 0);
if (accumulator_init != NULL) {
void* weights_start =
(void*)((uintptr_t)packed_weights_ptr +
nr * (sizeof(float) + (block_size * sizeof(int8_t) / 2)));
weights_start = (void*)((uintptr_t)packed_weights_ptr +
nr * (weights_stride - sizeof(float)));
xnn_init_qs8_qc8w_scale_fp32_params(
output_channels, nr, nr * weights_stride,
(const float*)accumulator_init, weights_start);
}
}
void xnn_pack_kai_f16_conv_goki_w_sme(size_t g, size_t nc, size_t ks,
size_t kc, size_t nr, size_t kr,
size_t sr, const uint16_t* k,
const uint16_t* b, const void* scale,
void* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
uint16_t* tmp_bias = NULL;
if (b == NULL) {
tmp_bias = (uint16_t*)xnn_allocate_zero_memory(g * nc * sizeof(uint16_t));
b = tmp_bias;
}
uint16_t* tmp_data =
(uint16_t*)xnn_allocate_memory(nc * ks * kc * sizeof(uint16_t));
const size_t rhs_row_stride = nc * sizeof(uint16_t);
const size_t packed_rhs_size =
kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
nc, ks, kc);
for (size_t g_idx = 0; g_idx < g; ++g_idx) {
// TODO: Remove transpose_weights_x16 if KleidiAI release imatmul_pack_nxk packing variant
transpose_weights_x16(k, tmp_data, nc, ks * kc);
// Pass FP16 bias directly to the rhs_imatmul packer which expects FP16 bias
// for this kernel.
kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
nc, ks, kc, rhs_row_stride, tmp_data, b, packed_weights);
k += nc * ks * kc;
b += nc;
packed_weights = (void*)((uintptr_t)packed_weights + packed_rhs_size);
}
xnn_release_memory(tmp_data);
if (tmp_bias != NULL) {
xnn_release_memory(tmp_bias);
}
}
size_t xnn_packed_size_kai_f16_conv_goki_w(size_t nc, size_t ks, size_t kc) {
return kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
nc, ks, kc);
}
void xnn_pack_kai_qs8_conv_goki_w_sme(
size_t g, size_t nc, size_t ks, size_t kc, size_t nr, size_t kr, size_t sr,
const int8_t* k, const int32_t* b, const float* scale, void* packed_weights,
size_t extra_bytes, const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
kai_rhs_pack_qsi8cx_params kai_params{};
kai_params.lhs_zero_point = params->input_zero_point;
kai_params.scale_multiplier = 1.0F;
int32_t* tmp_bias = NULL;
if (b == NULL) {
tmp_bias = (int32_t*)xnn_allocate_zero_memory(g * nc * sizeof(int32_t));
b = tmp_bias;
}
int8_t* tmp_data =
(int8_t*)xnn_allocate_memory(nc * ks * kc * sizeof(int8_t));
const size_t rhs_row_stride = nc * sizeof(int8_t);
const size_t packed_rhs_size =
kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
nc, ks, kc);
for (size_t g_idx = 0; g_idx < g; ++g_idx) {
transpose_weights_x8(k, tmp_data, nc, ks * kc);
kai_run_rhs_imatmul_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme(
nc, ks, kc, rhs_row_stride, tmp_data, b, scale, packed_weights,
&kai_params);
k += nc * ks * kc;
b += nc;
if (scale != NULL) {
scale += nc;
}
packed_weights = (uint8_t*)packed_weights + packed_rhs_size;
}
xnn_release_memory(tmp_data);
if (tmp_bias != NULL) {
xnn_release_memory(tmp_bias);
}
}
void transpose_weights(const float* in, float* out, size_t height,
size_t width) {
for (size_t i = 0; i < height; ++i) {
for (size_t j = 0; j < width; ++j) {
out[j * height + i] = in[i * width + j];
}
}
}
void xnn_pack_kai_pf32_conv_goki_w_sme(
size_t g, size_t nc, size_t ks, size_t kc,
size_t nr, size_t kr, size_t sr, const float* k,
const float* b, const void* scale,
float* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
float* tmp_bias = NULL;
if (b == NULL) {
tmp_bias = (float*) calloc(g * nc, sizeof(float));
b = tmp_bias;
}
float* tmp_data = (float*) malloc(nc * ks * kc * sizeof(float));
const size_t rhs_row_stride = nc * sizeof(float);
const size_t packed_rhs_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(nc, ks, kc);
for (size_t g_idx = 0; g_idx < g; ++g_idx) {
transpose_weights(k, tmp_data, nc, ks * kc);
kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
nc, ks, kc, rhs_row_stride, tmp_data, b, packed_weights);
k += nc * ks * kc;
b += nc;
packed_weights = (float*)((uintptr_t)packed_weights + packed_rhs_size);
}
free(tmp_data);
if (tmp_bias != NULL) {
free(tmp_bias);
}
}
#endif // XNN_ENABLE_KLEIDIAI
size_t xnn_packed_stride_qc2w_weights_and_biases(
const struct xnn_gemm_config* gemm_config, size_t k,
size_t unused_block_size, size_t unused_k_stride,
size_t unused_extra_bytes) {
// Extract some useful constants.
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const uint32_t planes = gemm_config->planes;
k = round_up_po2(k, 4);
const size_t packed_k4 =
round_up_po2(k, kr * sr * planes); // 4 blocks for crumbs
const size_t packed_k_bytes = (packed_k4 + 3) / 4;
return packed_k_bytes + (sizeof(int32_t) + sizeof(float) * 3);
}
void xnn_pack_qs8_qc2w_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
// Extract some useful constants.
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const size_t packed_stride = xnn_packed_stride_qc2w_weights_and_biases(
gemm_config, input_channels, unused_block_size, k_stride,
/*unused_extra_bytes=*/0);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
xnn_pack_qs8_qc2w_gemm_gio_w(
/*g=*/1, output_channels, input_channels, nr, kr, sr, output_channels,
static_cast<const uint8_t*>(weights), /*b=*/nullptr,
/*scale=*/nullptr, packed_weights_ptr, 2 * sizeof(float) * nr,
static_cast<const struct xnn_qs8_qc2w_packing_params*>(params));
} else {
xnn_pack_qs8_qc2w_gemm_goi_w(
/*g=*/1, output_channels, input_channels, nr, kr, sr,
static_cast<const uint8_t*>(weights), /*b=*/nullptr,
/*scale=*/nullptr, packed_weights_ptr, 2 * sizeof(float) * nr,
static_cast<const struct xnn_qs8_qc2w_packing_params*>(params));
}
// Pack the kernel_scale.
if (extra_data1 != nullptr) {
xnn_init_qs8_qc8w_scale_fp32_params(
output_channels, nr, nr * packed_stride,
reinterpret_cast<const float*>(extra_data1),
reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights_ptr) +
nr * (packed_stride - 2 * sizeof(float))));
}
// Pack the bias.
if (extra_data0 != nullptr) {
xnn_init_qs8_qc8w_scale_fp32_params(
output_channels, nr, nr * packed_stride,
reinterpret_cast<const float*>(extra_data0),
reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights_ptr) +
nr * (packed_stride - sizeof(float))));
}
}
void xnn_pack_qd8_qc2w_weights_and_biases(
uint32_t flags, const struct xnn_gemm_config* gemm_config,
size_t input_channels, size_t output_channels, size_t groups,
size_t unused_block_size, size_t k_stride, const void* accumulator_init,
const void* weights, xnn_init_scale_params_fn init_extra_data0_fn,
const void* extra_data0, size_t extra_data0_element_size,
xnn_init_scale_params_fn init_extra_data1_fn, const void* extra_data1,
size_t extra_data1_element_size, void* packed_weights_ptr,
const void* params) {
// Extract some useful constants.
const uint32_t nr = gemm_config->nr;
const uint32_t kr = UINT32_C(1) << gemm_config->log2_kr;
const uint32_t sr = UINT32_C(1) << gemm_config->log2_sr;
const size_t packed_stride = xnn_packed_stride_qc2w_weights_and_biases(
gemm_config, input_channels, unused_block_size, k_stride,
/*unused_extra_bytes=*/0);
if (flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
xnn_pack_qd8_qc2w_gemm_gio_w(
/*g=*/1, output_channels, input_channels, nr, kr, sr, output_channels,
static_cast<const uint8_t*>(weights), /*b=*/nullptr,
/*scale=*/nullptr, packed_weights_ptr, 2 * sizeof(float) * nr,
static_cast<const struct xnn_qd8_qc2w_packing_params*>(params));
} else {
xnn_pack_qd8_qc2w_gemm_goi_w(
/*g=*/1, output_channels, input_channels, nr, kr, sr,
static_cast<const uint8_t*>(weights), /*b=*/nullptr,
/*scale=*/nullptr, packed_weights_ptr, 2 * sizeof(float) * nr,
static_cast<const struct xnn_qd8_qc2w_packing_params*>(params));
}
// Pack the kernel_scale.
if (extra_data1 != nullptr) {
xnn_init_qs8_qc8w_scale_fp32_params(
output_channels, nr, nr * packed_stride,
reinterpret_cast<const float*>(extra_data1),
reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights_ptr) +
nr * (packed_stride - 2 * sizeof(float))));
}
// Pack the bias.
if (extra_data0 != nullptr) {
xnn_init_qs8_qc8w_scale_fp32_params(
output_channels, nr, nr * packed_stride,
reinterpret_cast<const float*>(extra_data0),
reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(packed_weights_ptr) +
nr * (packed_stride - sizeof(float))));
}
}
void xnn_pack_f32_qs8w_gemm_gio_w(size_t g, size_t nc, size_t kc, size_t nr,
size_t kr, size_t sr, size_t k_stride,
const int8_t* k, const float* bias,
const float* scale, void* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const int32_t* b = (const int32_t*)bias;
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights = (int32_t*)packed_weights + nr;
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
for (size_t kr_block_offset = 0; kr_block_offset < kr;
kr_block_offset++) {
const size_t kc_idx = kc_begin + kr_block_offset;
const int8_t kv =
kc_idx < kc
? k[kc_idx * k_stride + (nr_block_start + nr_block_offset)]
: INT8_C(0);
((int8_t*)packed_weights)[kr_block_offset] = kv;
}
packed_weights = (int8_t*)packed_weights + kr;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc * kc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f32_conv_goki_w(size_t g, size_t nc, size_t ks, size_t kc,
size_t nr, size_t kr, size_t sr, const float* k,
const float* b, const void* scale,
float* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ki = 0; ki < ks; ki++) {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
float* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(
&k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, 0.0f);
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
packed_weights = (float*)((uintptr_t)packed_weights + extra_bytes);
}
k += ks * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f16_conv_goki_w(size_t g, size_t nc, size_t ks, size_t kc,
size_t nr, size_t kr, size_t sr,
const uint16_t* k, const uint16_t* b,
const void* scale, uint16_t* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ki = 0; ki < ks; ki++) {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
uint16_t* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(
&k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, UINT16_C(0));
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
packed_weights = (uint16_t*)((uintptr_t)packed_weights + extra_bytes);
}
k += ks * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f32_to_f16_conv_goki_w(size_t g, size_t nc, size_t ks, size_t kc,
size_t nr, size_t kr, size_t sr,
const float* k, const float* b,
const void* scale,
xnn_float16* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ki = 0; ki < ks; ki++) {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
xnn_float16* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(
&k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, 0.0f);
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
packed_weights = (xnn_float16*)((uintptr_t)packed_weights + extra_bytes);
}
k += ks * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qu8_conv_goki_w(size_t g, size_t nc, size_t ks, size_t kc,
size_t nr, size_t kr, size_t sr, const uint8_t* k,
const int32_t* b, const void* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qu8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const int32_t izp = (int32_t)params->input_zero_point;
const int32_t bzp =
(int32_t)ks * (int32_t)kc * izp * (int32_t)params->kernel_zero_point;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b, bzp);
packed_weights =
(void*)((uintptr_t)packed_weights + nr * sizeof(int32_t));
for (size_t ki = 0; ki < ks; ki++) {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
uint8_t* end = (uint8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
int32_t ksum = copy_n_and_sum(
&k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
kc_begin],
kc_end - kc_begin, (uint8_t*)packed_weights);
packed_weights = (uint8_t*)packed_weights + kc_end - kc_begin;
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - ksum * izp;
}
std::fill((uint8_t*)packed_weights, end, params->kernel_zero_point);
packed_weights = end;
}
packed_weights = (uint8_t*)packed_weights + (nr - nr_block_size) * kr;
}
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += ks * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qs8_to_qu8_conv_goki_w(
size_t g, size_t nc, size_t ks, size_t kc, size_t nr, size_t kr, size_t sr,
const int8_t* k, const int32_t* b, const float* scale, void* packed_weights,
size_t extra_bytes, const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const uint32_t izp = (int32_t)params->input_zero_point + 128;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights =
(void*)((uintptr_t)packed_weights + nr * sizeof(int32_t));
for (size_t ki = 0; ki < ks; ki++) {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
int8_t* end = (int8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
uint32_t ksum = copy_n_and_sum(
&k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
kc_begin],
kc_end - kc_begin, (int8_t*)packed_weights);
packed_weights = (int8_t*)packed_weights + kc_end - kc_begin;
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - ksum * izp;
}
std::fill((int8_t*)packed_weights, end, INT8_C(0));
packed_weights = end;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += ks * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_qs8_conv_goki_w(size_t g, size_t nc, size_t ks, size_t kc,
size_t nr, size_t kr, size_t sr, const int8_t* k,
const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const uint32_t izp = (int32_t)params->input_zero_point;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights =
(void*)((uintptr_t)packed_weights + nr * sizeof(int32_t));
for (size_t ki = 0; ki < ks; ki++) {
for (size_t kr_block_start = 0; kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
int8_t* end = (int8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
uint32_t ksum = copy_n_and_sum(
&k[((nr_block_start + nr_block_offset) * ks + ki) * kc +
kc_begin],
kc_end - kc_begin, (int8_t*)packed_weights);
packed_weights = (int8_t*)packed_weights + kc_end - kc_begin;
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - ksum * izp;
}
std::fill((int8_t*)packed_weights, end, INT8_C(0));
packed_weights = end;
}
packed_weights = (int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += ks * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
} while (--g != 0);
}
void xnn_pack_f32_conv_kgo_w(size_t g, size_t nc, size_t ks, size_t nr,
size_t kr, size_t sr, const float* k,
const float* b, const void* scale,
float* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t i = 0; i < g; i++) {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ki = 0; ki < ks; ki++) {
for (size_t sr_block_offset = 0; sr_block_offset < sr;
sr_block_offset++) {
// TODO(unassigned): Is there a more precise zeroing we could do here?
std::fill_n(packed_weights, nr * kr, 0.0f);
for (size_t nr_block_offset = (-sr_block_offset) & (sr - 1);
nr_block_offset < nr_block_size; nr_block_offset += sr) {
packed_weights[nr_block_offset * kr] =
k[ki * g * nc + (nr_block_start + nr_block_offset)];
}
packed_weights += nr * kr;
}
}
packed_weights = (float*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void xnn_pack_f16_conv_kgo_w(size_t g, size_t nc, size_t ks, size_t nr,
size_t kr, size_t sr, const uint16_t* k,
const uint16_t* b, const void* scale,
uint16_t* packed_weights, size_t extra_bytes,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t i = 0; i < g; i++) {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ki = 0; ki < ks; ki++) {
for (size_t sr_block_offset = 0; sr_block_offset < sr;
sr_block_offset++) {
// TODO(unassigned): Is there a more precise zeroing we could do here?
std::fill_n(packed_weights, nr * kr, UINT16_C(0));
for (size_t nr_block_offset = (-sr_block_offset) & (sr - 1);
nr_block_offset < nr_block_size; nr_block_offset += sr) {
packed_weights[nr_block_offset * kr] =
k[ki * g * nc + (nr_block_start + nr_block_offset)];
}
packed_weights += nr * kr;
}
}
packed_weights = (uint16_t*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void xnn_pack_f32_to_f16_conv_kgo_w(size_t g, size_t nc, size_t ks, size_t nr,
size_t kr, size_t sr, const float* k,
const float* b, const void* scale,
xnn_float16* packed_weights,
size_t extra_bytes, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t i = 0; i < g; i++) {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ki = 0; ki < ks; ki++) {
for (size_t sr_block_offset = 0; sr_block_offset < sr;
sr_block_offset++) {
// TODO(unassigned): Is there a more precise zeroing we could do here?
std::fill_n(packed_weights, nr * kr, static_cast<xnn_float16>(0.0f));
for (size_t nr_block_offset = (-sr_block_offset) & (sr - 1);
nr_block_offset < nr_block_size; nr_block_offset += sr) {
packed_weights[nr_block_offset * kr] = xnn_float16_from_float(
k[ki * g * nc + (nr_block_start + nr_block_offset)]);
}
packed_weights += nr * kr;
}
}
packed_weights = (xnn_float16*)((uintptr_t)packed_weights + extra_bytes);
}
k += nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void xnn_pack_qu8_conv_kgo_w(size_t g, size_t nc, size_t ks, size_t nr,
size_t kr, size_t sr, const uint8_t* k,
const int32_t* b, const void* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qu8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const int32_t izp = (int32_t)params->input_zero_point;
const int32_t bzp = (int32_t)ks * izp * (int32_t)params->kernel_zero_point;
for (size_t i = 0; i < g; i++) {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b, bzp);
packed_weights =
(void*)((uintptr_t)packed_weights + nr * sizeof(int32_t));
for (size_t ki = 0; ki < ks; ki++) {
for (size_t sr_block_offset = 0; sr_block_offset < sr;
sr_block_offset++) {
// TODO(unassigned): Is there a more precise zeroing we could do here?
std::fill_n((uint8_t*)packed_weights, nr * kr,
params->kernel_zero_point);
for (size_t nr_block_offset = (-sr_block_offset) & (sr - 1);
nr_block_offset < nr_block_size; nr_block_offset += sr) {
const uint8_t kv =
k[ki * g * nc + (nr_block_start + nr_block_offset)];
((uint8_t*)packed_weights)[nr_block_offset * kr] = kv;
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - (int32_t)kv * izp;
}
packed_weights = (uint8_t*)packed_weights + nr * kr;
}
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void pack_qs8_conv_kgo_w(size_t g, size_t nc, size_t ks, size_t nr, size_t kr,
size_t sr, const int8_t* k, const int32_t* b,
const float* scale, void* packed_weights,
size_t extra_bytes, int32_t zero_point_offset,
const struct xnn_qs8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const uint32_t izp = (uint32_t)params->input_zero_point + zero_point_offset;
for (size_t i = 0; i < g; i++) {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights =
(void*)((uintptr_t)packed_weights + nr * sizeof(int32_t));
for (size_t ki = 0; ki < ks; ki++) {
for (size_t sr_block_offset = 0; sr_block_offset < sr;
sr_block_offset++) {
// TODO(unassigned): Is there a more precise zeroing we could do here?
std::fill_n((int8_t*)packed_weights, nr * kr, INT8_C(0));
for (size_t nr_block_offset = (-sr_block_offset) & (sr - 1);
nr_block_offset < nr_block_size; nr_block_offset += sr) {
const int8_t kv =
k[ki * g * nc + (nr_block_start + nr_block_offset)];
((int8_t*)packed_weights)[nr_block_offset * kr] = kv;
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - (uint32_t)kv * izp;
}
packed_weights = (int8_t*)packed_weights + nr * kr;
}
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
k += nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void xnn_pack_qs8_conv_kgo_w(size_t g, size_t nc, size_t ks, size_t nr,
size_t kr, size_t sr, const int8_t* k,
const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
const struct xnn_qs8_packing_params* params) {
pack_qs8_conv_kgo_w(g, nc, ks, nr, kr, sr, k, b, scale, packed_weights,
extra_bytes, /*zero_point_offset=*/0, params);
}
void xnn_pack_qs8_to_qu8_conv_kgo_w(
size_t g, size_t nc, size_t ks, size_t nr, size_t kr, size_t sr,
const int8_t* k, const int32_t* b, const float* scale, void* packed_weights,
size_t extra_bytes, const struct xnn_qs8_packing_params* params) {
pack_qs8_conv_kgo_w(g, nc, ks, nr, kr, sr, k, b, scale, packed_weights,
extra_bytes, /*zero_point_offset=*/128, params);
}
void xnn_pack_f32_deconv_goki_w(size_t g, size_t nc, size_t kh, size_t kw,
size_t kc, size_t sh, size_t sw, size_t nr,
size_t kr, size_t sr, const float* k,
const float* b, const void* scale,
float* packed_weights, size_t extra_bytes,
struct subconvolution_params* subconv_params,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
for (size_t i = 0; i < g; i++) {
for (size_t oy = 0; oy < sh; oy++) {
for (size_t ox = 0; ox < sw; ox++) {
if (i == 0) {
(*subconv_params++).weights = packed_weights;
}
for (size_t nr_block_start = 0; nr_block_start < nc;
nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ky = oy; ky < kh; ky += sh) {
for (size_t kx = ox; kx < kw; kx += sw) {
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0;
nr_block_offset < nr_block_size; nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
float* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(
&k[(((nr_block_start + nr_block_offset) * kh + ky) *
kw +
kx) *
kc +
kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, 0.0f);
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
}
packed_weights =
reinterpret_cast<float*>((uintptr_t)packed_weights + extra_bytes);
}
}
}
k += kh * kw * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void xnn_pack_f16_deconv_goki_w(size_t g, size_t nc, size_t kh, size_t kw,
size_t kc, size_t sh, size_t sw, size_t nr,
size_t kr, size_t sr, const uint16_t* k,
const uint16_t* b, const void* scale,
uint16_t* packed_weights, size_t extra_bytes,
struct subconvolution_params* subconv_params,
const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
for (size_t i = 0; i < g; i++) {
for (size_t oy = 0; oy < sh; oy++) {
for (size_t ox = 0; ox < sw; ox++) {
if (i == 0) {
(*subconv_params++).weights = packed_weights;
}
for (size_t nr_block_start = 0; nr_block_start < nc;
nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ky = oy; ky < kh; ky += sh) {
for (size_t kx = ox; kx < kw; kx += sw) {
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0;
nr_block_offset < nr_block_size; nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
uint16_t* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(
&k[(((nr_block_start + nr_block_offset) * kh + ky) *
kw +
kx) *
kc +
kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end, UINT16_C(0));
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
}
packed_weights = reinterpret_cast<uint16_t*>(
(uintptr_t)packed_weights + extra_bytes);
}
}
}
k += kh * kw * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void xnn_pack_f32_to_f16_deconv_goki_w(
size_t g, size_t nc, size_t kh, size_t kw, size_t kc, size_t sh, size_t sw,
size_t nr, size_t kr, size_t sr, const float* k, const float* b,
const void* scale, xnn_float16* packed_weights, size_t extra_bytes,
struct subconvolution_params* subconv_params, const void* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
for (size_t i = 0; i < g; i++) {
for (size_t oy = 0; oy < sh; oy++) {
for (size_t ox = 0; ox < sw; ox++) {
if (i == 0) {
(*subconv_params++).weights = packed_weights;
}
for (size_t nr_block_start = 0; nr_block_start < nc;
nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
copy_bias(b, nr_block_start, nr_block_size, packed_weights);
packed_weights += nr;
for (size_t ky = oy; ky < kh; ky += sh) {
for (size_t kx = ox; kx < kw; kx += sw) {
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0;
nr_block_offset < nr_block_size; nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
xnn_float16* end = packed_weights + kr;
if (kc_begin < kc_end) {
std::copy_n(
&k[(((nr_block_start + nr_block_offset) * kh + ky) *
kw +
kx) *
kc +
kc_begin],
kc_end - kc_begin, packed_weights);
packed_weights += kc_end - kc_begin;
}
std::fill(packed_weights, end,
static_cast<xnn_float16>(0.0f));
packed_weights = end;
}
packed_weights += (nr - nr_block_size) * kr;
}
}
}
packed_weights = reinterpret_cast<xnn_float16*>(
(uintptr_t)packed_weights + extra_bytes);
}
}
}
k += kh * kw * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void pack_qs8_deconv_goki_w(size_t groups, size_t nc, size_t kh, size_t kw,
size_t kc, size_t sh, size_t sw, size_t nr,
size_t kr, size_t sr, const int8_t* k,
const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
int32_t zero_point_offset,
struct subconvolution_params* subconv_params,
const struct xnn_qs8_packing_params* params) {
assert(groups != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const uint32_t izp = (uint32_t)params->input_zero_point + zero_point_offset;
for (size_t i = 0; i < groups; i++) {
for (size_t oy = 0; oy < sh; oy++) {
for (size_t ox = 0; ox < sw; ox++) {
if (i == 0) {
(*subconv_params++).weights = packed_weights;
}
for (size_t nr_block_start = 0; nr_block_start < nc;
nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b);
packed_weights =
(void*)((uintptr_t)packed_weights + nr * sizeof(int32_t));
for (size_t ky = oy; ky < kh; ky += sh) {
for (size_t kx = ox; kx < kw; kx += sw) {
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0;
nr_block_offset < nr_block_size; nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
int8_t* end = (int8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
uint32_t ksum = copy_n_and_sum(
&k[(((nr_block_start + nr_block_offset) * kh + ky) *
kw +
kx) *
kc +
kc_begin],
kc_end - kc_begin, (int8_t*)packed_weights);
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - ksum * izp;
packed_weights =
(int8_t*)packed_weights + kc_end - kc_begin;
}
std::fill((int8_t*)packed_weights, end, INT8_C(0));
packed_weights = end;
}
packed_weights =
(int8_t*)packed_weights + (nr - nr_block_size) * kr;
}
}
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
}
}
k += kh * kw * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
void xnn_pack_qs8_deconv_goki_w(size_t g, size_t nc, size_t kh, size_t kw,
size_t kc, size_t sh, size_t sw, size_t nr,
size_t kr, size_t sr, const int8_t* k,
const int32_t* b, const float* scale,
void* packed_weights, size_t extra_bytes,
struct subconvolution_params* subconv_params,
const struct xnn_qs8_packing_params* params) {
pack_qs8_deconv_goki_w(g, nc, kh, kw, kc, sh, sw, nr, kr, sr, k, b, scale,
packed_weights, extra_bytes, /*zero_point_offset=*/0,
subconv_params, params);
}
void xnn_pack_qs8_to_qu8_deconv_goki_w(
size_t g, size_t nc, size_t kh, size_t kw, size_t kc, size_t sh, size_t sw,
size_t nr, size_t kr, size_t sr, const int8_t* k, const int32_t* b,
const float* scale, void* packed_weights, size_t extra_bytes,
struct subconvolution_params* subconv_params,
const struct xnn_qs8_packing_params* params) {
pack_qs8_deconv_goki_w(g, nc, kh, kw, kc, sh, sw, nr, kr, sr, k, b, scale,
packed_weights, extra_bytes, /*zero_point_offset=*/128,
subconv_params, params);
}
void xnn_pack_qu8_deconv_goki_w(size_t g, size_t nc, size_t kh, size_t kw,
size_t kc, size_t sh, size_t sw, size_t nr,
size_t kr, size_t sr, const uint8_t* k,
const int32_t* b, const void* scale,
void* packed_weights, size_t extra_bytes,
struct subconvolution_params* subconv_params,
const struct xnn_qu8_packing_params* params) {
assert(g != 0);
assert(nr >= sr);
assert(k != nullptr);
assert(packed_weights != nullptr);
const size_t skr = sr * kr;
const int32_t izp = (int32_t)params->input_zero_point;
const int32_t kzp = (int32_t)params->kernel_zero_point;
for (size_t i = 0; i < g; i++) {
for (size_t oy = 0; oy < sh; oy++) {
for (size_t ox = 0; ox < sw; ox++) {
if (i == 0) {
(*subconv_params++).weights = packed_weights;
}
const int32_t bzp = (int32_t)divide_round_up(kh - oy, sh) *
(int32_t)divide_round_up(kw - ox, sw) *
(int32_t)kc * izp * kzp;
for (size_t nr_block_start = 0; nr_block_start < nc;
nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
copy_bias(b, nr_block_start, nr_block_size, packed_b, bzp);
packed_weights =
(void*)((uintptr_t)packed_weights + nr * sizeof(int32_t));
for (size_t ky = oy; ky < kh; ky += sh) {
for (size_t kx = ox; kx < kw; kx += sw) {
for (size_t kr_block_start = 0;
kr_block_start < round_up_po2(kc, skr);
kr_block_start += kr) {
for (size_t nr_block_offset = 0;
nr_block_offset < nr_block_size; nr_block_offset++) {
const size_t kc_begin =
round_down_po2(kr_block_start, skr) +
((kr_block_start + nr_block_offset * kr) & (skr - 1));
const size_t kc_end = std::min(kc, kc_begin + kr);
uint8_t* end = (uint8_t*)packed_weights + kr;
if (kc_begin < kc_end) {
int32_t ksum = copy_n_and_sum(
&k[(((nr_block_start + nr_block_offset) * kh + ky) *
kw +
kx) *
kc +
kc_begin],
kc_end - kc_begin, (uint8_t*)packed_weights);
packed_b[nr_block_offset] =
packed_b[nr_block_offset] - ksum * izp;
packed_weights =
(uint8_t*)packed_weights + kc_end - kc_begin;
}
std::fill((uint8_t*)packed_weights, end,
params->kernel_zero_point);
packed_weights = end;
}
packed_weights =
(uint8_t*)packed_weights + (nr - nr_block_size) * kr;
}
}
}
packed_weights =
reinterpret_cast<void*>((uintptr_t)packed_weights + extra_bytes);
}
}
}
k += kh * kw * kc * nc;
if XNN_UNPREDICTABLE (b != nullptr) {
b += nc;
}
}
}
// Helper function to advance x and y indices.
inline static void advance_x_y(size_t h, size_t* x, size_t* y) {
if (++*y == h) {
*y = 0;
++*x;
}
}
void xnn_pack_f32_dwconv_ghw_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const float* k,
const float* b, const void* scale,
float* packed_weights,
size_t per_tile_extra_bytes,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
assert(kernel_size <= primary_tile);
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_weights);
packed_weights += channel_tile;
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const float kv =
k[((cr_block_start + cr_block_offset) * h + y) * w + x];
*packed_weights++ = kv;
}
packed_weights += channel_tile - cr_block_size;
advance_x_y(h, &x, &y);
}
std::fill_n(packed_weights, (primary_tile - kernel_size) * channel_tile,
0.0f);
packed_weights += (primary_tile - kernel_size) * cr_block_size;
}
}
void xnn_pack_f16_dwconv_ghw_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const uint16_t* k,
const uint16_t* b, const void* scale,
uint16_t* packed_weights,
size_t per_tile_extra_bytes,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_weights);
packed_weights += channel_tile;
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const uint16_t kv =
k[((cr_block_start + cr_block_offset) * h + y) * w + x];
*packed_weights++ = kv;
}
packed_weights += channel_tile - cr_block_size;
advance_x_y(h, &x, &y);
}
std::fill_n(packed_weights, (primary_tile - kernel_size) * channel_tile,
UINT16_C(0));
packed_weights += (primary_tile - kernel_size) * cr_block_size;
}
}
void xnn_pack_f32_to_f16_dwconv_ghw_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile,
const float* k, const float* b,
const void* scale,
xnn_float16* packed_weights,
size_t per_tile_extra_bytes,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_weights);
packed_weights += channel_tile;
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const xnn_float16 kv = xnn_float16_from_float(
k[((cr_block_start + cr_block_offset) * h + y) * w + x]);
*packed_weights++ = kv;
}
packed_weights += channel_tile - cr_block_size;
advance_x_y(h, &x, &y);
}
std::fill_n(packed_weights, (primary_tile - kernel_size) * channel_tile,
static_cast<xnn_float16>(0.0f));
packed_weights += (primary_tile - kernel_size) * cr_block_size;
}
}
void xnn_pack_qu8_dwconv_ghw_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const uint8_t* k,
const int32_t* b, const void* scale,
void* packed_weights,
size_t per_tile_extra_bytes,
const struct xnn_qu8_packing_params* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
const int32_t izp = (int32_t)params->input_zero_point;
const int32_t boff =
(int32_t)h * (int32_t)w * izp * (int32_t)params->kernel_zero_point;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_b, boff);
packed_weights =
(void*)((uintptr_t)packed_weights + channel_tile * sizeof(int32_t));
// Biases need to be offset by all kernel values.
for (size_t x = 0; x < w; x++) {
for (size_t y = 0; y < h; y++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const uint8_t kv =
k[((cr_block_start + cr_block_offset) * h + y) * w + x];
unaligned_indexed_store_s32(
packed_b, cr_block_offset,
unaligned_indexed_load_s32(packed_b, cr_block_offset) -
(int32_t)kv * izp);
}
}
}
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const uint8_t kv =
k[((cr_block_start + cr_block_offset) * h + y) * w + x];
*((uint8_t*)packed_weights) = kv;
packed_weights = (void*)((uintptr_t)packed_weights + sizeof(uint8_t));
}
packed_weights =
(void*)((uintptr_t)packed_weights +
(channel_tile - cr_block_size) * sizeof(uint8_t));
advance_x_y(h, &x, &y);
}
std::fill_n((uint8_t*)packed_weights,
(primary_tile - kernel_size) * channel_tile,
params->kernel_zero_point);
packed_weights = (void*)((uintptr_t)packed_weights +
(primary_tile - kernel_size) * cr_block_size);
}
}
void xnn_pack_qs8_dwconv_ghw_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const int8_t* k,
const int32_t* b, const float* scale,
void* packed_weights,
size_t per_tile_extra_bytes,
const struct xnn_qs8_packing_params* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
const uint32_t izp = (uint32_t)params->input_zero_point;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_b);
packed_weights =
(void*)((uintptr_t)packed_weights + channel_tile * sizeof(int32_t));
// Biases need to be offset by all kernel values.
for (size_t x = 0; x < w; x++) {
for (size_t y = 0; y < h; y++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const int8_t kv =
k[((cr_block_start + cr_block_offset) * h + y) * w + x];
packed_b[cr_block_offset] =
packed_b[cr_block_offset] - (uint32_t)kv * izp;
}
}
}
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const int8_t kv =
k[((cr_block_start + cr_block_offset) * h + y) * w + x];
*((int8_t*)packed_weights) = kv;
packed_weights = (void*)((uintptr_t)packed_weights + sizeof(int8_t));
}
std::fill_n((int8_t*)packed_weights, channel_tile - cr_block_size,
INT8_C(0));
packed_weights = (void*)((uintptr_t)packed_weights +
(channel_tile - cr_block_size) * sizeof(int8_t));
advance_x_y(h, &x, &y);
}
std::fill_n((int8_t*)packed_weights,
(primary_tile - kernel_size) * channel_tile, INT8_C(0));
packed_weights = (void*)((uintptr_t)packed_weights +
(primary_tile - kernel_size) * cr_block_size);
// We need to pack extra bytes for scale values here.
packed_weights = (void*)((uintptr_t)packed_weights + per_tile_extra_bytes);
}
}
void xnn_pack_f32_dwconv_hwg_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const float* k,
const float* b, const void* scale,
float* packed_weights,
size_t per_tile_extra_bytes,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_weights);
packed_weights += channel_tile;
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const float kv =
k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
*packed_weights++ = kv;
}
packed_weights += channel_tile - cr_block_size;
advance_x_y(h, &x, &y);
}
std::fill_n(packed_weights, (primary_tile - kernel_size) * channel_tile,
0.0f);
packed_weights += (primary_tile - kernel_size) * cr_block_size;
}
}
void xnn_pack_f16_dwconv_hwg_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const uint16_t* k,
const uint16_t* b, const void* scale,
uint16_t* packed_weights,
size_t per_tile_extra_bytes,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_weights);
packed_weights += channel_tile;
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const uint16_t kv =
k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
*packed_weights++ = kv;
}
packed_weights += channel_tile - cr_block_size;
advance_x_y(h, &x, &y);
}
std::fill_n(packed_weights, (primary_tile - kernel_size) * channel_tile,
UINT16_C(0));
packed_weights += (primary_tile - kernel_size) * cr_block_size;
}
}
void xnn_pack_f32_to_f16_dwconv_hwg_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile,
const float* k, const float* b,
const void* scale,
xnn_float16* packed_weights,
size_t per_tile_extra_bytes,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_weights);
packed_weights += channel_tile;
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const xnn_float16 kv = xnn_float16_from_float(
k[(y * w + x) * c + (cr_block_start + cr_block_offset)]);
*packed_weights++ = kv;
}
packed_weights += channel_tile - cr_block_size;
advance_x_y(h, &x, &y);
}
std::fill_n(packed_weights, (primary_tile - kernel_size) * channel_tile,
xnn_float16_zero());
packed_weights += (primary_tile - kernel_size) * cr_block_size;
}
}
void xnn_pack_qu8_dwconv_hwg_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const uint8_t* k,
const int32_t* b, const void* scale,
void* packed_weights,
size_t per_tile_extra_bytes,
const struct xnn_qu8_packing_params* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
const int32_t izp = (int32_t)params->input_zero_point;
const int32_t boff =
(int32_t)h * (int32_t)w * izp * (int32_t)params->kernel_zero_point;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_b, boff);
packed_weights =
(void*)((uintptr_t)packed_weights + channel_tile * sizeof(int32_t));
// Biases need to be offset by all kernel values.
for (size_t x = 0; x < w; x++) {
for (size_t y = 0; y < h; y++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const uint8_t kv =
k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
unaligned_indexed_store_s32(
packed_b, cr_block_offset,
unaligned_indexed_load_s32(packed_b, cr_block_offset) -
(int32_t)kv * izp);
}
}
}
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const uint8_t kv =
k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
*((uint8_t*)packed_weights) = kv;
packed_weights = (void*)((uintptr_t)packed_weights + sizeof(uint8_t));
}
packed_weights =
(void*)((uintptr_t)packed_weights +
(channel_tile - cr_block_size) * sizeof(uint8_t));
advance_x_y(h, &x, &y);
}
std::fill_n((uint8_t*)packed_weights,
(primary_tile - kernel_size) * channel_tile,
params->kernel_zero_point);
packed_weights = (void*)((uintptr_t)packed_weights +
(primary_tile - kernel_size) * cr_block_size);
}
}
void xnn_pack_qs8_dwconv_hwg_w(size_t primary_tile, size_t h, size_t w,
size_t c, size_t channel_tile, const int8_t* k,
const int32_t* b, const float* scale,
void* packed_weights,
size_t per_tile_extra_bytes,
const struct xnn_qs8_packing_params* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
size_t kernel_size = h * w;
const uint32_t izp = (uint32_t)params->input_zero_point;
for (size_t cr_block_start = 0; cr_block_start < c;
cr_block_start += channel_tile) {
unaligned_int32_t* packed_b = (unaligned_int32_t*)packed_weights;
const size_t cr_block_size = min(c - cr_block_start, channel_tile);
copy_bias(b, cr_block_start, cr_block_size, packed_b);
packed_weights =
(void*)((uintptr_t)packed_weights + channel_tile * sizeof(int32_t));
// Biases need to be offset by all kernel values.
for (size_t x = 0; x < w; x++) {
for (size_t y = 0; y < h; y++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const int8_t kv =
k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
packed_b[cr_block_offset] =
packed_b[cr_block_offset] - (uint32_t)kv * izp;
}
}
}
// Stores the x and y index that should be processed next.
size_t x = 0;
size_t y = 0;
for (size_t i = 0; i < kernel_size; i++) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
const int8_t kv =
k[(y * w + x) * c + (cr_block_start + cr_block_offset)];
*((int8_t*)packed_weights) = kv;
packed_weights = (void*)((uintptr_t)packed_weights + sizeof(int8_t));
}
packed_weights = (void*)((uintptr_t)packed_weights +
(channel_tile - cr_block_size) * sizeof(int8_t));
advance_x_y(h, &x, &y);
}
std::fill_n((int8_t*)packed_weights,
(primary_tile - kernel_size) * channel_tile, INT8_C(0));
packed_weights = (void*)((uintptr_t)packed_weights +
(primary_tile - kernel_size) * cr_block_size);
// We need to pack extra bytes for scale values here.
packed_weights = (void*)((uintptr_t)packed_weights + per_tile_extra_bytes);
}
}
void xnn_pack_f32_dconv_oki_w(size_t nc, size_t kc, size_t nr, size_t kh,
size_t kw, const float* k, const float* b,
float* packed_weights, const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
if XNN_LIKELY (b != nullptr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr;
nr_block_offset++) {
*packed_weights++ = b[min(nr_block_offset, nr_block_size - 1)];
}
} else {
size_t n = nr;
do {
*packed_weights++ = 0.0f;
} while (--n != 0);
}
for (size_t kx = 0; kx < kw; kx++) {
for (size_t c = 0; c < kc; c++) {
for (size_t ky = 0; ky < kh; ky++) {
for (size_t nr_block_offset = 0; nr_block_offset < nr;
nr_block_offset++) {
*packed_weights++ =
k[(((nr_block_start + min(nr_block_offset, nr_block_size - 1)) *
kh +
ky) *
kw +
kx) *
kc +
c];
}
}
}
}
if XNN_UNPREDICTABLE (b != nullptr) {
b += nr;
}
}
}
void xnn_pack_f32_to_f16_dconv_oki_w(size_t nc, size_t kc, size_t nr, size_t kh,
size_t kw, const float* k, const float* b,
xnn_float16* packed_weights,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
if XNN_LIKELY (b != nullptr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr;
nr_block_offset++) {
*packed_weights++ =
xnn_float16_from_float(b[min(nr_block_offset, nr_block_size - 1)]);
}
} else {
size_t n = nr;
do {
*packed_weights++ = xnn_float16_zero();
} while (--n != 0);
}
for (size_t kx = 0; kx < kw; kx++) {
for (size_t c = 0; c < kc; c++) {
for (size_t ky = 0; ky < kh; ky++) {
for (size_t nr_block_offset = 0; nr_block_offset < nr;
nr_block_offset++) {
*packed_weights++ = xnn_float16_from_float(
k[(((nr_block_start + min(nr_block_offset, nr_block_size - 1)) *
kh +
ky) *
kw +
kx) *
kc +
c]);
}
}
}
}
if XNN_UNPREDICTABLE (b != nullptr) {
b += nr;
}
}
}
void xnn_pack_f16_dconv_oki_w(size_t nc, size_t kc, size_t nr, size_t kh,
size_t kw, const uint16_t* k, const uint16_t* b,
uint16_t* packed_weights, const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
if XNN_LIKELY (b != nullptr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr;
nr_block_offset++) {
*packed_weights++ = b[min(nr_block_offset, nr_block_size - 1)];
}
} else {
size_t n = nr;
do {
*packed_weights++ = 0;
} while (--n != 0);
}
for (size_t kx = 0; kx < kw; kx++) {
for (size_t c = 0; c < kc; c++) {
for (size_t ky = 0; ky < kh; ky++) {
for (size_t nr_block_offset = 0; nr_block_offset < nr;
nr_block_offset++) {
*packed_weights++ =
k[(((nr_block_start + min(nr_block_offset, nr_block_size - 1)) *
kh +
ky) *
kw +
kx) *
kc +
c];
}
}
}
}
if XNN_UNPREDICTABLE (b != nullptr) {
b += nr;
}
}
}
void xnn_pack_f32_chw_dwconv_ghw_w(size_t kernel_size, size_t groups,
const float* k, const float* b,
float* packed_weights, const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t g = 0; g < groups; g++) {
if XNN_LIKELY (b != nullptr) {
*packed_weights = *b++;
} else {
*packed_weights = 0.0f;
}
packed_weights += 1;
for (size_t i = 0; i < kernel_size; i++) {
*packed_weights++ = k[g * kernel_size + i];
}
}
}
void xnn_pack_f32_to_f16_chw_dwconv_ghw_w(size_t kernel_size, size_t groups,
const float* k, const float* b,
xnn_float16* packed_weights,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t g = 0; g < groups; g++) {
if XNN_LIKELY (b != nullptr) {
*packed_weights = xnn_float16_from_float(*b++);
} else {
*packed_weights = xnn_float16_zero();
}
packed_weights += 1;
for (size_t i = 0; i < kernel_size; i++) {
*packed_weights++ = xnn_float16_from_float(k[g * kernel_size + i]);
}
}
}
void xnn_pack_f16_chw_dwconv_ghw_w(size_t kernel_size, size_t groups,
const uint16_t* k, const uint16_t* b,
uint16_t* packed_weights,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t g = 0; g < groups; g++) {
if XNN_LIKELY (b != nullptr) {
*packed_weights = *b++;
} else {
*packed_weights = 0;
}
packed_weights += 1;
for (size_t i = 0; i < kernel_size; i++) {
*packed_weights++ = k[g * kernel_size + i];
}
}
}
void xnn_pack_f32_chw_dwconv_hwg_w(size_t kernel_size, size_t groups,
const float* k, const float* b,
float* packed_weights, const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t g = 0; g < groups; g++) {
if XNN_LIKELY (b != nullptr) {
*packed_weights = *b++;
} else {
*packed_weights = 0.0f;
}
packed_weights += 1;
for (size_t i = 0; i < kernel_size; i++) {
*packed_weights++ = k[i * groups + g];
}
}
}
void xnn_pack_f16_chw_dwconv_hwg_w(size_t kernel_size, size_t groups,
const uint16_t* k, const uint16_t* b,
uint16_t* packed_weights,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t g = 0; g < groups; g++) {
if XNN_LIKELY (b != nullptr) {
*packed_weights = *b++;
} else {
*packed_weights = 0;
}
packed_weights += 1;
for (size_t i = 0; i < kernel_size; i++) {
*packed_weights++ = k[i * groups + g];
}
}
}
void xnn_pack_f32_to_f16_chw_dwconv_hwg_w(size_t kernel_size, size_t groups,
const float* k, const float* b,
xnn_float16* packed_weights,
const void* params) {
assert(k != nullptr);
assert(packed_weights != nullptr);
for (size_t g = 0; g < groups; g++) {
if XNN_LIKELY (b != nullptr) {
*packed_weights = xnn_float16_from_float(*b++);
} else {
*packed_weights = xnn_float16_zero();
}
packed_weights += 1;
for (size_t i = 0; i < kernel_size; i++) {
*packed_weights++ = xnn_float16_from_float(k[i * groups + g]);
}
}
}
void xnn_pack_f32_vmulcaddc_w(size_t c, size_t cr, const float* s,
const float* b, float* packed_weights,
const void* params) {
assert(s != nullptr);
assert(packed_weights != nullptr);
for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
const size_t cr_block_size = min(c - cr_block_start, cr);
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
*packed_weights++ = s[cr_block_start + cr_block_offset];
}
packed_weights += cr - cr_block_size;
if XNN_LIKELY (b != nullptr) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
*packed_weights++ = b[cr_block_start + cr_block_offset];
}
} else {
size_t n = cr_block_size;
do {
*packed_weights++ = 0.0f;
} while (--n != 0);
}
packed_weights += cr - cr_block_size;
}
}
void xnn_pack_f16_vmulcaddc_w(size_t c, size_t cr, const uint16_t* s,
const uint16_t* b, uint16_t* packed_weights,
const void* params) {
assert(s != nullptr);
assert(packed_weights != nullptr);
for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
const size_t cr_block_size = min(c - cr_block_start, cr);
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
*packed_weights++ = s[cr_block_start + cr_block_offset];
}
packed_weights += cr - cr_block_size;
if XNN_LIKELY (b != nullptr) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
*packed_weights++ = b[cr_block_start + cr_block_offset];
}
} else {
size_t n = cr_block_size;
do {
*packed_weights++ = 0;
} while (--n != 0);
}
packed_weights += cr - cr_block_size;
}
}
void xnn_pack_f32_to_f16_vmulcaddc_w(size_t c, size_t cr, const float* s,
const float* b,
xnn_float16* packed_weights,
const void* params) {
assert(s != nullptr);
assert(packed_weights != nullptr);
for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) {
const size_t cr_block_size = min(c - cr_block_start, cr);
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
*packed_weights++ =
xnn_float16_from_float(s[cr_block_start + cr_block_offset]);
}
packed_weights += cr - cr_block_size;
if XNN_LIKELY (b != nullptr) {
for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size;
cr_block_offset++) {
*packed_weights++ =
xnn_float16_from_float(b[cr_block_start + cr_block_offset]);
}
} else {
size_t n = cr_block_size;
do {
*packed_weights++ = xnn_float16_zero();
} while (--n != 0);
}
packed_weights += cr - cr_block_size;
}
}
void xnn_analyze_f32_spmm_w(size_t group_output_channels,
size_t group_input_channels, const float* kernel,
struct xnn_spmm_packing_params* params) {
assert(kernel != nullptr);
assert(params != nullptr);
// Count number of non-zero values.
size_t num_nonzeroes = 0;
size_t num_nonzero_blocks2 = 0;
size_t num_nonzero_blocks4 = 0;
for (size_t oc = 0; oc < round_down_po2(group_output_channels, 4); oc += 4) {
for (size_t ic = 0; ic < group_input_channels; ic++) {
const size_t row0_nonzero =
(size_t)(kernel[oc * group_input_channels + ic] != 0.0f);
const size_t row1_nonzero =
(size_t)(kernel[(oc + 1) * group_input_channels + ic] != 0.0f);
const size_t row2_nonzero =
(size_t)(kernel[(oc + 2) * group_input_channels + ic] != 0.0f);
const size_t row3_nonzero =
(size_t)(kernel[(oc + 3) * group_input_channels + ic] != 0.0f);
num_nonzeroes +=
row0_nonzero + row1_nonzero + row2_nonzero + row3_nonzero;
num_nonzero_blocks2 +=
(row0_nonzero | row1_nonzero) + (row2_nonzero | row3_nonzero);
num_nonzero_blocks4 +=
(row0_nonzero | row1_nonzero | row2_nonzero | row3_nonzero);
}
}
const size_t num_block4_nonzeroes = num_nonzeroes;
for (size_t oc = round_down_po2(group_output_channels, 4);
oc < round_down_po2(group_output_channels, 2); oc += 2) {
for (size_t ic = 0; ic < group_input_channels; ic++) {
const size_t row0_nonzero =
(size_t)(kernel[oc * group_input_channels + ic] != 0.0f);
const size_t row1_nonzero =
(size_t)(kernel[(oc + 1) * group_input_channels + ic] != 0.0f);
num_nonzeroes += row0_nonzero + row1_nonzero;
num_nonzero_blocks2 += (row0_nonzero | row1_nonzero);
}
}
const size_t num_block2_nonzeroes = num_nonzeroes;
for (size_t oc = round_down_po2(group_output_channels, 2);
oc < group_output_channels; oc++) {
for (size_t ic = 0; ic < group_input_channels; ic++) {
num_nonzeroes += (size_t)(kernel[oc * group_input_channels + ic] != 0.0f);
}
}
params->num_nonzeroes = num_nonzeroes;
params->num_nonzero_blocks2 = num_nonzero_blocks2;
params->num_nonzero_blocks4 = num_nonzero_blocks4;
params->num_block2_nonzeroes = num_block2_nonzeroes;
params->num_block4_nonzeroes = num_block4_nonzeroes;
}
void xnn_analyze_f16_spmm_w(size_t group_output_channels,
size_t group_input_channels,
const xnn_float16* kernel,
struct xnn_spmm_packing_params* params) {
assert(kernel != nullptr);
assert(params != nullptr);
// Count number of non-zero values.
size_t num_nonzeroes = 0;
size_t num_nonzero_blocks2 = 0;
size_t num_nonzero_blocks4 = 0;
for (size_t oc = 0; oc < round_down_po2(group_output_channels, 4); oc += 4) {
for (size_t ic = 0; ic < group_input_channels; ic++) {
const size_t row0_nonzero =
(size_t)!xnn_float16_is_zero(kernel[oc * group_input_channels + ic]);
const size_t row1_nonzero = (size_t)!xnn_float16_is_zero(
kernel[(oc + 1) * group_input_channels + ic]);
const size_t row2_nonzero = (size_t)!xnn_float16_is_zero(
kernel[(oc + 2) * group_input_channels + ic]);
const size_t row3_nonzero = (size_t)!xnn_float16_is_zero(
kernel[(oc + 3) * group_input_channels + ic]);
num_nonzeroes +=
row0_nonzero + row1_nonzero + row2_nonzero + row3_nonzero;
num_nonzero_blocks2 +=
(row0_nonzero | row1_nonzero) + (row2_nonzero | row3_nonzero);
num_nonzero_blocks4 +=
(row0_nonzero | row1_nonzero | row2_nonzero | row3_nonzero);
}
}
const size_t num_block4_nonzeroes = num_nonzeroes;
for (size_t oc = round_down_po2(group_output_channels, 4);
oc < round_down_po2(group_output_channels, 2); oc += 2) {
for (size_t ic = 0; ic < group_input_channels; ic++) {
const size_t row0_nonzero =
(size_t)!xnn_float16_is_zero(kernel[oc * group_input_channels + ic]);
const size_t row1_nonzero = (size_t)!xnn_float16_is_zero(
kernel[(oc + 1) * group_input_channels + ic]);
num_nonzeroes += row0_nonzero + row1_nonzero;
num_nonzero_blocks2 += (row0_nonzero | row1_nonzero);
}
}
const size_t num_block2_nonzeroes = num_nonzeroes;
for (size_t oc = round_down_po2(group_output_channels, 2);
oc < group_output_channels; oc++) {
for (size_t ic = 0; ic < group_input_channels; ic++) {
num_nonzeroes +=
(size_t)!xnn_float16_is_zero(kernel[oc * group_input_channels + ic]);
}
}
params->num_nonzeroes = num_nonzeroes;
params->num_nonzero_blocks2 = num_nonzero_blocks2;
params->num_nonzero_blocks4 = num_nonzero_blocks4;
params->num_block2_nonzeroes = num_block2_nonzeroes;
params->num_block4_nonzeroes = num_block4_nonzeroes;
}
enum xnn_status xnn_pack_f32_spmm_w(
size_t group_output_channels, size_t output_channels_block_size,
size_t group_input_channels, const float* kernel, const float* bias,
int32_t* input_channel_diffs, uint32_t* output_channel_nonzeros,
float* nonzero_values, size_t* first_input_channel) {
size_t first_ic = 0, last_ic = 0;
bool first_nonzero = true;
for (size_t ocb = 0;
ocb < round_down_po2(group_output_channels, output_channels_block_size);
ocb += output_channels_block_size) {
if XNN_LIKELY (bias != nullptr) {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = bias[ocb + oco];
}
} else {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = 0.0f;
}
}
for (size_t ic = 0; ic < group_input_channels; ic++) {
bool is_nonzero_block = false;
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
is_nonzero_block |=
(kernel[(ocb + oco) * group_input_channels + ic] != 0.0f);
}
if (is_nonzero_block) {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = kernel[(ocb + oco) * group_input_channels + ic];
}
if (first_nonzero) {
first_ic = ic;
} else {
const int64_t diff = (int64_t)((uint64_t)ic - (uint64_t)last_ic) *
(int64_t)sizeof(float);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
first_nonzero = false;
last_ic = ic;
*output_channel_nonzeros += 1;
}
}
output_channel_nonzeros += 1;
}
for (size_t oc =
round_down_po2(group_output_channels, output_channels_block_size);
oc < group_output_channels; oc++) {
if XNN_LIKELY (bias != nullptr) {
*nonzero_values++ = bias[oc];
} else {
*nonzero_values++ = 0.0f;
}
for (size_t ic = 0; ic < group_input_channels; ic++) {
const float weight = kernel[oc * group_input_channels + ic];
if (weight != 0.0f) {
*nonzero_values++ = weight;
if (first_nonzero) {
first_ic = ic;
} else {
const int64_t diff = (int64_t)((uint64_t)ic - (uint64_t)last_ic) *
(int64_t)sizeof(float);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
first_nonzero = false;
last_ic = ic;
*output_channel_nonzeros += 1;
}
}
output_channel_nonzeros += 1;
}
// If there are any non-zero elements, we have to return to the initial input
// channel.
if (!first_nonzero) {
const int64_t diff = (int64_t)((uint64_t)first_ic - (uint64_t)last_ic) *
(int64_t)sizeof(float);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
*first_input_channel = first_ic;
return xnn_status_success;
}
enum xnn_status xnn_pack_f32_to_f16_spmm_w(
size_t group_output_channels, size_t output_channels_block_size,
size_t group_input_channels, const float* kernel, const float* bias,
int32_t* input_channel_diffs, uint32_t* output_channel_nonzeros,
xnn_float16* nonzero_values, // fp16 values
size_t* first_input_channel) {
size_t first_ic = 0, last_ic = 0;
bool first_nonzero = true;
for (size_t ocb = 0;
ocb < round_down_po2(group_output_channels, output_channels_block_size);
ocb += output_channels_block_size) {
if XNN_LIKELY (bias != nullptr) {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = xnn_float16_from_float(bias[ocb + oco]);
}
} else {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = xnn_float16_zero();
}
}
for (size_t ic = 0; ic < group_input_channels; ic++) {
bool is_nonzero_block = false;
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
is_nonzero_block |=
(kernel[(ocb + oco) * group_input_channels + ic] != 0.0f);
}
if (is_nonzero_block) {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = xnn_float16_from_float(
kernel[(ocb + oco) * group_input_channels + ic]);
}
if (first_nonzero) {
first_ic = ic;
} else {
const int64_t diff = (int64_t)((uint64_t)ic - (uint64_t)last_ic) *
(int64_t)sizeof(uint16_t);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
first_nonzero = false;
last_ic = ic;
*output_channel_nonzeros += 1;
}
}
output_channel_nonzeros += 1;
}
for (size_t oc =
round_down_po2(group_output_channels, output_channels_block_size);
oc < group_output_channels; oc++) {
if XNN_LIKELY (bias != nullptr) {
*nonzero_values++ = xnn_float16_from_float(bias[oc]);
} else {
*nonzero_values++ = xnn_float16_zero();
}
for (size_t ic = 0; ic < group_input_channels; ic++) {
const float weight = kernel[oc * group_input_channels + ic];
if (weight != 0.0f) {
*nonzero_values++ = xnn_float16_from_float(weight);
if (first_nonzero) {
first_ic = ic;
} else {
const int64_t diff = (int64_t)((uint64_t)ic - (uint64_t)last_ic) *
(int64_t)sizeof(uint16_t);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
first_nonzero = false;
last_ic = ic;
*output_channel_nonzeros += 1;
}
}
output_channel_nonzeros += 1;
}
// If there are any non-zero elements, we have to return to the initial input
// channel.
if (!first_nonzero) {
const int64_t diff = (int64_t)((uint64_t)first_ic - (uint64_t)last_ic) *
(int64_t)sizeof(uint16_t);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
*first_input_channel = first_ic;
return xnn_status_success;
}
enum xnn_status xnn_pack_f16_spmm_w(size_t group_output_channels,
size_t output_channels_block_size,
size_t group_input_channels,
const xnn_float16* kernel, // fp16 values
const xnn_float16* bias, // fp16 values
int32_t* input_channel_diffs,
uint32_t* output_channel_nonzeros,
xnn_float16* nonzero_values, // fp16 values
size_t* first_input_channel) {
size_t first_ic = 0, last_ic = 0;
bool first_nonzero = true;
for (size_t ocb = 0;
ocb < round_down_po2(group_output_channels, output_channels_block_size);
ocb += output_channels_block_size) {
if XNN_LIKELY (bias != nullptr) {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = bias[ocb + oco];
}
} else {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = xnn_float16_zero();
}
}
for (size_t ic = 0; ic < group_input_channels; ic++) {
bool is_nonzero_block = false;
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
is_nonzero_block |= !xnn_float16_is_zero(
kernel[(ocb + oco) * group_input_channels + ic]);
}
if (is_nonzero_block) {
for (size_t oco = 0; oco < output_channels_block_size; oco++) {
*nonzero_values++ = kernel[(ocb + oco) * group_input_channels + ic];
}
if (first_nonzero) {
first_ic = ic;
} else {
const int64_t diff = (int64_t)((uint64_t)ic - (uint64_t)last_ic) *
(int64_t)sizeof(uint16_t);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
first_nonzero = false;
last_ic = ic;
*output_channel_nonzeros += 1;
}
}
output_channel_nonzeros += 1;
}
for (size_t oc =
round_down_po2(group_output_channels, output_channels_block_size);
oc < group_output_channels; oc++) {
if XNN_LIKELY (bias != nullptr) {
*nonzero_values++ = bias[oc];
} else {
*nonzero_values++ = xnn_float16_zero();
}
for (size_t ic = 0; ic < group_input_channels; ic++) {
const xnn_float16 weight = kernel[oc * group_input_channels + ic];
if (!xnn_float16_is_zero(weight)) {
*nonzero_values++ = weight;
if (first_nonzero) {
first_ic = ic;
} else {
const int64_t diff = (int64_t)((uint64_t)ic - (uint64_t)last_ic) *
(int64_t)sizeof(uint16_t);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
first_nonzero = false;
last_ic = ic;
*output_channel_nonzeros += 1;
}
}
output_channel_nonzeros += 1;
}
// If there are any non-zero elements, we have to return to the initial input
// channel.
if (!first_nonzero) {
const int64_t diff = (int64_t)((uint64_t)first_ic - (uint64_t)last_ic) *
(int64_t)sizeof(uint16_t);
if (diff != (int64_t)(int32_t)diff) {
xnn_log_error(
"failed to convert kernel to sparse representation: scaled "
"difference in input channels exceeds int32_t range");
return xnn_status_unsupported_parameter;
}
*input_channel_diffs++ = (int32_t)diff;
}
*first_input_channel = first_ic;
return xnn_status_success;
}
} // extern "C"