// ARM NEON backend — SIMD vectorized GEMM for ARMv8-A+
// Targets: Cortex-A55/A76/A78/X1/X3, Apple M1-M4, Snapdragon, Graviton
// Features: 128-bit SIMD, dot product (SDOT/UDOT), FP16 FMLA
#include <arm_neon.h>
#include <stddef.h>
#include <stdint.h>
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
// INPI eSoleau: 7phf-Ueye-2nWr-Vsgu — BSL-1.1
// Inference-X — Universal Inference Protocol // Morocco

// ── Q4_K dequant + dot product — NEON vectorized ──
//
// Packing convention (defined by the scalar tail below): byte qs[j] holds
// two consecutive quantized values — (qs[j] & 0x0F) is element 2*j and
// (qs[j] >> 4) is element 2*j + 1.  Dequantization: w = scale * q + min_val.
//
// Returns dot(dequant(qs), x) over k elements.  k may be odd; the final
// lone element (low nibble only) is handled after the pairwise tail.
static inline float q4_dot_neon(
    const uint8_t* restrict qs,  // packed Q4 weights, ceil(k/2) bytes
    const float* restrict x,     // input activations, k floats
    int k,
    float scale,
    float min_val
) {
    float32x4_t vsum = vdupq_n_f32(0.0f);
    const float32x4_t vscale = vdupq_n_f32(scale);
    const float32x4_t vmin = vdupq_n_f32(min_val);
    const uint8x8_t nib_mask = vdup_n_u8(0x0F);

    int i = 0;
    // Main loop: one 8-byte load = 16 Q4 values per iteration.  The load
    // consumes exactly the 8 bytes it reads, so it never runs past the end
    // of qs (the original advanced by 8 values per 8-byte load — OOB read).
    for (; i + 15 < k; i += 16) {
        uint8x8_t packed = vld1_u8(qs + i / 2);
        uint8x8_t lo = vand_u8(packed, nib_mask);  // even-indexed elements
        uint8x8_t hi = vshr_n_u8(packed, 4);       // odd-indexed elements
        // Interleave nibbles back into natural element order:
        // zip.val[0] = elements i..i+7, zip.val[1] = elements i+8..i+15.
        // (The original omitted this zip, so vector and scalar paths disagreed.)
        uint8x8x2_t zip = vzip_u8(lo, hi);
        uint16x8_t q01 = vmovl_u8(zip.val[0]);
        uint16x8_t q23 = vmovl_u8(zip.val[1]);
        // Dequantize: w = min + q * scale, fused via vmlaq.
        float32x4_t w0 = vmlaq_f32(vmin, vcvtq_f32_u32(vmovl_u16(vget_low_u16(q01))), vscale);
        float32x4_t w1 = vmlaq_f32(vmin, vcvtq_f32_u32(vmovl_u16(vget_high_u16(q01))), vscale);
        float32x4_t w2 = vmlaq_f32(vmin, vcvtq_f32_u32(vmovl_u16(vget_low_u16(q23))), vscale);
        float32x4_t w3 = vmlaq_f32(vmin, vcvtq_f32_u32(vmovl_u16(vget_high_u16(q23))), vscale);
        // Fused multiply-accumulate against the activations.
        vsum = vmlaq_f32(vsum, w0, vld1q_f32(x + i));
        vsum = vmlaq_f32(vsum, w1, vld1q_f32(x + i + 4));
        vsum = vmlaq_f32(vsum, w2, vld1q_f32(x + i + 8));
        vsum = vmlaq_f32(vsum, w3, vld1q_f32(x + i + 12));
    }

    // Horizontal reduction of the 4 partial sums.
    float32x2_t pair = vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum));
    float result = vget_lane_f32(vpadd_f32(pair, pair), 0);

    // Scalar tail: whole byte pairs first.  Condition is i + 1 < k (the
    // original used i < k and read x[i + 1] out of bounds for odd k).
    for (; i + 1 < k; i += 2) {
        uint8_t byte = qs[i / 2];
        result += (scale * (float)(byte & 0x0F) + min_val) * x[i];
        result += (scale * (float)(byte >> 4) + min_val) * x[i + 1];
    }
    // Lone low nibble when k is odd.
    if (i < k) {
        result += (scale * (float)(qs[i / 2] & 0x0F) + min_val) * x[i];
    }
    return result;
}

// ── Full GEMM dispatch ──
//
// output[m * N + n] = dot(dequant(row m of weights), column n of input),
// where input is laid out as N contiguous vectors of K floats each and
// row m of weights occupies K/2 packed bytes with per-row scale/min.
void q4_gemm_arm_neon(
    const void* weights, const float* input, float* output,
    int M, int N, int K,
    const float* scales, const float* mins
) {
    for (int m = 0; m < M; m++) {
        const uint8_t* w_row = (const uint8_t*)weights + (size_t)m * (size_t)(K / 2);
        for (int n = 0; n < N; n++) {
            // BUG FIX: the original passed `input + n` (stride 1), which
            // reads K overlapping floats per column instead of column n's
            // own K-float vector.
            output[m * N + n] =
                q4_dot_neon(w_row, input + (size_t)n * (size_t)K, K, scales[m], mins[m]);
        }
    }
}

#ifdef __ARM_FEATURE_DOTPROD
// ── SDOT path for ARMv8.2-A+ (Cortex-A76+, Apple M1+, Graviton2+) ──
//
// NOTE(review): the original body was a stub whose accumulator was never
// updated, so every output element was mins[m] * K regardless of the
// weights or input.  Until the real SDOT kernel lands (it requires int8
// re-quantized activations, which this float interface does not provide),
// delegate to the correct NEON float path so results are right.
void q4_gemm_arm_neon_dotprod(
    const void* weights, const float* input, float* output,
    int M, int N, int K,
    const float* scales, const float* mins
) {
    // TODO: SDOT/UDOT fast path — 4x int8 MAC throughput per 32-bit lane.
    q4_gemm_arm_neon(weights, input, output, M, N, K, scales, mins);
}
#endif