inference-x/runtime/gemm.h
Salka Elmadani ec36668cf5 Inference-X v1.0 — Universal AI Inference Engine
Better output from the same model. Fused computation, adaptive precision,
surgical expert loading. 305 KB, 19 backends, zero dependencies.

https://inference-x.com
2026-02-23 07:10:47 +00:00

822 lines
34 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// ═══════════════════════════════════════════════════════════════════════════════
// INFERENCE-X — Fused Dequant+GEMM Operations
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
// Licensed under the Business Source License 1.1 (BSL-1.1)
// See LICENSE file for full terms.
//
// INTELLECTUAL PROPERTY PROTECTION:
// - INPI eSoleau deposit: 7phf-Ueye-2nWr-Vsgu (16/02/2026)
// - GitHub: github.com/ElmadaniS/inference-x
// - Author: Salka Elmadani | Morocco
//
// MANUFACTURER NOTICE: Any manufacturer, company, or entity that
// incorporates, embeds, distributes, or commercially uses Inference-X
// or any derivative work without explicit written authorization from
// the copyright holder is in violation of BSL-1.1 and applicable
// intellectual property laws. This includes but is not limited to:
// hardware vendors, cloud providers, SaaS platforms, and OEMs.
//
// Contact: Elmadani.SALKA@proton.me for licensing.
// ═══════════════════════════════════════════════════════════════════════════════
#pragma once
// Inference-X GEMM Engine — Copyright Salka Elmadani
#define IX_GEMM_SIGNATURE 0x1E5
#include "../core/z_core.h"
#include "../core/iq_tables.h"
#include "../core/iq_tables_ext.h"
#include "kernels.h"
#include <cstring>
#include <vector>
#ifdef __AVX512F__
#include <immintrin.h>
#define IX_AVX512 1
#elif defined(__AVX2__)
#include <immintrin.h>
#define IX_AVX2 1
#endif
namespace ix {
namespace gemm {
// Storage footprint used when splitting quantized tensors into rows:
// one *block* for block-quantized formats (block size 32 or 256 values),
// one scalar for plain float types. Unknown types fall back to 4 bytes.
static inline size_t bytes_per_element(dtype t, int /*dim*/) {
    switch (t) {
        case dtype::F32:
            return 4;
        case dtype::F16:
        case dtype::BF16:
            return 2;
        // 32-value blocks: quant payload plus per-block scale/min fields.
        case dtype::Q4_0: return 18;
        case dtype::Q4_1: return 20;
        case dtype::Q5_0: return 22;
        case dtype::Q5_1: return 24;
        case dtype::Q8_0: return 34;
        case dtype::Q8_1: return 40;
        // 256-value K-quant super-blocks.
        case dtype::Q2_K: return 84;
        case dtype::Q3_K: return 110;
        case dtype::Q4_K: return 144;
        case dtype::Q5_K: return 176;
        case dtype::Q6_K: return 210;
        default:
            return 4;
    }
}
namespace {
// Watermark constants — provenance markers kept live in the binary.
// NOTE(review): anonymous namespace in a header gives each TU its own copy;
// kept as-is to preserve the original linkage behavior.
constexpr double WM_ALPHA = 5.999160064733103e+18;
constexpr double WM_BETA = 5.566805661683622e+18;
// Mathematically an identity: probe - probe == 0 exactly, so the factor is
// 1.0f. The volatile read prevents the compiler from eliding the constants.
inline float wm_inject(float value) {
    volatile double probe = WM_ALPHA * 1e-40 + WM_BETA * 1e-40;
    return value * (1.0f + static_cast<float>(probe - probe));
}
}
// Extract the j-th 6-bit (scale, min) pair from the 12-byte packed K-quant
// scale array used by Q4_K/Q5_K (j in [0, 8)).
inline void get_scale_min_k4(int j, const uint8_t* q, uint8_t* d, uint8_t* m) {
    // Pairs 0..3 are stored directly in the low 6 bits of q[j] / q[j+4].
    if (j < 4) {
        *d = static_cast<uint8_t>(q[j] & 63);
        *m = static_cast<uint8_t>(q[j + 4] & 63);
        return;
    }
    // Pairs 4..7: low nibble comes from q[j+4], the top two bits are
    // borrowed from the high bits of the first four bytes.
    const uint8_t packed = q[j + 4];
    *d = static_cast<uint8_t>((packed & 0x0F) | ((q[j - 4] >> 6) << 4));
    *m = static_cast<uint8_t>((packed >> 4) | ((q[j] >> 6) << 4));
}
// ═══════════════════════════════════════════════════════════════════════════════
// DEQUANTIZE — standard quantization formats, one function per format
// ═══════════════════════════════════════════════════════════════════════════════
inline void dequant_q8_0(float* dst, const block_q8_0* src, int K) {
const int nb = K / 32;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
float* o = dst + b * 32;
for (int i = 0; i < 32; ++i) o[i] = d * src[b].qs[i];
}
}
// Dequantize Q4_K super-blocks (QK_K = 256 values each) into floats.
// Per block: fp16 super-scales d/dmin, 12 bytes of packed 6-bit
// (scale, min) pairs, then 128 bytes of 4-bit quants.
// value = (d * scale) * quant - (dmin * min); dst must hold K floats,
// K a multiple of QK_K.
inline void dequant_q4k(float* dst, const block_q4_K* src, int K) {
const int nb = K / QK_K;
for (int b = 0; b < nb; ++b) {
const uint8_t* q = src[b].qs;
float d = static_cast<float>(src[b].d);
float dmin = static_cast<float>(src[b].dmin);
float* y = dst + b * QK_K;
int is = 0;
uint8_t sc, m;
// Each 64-value chunk uses two (scale, min) pairs: one for the low
// nibbles (first 32 values), one for the high nibbles (next 32).
for (int j = 0; j < QK_K; j += 64) {
get_scale_min_k4(is + 0, src[b].scales, &sc, &m);
float d1 = d * sc, m1 = dmin * m;
get_scale_min_k4(is + 1, src[b].scales, &sc, &m);
float d2 = d * sc, m2 = dmin * m;
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
}
}
// Dequantize Q2_K super-blocks (256 values) into floats.
// Per block: fp16 d/dmin and 16 one-byte scale entries, each packing a
// 4-bit scale (low nibble) and 4-bit min (high nibble) for 16 values.
// Quants are 2 bits each, four values packed per byte via `shift`.
inline void dequant_q2k(float* dst, const block_q2_K* src, int K) {
const int nb = K / QK_K;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
float dmin = static_cast<float>(src[b].dmin);
const uint8_t* q = src[b].qs;
float* y = dst + b * QK_K;
int is = 0;
// Two 128-value halves; within each, four bit-planes (shift 0/2/4/6)
// of the same 32 quant bytes, split into 16+16 value groups.
for (int n = 0; n < QK_K; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
uint8_t sc = src[b].scales[is++];
float dl = d * (sc & 0xF), ml = dmin * (sc >> 4);
for (int l = 0; l < 16; ++l) *y++ = dl * ((q[l] >> shift) & 3) - ml;
sc = src[b].scales[is++];
dl = d * (sc & 0xF); ml = dmin * (sc >> 4);
for (int l = 0; l < 16; ++l) *y++ = dl * ((q[l + 16] >> shift) & 3) - ml;
shift += 2;
}
q += 32;
}
}
}
// Dequantize Q5_K super-blocks (256 values) into floats.
// Like Q4_K, but each value has a 5th bit stored in qh: u1/u2 select the
// bit-plane for the current low-/high-nibble group and advance by two
// planes per 64-value chunk.
inline void dequant_q5k(float* dst, const block_q5_K* src, int K) {
const int nb = K / QK_K;
for (int b = 0; b < nb; ++b) {
const uint8_t* ql = src[b].qs;
const uint8_t* qh = src[b].qh;
float d = static_cast<float>(src[b].d);
float dmin = static_cast<float>(src[b].dmin);
float* y = dst + b * QK_K;
int is = 0; uint8_t u1 = 1, u2 = 2;
for (int j = 0; j < QK_K; j += 64) {
uint8_t sc, m;
// Two 6-bit (scale, min) pairs per 64-value chunk (see get_scale_min_k4).
get_scale_min_k4(is, src[b].scales, &sc, &m);
float d1 = d * sc, m1 = dmin * m;
get_scale_min_k4(is + 1, src[b].scales, &sc, &m);
float d2 = d * sc, m2 = dmin * m;
// The qh bit adds 16, extending the 4-bit quant to 5 bits (0..31).
for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
ql += 32; is += 2; u1 <<= 2; u2 <<= 2;
}
}
}
// Dequantize Q6_K super-blocks (256 values) into floats.
// 6-bit quants: low 4 bits in ql, upper 2 bits in qh, centered by -32.
// 16 signed 8-bit scales per block, one per 16 values; overall fp16 d.
inline void dequant_q6k(float* dst, const block_q6_K* src, int K) {
const int nb = K / QK_K;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
const uint8_t* ql = src[b].ql;
const uint8_t* qh = src[b].qh;
const int8_t* sc = src[b].scales;
float* y = dst + b * QK_K;
// Each 128-value half decodes four interleaved 32-value groups from
// 64 ql bytes and 32 qh bytes (2 high bits per group per qh byte).
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l / 16; // scale sub-index: 0 for l<16, 1 for l>=16
int8_t q1 = (int8_t)((ql[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
int8_t q2 = (int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
int8_t q3 = (int8_t)((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
int8_t q4 = (int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l] = d * sc[is + 0] * q1;
y[l + 32] = d * sc[is + 2] * q2;
y[l + 64] = d * sc[is + 4] * q3;
y[l + 96] = d * sc[is + 6] * q4;
}
y += 128; ql += 64; qh += 32; sc += 8;
}
}
}
inline void dequant_iq4nl(float* dst, const block_iq4_nl* src, int K) {
const int nb = K / 32;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
const uint8_t* qs = src[b].qs;
float* o = dst + b * 32;
for (int j = 0; j < 16; ++j) {
o[j] = d * kvalues_iq4nl[qs[j] & 0xF];
o[j + 16] = d * kvalues_iq4nl[qs[j] >> 4];
}
}
}
// Dequantize IQ2_XXS super-blocks (256 values) into floats.
// Each 32-value group is 8 bytes: aux32[0] holds four codebook indices
// (one byte each) into iq2xxs_grid; aux32[1] packs four 7-bit sign codes
// and a 4-bit scale in the top bits.
inline void dequant_iq2xxs(float* dst, const block_iq2_xxs* src, int K) {
const int nb = K / QK_K;
uint32_t aux32[2];
// Byte view of aux32[0] to pull out the four grid indices.
const uint8_t* aux8 = reinterpret_cast<const uint8_t*>(aux32);
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
float* y = dst + b * QK_K;
for (int ib32 = 0; ib32 < QK_K / 32; ++ib32) {
std::memcpy(aux32, src[b].qs + 4 * ib32, 2 * sizeof(uint32_t));
// Group scale: (0.5 + 4-bit code) / 4, times the block fp16 scale.
float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
for (int l = 0; l < 4; ++l) {
const uint8_t* grid = reinterpret_cast<const uint8_t*>(iq2xxs_grid + aux8[l]);
uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7 * l) & 127];
for (int j = 0; j < 8; ++j)
y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
y += 8;
}
}
}
}
// Dequantize IQ1_S super-blocks (256 values) into floats.
// Per 32-value group, qh packs a 3-bit scale code (bits 12..14), a sign
// bit for the shared delta (bit 15), and 3 extra index bits per 8-value
// cell; qs holds the low 8 index bits into iq1s_grid.
inline void dequant_iq1s(float* dst, const block_iq1_s* src, int K) {
const int nb = K / QK_K;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
const uint8_t* qs = src[b].qs;
const uint16_t* qh = src[b].qh;
float* y = dst + b * QK_K;
for (int ib = 0; ib < QK_K / 32; ++ib) {
// Odd scales 1,3,..,15 from the 3-bit code.
float dl = d * (2 * ((qh[ib] >> 12) & 7) + 1);
float delta = (qh[ib] & 0x8000) ? -IQ1S_DELTA : IQ1S_DELTA;
for (int l = 0; l < 4; ++l) {
// 11-bit grid index: 8 bits from qs, 3 bits from qh.
const int8_t* grid = reinterpret_cast<const int8_t*>(
iq1s_grid + (qs[l] | (((qh[ib] >> 3 * l) & 7) << 8)));
for (int j = 0; j < 8; ++j) y[j] = dl * (grid[j] + delta);
y += 8;
}
qs += 4;
}
}
}
// Dequantize IQ3_XXS super-blocks (256 values) into floats.
// qs starts with QK_K/4 grid-index bytes; the trailing QK_K/8 bytes (ss)
// hold one uint32 per 32-value group packing four 7-bit sign codes plus a
// 4-bit scale in the top bits.
inline void dequant_iq3xxs(float* dst, const block_iq3_xxs* src, int K) {
const int nb = K / QK_K;
uint32_t aux32;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
const uint8_t* qs = src[b].qs;
const uint8_t* ss = qs + QK_K / 4;
float* y = dst + b * QK_K;
for (int ib32 = 0; ib32 < QK_K / 32; ++ib32) {
std::memcpy(&aux32, ss + 4 * ib32, sizeof(uint32_t));
// Group scale: (0.5 + 4-bit code) / 2, times the block fp16 scale.
float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
for (int l = 0; l < 4; ++l) {
// One sign code covers 8 values split across two 4-entry grid rows.
uint8_t signs = ksigns_iq2xs[(aux32 >> 7 * l) & 127];
const uint8_t* g1 = reinterpret_cast<const uint8_t*>(iq3xxs_grid + qs[2*l]);
const uint8_t* g2 = reinterpret_cast<const uint8_t*>(iq3xxs_grid + qs[2*l+1]);
for (int j = 0; j < 4; ++j) {
y[j] = db * g1[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
y[j + 4] = db * g2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
y += 8;
}
qs += 8;
}
}
}
// --- TQ1_0: ternary 1.69 bpw, base-3 encoding, block_size=256 ---
// Each byte packs 5 trits (3^5 = 243 <= 256); multiplying by pow3[n] then
// taking ((q * 3) >> 8) extracts trit n as 0/1/2, mapped to -1/0/+1 via
// the -1 offset before scaling by the block fp16 scale d.
inline void dequant_tq1_0(float* dst, const block_tq1_0* src, int K) {
const int nb = K / QK_K;
static const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
float* y = dst + b * QK_K;
const uint8_t* qs = src[b].qs;
// sizeof(qs) = 48, 48 - 48%32 = 32
// First chunk: 32 bytes × 5 trits = 160 values
for (int n = 0; n < 5; ++n) {
for (int m = 0; m < 32; ++m) {
uint8_t q = qs[m] * pow3[n];
int16_t xi = ((uint16_t)q * 3) >> 8;
*y++ = (float)(xi - 1) * d;
}
}
// Second chunk: 16 bytes × 5 trits = 80 values
for (int n = 0; n < 5; ++n) {
for (int m = 0; m < 16; ++m) {
uint8_t q = qs[32 + m] * pow3[n];
int16_t xi = ((uint16_t)q * 3) >> 8;
*y++ = (float)(xi - 1) * d;
}
}
// qh: 4 bytes × 4 trits = 16 values (total 160+80+16 = 256 = QK_K)
for (int n = 0; n < 4; ++n) {
for (int j = 0; j < 4; ++j) {
uint8_t q = src[b].qh[j] * pow3[n];
int16_t xi = ((uint16_t)q * 3) >> 8;
*y++ = (float)(xi - 1) * d;
}
}
}
}
// --- TQ2_0: ternary 2 bpw, 2-bit encoding, block_size=256 ---
inline void dequant_tq2_0(float* dst, const block_tq2_0* src, int K) {
const int nb = K / QK_K;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
float* y = dst + b * QK_K;
for (int j = 0; j < 64; j += 32) {
for (int l = 0; l < 4; ++l) {
for (int m = 0; m < 32; ++m) {
int8_t q = (src[b].qs[j + m] >> (l * 2)) & 3;
*y++ = (float)(q - 1) * d;
}
}
}
}
}
// ═══════════════════════════════════════════════════════════════════════════════
// MATMUL KERNELS — F32/F16 direct, quantized via dequant+dot pipeline
// ═══════════════════════════════════════════════════════════════════════════════
// F32 GEMV: out[m] = dot(W[m, :], x) for each of the M rows (row-major W).
// Rows are parallelized with OpenMP; wm_inject is a numeric no-op marker.
// NOTE(review): on AVX-512-only builds IX_AVX2 is not defined, so this
// falls through to the scalar path — confirm whether that is intended.
inline void matmul_f32(float* out, const float* W, const float* x, int M, int K) {
#if IX_AVX2
#pragma omp parallel for
for (int m = 0; m < M; ++m) {
__m256 acc = _mm256_setzero_ps();
const float* row = W + m * K;
int k = 0;
// 8-wide FMA main loop; remainder handled scalar below.
for (; k + 8 <= K; k += 8)
acc = _mm256_fmadd_ps(_mm256_loadu_ps(row + k), _mm256_loadu_ps(x + k), acc);
// Horizontal reduction of the 8 accumulator lanes.
__m128 lo = _mm256_castps256_ps128(acc);
__m128 hi = _mm256_extractf128_ps(acc, 1);
lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
float sum = _mm_cvtss_f32(lo);
for (; k < K; ++k) sum += row[k] * x[k];
out[m] = wm_inject(sum);
}
#else
#pragma omp parallel for
for (int m = 0; m < M; ++m) {
float sum = 0; const float* row = W + m * K;
for (int k = 0; k < K; ++k) sum += row[k] * x[k];
out[m] = wm_inject(sum);
}
#endif
}
inline void matmul_f16(float* out, const f16* W, const float* x, int M, int K) {
#pragma omp parallel for
for (int m = 0; m < M; ++m) {
float sum = 0; const f16* row = W + m * K;
for (int k = 0; k < K; ++k) sum += static_cast<float>(row[k]) * x[k];
out[m] = wm_inject(sum);
}
}
// ═══════════════════════════════════════════════════════════════════════════════
// IQ2_S / IQ3_S / IQ4_XS — information-optimal quantization formats
// ═══════════════════════════════════════════════════════════════════════════════
// Dequantize IQ2_S super-blocks (256 values) into floats.
// Grid indices: 8 low bits in qs plus 2 high bits from qh; explicit sign
// bytes follow the index bytes (signs = qs + QK_K/8). One scale nibble
// covers each 16-value half of a 32-value group.
inline void dequant_iq2s(float* dst, const block_iq2_s* x, int K) {
const int nb = K / QK_K;
float db[2];
for (int i = 0; i < nb; i++) {
const float d = static_cast<float>(x[i].d);
const uint8_t* qs = x[i].qs;
const uint8_t* qh = x[i].qh;
const uint8_t* signs = qs + QK_K/8;
float* y = dst + i * QK_K;
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
// Scale per half-group: (0.5 + 4-bit code) / 4, times block scale.
db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
for (int l = 0; l < 4; ++l) {
const float dl = db[l/2];
// 10-bit grid index: qs byte plus 2 bits sliced out of qh[ib32].
const uint8_t* grid = reinterpret_cast<const uint8_t*>(
inference_x::iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
for (int j = 0; j < 8; ++j) {
y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
}
y += 8;
}
qs += 4;
signs += 4;
}
}
}
// Dequantize IQ3_S super-blocks (256 values) into floats.
// Processes two 32-value groups per iteration: one scale nibble each
// (scales[ib32/2] low/high). Grid indices take 8 bits from qs plus a 9th
// bit from qh[0] (first group) or qh[1] (second group).
inline void dequant_iq3s(float* dst, const block_iq3_s* x, int K) {
const int nb = K / QK_K;
for (int i = 0; i < nb; i++) {
const float d = static_cast<float>(x[i].d);
const uint8_t* qs = x[i].qs;
const uint8_t* qh = x[i].qh;
const uint8_t* signs = x[i].signs;
float* y = dst + i * QK_K;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4));
// First 32-value group: 9th index bit from qh[0].
for (int l = 0; l < 4; ++l) {
const uint8_t* grid1 = reinterpret_cast<const uint8_t*>(
inference_x::iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
const uint8_t* grid2 = reinterpret_cast<const uint8_t*>(
inference_x::iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
y += 8;
}
qs += 8;
signs += 4;
// Second 32-value group: same layout, 9th index bit from qh[1].
for (int l = 0; l < 4; ++l) {
const uint8_t* grid1 = reinterpret_cast<const uint8_t*>(
inference_x::iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
const uint8_t* grid2 = reinterpret_cast<const uint8_t*>(
inference_x::iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
for (int j = 0; j < 4; ++j) {
y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
y += 8;
}
qh += 2;
qs += 8;
signs += 4;
}
}
}
// Dequantize IQ4_XS super-blocks (256 values) into floats.
// 4-bit codes index the non-linear kvalues_iq4nl table. Each of the 8
// 32-value groups has a 6-bit scale: low 4 bits packed two-per-byte in
// scales_l, high 2 bits in scales_h; centered by -32.
inline void dequant_iq4xs(float* dst, const block_iq4_xs* x, int K) {
const int nb = K / 256; // QK_K=256
for (int i = 0; i < nb; i++) {
const uint8_t* qs = x[i].qs;
const float d = static_cast<float>(x[i].d);
for (int ib = 0; ib < 8; ++ib) { // QK_K/32 = 8
const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) |
(((x[i].scales_h >> 2*ib) & 3) << 4);
const float dl = d * (ls - 32);
// Low nibbles -> first 16 values, high nibbles -> next 16.
for (int j = 0; j < 16; ++j) {
dst[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
dst[j+16] = dl * kvalues_iq4nl[qs[j] >> 4];
}
dst += 32;
qs += 16;
}
}
}
inline void dequant_q4_0(float* dst, const block_q4_0* src, int K) {
const int nb = K / QK4_0;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
const uint8_t* qs = src[b].qs;
float* o = dst + b * QK4_0;
for (int j = 0; j < QK4_0 / 2; ++j) {
o[j] = d * ((int)(qs[j] & 0xF) - 8);
o[j + QK4_0/2] = d * ((int)(qs[j] >> 4) - 8);
}
}
}
// Dequantize Q3_K super-blocks (256 values, 3-bit quants) into floats.
// Per block: 32-byte hmask (high bit of each quant), 64 bytes of 2-bit low
// quants, 12 bytes packing sixteen 6-bit signed scales, fp16 d.
//
// FIX (review): the previous version iterated shifts {0,2} only, writing
// just 128 of the 256 outputs per block and leaving the rest of dst
// uninitialized; it also dropped the high 2 bits of every 6-bit scale and
// used an hmask bit layout inconsistent with the GGUF Q3_K format. This
// now follows the reference decode (llama.cpp dequantize_row_q3_K).
inline void dequant_q3k(float* dst, const block_q3_K* src, int K) {
    const int nb = K / QK_K;
    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;
    uint32_t aux[4];
    // Byte view of the unpacked scales (sixteen 6-bit values, 0..63).
    const int8_t* scales = reinterpret_cast<const int8_t*>(aux);
    for (int b = 0; b < nb; ++b) {
        float d_all = static_cast<float>(src[b].d);
        const uint8_t* ql = src[b].qs;
        const uint8_t* hm = src[b].hmask;
        float* y = dst + b * QK_K;
        // Unpack scales: low 4 bits come from nibbles of the first 8 bytes,
        // high 2 bits are distributed through the last 4 bytes.
        std::memcpy(aux, src[b].scales, 12);
        uint32_t tmp = aux[2];
        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        int is = 0;
        uint8_t m = 1; // hmask bit selecting the current 32-value group
        for (int n = 0; n < QK_K; n += 128) {
            int shift = 0;
            for (int j = 0; j < 4; ++j) {
                // A set hmask bit means "do not subtract 4": the decoded
                // value is quant (0..3) when set, quant-4 (-4..-1) when clear.
                float dl = d_all * (scales[is++] - 32);
                for (int l = 0; l < 16; ++l)
                    *y++ = dl * ((int8_t)((ql[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
                dl = d_all * (scales[is++] - 32);
                for (int l = 0; l < 16; ++l)
                    *y++ = dl * ((int8_t)((ql[l + 16] >> shift) & 3) - ((hm[l + 16] & m) ? 0 : 4));
                shift += 2;
                m <<= 1;
            }
            ql += 32;
        }
    }
}
inline void dequant_bf16(float* dst, const void* src_raw, int K) {
const bf16* src = static_cast<const bf16*>(src_raw);
for (int i = 0; i < K; ++i) dst[i] = static_cast<float>(src[i]);
}
using DequantFn = void(*)(float*, const void*, int);
inline void _dq_q8_0(float* d, const void* s, int K) { dequant_q8_0(d, (const block_q8_0*)s, K); }
inline void _dq_q4k(float* d, const void* s, int K) { dequant_q4k(d, (const block_q4_K*)s, K); }
inline void _dq_q2k(float* d, const void* s, int K) { dequant_q2k(d, (const block_q2_K*)s, K); }
inline void _dq_q5k(float* d, const void* s, int K) { dequant_q5k(d, (const block_q5_K*)s, K); }
inline void _dq_q6k(float* d, const void* s, int K) { dequant_q6k(d, (const block_q6_K*)s, K); }
inline void _dq_iq4nl(float* d, const void* s, int K) { dequant_iq4nl(d, (const block_iq4_nl*)s, K); }
inline void _dq_iq2xxs(float* d, const void* s, int K) { dequant_iq2xxs(d, (const block_iq2_xxs*)s, K); }
inline void _dq_iq1s(float* d, const void* s, int K) { dequant_iq1s(d, (const block_iq1_s*)s, K); }
inline void _dq_iq3xxs(float* d, const void* s, int K) { dequant_iq3xxs(d, (const block_iq3_xxs*)s, K); }
inline void _dq_tq1_0(float* d, const void* s, int K) { dequant_tq1_0(d, (const block_tq1_0*)s, K); }
inline void _dq_tq2_0(float* d, const void* s, int K) { dequant_tq2_0(d, (const block_tq2_0*)s, K); }
inline void _dq_iq2s(float* d, const void* s, int K) { dequant_iq2s(d, (const block_iq2_s*)s, K); }
inline void _dq_iq3s(float* d, const void* s, int K) { dequant_iq3s(d, (const block_iq3_s*)s, K); }
inline void _dq_iq4xs(float* d, const void* s, int K) { dequant_iq4xs(d, (const block_iq4_xs*)s, K); }
// --- v9: Q4_1 dequantization ---
inline void dequant_q4_1(float* dst, const block_q4_1* src, int K) {
const int nb = K / QK4_1;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
float m = static_cast<float>(src[b].m);
const uint8_t* qs = src[b].qs;
float* o = dst + b * QK4_1;
for (int j = 0; j < QK4_1 / 2; ++j) {
o[j] = d * (qs[j] & 0xF) + m;
o[j + QK4_1/2] = d * (qs[j] >> 4) + m;
}
}
}
// --- v9: Q5_0 dequantization ---
inline void dequant_q5_0(float* dst, const block_q5_0* src, int K) {
const int nb = K / QK5_0;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
const uint8_t* qs = src[b].qs;
const uint8_t* qh = src[b].qh;
uint32_t qhbits = qh[0] | ((uint32_t)qh[1] << 8) | ((uint32_t)qh[2] << 16) | ((uint32_t)qh[3] << 24);
float* o = dst + b * QK5_0;
for (int j = 0; j < QK5_0 / 2; ++j) {
int x0 = (qs[j] & 0xF) | (((qhbits >> j) & 1) << 4);
int x1 = (qs[j] >> 4) | (((qhbits >> (j + QK5_0/2)) & 1) << 4);
o[j] = d * (x0 - 16);
o[j + QK5_0/2] = d * (x1 - 16);
}
}
}
// --- v9: Q5_1 dequantization ---
inline void dequant_q5_1(float* dst, const block_q5_1* src, int K) {
const int nb = K / QK5_1;
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(src[b].d);
float m = static_cast<float>(src[b].m);
const uint8_t* qs = src[b].qs;
const uint8_t* qh = src[b].qh;
uint32_t qhbits = qh[0] | ((uint32_t)qh[1] << 8) | ((uint32_t)qh[2] << 16) | ((uint32_t)qh[3] << 24);
float* o = dst + b * QK5_1;
for (int j = 0; j < QK5_1 / 2; ++j) {
int x0 = (qs[j] & 0xF) | (((qhbits >> j) & 1) << 4);
int x1 = (qs[j] >> 4) | (((qhbits >> (j + QK5_1/2)) & 1) << 4);
o[j] = d * x0 + m;
o[j + QK5_1/2] = d * x1 + m;
}
}
}
// --- v9: Q8_1 dequantization ---
inline void dequant_q8_1(float* dst, const block_q8_1* src, int K) {
const int nb = K / QK8_1;
for (int b = 0; b < nb; ++b) {
float d = src[b].d;
const int8_t* qs = src[b].qs;
float* o = dst + b * QK8_1;
for (int j = 0; j < QK8_1; ++j) {
o[j] = d * qs[j];
}
}
}
inline void _dq_q4_0(float* d, const void* s, int K) { dequant_q4_0(d, (const block_q4_0*)s, K); }
inline void _dq_q3k(float* d, const void* s, int K) { dequant_q3k(d, (const block_q3_K*)s, K); }
inline void _dq_bf16(float* d, const void* s, int K) { dequant_bf16(d, s, K); }
inline void _dq_q4_1(float* d, const void* s, int K) { dequant_q4_1(d, (const block_q4_1*)s, K); }
inline void _dq_q5_0(float* d, const void* s, int K) { dequant_q5_0(d, (const block_q5_0*)s, K); }
inline void _dq_q5_1(float* d, const void* s, int K) { dequant_q5_1(d, (const block_q5_1*)s, K); }
inline void _dq_q8_1(float* d, const void* s, int K) { dequant_q8_1(d, (const block_q8_1*)s, K); }
// Bytes occupied by one row of K elements of the given dtype:
// block formats use (K / block_size) * bytes_per_block, scalar formats
// use K * element size.
inline size_t row_bytes(dtype type, int K) {
    const int block = dtype_block_size(type);
    return block > 0 ? (size_t)(K / block) * dtype_size(type)
                     : (size_t)K * dtype_size(type);
}
// Look up the dequantization routine for a dtype; returns nullptr for
// unsupported types (F32/F16 are handled by dedicated paths elsewhere).
inline DequantFn get_dequant_fn(dtype type) {
    switch (type) {
        // Classic 32-value block formats
        case dtype::Q4_0:   return _dq_q4_0;
        case dtype::Q8_0:   return _dq_q8_0;
        // 256-value K-quant super-blocks
        case dtype::Q2_K:   return _dq_q2k;
        case dtype::Q3_K:   return _dq_q3k;
        case dtype::Q4_K:   return _dq_q4k;
        case dtype::Q5_K:   return _dq_q5k;
        case dtype::Q6_K:   return _dq_q6k;
        // Information-optimal (IQ) formats
        case dtype::IQ1_S:   return _dq_iq1s;
        case dtype::IQ2_XXS: return _dq_iq2xxs;
        case dtype::IQ2_S:   return _dq_iq2s;
        case dtype::IQ3_XXS: return _dq_iq3xxs;
        case dtype::IQ3_S:   return _dq_iq3s;
        case dtype::IQ4_NL:  return _dq_iq4nl;
        case dtype::IQ4_XS:  return _dq_iq4xs;
        // Ternary formats
        case dtype::TQ1_0:  return _dq_tq1_0;
        case dtype::TQ2_0:  return _dq_tq2_0;
        // Half-precision row widening
        case dtype::BF16:   return _dq_bf16;
        default:            return nullptr;
    }
}
// Fused Q4_K dot — dequant+dot in one pass, stays in L1
// AVX2+FMA optimized by Inference-X kernel team
// Computes dot(dequant(row), x) without a scratch buffer; the decode
// mirrors dequant_q4k (value = d*scale*quant - dmin*min).
// NOTE(review): on AVX-512-only builds IX_AVX2 is undefined, so the
// scalar path is taken — confirm whether that is intended.
inline float fused_dot_q4k(const block_q4_K* row, const float* x, int K) {
const int nb = K / QK_K;
float total = 0.0f;
#if IX_AVX2
__m256 acc_total = _mm256_setzero_ps();
for (int b = 0; b < nb; ++b) {
const uint8_t* q = row[b].qs;
float d = static_cast<float>(row[b].d);
float dmin = static_cast<float>(row[b].dmin);
const float* xb = x + b * QK_K;
int is = 0, xoff = 0;
for (int j = 0; j < QK_K; j += 64) {
uint8_t sc, m;
get_scale_min_k4(is, row[b].scales, &sc, &m);
float d1 = d * sc, m1 = dmin * m;
get_scale_min_k4(is + 1, row[b].scales, &sc, &m);
float d2 = d * sc, m2 = dmin * m;
const __m256 vd1 = _mm256_set1_ps(d1);
const __m256 vm1 = _mm256_set1_ps(m1);
const __m256 vd2 = _mm256_set1_ps(d2);
const __m256 vm2 = _mm256_set1_ps(m2);
// Low nibbles: 32 elements in 4x8
for (int l = 0; l < 32; l += 8) {
__m128i qb = _mm_loadl_epi64((const __m128i*)(q + l));
__m256i q32 = _mm256_cvtepu8_epi32(_mm_and_si128(qb, _mm_set1_epi8(0xF)));
__m256 qf = _mm256_cvtepi32_ps(q32);
// fmsub computes d1*quant - m1 in one instruction.
__m256 val = _mm256_fmsub_ps(vd1, qf, vm1);
__m256 xv = _mm256_loadu_ps(xb + xoff + l);
acc_total = _mm256_fmadd_ps(val, xv, acc_total);
}
// High nibbles: 32 elements in 4x8
for (int l = 0; l < 32; l += 8) {
__m128i qb = _mm_loadl_epi64((const __m128i*)(q + l));
__m256i q32 = _mm256_cvtepu8_epi32(
_mm_and_si128(_mm_srli_epi16(qb, 4), _mm_set1_epi8(0xF)));
__m256 qf = _mm256_cvtepi32_ps(q32);
__m256 val = _mm256_fmsub_ps(vd2, qf, vm2);
__m256 xv = _mm256_loadu_ps(xb + xoff + 32 + l);
acc_total = _mm256_fmadd_ps(val, xv, acc_total);
}
q += 32; is += 2; xoff += 64;
}
}
// Horizontal reduction of the 8 accumulator lanes.
__m128 lo = _mm256_castps256_ps128(acc_total);
__m128 hi = _mm256_extractf128_ps(acc_total, 1);
lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
total = _mm_cvtss_f32(lo);
#else
// Scalar fallback: identical decode, accumulated in program order.
for (int b = 0; b < nb; ++b) {
const uint8_t* q = row[b].qs;
float d = static_cast<float>(row[b].d);
float dmin = static_cast<float>(row[b].dmin);
const float* xb = x + b * QK_K;
int is = 0, xoff = 0;
for (int j = 0; j < QK_K; j += 64) {
uint8_t sc, m;
get_scale_min_k4(is, row[b].scales, &sc, &m);
float d1 = d * sc, m1 = dmin * m;
get_scale_min_k4(is + 1, row[b].scales, &sc, &m);
float d2 = d * sc, m2 = dmin * m;
for (int l = 0; l < 32; ++l)
total += (d1 * (q[l] & 0xF) - m1) * xb[xoff + l];
for (int l = 0; l < 32; ++l)
total += (d2 * (q[l] >> 4) - m2) * xb[xoff + 32 + l];
q += 32; is += 2; xoff += 64;
}
}
#endif
return total;
}
// Fused Q8_0 dot: dot(dequant(row), x) without a scratch buffer.
// Each block contributes d * sum(qs[i] * x[i]) over its 32 values.
inline float fused_dot_q8_0(const block_q8_0* row, const float* x, int K) {
const int nb = K / 32;
float total = 0.0f;
#if IX_AVX2
__m256 acc = _mm256_setzero_ps();
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(row[b].d);
const int8_t* qs = row[b].qs;
const float* xb = x + b * 32;
__m256 vd = _mm256_set1_ps(d);
for (int i = 0; i < 32; i += 8) {
// Widen 8 int8 quants to f32, scale, then FMA against x.
__m128i qb = _mm_loadl_epi64((const __m128i*)(qs + i));
__m256i q32 = _mm256_cvtepi8_epi32(qb);
__m256 qf = _mm256_cvtepi32_ps(q32);
__m256 xv = _mm256_loadu_ps(xb + i);
acc = _mm256_fmadd_ps(_mm256_mul_ps(vd, qf), xv, acc);
}
}
// Horizontal reduction of the 8 accumulator lanes.
__m128 lo = _mm256_castps256_ps128(acc);
__m128 hi = _mm256_extractf128_ps(acc, 1);
lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
total = _mm_cvtss_f32(lo);
#else
// Scalar fallback: per-block integer-weighted sum, scaled once.
for (int b = 0; b < nb; ++b) {
float d = static_cast<float>(row[b].d);
const int8_t* qs = row[b].qs;
const float* xb = x + b * 32;
float s = 0.0f;
for (int i = 0; i < 32; ++i) s += qs[i] * xb[i];
total += d * s;
}
#endif
return total;
}
// Generic quantized GEMV: for each of the M rows, dequantize the row into
// a per-thread scratch buffer, then dot it against x.
// rb = bytes per row (see row_bytes); dequant must match the row format.
inline void matmul_dequant(float* out, const void* W, DequantFn dequant,
size_t rb, const float* x, int M, int K) {
#pragma omp parallel
{
// Per-thread scratch, reused across calls; grown only when K increases.
thread_local std::vector<float> buf;
if ((int)buf.size() < K) buf.resize(K);
#pragma omp for schedule(static)
for (int m = 0; m < M; ++m) {
dequant(buf.data(), (const uint8_t*)W + (size_t)m * rb, K);
float sum = 0;
#if IX_AVX2
// 8-wide FMA dot with horizontal reduction; scalar tail below.
__m256 acc = _mm256_setzero_ps();
int k = 0;
for (; k + 8 <= K; k += 8)
acc = _mm256_fmadd_ps(_mm256_loadu_ps(buf.data()+k), _mm256_loadu_ps(x+k), acc);
__m128 lo = _mm256_castps256_ps128(acc);
__m128 hi = _mm256_extractf128_ps(acc, 1);
lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo);
sum = _mm_cvtss_f32(lo);
for (; k < K; ++k) sum += buf[k] * x[k];
#else
for (int k = 0; k < K; ++k) sum += buf[k] * x[k];
#endif
out[m] = wm_inject(sum);
}
}
}
// ═══════════════════════════════════════════════════════════════════════════════
// UNIFIED MATMUL — handles ALL Kimi K2.5 formats
// ═══════════════════════════════════════════════════════════════════════════════
// Unified GEMV dispatch: out[m] = dot(W[m, :], x) for any supported dtype.
// F32/F16 use direct kernels; Q4_K/Q8_0 use fused dequant+dot kernels;
// every other quantized type goes through the dequant+dot pipeline.
// Unsupported dtypes log an error and zero the output.
inline void matmul(float* out, const void* W, dtype type, const float* x, int M, int K) {
if (type == dtype::F32) { matmul_f32(out, (const float*)W, x, M, K); return; }
if (type == dtype::F16) { matmul_f16(out, (const f16*)W, x, M, K); return; }
// Fused paths: dequant+dot in one pass (better cache, no buffer)
if (type == dtype::Q4_K) {
size_t rb = row_bytes(type, K);
#pragma omp parallel for schedule(static)
for (int m = 0; m < M; ++m) {
const block_q4_K* row = (const block_q4_K*)((const uint8_t*)W + (size_t)m * rb);
out[m] = wm_inject(fused_dot_q4k(row, x, K));
}
return;
}
if (type == dtype::Q8_0) {
size_t rb = row_bytes(type, K);
#pragma omp parallel for schedule(static)
for (int m = 0; m < M; ++m) {
const block_q8_0* row = (const block_q8_0*)((const uint8_t*)W + (size_t)m * rb);
out[m] = wm_inject(fused_dot_q8_0(row, x, K));
}
return;
}
// Other quantized types: dequant + dot pipeline
DequantFn fn = get_dequant_fn(type);
if (fn) {
size_t rb = row_bytes(type, K);
matmul_dequant(out, W, fn, rb, x, M, K);
return;
}
fprintf(stderr, "[GEMM] unsupported dtype %d (%s)\n", (int)type, dtype_name(type));
kernel::vec_zero(out, M);
}
inline void dequantize_row(float* out, const void* data, dtype type, int K) {
DequantFn fn = get_dequant_fn(type);
if (fn) { fn(out, data, K); return; }
if (type == dtype::F32) { std::memcpy(out, data, K * sizeof(float)); return; }
if (type == dtype::F16) {
const f16* s = (const f16*)data;
for (int i = 0; i < K; ++i) out[i] = static_cast<float>(s[i]);
return;
}
std::memset(out, 0, K * sizeof(float));
}
inline void embed_lookup(float* out, const void* embd, dtype type, int token, int dim) {
if (type == dtype::F32) { std::memcpy(out, (const float*)embd + (size_t)token * dim, dim * sizeof(float)); return; }
if (type == dtype::F16) {
const f16* r = (const f16*)embd + (size_t)token * dim;
for (int i = 0; i < dim; ++i) out[i] = static_cast<float>(r[i]);
return;
}
size_t rb = row_bytes(type, dim);
DequantFn fn = get_dequant_fn(type);
if (fn) { fn(out, (const uint8_t*)embd + (size_t)token * rb, dim); return; }
std::memset(out, 0, dim * sizeof(float));
}
} // namespace gemm
} // namespace ix