inference-x/backends/q4_kernels/webgpu/q4_gemm_webgpu.cpp
Salka Elmadani ec36668cf5 Inference-X v1.0 — Universal AI Inference Engine
Better output from the same model. Fused computation, adaptive precision,
surgical expert loading. 305 KB, 19 backends, zero dependencies.

https://inference-x.com
2026-02-23 07:10:47 +00:00

63 lines
2.2 KiB
C++

// WebGPU backend — Browser-native GPU inference
// Targets: Chrome 113+, Firefox 121+, Safari 18+
// Features: WGSL compute shaders, zero-install web inference
#include <cstdint>
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
// INPI eSoleau: 7phf-Ueye-2nWr-Vsgu — BSL-1.1
// Inference-X — Universal Inference Protocol
// Morocco
// ── WGSL compute shader for Q4 GEMM ──
// This shader is compiled and dispatched via the WebGPU API in JavaScript.
// The C++ wrapper provides the bridge for Emscripten/WASM builds.
//
// Each invocation produces one element output[row * N + col] by accumulating
// dequantized weights of row `row` against column `col` of `input`
// (input is indexed as [k * N + col]). A 4-bit value q is dequantized as
// scales[row] * q + mins[row]; within every byte the low nibble is the
// even-k element and the high nibble the odd-k element.
static const char* q4_gemm_wgsl = R"WGSL(
@group(0) @binding(0) var<storage, read> weights: array<u32>;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<storage, read> scales: array<f32>;
@group(0) @binding(4) var<storage, read> mins: array<f32>;

struct Params {
    M: u32, N: u32, K: u32
};
@group(0) @binding(5) var<uniform> params: Params;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.y;
    let col = gid.x;
    // Workgroups may overhang the matrix edges; surplus invocations do nothing.
    if (row >= params.M || col >= params.N) { return; }

    // Per-row quantization parameters do not change across k.
    let scale = scales[row];
    let min_val = mins[row];

    // `weights` is indexed in u32 words and each word packs 8 Q4 values, so a
    // row of K values spans ceil(K / 8) words. (K / 2 is the row size in
    // bytes and must not be used as a word index.)
    // Assumes every row starts on a u32 boundary - TODO confirm against the
    // host-side weight packer.
    let words_per_row = (params.K + 7u) / 8u;
    let w_offset = row * words_per_row;

    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < params.K; k += 8u) {
        // Load 4 bytes = 8 Q4 values packed in a u32
        let packed = weights[w_offset + k / 8u]; // NOTE: simplified

        // Extract and dequantize up to 8 values, one byte (two nibbles) per step
        for (var j: u32 = 0u; j < 8u && (k + j) < params.K; j += 2u) {
            let byte_val = (packed >> ((j / 2u) * 8u)) & 0xFFu;

            let w0 = scale * f32(byte_val & 0x0Fu) + min_val;
            sum += w0 * input[(k + j) * params.N + col];

            // When K is odd the last byte carries only one valid nibble;
            // skip the high nibble so `input` is never read past row K - 1.
            if ((k + j + 1u) < params.K) {
                let w1 = scale * f32(byte_val >> 4u) + min_val;
                sum += w1 * input[(k + j + 1u) * params.N + col];
            }
        }
    }
    output[row * params.N + col] = sum;
}
)WGSL";
// ── C interface for Emscripten bridge ──
extern "C" {{
const char* ix_webgpu_get_shader() {{ return q4_gemm_wgsl; }}
int ix_webgpu_workgroup_x() {{ return 16; }}
int ix_webgpu_workgroup_y() {{ return 16; }}
}}