// inference-x/backends/q4_kernels/metal/q4_gemm_metal.mm
// Salka Elmadani ec36668cf5 — Inference-X v1.0, Universal AI Inference Engine
// Better output from the same model: fused computation, adaptive precision,
// surgical expert loading. 305 KB, 19 backends, zero dependencies.
//
// https://inference-x.com
// 2026-02-23 07:10:47 +00:00
// Apple Metal backend — Metal Performance Shaders + custom compute
// Targets: M1/M2/M3/M4 (Apple Silicon), A14+ (iPhone/iPad)
// Features: Unified memory, tile shaders, SIMD-group matrix ops
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <string.h>
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
// INPI eSoleau: 7phf-Ueye-2nWr-Vsgu — BSL-1.1
// Inference-X — Universal Inference Protocol
// Morocco
// ── Metal compute shader (embedded as string, compiled at runtime) ──
// Metal Shading Language source for the Q4 GEMM kernels, held as a C string
// and compiled at backend init via -[MTLDevice newLibraryWithSource:options:error:].
//
// Layout assumed by both kernels (as the indexing below shows):
//   weights : M x K matrix, 4-bit packed, two values per byte
//             (low nibble = element k, high nibble = element k+1), row-major
//   input   : K x N matrix, row-major (input[k * N + col])
//   output  : M x N matrix, row-major
//   scales/mins : one affine-dequant pair PER ROW of the weight matrix
//                 (w = scale[row] * nibble + min[row])
// NOTE(review): per-row scale/min is simpler than llama.cpp's Q4_K block-wise
// (per-32-element) quantization — confirm this matches the packer upstream.
// Each thread computes one output element: gid.x = output column, gid.y = row;
// both kernels bounds-check so over-provisioned grids are safe.
static const char* q4_gemm_metal_shader = R"(
#include <metal_stdlib>
using namespace metal;
// Q4_K dequantize + GEMM kernel
kernel void q4_gemm_kernel(
device const uint8_t* weights [[buffer(0)]],
device const float* input [[buffer(1)]],
device float* output [[buffer(2)]],
device const float* scales [[buffer(3)]],
device const float* mins [[buffer(4)]],
constant int& M [[buffer(5)]],
constant int& N [[buffer(6)]],
constant int& K [[buffer(7)]],
uint2 gid [[thread_position_in_grid]]
) {
int row = gid.y;
int col = gid.x;
if (row >= M || col >= N) return;
float sum = 0.0f;
device const uint8_t* weight_row = weights + row * (K / 2);
// Fused dequant + dot — Metal SIMD-group vectorized
for (int k = 0; k < K; k += 2) {
uint8_t packed = weight_row[k / 2];
float w0 = scales[row] * float(packed & 0x0F) + mins[row];
float w1 = scales[row] * float(packed >> 4) + mins[row];
sum += w0 * input[k * N + col] + w1 * input[(k + 1) * N + col];
}
output[row * N + col] = sum;
}
// FP16 variant for Apple Neural Engine acceleration
kernel void q4_gemm_kernel_fp16(
device const uint8_t* weights [[buffer(0)]],
device const half* input [[buffer(1)]],
device half* output [[buffer(2)]],
device const half* scales [[buffer(3)]],
device const half* mins [[buffer(4)]],
constant int& M [[buffer(5)]],
constant int& N [[buffer(6)]],
constant int& K [[buffer(7)]],
uint2 gid [[thread_position_in_grid]]
) {
int row = gid.y;
int col = gid.x;
if (row >= M || col >= N) return;
half sum = 0.0h;
device const uint8_t* weight_row = weights + row * (K / 2);
for (int k = 0; k < K; k += 2) {
uint8_t packed = weight_row[k / 2];
half w0 = scales[row] * half(packed & 0x0F) + mins[row];
half w1 = scales[row] * half(packed >> 4) + mins[row];
sum += w0 * input[k * N + col] + w1 * input[(k + 1) * N + col];
}
output[row * N + col] = sum;
}
)";
// ── Metal pipeline setup ──
/// Apple Metal backend for quantized (4-bit packed) GEMM.
/// Compiles the embedded MSL source at init and exposes a blocking
/// dequantize-and-multiply entry point over unified memory.
@interface IXMetalBackend : NSObject
/// System default Metal device (nil on machines without Metal support).
@property (nonatomic, strong) id<MTLDevice> device;
/// Command queue used to submit GEMM dispatches.
@property (nonatomic, strong) id<MTLCommandQueue> queue;
/// Compute pipeline for the FP32 kernel ("q4_gemm_kernel").
@property (nonatomic, strong) id<MTLComputePipelineState> pipeline;
/// Compute pipeline for the FP16 kernel ("q4_gemm_kernel_fp16").
@property (nonatomic, strong) id<MTLComputePipelineState> pipeline_fp16;
/// Creates the device/queue and compiles both kernel pipelines.
- (instancetype)init;
/// Computes output = dequant(weights) x input using the FP32 kernel.
/// @param weights 4-bit packed weight matrix, M x K (two values per byte,
///        low nibble first), i.e. M * K/2 bytes. K must be even.
/// @param input   Dense FP32 input, K x N, row-major.
/// @param output  Caller-allocated FP32 result, M x N, row-major.
/// @param scales  Per-row dequant scale, M floats.
/// @param mins    Per-row dequant offset, M floats.
/// Blocks until the GPU finishes.
- (void)q4_gemm:(const void*)weights input:(const float*)input output:(float*)output
M:(int)M N:(int)N K:(int)K scales:(const float*)scales mins:(const float*)mins;
@end
@implementation IXMetalBackend

/// Designated initializer: acquires the default device, creates a command
/// queue, compiles the embedded MSL source, and builds both pipelines.
/// Returns a backend with nil pipelines (and logs) if Metal is unavailable
/// or the shader fails to compile.
- (instancetype)init {
    self = [super init];
    if (self) {
        _device = MTLCreateSystemDefaultDevice();
        if (!_device) {
            NSLog(@"[IXMetalBackend] No Metal device available");
            return self;
        }
        _queue = [_device newCommandQueue];
        NSError* error = nil;
        id<MTLLibrary> lib = [_device newLibraryWithSource:
            [NSString stringWithUTF8String:q4_gemm_metal_shader]
            options:nil error:&error];
        if (lib) {
            id<MTLFunction> fn = [lib newFunctionWithName:@"q4_gemm_kernel"];
            _pipeline = [_device newComputePipelineStateWithFunction:fn error:&error];
            id<MTLFunction> fn16 = [lib newFunctionWithName:@"q4_gemm_kernel_fp16"];
            _pipeline_fp16 = [_device newComputePipelineStateWithFunction:fn16 error:&error];
            if (!_pipeline || !_pipeline_fp16) {
                NSLog(@"[IXMetalBackend] Pipeline creation failed: %@", error);
            }
        } else {
            // Shader compile errors were previously swallowed silently.
            NSLog(@"[IXMetalBackend] Shader compilation failed: %@", error);
        }
    }
    return self;
}

/// Dispatches the FP32 q4_gemm_kernel: output = dequant(weights) x input.
/// weights is M x K packed 4-bit (M * K/2 bytes, K even); input is K x N;
/// output is caller-allocated M x N; scales/mins are M floats each.
/// Synchronous: blocks until the GPU completes, then copies the result back.
/// NOTE(review): this method was declared in the interface but had no
/// implementation in this translation unit.
- (void)q4_gemm:(const void*)weights input:(const float*)input output:(float*)output
M:(int)M N:(int)N K:(int)K scales:(const float*)scales mins:(const float*)mins {
    if (!_pipeline || M <= 0 || N <= 0 || K <= 0) return;  // backend unusable or degenerate shape

    // Shared (unified-memory) buffers — zero-copy on Apple Silicon.
    id<MTLBuffer> wBuf   = [_device newBufferWithBytes:weights
                                                length:(NSUInteger)M * (NSUInteger)(K / 2)
                                               options:MTLResourceStorageModeShared];
    id<MTLBuffer> inBuf  = [_device newBufferWithBytes:input
                                                length:sizeof(float) * (NSUInteger)K * (NSUInteger)N
                                               options:MTLResourceStorageModeShared];
    id<MTLBuffer> outBuf = [_device newBufferWithLength:sizeof(float) * (NSUInteger)M * (NSUInteger)N
                                                options:MTLResourceStorageModeShared];
    id<MTLBuffer> sBuf   = [_device newBufferWithBytes:scales
                                                length:sizeof(float) * (NSUInteger)M
                                               options:MTLResourceStorageModeShared];
    id<MTLBuffer> mBuf   = [_device newBufferWithBytes:mins
                                                length:sizeof(float) * (NSUInteger)M
                                               options:MTLResourceStorageModeShared];

    id<MTLCommandBuffer> cmd = [_queue commandBuffer];
    id<MTLComputeCommandEncoder> enc = [cmd computeCommandEncoder];
    [enc setComputePipelineState:_pipeline];
    // Buffer indices match the [[buffer(n)]] bindings in the kernel source.
    [enc setBuffer:wBuf   offset:0 atIndex:0];
    [enc setBuffer:inBuf  offset:0 atIndex:1];
    [enc setBuffer:outBuf offset:0 atIndex:2];
    [enc setBuffer:sBuf   offset:0 atIndex:3];
    [enc setBuffer:mBuf   offset:0 atIndex:4];
    [enc setBytes:&M length:sizeof(int) atIndex:5];
    [enc setBytes:&N length:sizeof(int) atIndex:6];
    [enc setBytes:&K length:sizeof(int) atIndex:7];

    // One thread per output element: x = column, y = row (matches gid use).
    MTLSize grid = MTLSizeMake((NSUInteger)N, (NSUInteger)M, 1);
    NSUInteger w = _pipeline.threadExecutionWidth;
    NSUInteger h = _pipeline.maxTotalThreadsPerThreadgroup / w;
    // dispatchThreads: requires non-uniform threadgroup support (A14+/M1+,
    // this file's stated targets); the kernel bounds-checks regardless.
    [enc dispatchThreads:grid threadsPerThreadgroup:MTLSizeMake(w, h, 1)];
    [enc endEncoding];
    [cmd commit];
    [cmd waitUntilCompleted];

    memcpy(output, outBuf.contents, sizeof(float) * (size_t)M * (size_t)N);
}

@end