inference-x/backends/q4_kernels/cpu/q4_gemm_cpu.c
Salka Elmadani ec36668cf5 Inference-X v1.0 — Universal AI Inference Engine
Better output from the same model. Fused computation, adaptive precision,
surgical expert loading. 305 KB, 19 backends, zero dependencies.

https://inference-x.com
2026-02-23 07:10:47 +00:00

69 lines
2.4 KiB
C

// ═══════════════════════════════════════════════════════════════════════════════
// INFERENCE-X — CPU AVX-512 Q4_K Backend
// Copyright (C) 2025-2026 Salka Elmadani. All rights reserved.
// BSL-1.1 | Morocco
// ═══════════════════════════════════════════════════════════════════════════════
#include "q4_types.h"
#include <string.h>
#ifdef __AVX512F__
#include <immintrin.h>
#endif
// Identification string for this backend. It is not referenced anywhere in the
// visible part of this translation unit.
#define IX_BACKEND_ID "Inference-X-CPU-AVX512"
// Expand one Q4_K block into FP32 values written consecutively to y.
//
// The block-level scale (x->d) and minimum (x->dmin) are decoded from f16 with
// f16_to_float(). The block is then processed in groups of 64 outputs; each
// group consumes 32 bytes of x->qs and two (scale, min) pairs fetched with
// get_scale_min_k4():
//   - outputs  0..31 of the group come from the low nibbles, using pair 2*g
//   - outputs 32..63 of the group come from the high nibbles, using pair 2*g+1
// Every output is (d * scale) * nibble - (dmin * min).
//
// y must have room for one value per quantized weight in the block (QK_K
// floats when QK_K is a multiple of 64).
static void dequantize_q4_K_block(const block_q4_K *x, float *y) {
    const float block_scale = f16_to_float(x->d);
    const float block_min = f16_to_float(x->dmin);
    const uint8_t *packed = x->qs;
    float *dst = y;

    for (int group = 0; group * 64 < QK_K; ++group) {
        const int pair = 2 * group;
        uint8_t scale_bits, min_bits;

        // Scale/offset applied to the low-nibble half of this group.
        get_scale_min_k4(pair, x->scales, &scale_bits, &min_bits);
        const float lo_scale = block_scale * scale_bits;
        const float lo_offset = block_min * min_bits;

        // Scale/offset applied to the high-nibble half of this group.
        get_scale_min_k4(pair + 1, x->scales, &scale_bits, &min_bits);
        const float hi_scale = block_scale * scale_bits;
        const float hi_offset = block_min * min_bits;

        for (int b = 0; b < 32; ++b) {
            dst[b] = lo_scale * (packed[b] & 0xF) - lo_offset;
        }
        for (int b = 0; b < 32; ++b) {
            dst[32 + b] = hi_scale * (packed[b] >> 4) - hi_offset;
        }

        dst += 64;
        packed += 32;
    }
}
// Add the dot product of two QK_K-element FP32 spans to `running` and return
// the result.
//
// With AVX-512F the products are accumulated 16 lanes at a time with fused
// multiply-add, reduced horizontally once, and then added to `running`.
// Without it, each product is added to `running` one element at a time, in
// index order.
static float q4k_accumulate_block(float running, const float *w, const float *xv) {
#ifdef __AVX512F__
    __m512 lanes = _mm512_setzero_ps();
    for (int off = 0; off < QK_K; off += 16) {
        const __m512 wv = _mm512_loadu_ps(w + off);
        const __m512 xs = _mm512_loadu_ps(xv + off);
        lanes = _mm512_fmadd_ps(wv, xs, lanes);
    }
    return running + _mm512_reduce_add_ps(lanes);
#else
    for (int off = 0; off < QK_K; ++off)
        running += w[off] * xv[off];
    return running;
#endif
}

// Q4_K matrix-vector product: out[m] = sum_k W[m,k] * x[k] for m in [0, M).
//
// W_raw  - Q4_K blocks, read as block_q4_K[]; row m occupies blocks
//          [m * (K / QK_K), (m + 1) * (K / QK_K)).
// x      - FP32 input vector; the first (K / QK_K) * QK_K elements are read.
// out    - FP32 output vector; elements 0..M-1 are written.
// M, K   - row count and column count. K / QK_K is an integer division, so
//          any remainder of K past the last whole block is not processed.
//
// Rows are independent; when OpenMP is enabled they are split across threads
// with a static schedule. Each row dequantizes one block at a time into a
// stack scratch buffer and accumulates its dot product with the matching
// slice of x.
void gemm_q4_K_fp32_cpu(
    const void *W_raw,
    const float *x,
    float *out,
    int M, int K)
{
    const block_q4_K *blocks = (const block_q4_K *)W_raw;
    const int blocks_per_row = K / QK_K;

#pragma omp parallel for schedule(static)
    for (int row = 0; row < M; row++) {
        float scratch[QK_K];   // dequantized weights of the current block
        float row_sum = 0.0f;
        int kb = 0;

        while (kb < blocks_per_row) {
            dequantize_q4_K_block(&blocks[row * blocks_per_row + kb], scratch);
            row_sum = q4k_accumulate_block(row_sum, scratch, x + kb * QK_K);
            ++kb;
        }

        out[row] = row_sum;
    }
}