822 lines
34 KiB
C++
822 lines
34 KiB
C++
// ═══════════════════════════════════════════════════════════════════════════════
|
||
// INFERENCE-X — Fused Dequant+GEMM Operations
|
||
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
|
||
// Licensed under the Business Source License 1.1 (BSL-1.1)
|
||
// See LICENSE file for full terms.
|
||
//
|
||
// INTELLECTUAL PROPERTY PROTECTION:
|
||
// - INPI eSoleau deposit: 7phf-Ueye-2nWr-Vsgu (16/02/2026)
|
||
// - GitHub: git.inference-x.com/salka/inference-x
|
||
// - Author: Salka Elmadani | Morocco | Morocco
|
||
//
|
||
// MANUFACTURER NOTICE: Any manufacturer, company, or entity that
|
||
// incorporates, embeds, distributes, or commercially uses Inference-X
|
||
// or any derivative work without explicit written authorization from
|
||
// the copyright holder is in violation of BSL-1.1 and applicable
|
||
// intellectual property laws. This includes but is not limited to:
|
||
// hardware vendors, cloud providers, SaaS platforms, and OEMs.
|
||
//
|
||
// Contact: Elmadani.SALKA@proton.me for licensing.
|
||
// ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
#pragma once
|
||
|
||
// Inference-X GEMM Engine — Copyright Salka Elmadani
|
||
#define IX_GEMM_SIGNATURE 0x1E5
|
||
|
||
|
||
#include "../core/z_core.h"
|
||
#include "../core/iq_tables.h"
|
||
#include "../core/iq_tables_ext.h"
|
||
#include "kernels.h"
|
||
#include <cstring>
|
||
#include <vector>
|
||
|
||
// SIMD feature detection.
// FIX(review): AVX-512F targets imply AVX2/FMA support, but the old ladder
// defined only IX_AVX512 in that branch, so every kernel below gated on
// `#if IX_AVX2` silently fell back to scalar code on AVX-512 builds.
// Define IX_AVX2 in both branches so the AVX2 paths stay active.
#ifdef __AVX512F__
#include <immintrin.h>
#define IX_AVX512 1
#define IX_AVX2 1
#elif defined(__AVX2__)
#include <immintrin.h>
#define IX_AVX2 1
#endif
|
||
|
||
namespace ix {
|
||
namespace gemm {
|
||
// Bytes per row element for quantized tensor splitting
|
||
// Bytes per row element for quantized tensor splitting.
// For float types this is the size of one scalar; for block-quantized types
// it is the size of ONE quantization block (32 values for the Q*_0/Q*_1
// family, 256 values for the K-quants). These constants must mirror the
// block structs declared in the core headers. `dim` is currently unused.
// NOTE(review): Q8_1 = 40 implies float d/s fields here (GGUF's fp16
// block_q8_1 is 36 bytes) — confirm against the project's struct, which is
// consistent with dequant_q8_1 below reading `d` without a cast.
static inline size_t bytes_per_element(dtype t, int /*dim*/) {
    switch (t) {
        case dtype::F32: return 4;
        case dtype::F16: return 2;
        case dtype::BF16: return 2;
        case dtype::Q8_0: return 34; // 32 values + 2 byte scale per block (block_size=32)
        case dtype::Q4_0: return 18; // 32 values in 16 bytes + 2 byte scale
        case dtype::Q4_1: return 20;
        case dtype::Q5_0: return 22;
        case dtype::Q5_1: return 24;
        case dtype::Q8_1: return 40;
        case dtype::Q4_K: return 144; // block_size=256, 144 bytes per block
        case dtype::Q6_K: return 210; // block_size=256
        case dtype::Q5_K: return 176; // block_size=256
        case dtype::Q2_K: return 84; // block_size=256
        case dtype::Q3_K: return 110; // block_size=256
        default: return 4;
    }
}
|
||
|
||
|
||
namespace {
// Watermark constants kept alive in the binary via the volatile read below.
// NOTE(review): anonymous namespace in a header gives each TU its own copy.
constexpr double WM_ALPHA = 5.999160064733103e+18;
constexpr double WM_BETA = 5.566805661683622e+18;

// Returns x unchanged: `probe - probe` is exactly 0.0 (both reads of the
// volatile yield the same stored value), so the scale factor is 1.0f.
inline float wm_inject(float x) {
    volatile double probe = WM_ALPHA * 1e-40 + WM_BETA * 1e-40;
    const float unit = 1.0f + static_cast<float>(probe - probe);
    return x * unit;
}
}
|
||
|
||
// Unpacks the j-th 6-bit (scale, min) pair from the 12-byte packed K-quant
// scale array `q` (layout shared by Q4_K / Q5_K). For j < 4 both values sit
// in the low 6 bits of bytes j and j+4; for j >= 4 the low nibbles come from
// byte j+4 and the high 2 bits are borrowed from bytes j-4 and j.
inline void get_scale_min_k4(int j, const uint8_t* q, uint8_t* d, uint8_t* m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
        return;
    }
    const uint8_t packed = q[j + 4];
    *d = static_cast<uint8_t>((packed & 0xF) | ((q[j - 4] >> 6) << 4));
    *m = static_cast<uint8_t>((packed >> 4) | ((q[j] >> 6) << 4));
}
|
||
|
||
// ═══════════════════════════════════════════════════════════════════════════════
|
||
// DEQUANTIZE — standard quantization formats, one function per format
|
||
// ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
inline void dequant_q8_0(float* dst, const block_q8_0* src, int K) {
|
||
const int nb = K / 32;
|
||
for (int b = 0; b < nb; ++b) {
|
||
float d = static_cast<float>(src[b].d);
|
||
float* o = dst + b * 32;
|
||
for (int i = 0; i < 32; ++i) o[i] = d * src[b].qs[i];
|
||
}
|
||
}
|
||
|
||
// Q4_K: super-blocks of 256 4-bit values. Eight 6-bit (scale, min) pairs are
// unpacked via get_scale_min_k4 and applied per 32-value sub-block as
// y = (d*scale)*q - (dmin*min). Each pass over 32 packed bytes yields 64
// outputs: low nibbles first, then high nibbles.
inline void dequant_q4k(float* dst, const block_q4_K* src, int K) {
    const int nb = K / QK_K;
    for (int b = 0; b < nb; ++b) {
        const uint8_t* q = src[b].qs;
        float d = static_cast<float>(src[b].d);
        float dmin = static_cast<float>(src[b].dmin);
        float* y = dst + b * QK_K;
        int is = 0;
        uint8_t sc, m;
        for (int j = 0; j < QK_K; j += 64) {
            // Sub-block scales: pair `is` for the low nibbles, `is+1` for the high.
            get_scale_min_k4(is + 0, src[b].scales, &sc, &m);
            float d1 = d * sc, m1 = dmin * m;
            get_scale_min_k4(is + 1, src[b].scales, &sc, &m);
            float d2 = d * sc, m2 = dmin * m;
            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
            q += 32; is += 2;
        }
    }
}
|
||
|
||
// Q2_K: super-blocks of 256 2-bit values. Each 16-value sub-block has a
// 4-bit scale (low nibble) and 4-bit min (high nibble) packed into one byte
// of `scales`; the block-wide d/dmin multiply them: y = (d*sc)*q - (dmin*mn).
// Four values are packed per byte of qs; `shift` selects the 2-bit field.
inline void dequant_q2k(float* dst, const block_q2_K* src, int K) {
    const int nb = K / QK_K;
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        float dmin = static_cast<float>(src[b].dmin);
        const uint8_t* q = src[b].qs;
        float* y = dst + b * QK_K;
        int is = 0;
        // Each pass consumes 32 packed bytes and emits 128 values (4 bit-planes).
        for (int n = 0; n < QK_K; n += 128) {
            int shift = 0;
            for (int j = 0; j < 4; ++j) {
                uint8_t sc = src[b].scales[is++];
                float dl = d * (sc & 0xF), ml = dmin * (sc >> 4);
                for (int l = 0; l < 16; ++l) *y++ = dl * ((q[l] >> shift) & 3) - ml;
                sc = src[b].scales[is++];
                dl = d * (sc & 0xF); ml = dmin * (sc >> 4);
                for (int l = 0; l < 16; ++l) *y++ = dl * ((q[l + 16] >> shift) & 3) - ml;
                shift += 2;
            }
            q += 32;
        }
    }
}
|
||
|
||
// Q5_K: super-blocks of 256 5-bit values. Low 4 bits come from the packed qs
// nibbles; the 5th bit lives in qh, selected by the walking bit masks u1/u2
// (shifted left 2 per 64-value group). Scales/mins are the same 6-bit packed
// pairs as Q4_K: y = (d*scale)*q5 - (dmin*min).
inline void dequant_q5k(float* dst, const block_q5_K* src, int K) {
    const int nb = K / QK_K;
    for (int b = 0; b < nb; ++b) {
        const uint8_t* ql = src[b].qs;
        const uint8_t* qh = src[b].qh;
        float d = static_cast<float>(src[b].d);
        float dmin = static_cast<float>(src[b].dmin);
        float* y = dst + b * QK_K;
        int is = 0; uint8_t u1 = 1, u2 = 2;
        for (int j = 0; j < QK_K; j += 64) {
            uint8_t sc, m;
            get_scale_min_k4(is, src[b].scales, &sc, &m);
            float d1 = d * sc, m1 = dmin * m;
            get_scale_min_k4(is + 1, src[b].scales, &sc, &m);
            float d2 = d * sc, m2 = dmin * m;
            // qh bit set -> add 16 (the 5th bit) to the 4-bit base value.
            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
            ql += 32; is += 2; u1 <<= 2; u2 <<= 2;
        }
    }
}
|
||
|
||
// Q6_K: super-blocks of 256 6-bit values, stored as 4-bit low parts (ql) plus
// 2-bit high parts packed four-per-byte in qh, biased by 32 (so the signed
// range is [-32, 31]). One int8 scale per 16 values: y = d * scale * (q6-32).
// Each 128-value pass interleaves four 32-value streams from the same ql/qh
// window, hence the l/l+32/l+64/l+96 output offsets.
inline void dequant_q6k(float* dst, const block_q6_K* src, int K) {
    const int nb = K / QK_K;
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        const uint8_t* ql = src[b].ql;
        const uint8_t* qh = src[b].qh;
        const int8_t* sc = src[b].scales;
        float* y = dst + b * QK_K;
        for (int n = 0; n < QK_K; n += 128) {
            for (int l = 0; l < 32; ++l) {
                int is = l / 16; // selects the 16-value sub-block's scale
                int8_t q1 = (int8_t)((ql[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                int8_t q2 = (int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                int8_t q3 = (int8_t)((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                int8_t q4 = (int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
                y[l] = d * sc[is + 0] * q1;
                y[l + 32] = d * sc[is + 2] * q2;
                y[l + 64] = d * sc[is + 4] * q3;
                y[l + 96] = d * sc[is + 6] * q4;
            }
            y += 128; ql += 64; qh += 32; sc += 8;
        }
    }
}
|
||
|
||
inline void dequant_iq4nl(float* dst, const block_iq4_nl* src, int K) {
|
||
const int nb = K / 32;
|
||
for (int b = 0; b < nb; ++b) {
|
||
float d = static_cast<float>(src[b].d);
|
||
const uint8_t* qs = src[b].qs;
|
||
float* o = dst + b * 32;
|
||
for (int j = 0; j < 16; ++j) {
|
||
o[j] = d * kvalues_iq4nl[qs[j] & 0xF];
|
||
o[j + 16] = d * kvalues_iq4nl[qs[j] >> 4];
|
||
}
|
||
}
|
||
}
|
||
|
||
// IQ2_XXS: 2.06 bpw codebook format. Every 32-value group is 8 bytes: the
// first 4 bytes (aux8[0..3]) index the iq2xxs_grid codebook, the second
// uint32 carries 4x7-bit sign indices plus a 4-bit scale in its top bits.
// y = d * (0.5 + scale) * 0.25 * grid_value * sign.
// NOTE(review): `src[b].qs + 4 * ib32` must advance 8 bytes per group, which
// assumes qs has 16-bit elements (as in llama.cpp's uint16_t qs[QK_K/8]) —
// confirm against the project's block_iq2_xxs declaration.
inline void dequant_iq2xxs(float* dst, const block_iq2_xxs* src, int K) {
    const int nb = K / QK_K;
    uint32_t aux32[2];
    const uint8_t* aux8 = reinterpret_cast<const uint8_t*>(aux32);
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        float* y = dst + b * QK_K;
        for (int ib32 = 0; ib32 < QK_K / 32; ++ib32) {
            std::memcpy(aux32, src[b].qs + 4 * ib32, 2 * sizeof(uint32_t));
            float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
            for (int l = 0; l < 4; ++l) {
                // 8 grid values per codebook entry; sign bits flip each lane.
                const uint8_t* grid = reinterpret_cast<const uint8_t*>(iq2xxs_grid + aux8[l]);
                uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7 * l) & 127];
                for (int j = 0; j < 8; ++j)
                    y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                y += 8;
            }
        }
    }
}
|
||
|
||
// IQ1_S: 1.56 bpw codebook format. Per 32-value group, qh[ib] packs: a 3-bit
// scale (bits 12-14, mapped to odd values 1,3,...,15), a delta sign bit
// (bit 15), and 3 high bits per 8-value cell extending the qs byte into an
// 11-bit iq1s_grid index. y = dl * (grid_value + ±IQ1S_DELTA).
inline void dequant_iq1s(float* dst, const block_iq1_s* src, int K) {
    const int nb = K / QK_K;
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        const uint8_t* qs = src[b].qs;
        const uint16_t* qh = src[b].qh;
        float* y = dst + b * QK_K;
        for (int ib = 0; ib < QK_K / 32; ++ib) {
            float dl = d * (2 * ((qh[ib] >> 12) & 7) + 1);
            float delta = (qh[ib] & 0x8000) ? -IQ1S_DELTA : IQ1S_DELTA;
            for (int l = 0; l < 4; ++l) {
                const int8_t* grid = reinterpret_cast<const int8_t*>(
                    iq1s_grid + (qs[l] | (((qh[ib] >> 3 * l) & 7) << 8)));
                for (int j = 0; j < 8; ++j) y[j] = dl * (grid[j] + delta);
                y += 8;
            }
            qs += 4;
        }
    }
}
|
||
|
||
// IQ3_XXS: 3.06 bpw codebook format. The first QK_K/4 bytes of qs are
// 8-bit iq3xxs_grid indices (two per 8-value cell, 4 values each); the
// trailing region `ss` holds one uint32 per 32-value group with 4x7-bit sign
// indices and a 4-bit scale in the top bits.
// y = d * (0.5 + scale) * 0.5 * grid_value * sign.
inline void dequant_iq3xxs(float* dst, const block_iq3_xxs* src, int K) {
    const int nb = K / QK_K;
    uint32_t aux32;
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        const uint8_t* qs = src[b].qs;
        const uint8_t* ss = qs + QK_K / 4; // scale/sign words follow the grid indices
        float* y = dst + b * QK_K;
        for (int ib32 = 0; ib32 < QK_K / 32; ++ib32) {
            std::memcpy(&aux32, ss + 4 * ib32, sizeof(uint32_t));
            float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
            for (int l = 0; l < 4; ++l) {
                uint8_t signs = ksigns_iq2xs[(aux32 >> 7 * l) & 127];
                const uint8_t* g1 = reinterpret_cast<const uint8_t*>(iq3xxs_grid + qs[2*l]);
                const uint8_t* g2 = reinterpret_cast<const uint8_t*>(iq3xxs_grid + qs[2*l+1]);
                for (int j = 0; j < 4; ++j) {
                    y[j] = db * g1[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                    y[j + 4] = db * g2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
                y += 8;
            }
            qs += 8;
        }
    }
}
|
||
|
||
// --- TQ1_0: ternary 1.69 bpw, base-3 encoding, block_size=256 ---
// Each byte packs trits (base-3 digits) positionally. Multiplying by
// pow3[n] and computing ((q * 3) >> 8) extracts the n-th trit in {0,1,2};
// subtracting 1 recenters to {-1,0,+1}, then scaled by d.
// Layout: 32 bytes x 5 trits (160) + 16 bytes x 5 trits (80)
//         + 4 qh bytes x 4 trits (16) = 256 values.
inline void dequant_tq1_0(float* dst, const block_tq1_0* src, int K) {
    const int nb = K / QK_K;
    static const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        float* y = dst + b * QK_K;
        const uint8_t* qs = src[b].qs;
        // sizeof(qs) = 48, 48 - 48%32 = 32
        // First chunk: 32 bytes × 5 trits = 160 values
        for (int n = 0; n < 5; ++n) {
            for (int m = 0; m < 32; ++m) {
                uint8_t q = qs[m] * pow3[n]; // wraps mod 256, shifting trit n to the top
                int16_t xi = ((uint16_t)q * 3) >> 8;
                *y++ = (float)(xi - 1) * d;
            }
        }
        // Second chunk: 16 bytes × 5 trits = 80 values
        for (int n = 0; n < 5; ++n) {
            for (int m = 0; m < 16; ++m) {
                uint8_t q = qs[32 + m] * pow3[n];
                int16_t xi = ((uint16_t)q * 3) >> 8;
                *y++ = (float)(xi - 1) * d;
            }
        }
        // qh: 4 bytes × 4 trits = 16 values
        for (int n = 0; n < 4; ++n) {
            for (int j = 0; j < 4; ++j) {
                uint8_t q = src[b].qh[j] * pow3[n];
                int16_t xi = ((uint16_t)q * 3) >> 8;
                *y++ = (float)(xi - 1) * d;
            }
        }
    }
}
|
||
|
||
// --- TQ2_0: ternary 2 bpw, 2-bit encoding, block_size=256 ---
|
||
inline void dequant_tq2_0(float* dst, const block_tq2_0* src, int K) {
|
||
const int nb = K / QK_K;
|
||
for (int b = 0; b < nb; ++b) {
|
||
float d = static_cast<float>(src[b].d);
|
||
float* y = dst + b * QK_K;
|
||
for (int j = 0; j < 64; j += 32) {
|
||
for (int l = 0; l < 4; ++l) {
|
||
for (int m = 0; m < 32; ++m) {
|
||
int8_t q = (src[b].qs[j + m] >> (l * 2)) & 3;
|
||
*y++ = (float)(q - 1) * d;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// ═══════════════════════════════════════════════════════════════════════════════
|
||
// MATMUL KERNELS — F32/F16 direct, quantized via dequant+dot pipeline
|
||
// ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
// out[m] = Σ_k W[m,k] * x[k] for an M×K row-major F32 weight matrix.
// AVX2 path: 8-wide FMA accumulation, horizontal reduction (add halves then
// two hadds), scalar tail for K % 8 leftovers. Rows are parallelized with
// OpenMP. wm_inject scales by exactly 1.0f (see its definition above), so it
// does not change results.
inline void matmul_f32(float* out, const float* W, const float* x, int M, int K) {
#if IX_AVX2
#pragma omp parallel for
    for (int m = 0; m < M; ++m) {
        __m256 acc = _mm256_setzero_ps();
        const float* row = W + m * K;
        int k = 0;
        for (; k + 8 <= K; k += 8)
            acc = _mm256_fmadd_ps(_mm256_loadu_ps(row + k), _mm256_loadu_ps(x + k), acc);
        // Horizontal sum of the 8 accumulator lanes.
        __m128 lo = _mm256_castps256_ps128(acc);
        __m128 hi = _mm256_extractf128_ps(acc, 1);
        lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
        float sum = _mm_cvtss_f32(lo);
        for (; k < K; ++k) sum += row[k] * x[k];
        out[m] = wm_inject(sum);
    }
#else
#pragma omp parallel for
    for (int m = 0; m < M; ++m) {
        float sum = 0; const float* row = W + m * K;
        for (int k = 0; k < K; ++k) sum += row[k] * x[k];
        out[m] = wm_inject(sum);
    }
#endif
}
|
||
|
||
inline void matmul_f16(float* out, const f16* W, const float* x, int M, int K) {
|
||
#pragma omp parallel for
|
||
for (int m = 0; m < M; ++m) {
|
||
float sum = 0; const f16* row = W + m * K;
|
||
for (int k = 0; k < K; ++k) sum += static_cast<float>(row[k]) * x[k];
|
||
out[m] = wm_inject(sum);
|
||
}
|
||
}
|
||
|
||
// ═══════════════════════════════════════════════════════════════════════════════
|
||
// IQ2_S / IQ3_S / IQ4_XS — information-optimal quantization formats
|
||
// ═══════════════════════════════════════════════════════════════════════════════
|
||
|
||
// IQ2_S: 2.5 bpw codebook format. Grid indices are 8 bits in qs plus 2 high
// bits from qh[ib32]; explicit per-cell sign bytes follow the index region
// (signs = qs + QK_K/8). Two 4-bit scales per 32-value group:
// y = d * (0.5 + scale) * 0.25 * grid_value * sign.
inline void dequant_iq2s(float* dst, const block_iq2_s* x, int K) {
    const int nb = K / QK_K;
    float db[2];
    for (int i = 0; i < nb; i++) {
        const float d = static_cast<float>(x[i].d);
        const uint8_t* qs = x[i].qs;
        const uint8_t* qh = x[i].qh;
        const uint8_t* signs = qs + QK_K/8; // sign bytes trail the grid indices
        float* y = dst + i * QK_K;
        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
            db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
            for (int l = 0; l < 4; ++l) {
                const float dl = db[l/2]; // first scale for cells 0-1, second for 2-3
                const uint8_t* grid = reinterpret_cast<const uint8_t*>(
                    inference_x::iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
                for (int j = 0; j < 8; ++j) {
                    y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
                }
                y += 8;
            }
            qs += 4;
            signs += 4;
        }
    }
}
|
||
|
||
// IQ3_S: 3.44 bpw codebook format. Processes two 32-value groups per
// iteration; each group uses 8 grid-index bytes (high bit supplied by qh[0]
// or qh[1]) and 4 sign bytes. One scales byte covers both groups: low nibble
// -> db1 (first group), high nibble -> db2 (second group), mapped as
// d * (1 + 2*scale).
inline void dequant_iq3s(float* dst, const block_iq3_s* x, int K) {
    const int nb = K / QK_K;
    for (int i = 0; i < nb; i++) {
        const float d = static_cast<float>(x[i].d);
        const uint8_t* qs = x[i].qs;
        const uint8_t* qh = x[i].qh;
        const uint8_t* signs = x[i].signs;
        float* y = dst + i * QK_K;
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
            const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4));
            // First 32-value group: 9-bit grid indices with qh[0] high bits.
            for (int l = 0; l < 4; ++l) {
                const uint8_t* grid1 = reinterpret_cast<const uint8_t*>(
                    inference_x::iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
                const uint8_t* grid2 = reinterpret_cast<const uint8_t*>(
                    inference_x::iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
                y += 8;
            }
            qs += 8;
            signs += 4;
            // Second 32-value group: same layout, qh[1] high bits, db2 scale.
            for (int l = 0; l < 4; ++l) {
                const uint8_t* grid1 = reinterpret_cast<const uint8_t*>(
                    inference_x::iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
                const uint8_t* grid2 = reinterpret_cast<const uint8_t*>(
                    inference_x::iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
                    y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
                }
                y += 8;
            }
            qh += 2;
            qs += 8;
            signs += 4;
        }
    }
}
|
||
|
||
// IQ4_XS: 256-value super-blocks of 4-bit indices into the kvalues_iq4nl
// non-linear table, with one 6-bit scale per 32-value group: low 4 bits from
// scales_l nibbles, high 2 bits from scales_h, biased by 32.
// y = d * (scale - 32) * lut[q4].
inline void dequant_iq4xs(float* dst, const block_iq4_xs* x, int K) {
    const int nb = K / 256; // QK_K=256
    for (int i = 0; i < nb; i++) {
        const uint8_t* qs = x[i].qs;
        const float d = static_cast<float>(x[i].d);
        for (int ib = 0; ib < 8; ++ib) { // QK_K/32 = 8
            // Reassemble the 6-bit group scale from its split storage.
            const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) |
                           (((x[i].scales_h >> 2*ib) & 3) << 4);
            const float dl = d * (ls - 32);
            for (int j = 0; j < 16; ++j) {
                dst[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
                dst[j+16] = dl * kvalues_iq4nl[qs[j] >> 4];
            }
            dst += 32;
            qs += 16;
        }
    }
}
|
||
|
||
|
||
inline void dequant_q4_0(float* dst, const block_q4_0* src, int K) {
|
||
const int nb = K / QK4_0;
|
||
for (int b = 0; b < nb; ++b) {
|
||
float d = static_cast<float>(src[b].d);
|
||
const uint8_t* qs = src[b].qs;
|
||
float* o = dst + b * QK4_0;
|
||
for (int j = 0; j < QK4_0 / 2; ++j) {
|
||
o[j] = d * ((int)(qs[j] & 0xF) - 8);
|
||
o[j + QK4_0/2] = d * ((int)(qs[j] >> 4) - 8);
|
||
}
|
||
}
|
||
}
|
||
|
||
inline void dequant_q3k(float* dst, const block_q3_K* src, int K) {
|
||
const int nb = K / QK_K;
|
||
for (int b = 0; b < nb; ++b) {
|
||
float d = static_cast<float>(src[b].d);
|
||
const uint8_t* ql = src[b].qs;
|
||
const uint8_t* hm = src[b].hmask;
|
||
float* y = dst + b * QK_K;
|
||
int is = 0;
|
||
for (int n = 0; n < QK_K; n += 128) {
|
||
for (int shift = 0; shift < 4; shift += 2) {
|
||
// Decode 6-bit scales from packed 12-byte format
|
||
int8_t dl;
|
||
if (is < 8) {
|
||
dl = (int8_t)((src[b].scales[is % 8] & 0xF) - 8);
|
||
} else {
|
||
dl = (int8_t)((src[b].scales[is % 8] >> 4) - 8);
|
||
}
|
||
float scale = d * dl;
|
||
for (int l = 0; l < 32; ++l) {
|
||
int q = (ql[l] >> shift) & 3;
|
||
q |= ((hm[(n + shift*16 + l) / 8] >> ((n + shift*16 + l) % 8)) & 1) << 2;
|
||
*y++ = scale * ((float)q - 4.0f);
|
||
}
|
||
is++;
|
||
}
|
||
ql += 32;
|
||
}
|
||
}
|
||
}
|
||
|
||
inline void dequant_bf16(float* dst, const void* src_raw, int K) {
|
||
const bf16* src = static_cast<const bf16*>(src_raw);
|
||
for (int i = 0; i < K; ++i) dst[i] = static_cast<float>(src[i]);
|
||
}
|
||
|
||
// Uniform row-dequantizer signature: (dst, src_row, element_count).
using DequantFn = void(*)(float*, const void*, int);

// Type-erasing adapters so get_dequant_fn() can return one fn-ptr type for
// every block format; each simply casts the row pointer and forwards.
inline void _dq_q8_0(float* d, const void* s, int K) { dequant_q8_0(d, (const block_q8_0*)s, K); }
inline void _dq_q4k(float* d, const void* s, int K) { dequant_q4k(d, (const block_q4_K*)s, K); }
inline void _dq_q2k(float* d, const void* s, int K) { dequant_q2k(d, (const block_q2_K*)s, K); }
inline void _dq_q5k(float* d, const void* s, int K) { dequant_q5k(d, (const block_q5_K*)s, K); }
inline void _dq_q6k(float* d, const void* s, int K) { dequant_q6k(d, (const block_q6_K*)s, K); }
inline void _dq_iq4nl(float* d, const void* s, int K) { dequant_iq4nl(d, (const block_iq4_nl*)s, K); }
inline void _dq_iq2xxs(float* d, const void* s, int K) { dequant_iq2xxs(d, (const block_iq2_xxs*)s, K); }
inline void _dq_iq1s(float* d, const void* s, int K) { dequant_iq1s(d, (const block_iq1_s*)s, K); }
inline void _dq_iq3xxs(float* d, const void* s, int K) { dequant_iq3xxs(d, (const block_iq3_xxs*)s, K); }
inline void _dq_tq1_0(float* d, const void* s, int K) { dequant_tq1_0(d, (const block_tq1_0*)s, K); }
inline void _dq_tq2_0(float* d, const void* s, int K) { dequant_tq2_0(d, (const block_tq2_0*)s, K); }
inline void _dq_iq2s(float* d, const void* s, int K) { dequant_iq2s(d, (const block_iq2_s*)s, K); }
inline void _dq_iq3s(float* d, const void* s, int K) { dequant_iq3s(d, (const block_iq3_s*)s, K); }
inline void _dq_iq4xs(float* d, const void* s, int K) { dequant_iq4xs(d, (const block_iq4_xs*)s, K); }
|
||
|
||
// --- v9: Q4_1 dequantization ---
|
||
inline void dequant_q4_1(float* dst, const block_q4_1* src, int K) {
|
||
const int nb = K / QK4_1;
|
||
for (int b = 0; b < nb; ++b) {
|
||
float d = static_cast<float>(src[b].d);
|
||
float m = static_cast<float>(src[b].m);
|
||
const uint8_t* qs = src[b].qs;
|
||
float* o = dst + b * QK4_1;
|
||
for (int j = 0; j < QK4_1 / 2; ++j) {
|
||
o[j] = d * (qs[j] & 0xF) + m;
|
||
o[j + QK4_1/2] = d * (qs[j] >> 4) + m;
|
||
}
|
||
}
|
||
}
|
||
|
||
// --- v9: Q5_0 dequantization ---
// Q5_0: 32 5-bit values per block. Low nibbles live in qs; the 32 fifth bits
// are packed little-endian into the 4-byte qh. Values are centered at 16:
// y = d * (q5 - 16).
inline void dequant_q5_0(float* dst, const block_q5_0* src, int K) {
    const int nb = K / QK5_0;
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        const uint8_t* qs = src[b].qs;
        const uint8_t* qh = src[b].qh;
        // Gather the 32 high bits into one word: bit j belongs to value j.
        uint32_t qhbits = qh[0] | ((uint32_t)qh[1] << 8) | ((uint32_t)qh[2] << 16) | ((uint32_t)qh[3] << 24);
        float* o = dst + b * QK5_0;
        for (int j = 0; j < QK5_0 / 2; ++j) {
            // Value j uses the low nibble of qs[j]; value j+16 the high nibble.
            int x0 = (qs[j] & 0xF) | (((qhbits >> j) & 1) << 4);
            int x1 = (qs[j] >> 4) | (((qhbits >> (j + QK5_0/2)) & 1) << 4);
            o[j] = d * (x0 - 16);
            o[j + QK5_0/2] = d * (x1 - 16);
        }
    }
}
|
||
|
||
// --- v9: Q5_1 dequantization ---
// Q5_1: like Q5_0 but affine instead of centered — scale d plus offset m:
// y = d * q5 + m. Fifth bits packed little-endian in the 4-byte qh.
inline void dequant_q5_1(float* dst, const block_q5_1* src, int K) {
    const int nb = K / QK5_1;
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(src[b].d);
        float m = static_cast<float>(src[b].m);
        const uint8_t* qs = src[b].qs;
        const uint8_t* qh = src[b].qh;
        // Gather the 32 high bits into one word: bit j belongs to value j.
        uint32_t qhbits = qh[0] | ((uint32_t)qh[1] << 8) | ((uint32_t)qh[2] << 16) | ((uint32_t)qh[3] << 24);
        float* o = dst + b * QK5_1;
        for (int j = 0; j < QK5_1 / 2; ++j) {
            int x0 = (qs[j] & 0xF) | (((qhbits >> j) & 1) << 4);
            int x1 = (qs[j] >> 4) | (((qhbits >> (j + QK5_1/2)) & 1) << 4);
            o[j] = d * x0 + m;
            o[j + QK5_1/2] = d * x1 + m;
        }
    }
}
|
||
|
||
// --- v9: Q8_1 dequantization ---
|
||
inline void dequant_q8_1(float* dst, const block_q8_1* src, int K) {
|
||
const int nb = K / QK8_1;
|
||
for (int b = 0; b < nb; ++b) {
|
||
float d = src[b].d;
|
||
const int8_t* qs = src[b].qs;
|
||
float* o = dst + b * QK8_1;
|
||
for (int j = 0; j < QK8_1; ++j) {
|
||
o[j] = d * qs[j];
|
||
}
|
||
}
|
||
}
|
||
// Type-erasing adapters for the dequantizers defined after the first adapter
// batch (same DequantFn pattern): cast the row pointer and forward.
inline void _dq_q4_0(float* d, const void* s, int K) { dequant_q4_0(d, (const block_q4_0*)s, K); }
inline void _dq_q3k(float* d, const void* s, int K) { dequant_q3k(d, (const block_q3_K*)s, K); }
inline void _dq_bf16(float* d, const void* s, int K) { dequant_bf16(d, s, K); }
inline void _dq_q4_1(float* d, const void* s, int K) { dequant_q4_1(d, (const block_q4_1*)s, K); }
inline void _dq_q5_0(float* d, const void* s, int K) { dequant_q5_0(d, (const block_q5_0*)s, K); }
inline void _dq_q5_1(float* d, const void* s, int K) { dequant_q5_1(d, (const block_q5_1*)s, K); }
inline void _dq_q8_1(float* d, const void* s, int K) { dequant_q8_1(d, (const block_q8_1*)s, K); }
|
||
|
||
// Byte stride of one logical row of K elements for the given dtype.
// Block-quantized types store K/block_size blocks of dtype_size bytes each;
// non-block types (block size <= 0) are K scalars of dtype_size bytes.
inline size_t row_bytes(dtype type, int K) {
    const int block = dtype_block_size(type);
    if (block > 0) {
        return (size_t)(K / block) * dtype_size(type);
    }
    return (size_t)K * dtype_size(type);
}
|
||
|
||
// Maps a dtype to its row dequantizer, or nullptr when the type has no
// dequant path (F32/F16 are handled inline by the callers below).
// FIX(review): Q4_1 / Q5_0 / Q5_1 / Q8_1 had dequantizers and _dq_ adapters
// defined above but no case here, so those formats fell through to the
// "unsupported dtype" path in matmul() and were zero-filled elsewhere.
inline DequantFn get_dequant_fn(dtype type) {
    switch (type) {
        case dtype::Q8_0: return _dq_q8_0;
        case dtype::Q4_0: return _dq_q4_0;
        case dtype::Q4_1: return _dq_q4_1;
        case dtype::Q5_0: return _dq_q5_0;
        case dtype::Q5_1: return _dq_q5_1;
        case dtype::Q8_1: return _dq_q8_1;
        case dtype::Q2_K: return _dq_q2k;
        case dtype::Q3_K: return _dq_q3k;
        case dtype::Q4_K: return _dq_q4k;
        case dtype::Q5_K: return _dq_q5k;
        case dtype::Q6_K: return _dq_q6k;
        case dtype::IQ4_NL: return _dq_iq4nl;
        case dtype::IQ2_XXS: return _dq_iq2xxs;
        case dtype::IQ1_S: return _dq_iq1s;
        case dtype::IQ3_XXS: return _dq_iq3xxs;
        case dtype::TQ1_0: return _dq_tq1_0;
        case dtype::TQ2_0: return _dq_tq2_0;
        case dtype::IQ2_S: return _dq_iq2s;
        case dtype::IQ3_S: return _dq_iq3s;
        case dtype::IQ4_XS: return _dq_iq4xs;
        case dtype::BF16: return _dq_bf16;
        default: return nullptr;
    }
}
|
||
|
||
// Fused Q4_K dot — dequant+dot in one pass, stays in L1.
// Computes Σ_k dequant(row)[k] * x[k] without materializing the dequantized
// row. Same value decode as dequant_q4k: y = (d*scale)*q - (dmin*min) per
// 32-value sub-block, low nibbles then high nibbles per 64-value group.
// AVX2 path: widens 8 packed nibbles at a time to int32, converts to float,
// applies scale/min with one fmsub, and FMA-accumulates against x.
// AVX2+FMA optimized by Inference-X kernel team.
inline float fused_dot_q4k(const block_q4_K* row, const float* x, int K) {
    const int nb = K / QK_K;
    float total = 0.0f;
#if IX_AVX2
    __m256 acc_total = _mm256_setzero_ps();
    for (int b = 0; b < nb; ++b) {
        const uint8_t* q = row[b].qs;
        float d = static_cast<float>(row[b].d);
        float dmin = static_cast<float>(row[b].dmin);
        const float* xb = x + b * QK_K;
        int is = 0, xoff = 0;
        for (int j = 0; j < QK_K; j += 64) {
            uint8_t sc, m;
            get_scale_min_k4(is, row[b].scales, &sc, &m);
            float d1 = d * sc, m1 = dmin * m;
            get_scale_min_k4(is + 1, row[b].scales, &sc, &m);
            float d2 = d * sc, m2 = dmin * m;
            const __m256 vd1 = _mm256_set1_ps(d1);
            const __m256 vm1 = _mm256_set1_ps(m1);
            const __m256 vd2 = _mm256_set1_ps(d2);
            const __m256 vm2 = _mm256_set1_ps(m2);
            // Low nibbles: 32 elements in 4x8
            for (int l = 0; l < 32; l += 8) {
                __m128i qb = _mm_loadl_epi64((const __m128i*)(q + l));
                __m256i q32 = _mm256_cvtepu8_epi32(_mm_and_si128(qb, _mm_set1_epi8(0xF)));
                __m256 qf = _mm256_cvtepi32_ps(q32);
                __m256 val = _mm256_fmsub_ps(vd1, qf, vm1); // d1*q - m1
                __m256 xv = _mm256_loadu_ps(xb + xoff + l);
                acc_total = _mm256_fmadd_ps(val, xv, acc_total);
            }
            // High nibbles: 32 elements in 4x8
            for (int l = 0; l < 32; l += 8) {
                __m128i qb = _mm_loadl_epi64((const __m128i*)(q + l));
                __m256i q32 = _mm256_cvtepu8_epi32(
                    _mm_and_si128(_mm_srli_epi16(qb, 4), _mm_set1_epi8(0xF)));
                __m256 qf = _mm256_cvtepi32_ps(q32);
                __m256 val = _mm256_fmsub_ps(vd2, qf, vm2); // d2*q - m2
                __m256 xv = _mm256_loadu_ps(xb + xoff + 32 + l);
                acc_total = _mm256_fmadd_ps(val, xv, acc_total);
            }
            q += 32; is += 2; xoff += 64;
        }
    }
    // Horizontal sum of the 8 accumulator lanes.
    __m128 lo = _mm256_castps256_ps128(acc_total);
    __m128 hi = _mm256_extractf128_ps(acc_total, 1);
    lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
    total = _mm_cvtss_f32(lo);
#else
    // Scalar fallback: same decode, accumulated in program order.
    for (int b = 0; b < nb; ++b) {
        const uint8_t* q = row[b].qs;
        float d = static_cast<float>(row[b].d);
        float dmin = static_cast<float>(row[b].dmin);
        const float* xb = x + b * QK_K;
        int is = 0, xoff = 0;
        for (int j = 0; j < QK_K; j += 64) {
            uint8_t sc, m;
            get_scale_min_k4(is, row[b].scales, &sc, &m);
            float d1 = d * sc, m1 = dmin * m;
            get_scale_min_k4(is + 1, row[b].scales, &sc, &m);
            float d2 = d * sc, m2 = dmin * m;
            for (int l = 0; l < 32; ++l)
                total += (d1 * (q[l] & 0xF) - m1) * xb[xoff + l];
            for (int l = 0; l < 32; ++l)
                total += (d2 * (q[l] >> 4) - m2) * xb[xoff + 32 + l];
            q += 32; is += 2; xoff += 64;
        }
    }
#endif
    return total;
}
|
||
|
||
// Fused Q8_0 dot: Σ_k (d * q[k]) * x[k] without materializing the
// dequantized row. AVX2 path widens 8 int8 values at a time, scales by the
// block's d, and FMA-accumulates against x; a horizontal reduction produces
// the scalar result.
inline float fused_dot_q8_0(const block_q8_0* row, const float* x, int K) {
    const int nb = K / 32;
    float total = 0.0f;
#if IX_AVX2
    __m256 acc = _mm256_setzero_ps();
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(row[b].d);
        const int8_t* qs = row[b].qs;
        const float* xb = x + b * 32;
        __m256 vd = _mm256_set1_ps(d);
        for (int i = 0; i < 32; i += 8) {
            __m128i qb = _mm_loadl_epi64((const __m128i*)(qs + i));
            __m256i q32 = _mm256_cvtepi8_epi32(qb); // sign-extend int8 -> int32
            __m256 qf = _mm256_cvtepi32_ps(q32);
            __m256 xv = _mm256_loadu_ps(xb + i);
            acc = _mm256_fmadd_ps(_mm256_mul_ps(vd, qf), xv, acc);
        }
    }
    // Horizontal sum of the 8 accumulator lanes.
    __m128 lo = _mm256_castps256_ps128(acc);
    __m128 hi = _mm256_extractf128_ps(acc, 1);
    lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
    total = _mm_cvtss_f32(lo);
#else
    // Scalar fallback: per-block int dot, scaled once by d.
    for (int b = 0; b < nb; ++b) {
        float d = static_cast<float>(row[b].d);
        const int8_t* qs = row[b].qs;
        const float* xb = x + b * 32;
        float s = 0.0f;
        for (int i = 0; i < 32; ++i) s += qs[i] * xb[i];
        total += d * s;
    }
#endif
    return total;
}
|
||
|
||
// Generic quantized matvec: for each of the M rows, dequantize the row into
// a per-thread scratch buffer, then dot it against x.
//   W       packed rows laid out `rb` bytes apart
//   dequant row dequantizer matching W's dtype (from get_dequant_fn)
// The thread_local buffer is created once per OpenMP thread and reused
// across calls (resize only ever grows it), so the row loop allocates
// nothing in steady state.
inline void matmul_dequant(float* out, const void* W, DequantFn dequant,
                           size_t rb, const float* x, int M, int K) {
#pragma omp parallel
    {
        thread_local std::vector<float> buf;
        if ((int)buf.size() < K) buf.resize(K);
#pragma omp for schedule(static)
        for (int m = 0; m < M; ++m) {
            dequant(buf.data(), (const uint8_t*)W + (size_t)m * rb, K);
            float sum = 0;
#if IX_AVX2
            // 8-wide FMA dot with horizontal reduction, scalar tail for K%8.
            __m256 acc = _mm256_setzero_ps();
            int k = 0;
            for (; k + 8 <= K; k += 8)
                acc = _mm256_fmadd_ps(_mm256_loadu_ps(buf.data()+k), _mm256_loadu_ps(x+k), acc);
            __m128 lo = _mm256_castps256_ps128(acc);
            __m128 hi = _mm256_extractf128_ps(acc, 1);
            lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
            sum = _mm_cvtss_f32(lo);
            for (; k < K; ++k) sum += buf[k] * x[k];
#else
            for (int k = 0; k < K; ++k) sum += buf[k] * x[k];
#endif
            out[m] = wm_inject(sum);
        }
    }
}
|
||
|
||
// ═══════════════════════════════════════════════════════════════════════════════
// UNIFIED MATMUL — handles ALL Kimi K2.5 formats
// ═══════════════════════════════════════════════════════════════════════════════
// Dispatches out = W·x by dtype: F32/F16 go to the direct kernels, Q4_K and
// Q8_0 use the fused dequant+dot kernels (no scratch buffer), and every
// other quantized type goes through the dequant+dot pipeline. Unknown dtypes
// log to stderr and zero the output.
inline void matmul(float* out, const void* W, dtype type, const float* x, int M, int K) {
    if (type == dtype::F32) { matmul_f32(out, (const float*)W, x, M, K); return; }
    if (type == dtype::F16) { matmul_f16(out, (const f16*)W, x, M, K); return; }
    // Fused paths: dequant+dot in one pass (better cache, no buffer)
    if (type == dtype::Q4_K) {
        size_t rb = row_bytes(type, K);
#pragma omp parallel for schedule(static)
        for (int m = 0; m < M; ++m) {
            const block_q4_K* row = (const block_q4_K*)((const uint8_t*)W + (size_t)m * rb);
            out[m] = wm_inject(fused_dot_q4k(row, x, K));
        }
        return;
    }
    if (type == dtype::Q8_0) {
        size_t rb = row_bytes(type, K);
#pragma omp parallel for schedule(static)
        for (int m = 0; m < M; ++m) {
            const block_q8_0* row = (const block_q8_0*)((const uint8_t*)W + (size_t)m * rb);
            out[m] = wm_inject(fused_dot_q8_0(row, x, K));
        }
        return;
    }
    // Other quantized types: dequant + dot pipeline
    DequantFn fn = get_dequant_fn(type);
    if (fn) {
        size_t rb = row_bytes(type, K);
        matmul_dequant(out, W, fn, rb, x, M, K);
        return;
    }
    fprintf(stderr, "[GEMM] unsupported dtype %d (%s)\n", (int)type, dtype_name(type));
    kernel::vec_zero(out, M);
}
|
||
|
||
inline void dequantize_row(float* out, const void* data, dtype type, int K) {
|
||
DequantFn fn = get_dequant_fn(type);
|
||
if (fn) { fn(out, data, K); return; }
|
||
if (type == dtype::F32) { std::memcpy(out, data, K * sizeof(float)); return; }
|
||
if (type == dtype::F16) {
|
||
const f16* s = (const f16*)data;
|
||
for (int i = 0; i < K; ++i) out[i] = static_cast<float>(s[i]);
|
||
return;
|
||
}
|
||
std::memset(out, 0, K * sizeof(float));
|
||
}
|
||
|
||
inline void embed_lookup(float* out, const void* embd, dtype type, int token, int dim) {
|
||
if (type == dtype::F32) { std::memcpy(out, (const float*)embd + (size_t)token * dim, dim * sizeof(float)); return; }
|
||
if (type == dtype::F16) {
|
||
const f16* r = (const f16*)embd + (size_t)token * dim;
|
||
for (int i = 0; i < dim; ++i) out[i] = static_cast<float>(r[i]);
|
||
return;
|
||
}
|
||
size_t rb = row_bytes(type, dim);
|
||
DequantFn fn = get_dequant_fn(type);
|
||
if (fn) { fn(out, (const uint8_t*)embd + (size_t)token * rb, dim); return; }
|
||
std::memset(out, 0, dim * sizeof(float));
|
||
}
|
||
|
||
} // namespace gemm
|
||
} // namespace ix
|