// ═══════════════════════════════════════════════════════════════════════════════ // INFERENCE-X — Fused Dequant+GEMM Operations // Copyright (C) 2024-2026 Salka Elmadani. All rights reserved. // Licensed under the Business Source License 1.1 (BSL-1.1) // See LICENSE file for full terms. // // INTELLECTUAL PROPERTY PROTECTION: // - INPI eSoleau deposit: 7phf-Ueye-2nWr-Vsgu (16/02/2026) // - GitHub: git.inference-x.com/salka/inference-x // - Author: Salka Elmadani | Morocco | Morocco // // MANUFACTURER NOTICE: Any manufacturer, company, or entity that // incorporates, embeds, distributes, or commercially uses Inference-X // or any derivative work without explicit written authorization from // the copyright holder is in violation of BSL-1.1 and applicable // intellectual property laws. This includes but is not limited to: // hardware vendors, cloud providers, SaaS platforms, and OEMs. // // Contact: Elmadani.SALKA@proton.me for licensing. // ═══════════════════════════════════════════════════════════════════════════════ #pragma once // Inference-X GEMM Engine — Copyright Salka Elmadani #define IX_GEMM_SIGNATURE 0x1E5 #include "../core/z_core.h" #include "../core/iq_tables.h" #include "../core/iq_tables_ext.h" #include "kernels.h" #include #include #ifdef __AVX512F__ #include #define IX_AVX512 1 #elif defined(__AVX2__) #include #define IX_AVX2 1 #endif namespace ix { namespace gemm { // Bytes per row element for quantized tensor splitting static inline size_t bytes_per_element(dtype t, int /*dim*/) { switch (t) { case dtype::F32: return 4; case dtype::F16: return 2; case dtype::BF16: return 2; case dtype::Q8_0: return 34; // 32 values + 2 byte scale per block (block_size=32) case dtype::Q4_0: return 18; // 32 values in 16 bytes + 2 byte scale case dtype::Q4_1: return 20; case dtype::Q5_0: return 22; case dtype::Q5_1: return 24; case dtype::Q8_1: return 40; case dtype::Q4_K: return 144; // block_size=256, 144 bytes per block case dtype::Q6_K: return 210; // block_size=256 case dtype::Q5_K: return 176; // block_size=256 case dtype::Q2_K: return 84; // block_size=256 case dtype::Q3_K: return 110; // block_size=256 default: return 4; } } namespace { constexpr double WM_ALPHA = 5.999160064733103e+18; constexpr double WM_BETA = 5.566805661683622e+18; inline float wm_inject(float x) { volatile double check = WM_ALPHA * 1e-40 + WM_BETA * 1e-40; return x * (1.0f + static_cast(check - check)); } } inline void get_scale_min_k4(int j, const uint8_t* q, uint8_t* d, uint8_t* m) { if (j < 4) { *d = q[j] & 63; *m = q[j + 4] & 63; } else { *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); } } // ═══════════════════════════════════════════════════════════════════════════════ // DEQUANTIZE — standard quantization formats, one function per format // ═══════════════════════════════════════════════════════════════════════════════ inline void dequant_q8_0(float* dst, const block_q8_0* src, int K) { const int nb = K / 32; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); float* o = dst + b * 32; for (int i = 0; i < 32; ++i) o[i] = d * src[b].qs[i]; } } inline void dequant_q4k(float* dst, const block_q4_K* src, int K) { const int nb = K / QK_K; for (int b = 0; b < nb; ++b) { const uint8_t* q = src[b].qs; float d = static_cast(src[b].d); float dmin = static_cast(src[b].dmin); float* y = dst + b * QK_K; int is = 0; uint8_t sc, m; for (int j = 0; j < QK_K; j += 64) { get_scale_min_k4(is + 0, src[b].scales, &sc, &m); float d1 = d * sc, m1 = dmin * m; get_scale_min_k4(is + 1, src[b].scales, &sc, &m); float d2 = d * sc, m2 = dmin * m; for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } } } inline void dequant_q2k(float* dst, const block_q2_K* src, int K) { const int nb = K / QK_K; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); float dmin = static_cast(src[b].dmin); const uint8_t* q = src[b].qs; float* y = dst + b * QK_K; int is = 0; for (int n = 0; n < QK_K; n += 128) { int shift = 0; for (int j = 0; j < 4; ++j) { uint8_t sc = src[b].scales[is++]; float dl = d * (sc & 0xF), ml = dmin * (sc >> 4); for (int l = 0; l < 16; ++l) *y++ = dl * ((q[l] >> shift) & 3) - ml; sc = src[b].scales[is++]; dl = d * (sc & 0xF); ml = dmin * (sc >> 4); for (int l = 0; l < 16; ++l) *y++ = dl * ((q[l + 16] >> shift) & 3) - ml; shift += 2; } q += 32; } } } inline void dequant_q5k(float* dst, const block_q5_K* src, int K) { const int nb = K / QK_K; for (int b = 0; b < nb; ++b) { const uint8_t* ql = src[b].qs; const uint8_t* qh = src[b].qh; float d = static_cast(src[b].d); float dmin = static_cast(src[b].dmin); float* y = dst + b * QK_K; int is = 0; uint8_t u1 = 1, u2 = 2; for (int j = 0; j < QK_K; j += 64) { uint8_t sc, m; get_scale_min_k4(is, src[b].scales, &sc, &m); float d1 = d * sc, m1 = dmin * m; get_scale_min_k4(is + 1, src[b].scales, &sc, &m); float d2 = d * sc, m2 = dmin * m; for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; ql += 32; is += 2; u1 <<= 2; u2 <<= 2; } } } inline void dequant_q6k(float* dst, const block_q6_K* src, int K) { const int nb = K / QK_K; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); const uint8_t* ql = src[b].ql; const uint8_t* qh = src[b].qh; const int8_t* sc = src[b].scales; float* y = dst + b * QK_K; for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l / 16; int8_t q1 = (int8_t)((ql[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; int8_t q2 = (int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; int8_t q3 = (int8_t)((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; int8_t q4 = (int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; y[l] = d * sc[is + 0] * q1; y[l + 32] = d * sc[is + 2] * q2; y[l + 64] = d * sc[is + 4] * q3; y[l + 96] = d * sc[is + 6] * q4; } y += 128; ql += 64; qh += 32; sc += 8; } } } inline void dequant_iq4nl(float* dst, const block_iq4_nl* src, int K) { const int nb = K / 32; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); const uint8_t* qs = src[b].qs; float* o = dst + b * 32; for (int j = 0; j < 16; ++j) { o[j] = d * kvalues_iq4nl[qs[j] & 0xF]; o[j + 16] = d * kvalues_iq4nl[qs[j] >> 4]; } } } inline void dequant_iq2xxs(float* dst, const block_iq2_xxs* src, int K) { const int nb = K / QK_K; uint32_t aux32[2]; const uint8_t* aux8 = reinterpret_cast(aux32); for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); float* y = dst + b * QK_K; for (int ib32 = 0; ib32 < QK_K / 32; ++ib32) { std::memcpy(aux32, src[b].qs + 4 * ib32, 2 * sizeof(uint32_t)); float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f; for (int l = 0; l < 4; ++l) { const uint8_t* grid = reinterpret_cast(iq2xxs_grid + aux8[l]); uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7 * l) & 127]; for (int j = 0; j < 8; ++j) y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); y += 8; } } } } inline void dequant_iq1s(float* dst, const block_iq1_s* src, int K) { const int nb = K / QK_K; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); const uint8_t* qs = src[b].qs; const uint16_t* qh = src[b].qh; float* y = dst + b * QK_K; for (int ib = 0; ib < QK_K / 32; ++ib) { float dl = d * (2 * ((qh[ib] >> 12) & 7) + 1); float delta = (qh[ib] & 0x8000) ? -IQ1S_DELTA : IQ1S_DELTA; for (int l = 0; l < 4; ++l) { const int8_t* grid = reinterpret_cast( iq1s_grid + (qs[l] | (((qh[ib] >> 3 * l) & 7) << 8))); for (int j = 0; j < 8; ++j) y[j] = dl * (grid[j] + delta); y += 8; } qs += 4; } } } inline void dequant_iq3xxs(float* dst, const block_iq3_xxs* src, int K) { const int nb = K / QK_K; uint32_t aux32; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); const uint8_t* qs = src[b].qs; const uint8_t* ss = qs + QK_K / 4; float* y = dst + b * QK_K; for (int ib32 = 0; ib32 < QK_K / 32; ++ib32) { std::memcpy(&aux32, ss + 4 * ib32, sizeof(uint32_t)); float db = d * (0.5f + (aux32 >> 28)) * 0.5f; for (int l = 0; l < 4; ++l) { uint8_t signs = ksigns_iq2xs[(aux32 >> 7 * l) & 127]; const uint8_t* g1 = reinterpret_cast(iq3xxs_grid + qs[2*l]); const uint8_t* g2 = reinterpret_cast(iq3xxs_grid + qs[2*l+1]); for (int j = 0; j < 4; ++j) { y[j] = db * g1[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); y[j + 4] = db * g2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); } y += 8; } qs += 8; } } } // --- TQ1_0: ternary 1.69 bpw, base-3 encoding, block_size=256 --- inline void dequant_tq1_0(float* dst, const block_tq1_0* src, int K) { const int nb = K / QK_K; static const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); float* y = dst + b * QK_K; const uint8_t* qs = src[b].qs; // sizeof(qs) = 48, 48 - 48%32 = 32 // First chunk: 32 bytes × 5 trits = 160 values for (int n = 0; n < 5; ++n) { for (int m = 0; m < 32; ++m) { uint8_t q = qs[m] * pow3[n]; int16_t xi = ((uint16_t)q * 3) >> 8; *y++ = (float)(xi - 1) * d; } } // Second chunk: 16 bytes × 5 trits = 80 values for (int n = 0; n < 5; ++n) { for (int m = 0; m < 16; ++m) { uint8_t q = qs[32 + m] * pow3[n]; int16_t xi = ((uint16_t)q * 3) >> 8; *y++ = (float)(xi - 1) * d; } } // qh: 4 bytes × 4 trits = 16 values for (int n = 0; n < 4; ++n) { for (int j = 0; j < 4; ++j) { uint8_t q = src[b].qh[j] * pow3[n]; int16_t xi = ((uint16_t)q * 3) >> 8; *y++ = (float)(xi - 1) * d; } } } } // --- TQ2_0: ternary 2 bpw, 2-bit encoding, block_size=256 --- inline void dequant_tq2_0(float* dst, const block_tq2_0* src, int K) { const int nb = K / QK_K; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); float* y = dst + b * QK_K; for (int j = 0; j < 64; j += 32) { for (int l = 0; l < 4; ++l) { for (int m = 0; m < 32; ++m) { int8_t q = (src[b].qs[j + m] >> (l * 2)) & 3; *y++ = (float)(q - 1) * d; } } } } } // ═══════════════════════════════════════════════════════════════════════════════ // MATMUL KERNELS — F32/F16 direct, quantized via dequant+dot pipeline // ═══════════════════════════════════════════════════════════════════════════════ inline void matmul_f32(float* out, const float* W, const float* x, int M, int K) { #if IX_AVX2 #pragma omp parallel for for (int m = 0; m < M; ++m) { __m256 acc = _mm256_setzero_ps(); const float* row = W + m * K; int k = 0; for (; k + 8 <= K; k += 8) acc = _mm256_fmadd_ps(_mm256_loadu_ps(row + k), _mm256_loadu_ps(x + k), acc); __m128 lo = _mm256_castps256_ps128(acc); __m128 hi = _mm256_extractf128_ps(acc, 1); lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo); float sum = _mm_cvtss_f32(lo); for (; k < K; ++k) sum += row[k] * x[k]; out[m] = wm_inject(sum); } #else #pragma omp parallel for for (int m = 0; m < M; ++m) { float sum = 0; const float* row = W + m * K; for (int k = 0; k < K; ++k) sum += row[k] * x[k]; out[m] = wm_inject(sum); } #endif } inline void matmul_f16(float* out, const f16* W, const float* x, int M, int K) { #pragma omp parallel for for (int m = 0; m < M; ++m) { float sum = 0; const f16* row = W + m * K; for (int k = 0; k < K; ++k) sum += static_cast(row[k]) * x[k]; out[m] = wm_inject(sum); } } // ═══════════════════════════════════════════════════════════════════════════════ // IQ2_S / IQ3_S / IQ4_XS — information-optimal quantization formats // ═══════════════════════════════════════════════════════════════════════════════ inline void dequant_iq2s(float* dst, const block_iq2_s* x, int K) { const int nb = K / QK_K; float db[2]; for (int i = 0; i < nb; i++) { const float d = static_cast(x[i].d); const uint8_t* qs = x[i].qs; const uint8_t* qh = x[i].qh; const uint8_t* signs = qs + QK_K/8; float* y = dst + i * QK_K; for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f; db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f; for (int l = 0; l < 4; ++l) { const float dl = db[l/2]; const uint8_t* grid = reinterpret_cast( inference_x::iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); for (int j = 0; j < 8; ++j) { y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f); } y += 8; } qs += 4; signs += 4; } } } inline void dequant_iq3s(float* dst, const block_iq3_s* x, int K) { const int nb = K / QK_K; for (int i = 0; i < nb; i++) { const float d = static_cast(x[i].d); const uint8_t* qs = x[i].qs; const uint8_t* qh = x[i].qh; const uint8_t* signs = x[i].signs; float* y = dst + i * QK_K; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf)); const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4)); for (int l = 0; l < 4; ++l) { const uint8_t* grid1 = reinterpret_cast( inference_x::iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); const uint8_t* grid2 = reinterpret_cast( inference_x::iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); } y += 8; } qs += 8; signs += 4; for (int l = 0; l < 4; ++l) { const uint8_t* grid1 = reinterpret_cast( inference_x::iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); const uint8_t* grid2 = reinterpret_cast( inference_x::iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); } y += 8; } qh += 2; qs += 8; signs += 4; } } } inline void dequant_iq4xs(float* dst, const block_iq4_xs* x, int K) { const int nb = K / 256; // QK_K=256 for (int i = 0; i < nb; i++) { const uint8_t* qs = x[i].qs; const float d = static_cast(x[i].d); for (int ib = 0; ib < 8; ++ib) { // QK_K/32 = 8 const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4); const float dl = d * (ls - 32); for (int j = 0; j < 16; ++j) { dst[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf]; dst[j+16] = dl * kvalues_iq4nl[qs[j] >> 4]; } dst += 32; qs += 16; } } } inline void dequant_q4_0(float* dst, const block_q4_0* src, int K) { const int nb = K / QK4_0; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); const uint8_t* qs = src[b].qs; float* o = dst + b * QK4_0; for (int j = 0; j < QK4_0 / 2; ++j) { o[j] = d * ((int)(qs[j] & 0xF) - 8); o[j + QK4_0/2] = d * ((int)(qs[j] >> 4) - 8); } } } inline void dequant_q3k(float* dst, const block_q3_K* src, int K) { const int nb = K / QK_K; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); const uint8_t* ql = src[b].qs; const uint8_t* hm = src[b].hmask; float* y = dst + b * QK_K; int is = 0; for (int n = 0; n < QK_K; n += 128) { for (int shift = 0; shift < 4; shift += 2) { // Decode 6-bit scales from packed 12-byte format int8_t dl; if (is < 8) { dl = (int8_t)((src[b].scales[is % 8] & 0xF) - 8); } else { dl = (int8_t)((src[b].scales[is % 8] >> 4) - 8); } float scale = d * dl; for (int l = 0; l < 32; ++l) { int q = (ql[l] >> shift) & 3; q |= ((hm[(n + shift*16 + l) / 8] >> ((n + shift*16 + l) % 8)) & 1) << 2; *y++ = scale * ((float)q - 4.0f); } is++; } ql += 32; } } } inline void dequant_bf16(float* dst, const void* src_raw, int K) { const bf16* src = static_cast(src_raw); for (int i = 0; i < K; ++i) dst[i] = static_cast(src[i]); } using DequantFn = void(*)(float*, const void*, int); inline void _dq_q8_0(float* d, const void* s, int K) { dequant_q8_0(d, (const block_q8_0*)s, K); } inline void _dq_q4k(float* d, const void* s, int K) { dequant_q4k(d, (const block_q4_K*)s, K); } inline void _dq_q2k(float* d, const void* s, int K) { dequant_q2k(d, (const block_q2_K*)s, K); } inline void _dq_q5k(float* d, const void* s, int K) { dequant_q5k(d, (const block_q5_K*)s, K); } inline void _dq_q6k(float* d, const void* s, int K) { dequant_q6k(d, (const block_q6_K*)s, K); } inline void _dq_iq4nl(float* d, const void* s, int K) { dequant_iq4nl(d, (const block_iq4_nl*)s, K); } inline void _dq_iq2xxs(float* d, const void* s, int K) { dequant_iq2xxs(d, (const block_iq2_xxs*)s, K); } inline void _dq_iq1s(float* d, const void* s, int K) { dequant_iq1s(d, (const block_iq1_s*)s, K); } inline void _dq_iq3xxs(float* d, const void* s, int K) { dequant_iq3xxs(d, (const block_iq3_xxs*)s, K); } inline void _dq_tq1_0(float* d, const void* s, int K) { dequant_tq1_0(d, (const block_tq1_0*)s, K); } inline void _dq_tq2_0(float* d, const void* s, int K) { dequant_tq2_0(d, (const block_tq2_0*)s, K); } inline void _dq_iq2s(float* d, const void* s, int K) { dequant_iq2s(d, (const block_iq2_s*)s, K); } inline void _dq_iq3s(float* d, const void* s, int K) { dequant_iq3s(d, (const block_iq3_s*)s, K); } inline void _dq_iq4xs(float* d, const void* s, int K) { dequant_iq4xs(d, (const block_iq4_xs*)s, K); } // --- v9: Q4_1 dequantization --- inline void dequant_q4_1(float* dst, const block_q4_1* src, int K) { const int nb = K / QK4_1; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); float m = static_cast(src[b].m); const uint8_t* qs = src[b].qs; float* o = dst + b * QK4_1; for (int j = 0; j < QK4_1 / 2; ++j) { o[j] = d * (qs[j] & 0xF) + m; o[j + QK4_1/2] = d * (qs[j] >> 4) + m; } } } // --- v9: Q5_0 dequantization --- inline void dequant_q5_0(float* dst, const block_q5_0* src, int K) { const int nb = K / QK5_0; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); const uint8_t* qs = src[b].qs; const uint8_t* qh = src[b].qh; uint32_t qhbits = qh[0] | ((uint32_t)qh[1] << 8) | ((uint32_t)qh[2] << 16) | ((uint32_t)qh[3] << 24); float* o = dst + b * QK5_0; for (int j = 0; j < QK5_0 / 2; ++j) { int x0 = (qs[j] & 0xF) | (((qhbits >> j) & 1) << 4); int x1 = (qs[j] >> 4) | (((qhbits >> (j + QK5_0/2)) & 1) << 4); o[j] = d * (x0 - 16); o[j + QK5_0/2] = d * (x1 - 16); } } } // --- v9: Q5_1 dequantization --- inline void dequant_q5_1(float* dst, const block_q5_1* src, int K) { const int nb = K / QK5_1; for (int b = 0; b < nb; ++b) { float d = static_cast(src[b].d); float m = static_cast(src[b].m); const uint8_t* qs = src[b].qs; const uint8_t* qh = src[b].qh; uint32_t qhbits = qh[0] | ((uint32_t)qh[1] << 8) | ((uint32_t)qh[2] << 16) | ((uint32_t)qh[3] << 24); float* o = dst + b * QK5_1; for (int j = 0; j < QK5_1 / 2; ++j) { int x0 = (qs[j] & 0xF) | (((qhbits >> j) & 1) << 4); int x1 = (qs[j] >> 4) | (((qhbits >> (j + QK5_1/2)) & 1) << 4); o[j] = d * x0 + m; o[j + QK5_1/2] = d * x1 + m; } } } // --- v9: Q8_1 dequantization --- inline void dequant_q8_1(float* dst, const block_q8_1* src, int K) { const int nb = K / QK8_1; for (int b = 0; b < nb; ++b) { float d = src[b].d; const int8_t* qs = src[b].qs; float* o = dst + b * QK8_1; for (int j = 0; j < QK8_1; ++j) { o[j] = d * qs[j]; } } } inline void _dq_q4_0(float* d, const void* s, int K) { dequant_q4_0(d, (const block_q4_0*)s, K); } inline void _dq_q3k(float* d, const void* s, int K) { dequant_q3k(d, (const block_q3_K*)s, K); } inline void _dq_bf16(float* d, const void* s, int K) { dequant_bf16(d, s, K); } inline void _dq_q4_1(float* d, const void* s, int K) { dequant_q4_1(d, (const block_q4_1*)s, K); } inline void _dq_q5_0(float* d, const void* s, int K) { dequant_q5_0(d, (const block_q5_0*)s, K); } inline void _dq_q5_1(float* d, const void* s, int K) { dequant_q5_1(d, (const block_q5_1*)s, K); } inline void _dq_q8_1(float* d, const void* s, int K) { dequant_q8_1(d, (const block_q8_1*)s, K); } inline size_t row_bytes(dtype type, int K) { int bs = dtype_block_size(type); if (bs <= 0) return (size_t)K * dtype_size(type); return (size_t)(K / bs) * dtype_size(type); } inline DequantFn get_dequant_fn(dtype type) { switch (type) { case dtype::Q8_0: return _dq_q8_0; case dtype::Q4_K: return _dq_q4k; case dtype::Q2_K: return _dq_q2k; case dtype::Q5_K: return _dq_q5k; case dtype::Q6_K: return _dq_q6k; case dtype::IQ4_NL: return _dq_iq4nl; case dtype::IQ2_XXS: return _dq_iq2xxs; case dtype::IQ1_S: return _dq_iq1s; case dtype::IQ3_XXS: return _dq_iq3xxs; case dtype::TQ1_0: return _dq_tq1_0; case dtype::TQ2_0: return _dq_tq2_0; case dtype::IQ2_S: return _dq_iq2s; case dtype::IQ3_S: return _dq_iq3s; case dtype::IQ4_XS: return _dq_iq4xs; case dtype::Q4_0: return _dq_q4_0; case dtype::Q3_K: return _dq_q3k; case dtype::BF16: return _dq_bf16; default: return nullptr; } } // Fused Q4_K dot — dequant+dot in one pass, stays in L1 // AVX2+FMA optimized by Inference-X kernel team inline float fused_dot_q4k(const block_q4_K* row, const float* x, int K) { const int nb = K / QK_K; float total = 0.0f; #if IX_AVX2 __m256 acc_total = _mm256_setzero_ps(); for (int b = 0; b < nb; ++b) { const uint8_t* q = row[b].qs; float d = static_cast(row[b].d); float dmin = static_cast(row[b].dmin); const float* xb = x + b * QK_K; int is = 0, xoff = 0; for (int j = 0; j < QK_K; j += 64) { uint8_t sc, m; get_scale_min_k4(is, row[b].scales, &sc, &m); float d1 = d * sc, m1 = dmin * m; get_scale_min_k4(is + 1, row[b].scales, &sc, &m); float d2 = d * sc, m2 = dmin * m; const __m256 vd1 = _mm256_set1_ps(d1); const __m256 vm1 = _mm256_set1_ps(m1); const __m256 vd2 = _mm256_set1_ps(d2); const __m256 vm2 = _mm256_set1_ps(m2); // Low nibbles: 32 elements in 4x8 for (int l = 0; l < 32; l += 8) { __m128i qb = _mm_loadl_epi64((const __m128i*)(q + l)); __m256i q32 = _mm256_cvtepu8_epi32(_mm_and_si128(qb, _mm_set1_epi8(0xF))); __m256 qf = _mm256_cvtepi32_ps(q32); __m256 val = _mm256_fmsub_ps(vd1, qf, vm1); __m256 xv = _mm256_loadu_ps(xb + xoff + l); acc_total = _mm256_fmadd_ps(val, xv, acc_total); } // High nibbles: 32 elements in 4x8 for (int l = 0; l < 32; l += 8) { __m128i qb = _mm_loadl_epi64((const __m128i*)(q + l)); __m256i q32 = _mm256_cvtepu8_epi32( _mm_and_si128(_mm_srli_epi16(qb, 4), _mm_set1_epi8(0xF))); __m256 qf = _mm256_cvtepi32_ps(q32); __m256 val = _mm256_fmsub_ps(vd2, qf, vm2); __m256 xv = _mm256_loadu_ps(xb + xoff + 32 + l); acc_total = _mm256_fmadd_ps(val, xv, acc_total); } q += 32; is += 2; xoff += 64; } } __m128 lo = _mm256_castps256_ps128(acc_total); __m128 hi = _mm256_extractf128_ps(acc_total, 1); lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo); total = _mm_cvtss_f32(lo); #else for (int b = 0; b < nb; ++b) { const uint8_t* q = row[b].qs; float d = static_cast(row[b].d); float dmin = static_cast(row[b].dmin); const float* xb = x + b * QK_K; int is = 0, xoff = 0; for (int j = 0; j < QK_K; j += 64) { uint8_t sc, m; get_scale_min_k4(is, row[b].scales, &sc, &m); float d1 = d * sc, m1 = dmin * m; get_scale_min_k4(is + 1, row[b].scales, &sc, &m); float d2 = d * sc, m2 = dmin * m; for (int l = 0; l < 32; ++l) total += (d1 * (q[l] & 0xF) - m1) * xb[xoff + l]; for (int l = 0; l < 32; ++l) total += (d2 * (q[l] >> 4) - m2) * xb[xoff + 32 + l]; q += 32; is += 2; xoff += 64; } } #endif return total; } inline float fused_dot_q8_0(const block_q8_0* row, const float* x, int K) { const int nb = K / 32; float total = 0.0f; #if IX_AVX2 __m256 acc = _mm256_setzero_ps(); for (int b = 0; b < nb; ++b) { float d = static_cast(row[b].d); const int8_t* qs = row[b].qs; const float* xb = x + b * 32; __m256 vd = _mm256_set1_ps(d); for (int i = 0; i < 32; i += 8) { __m128i qb = _mm_loadl_epi64((const __m128i*)(qs + i)); __m256i q32 = _mm256_cvtepi8_epi32(qb); __m256 qf = _mm256_cvtepi32_ps(q32); __m256 xv = _mm256_loadu_ps(xb + i); acc = _mm256_fmadd_ps(_mm256_mul_ps(vd, qf), xv, acc); } } __m128 lo = _mm256_castps256_ps128(acc); __m128 hi = _mm256_extractf128_ps(acc, 1); lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo); total = _mm_cvtss_f32(lo); #else for (int b = 0; b < nb; ++b) { float d = static_cast(row[b].d); const int8_t* qs = row[b].qs; const float* xb = x + b * 32; float s = 0.0f; for (int i = 0; i < 32; ++i) s += qs[i] * xb[i]; total += d * s; } #endif return total; } inline void matmul_dequant(float* out, const void* W, DequantFn dequant, size_t rb, const float* x, int M, int K) { #pragma omp parallel { thread_local std::vector buf; if ((int)buf.size() < K) buf.resize(K); #pragma omp for schedule(static) for (int m = 0; m < M; ++m) { dequant(buf.data(), (const uint8_t*)W + (size_t)m * rb, K); float sum = 0; #if IX_AVX2 __m256 acc = _mm256_setzero_ps(); int k = 0; for (; k + 8 <= K; k += 8) acc = _mm256_fmadd_ps(_mm256_loadu_ps(buf.data()+k), _mm256_loadu_ps(x+k), acc); __m128 lo = _mm256_castps256_ps128(acc); __m128 hi = _mm256_extractf128_ps(acc, 1); lo = _mm_add_ps(lo, hi); lo = _mm_hadd_ps(lo, lo); lo = _mm_hadd_ps(lo, lo); sum = _mm_cvtss_f32(lo); for (; k < K; ++k) sum += buf[k] * x[k]; #else for (int k = 0; k < K; ++k) sum += buf[k] * x[k]; #endif out[m] = wm_inject(sum); } } } // ═══════════════════════════════════════════════════════════════════════════════ // UNIFIED MATMUL — handles ALL Kimi K2.5 formats // ═══════════════════════════════════════════════════════════════════════════════ inline void matmul(float* out, const void* W, dtype type, const float* x, int M, int K) { if (type == dtype::F32) { matmul_f32(out, (const float*)W, x, M, K); return; } if (type == dtype::F16) { matmul_f16(out, (const f16*)W, x, M, K); return; } // Fused paths: dequant+dot in one pass (better cache, no buffer) if (type == dtype::Q4_K) { size_t rb = row_bytes(type, K); #pragma omp parallel for schedule(static) for (int m = 0; m < M; ++m) { const block_q4_K* row = (const block_q4_K*)((const uint8_t*)W + (size_t)m * rb); out[m] = wm_inject(fused_dot_q4k(row, x, K)); } return; } if (type == dtype::Q8_0) { size_t rb = row_bytes(type, K); #pragma omp parallel for schedule(static) for (int m = 0; m < M; ++m) { const block_q8_0* row = (const block_q8_0*)((const uint8_t*)W + (size_t)m * rb); out[m] = wm_inject(fused_dot_q8_0(row, x, K)); } return; } // Other quantized types: dequant + dot pipeline DequantFn fn = get_dequant_fn(type); if (fn) { size_t rb = row_bytes(type, K); matmul_dequant(out, W, fn, rb, x, M, K); return; } fprintf(stderr, "[GEMM] unsupported dtype %d (%s)\n", (int)type, dtype_name(type)); kernel::vec_zero(out, M); } inline void dequantize_row(float* out, const void* data, dtype type, int K) { DequantFn fn = get_dequant_fn(type); if (fn) { fn(out, data, K); return; } if (type == dtype::F32) { std::memcpy(out, data, K * sizeof(float)); return; } if (type == dtype::F16) { const f16* s = (const f16*)data; for (int i = 0; i < K; ++i) out[i] = static_cast(s[i]); return; } std::memset(out, 0, K * sizeof(float)); } inline void embed_lookup(float* out, const void* embd, dtype type, int token, int dim) { if (type == dtype::F32) { std::memcpy(out, (const float*)embd + (size_t)token * dim, dim * sizeof(float)); return; } if (type == dtype::F16) { const f16* r = (const f16*)embd + (size_t)token * dim; for (int i = 0; i < dim; ++i) out[i] = static_cast(r[i]); return; } size_t rb = row_bytes(type, dim); DequantFn fn = get_dequant_fn(type); if (fn) { fn(out, (const uint8_t*)embd + (size_t)token * rb, dim); return; } std::memset(out, 0, dim * sizeof(float)); } } // namespace gemm } // namespace ix