Better output from the same model. Fused computation, adaptive precision, surgical expert loading. 305 KB, 19 backends, zero dependencies. https://inference-x.com
63 lines
2.2 KiB
C++
63 lines
2.2 KiB
C++
// WebGPU backend — Browser-native GPU inference
|
|
// Targets: Chrome 113+, Firefox 121+, Safari 18+
|
|
// Features: WGSL compute shaders, zero-install web inference
|
|
|
|
#include <cstdint>
|
|
|
|
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
|
|
// INPI eSoleau: 7phf-Ueye-2nWr-Vsgu — BSL-1.1
|
|
// Inference-X — Universal Inference Protocol
|
|
// Morocco
|
|
|
|
// ── WGSL compute shader for Q4 GEMM ──
|
|
// This shader is compiled and dispatched via the WebGPU API in JavaScript
|
|
// The C++ wrapper provides the bridge for Emscripten/WASM builds
|
|
|
|
// WGSL source for a Q4 GEMM compute shader: output[M x N] = dequant(weights[M x K]) * input[K x N].
//
// Buffer layout as indexed by the shader below:
//   weights - 4-bit quantized values packed eight per u32 word, one row per
//             output row; each row occupies ceil(K / 8) words.
//   input   - f32, indexed as input[k * N + col].
//   output  - f32, written as output[row * N + col].
//   scales / mins - one f32 each per output row; a 4-bit value q is
//             dequantized as scale * q + min.
// NOTE(review): a single scale/min per row is a simplification compared with
// block-wise Q4 formats; confirm this matches the host-side packing.
// One invocation computes one output element; x maps to the column and y to
// the row of the output, so dispatch needs ceil(N/16) x ceil(M/16) workgroups.
static const char* q4_gemm_wgsl = R"WGSL(
@group(0) @binding(0) var<storage, read> weights: array<u32>;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<storage, read> scales: array<f32>;
@group(0) @binding(4) var<storage, read> mins: array<f32>;

struct Params {
    M: u32, N: u32, K: u32
};

@group(0) @binding(5) var<uniform> params: Params;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.y;
    let col = gid.x;
    if (row >= params.M || col >= params.N) { return; }

    // Each weight row holds K 4-bit values packed eight per u32 word.
    // Round up so a K that is not a multiple of 8 still owns its last word.
    let words_per_row = (params.K + 7u) / 8u;
    let w_offset = row * words_per_row;

    // One scale/min pair per output row: loop-invariant, load once.
    let scale = scales[row];
    let min_val = mins[row];

    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < params.K; k += 8u) {
        // Load 4 bytes = 8 Q4 values packed in a u32
        let packed = weights[w_offset + k / 8u];

        // Each byte carries two values: low nibble = element k+j,
        // high nibble = element k+j+1.
        for (var j: u32 = 0u; j < 8u && (k + j) < params.K; j += 2u) {
            let byte_val = (packed >> ((j / 2u) * 8u)) & 0xFFu;

            let w0 = scale * f32(byte_val & 0x0Fu) + min_val;
            sum += w0 * input[(k + j) * params.N + col];

            // Guard the high nibble: when K is odd the final byte only
            // carries one real value and row K of input does not exist.
            if (k + j + 1u < params.K) {
                let w1 = scale * f32(byte_val >> 4u) + min_val;
                sum += w1 * input[(k + j + 1u) * params.N + col];
            }
        }
    }

    output[row * params.N + col] = sum;
}
)WGSL";
|
|
|
|
// ── C interface for Emscripten bridge ──
|
|
extern "C" {{
|
|
const char* ix_webgpu_get_shader() {{ return q4_gemm_wgsl; }}
|
|
int ix_webgpu_workgroup_x() {{ return 16; }}
|
|
int ix_webgpu_workgroup_y() {{ return 16; }}
|
|
}}
|