Better output from the same model. Fused computation, adaptive precision, surgical expert loading. 305 KB, 19 backends, zero dependencies. https://inference-x.com
109 lines
3.9 KiB
Plaintext
109 lines
3.9 KiB
Plaintext
// Apple Metal backend — Metal Performance Shaders + custom compute
|
|
// Targets: M1/M2/M3/M4 (Apple Silicon), A14+ (iPhone/iPad)
|
|
// Features (planned): unified memory, tile shaders, SIMD-group matrix ops — the kernels below are currently plain per-thread scalar loops
|
|
|
|
#import <Metal/Metal.h>
|
|
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
|
|
|
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
|
|
// INPI eSoleau: 7phf-Ueye-2nWr-Vsgu — BSL-1.1
|
|
// Inference-X — Universal Inference Protocol
|
|
// Morocco
|
|
|
|
// ── Metal compute shader (embedded as string, compiled at runtime) ──
|
|
// Metal Shading Language source for the 4-bit GEMM kernels. It is compiled at
// runtime via -[MTLDevice newLibraryWithSource:options:error:], so the kernel
// names and [[buffer(n)]] indices below form the contract with the host code.
// Keep this text ASCII-only and free of double quotes (it is a raw string).
static const char* q4_gemm_metal_shader = R"(
#include <metal_stdlib>
using namespace metal;

// 4-bit dequantize + GEMM: output[M x N] = dequant(weights)[M x K] * input[K x N].
//
// Quantization: one (scale, min) pair per weight row, w = scale * q + min with
// q in [0, 15]. There are no sub-block scales.
//
// Layout (all row-major):
//   weights : M rows of (K + 1) / 2 bytes; byte j of a row holds element 2j in its
//             low nibble and element 2j + 1 in its high nibble. For odd K the high
//             nibble of the last byte is padding and is ignored.
//   input   : K x N, element (k, n) at input[k * N + n].
//   output  : M x N, element (m, n) at output[m * N + n].
//   scales, mins : M values, one per weight row.
//
// One thread computes one output element: gid.x selects the column, gid.y the row.
kernel void q4_gemm_kernel(
    device const uint8_t* weights [[buffer(0)]],
    device const float*   input   [[buffer(1)]],
    device float*         output  [[buffer(2)]],
    device const float*   scales  [[buffer(3)]],
    device const float*   mins    [[buffer(4)]],
    constant int&         M       [[buffer(5)]],
    constant int&         N       [[buffer(6)]],
    constant int&         K       [[buffer(7)]],
    uint2                 gid     [[thread_position_in_grid]]
) {
    int row = gid.y;
    int col = gid.x;
    if (row >= M || col >= N) return;

    // Per-row dequant parameters are loop-invariant; load them once.
    float scale = scales[row];
    float bias = mins[row];
    device const uint8_t* weight_row = weights + row * ((K + 1) / 2);

    float sum = 0.0f;
    for (int k = 0; k < K; k += 2) {
        uint8_t packed = weight_row[k / 2];
        float w0 = scale * float(packed & 0x0F) + bias;
        float pair = w0 * input[k * N + col];
        // Tail guard: for odd K the last byte carries only one real element, and
        // input row k + 1 does not exist.
        if (k + 1 < K) {
            float w1 = scale * float(packed >> 4) + bias;
            pair += w1 * input[(k + 1) * N + col];
        }
        sum += pair;
    }

    output[row * N + col] = sum;
}

// Half-precision I/O variant: same layout and math as q4_gemm_kernel, but input,
// output, scales and mins are half. Dequantization happens in half; the dot product
// is accumulated in float, because a half accumulator loses precision quickly and
// overflows above 65504 for long K.
kernel void q4_gemm_kernel_fp16(
    device const uint8_t* weights [[buffer(0)]],
    device const half*    input   [[buffer(1)]],
    device half*          output  [[buffer(2)]],
    device const half*    scales  [[buffer(3)]],
    device const half*    mins    [[buffer(4)]],
    constant int&         M       [[buffer(5)]],
    constant int&         N       [[buffer(6)]],
    constant int&         K       [[buffer(7)]],
    uint2                 gid     [[thread_position_in_grid]]
) {
    int row = gid.y;
    int col = gid.x;
    if (row >= M || col >= N) return;

    half scale = scales[row];
    half bias = mins[row];
    device const uint8_t* weight_row = weights + row * ((K + 1) / 2);

    float sum = 0.0f;
    for (int k = 0; k < K; k += 2) {
        uint8_t packed = weight_row[k / 2];
        half w0 = scale * half(packed & 0x0F) + bias;
        float pair = float(w0) * float(input[k * N + col]);
        if (k + 1 < K) {
            half w1 = scale * half(packed >> 4) + bias;
            pair += float(w1) * float(input[(k + 1) * N + col]);
        }
        sum += pair;
    }

    output[row * N + col] = half(sum);
}
)";
|
|
|
|
// ── Metal pipeline setup ──
|
|
/// Metal backend that owns the GPU device, a command queue, and the compute
/// pipelines built from q4_gemm_metal_shader.
@interface IXMetalBackend : NSObject

/// System default Metal device.
@property (nonatomic, strong) id<MTLDevice> device;
/// Command queue used for every dispatch issued by this backend.
@property (nonatomic, strong) id<MTLCommandQueue> queue;
/// Pipeline for q4_gemm_kernel (float I/O); nil if the shader failed to build.
@property (nonatomic, strong) id<MTLComputePipelineState> pipeline;
/// Pipeline for q4_gemm_kernel_fp16 (half I/O); nil if the shader failed to build.
@property (nonatomic, strong) id<MTLComputePipelineState> pipeline_fp16;

/// Creates the device, command queue and pipelines.
/// @return nil when the system has no Metal device. Build errors are logged, and the
///         affected pipeline properties are left nil.
- (instancetype)init;

/// Computes output[M x N] = dequant(weights)[M x K] * input[K x N] on the GPU and
/// blocks until the result has been copied into `output`.
/// @param weights M rows of (K + 1) / 2 bytes, two 4-bit values per byte (low nibble first).
/// @param input   K x N row-major floats.
/// @param output  M x N row-major floats. Left untouched if the kernel is unavailable
///                or the GPU work fails.
/// @param scales  M per-row scale factors.
/// @param mins    M per-row offsets (w = scale * q + min).
- (void)q4_gemm:(const void*)weights input:(const float*)input output:(float*)output
              M:(int)M N:(int)N K:(int)K scales:(const float*)scales mins:(const float*)mins;

@end

/// Builds a compute pipeline for kernel `name` in `library`. Returns nil and logs
/// the reason if the function is missing or the pipeline cannot be created.
static id<MTLComputePipelineState> IXMakeComputePipeline(id<MTLDevice> device,
                                                         id<MTLLibrary> library,
                                                         NSString* name) {
    id<MTLFunction> function = [library newFunctionWithName:name];
    if (!function) {
        NSLog(@"IXMetalBackend: kernel %@ not found in shader library", name);
        return nil;
    }
    NSError* error = nil;
    id<MTLComputePipelineState> pipeline =
        [device newComputePipelineStateWithFunction:function error:&error];
    if (!pipeline) {
        NSLog(@"IXMetalBackend: failed to build pipeline for %@: %@", name, error);
    }
    return pipeline;
}

@implementation IXMetalBackend

- (instancetype)init {
    self = [super init];
    if (!self) {
        return nil;
    }

    _device = MTLCreateSystemDefaultDevice();
    if (!_device) {
        // No Metal-capable GPU: nothing in this backend can work.
        return nil;
    }
    _queue = [_device newCommandQueue];

    NSError* error = nil;
    id<MTLLibrary> library =
        [_device newLibraryWithSource:[NSString stringWithUTF8String:q4_gemm_metal_shader]
                              options:nil
                                error:&error];
    // Check the return value, not `error`: a successful compile may still report warnings.
    if (!library) {
        NSLog(@"IXMetalBackend: shader compilation failed: %@", error);
        return self;  // pipelines stay nil; q4_gemm then becomes a no-op
    }

    _pipeline = IXMakeComputePipeline(_device, library, @"q4_gemm_kernel");
    _pipeline_fp16 = IXMakeComputePipeline(_device, library, @"q4_gemm_kernel_fp16");
    return self;
}

- (void)q4_gemm:(const void*)weights input:(const float*)input output:(float*)output
              M:(int)M N:(int)N K:(int)K scales:(const float*)scales mins:(const float*)mins {
    if (M <= 0 || N <= 0) {
        return;  // empty output matrix: nothing to compute
    }
    NSParameterAssert(output != NULL);
    if (K <= 0) {
        // Empty reduction: every dot product is zero. (Metal buffers cannot have length 0.)
        memset(output, 0, (size_t)M * (size_t)N * sizeof(float));
        return;
    }
    NSParameterAssert(weights != NULL);
    NSParameterAssert(input != NULL);
    NSParameterAssert(scales != NULL);
    NSParameterAssert(mins != NULL);

    id<MTLComputePipelineState> pipeline = self.pipeline;
    if (!pipeline) {
        return;  // init could not build the kernel (already logged)
    }

    // Byte counts in NSUInteger so large M/N/K cannot overflow int arithmetic.
    const NSUInteger rows = (NSUInteger)M;
    const NSUInteger cols = (NSUInteger)N;
    const NSUInteger depth = (NSUInteger)K;
    const NSUInteger weightBytes = rows * ((depth + 1) / 2);
    const NSUInteger inputBytes = depth * cols * sizeof(float);
    const NSUInteger outputBytes = rows * cols * sizeof(float);
    const NSUInteger rowParamBytes = rows * sizeof(float);

    id<MTLDevice> device = self.device;
    const MTLResourceOptions storage = MTLResourceStorageModeShared;
    id<MTLBuffer> weightBuffer = [device newBufferWithBytes:weights length:weightBytes options:storage];
    id<MTLBuffer> inputBuffer = [device newBufferWithBytes:input length:inputBytes options:storage];
    id<MTLBuffer> outputBuffer = [device newBufferWithLength:outputBytes options:storage];
    id<MTLBuffer> scaleBuffer = [device newBufferWithBytes:scales length:rowParamBytes options:storage];
    id<MTLBuffer> minBuffer = [device newBufferWithBytes:mins length:rowParamBytes options:storage];
    if (!weightBuffer || !inputBuffer || !outputBuffer || !scaleBuffer || !minBuffer) {
        NSLog(@"IXMetalBackend: q4_gemm buffer allocation failed");
        return;
    }

    id<MTLCommandBuffer> commandBuffer = [self.queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
    [encoder setComputePipelineState:pipeline];
    // Buffer indices must match the [[buffer(n)]] attributes in q4_gemm_kernel.
    [encoder setBuffer:weightBuffer offset:0 atIndex:0];
    [encoder setBuffer:inputBuffer offset:0 atIndex:1];
    [encoder setBuffer:outputBuffer offset:0 atIndex:2];
    [encoder setBuffer:scaleBuffer offset:0 atIndex:3];
    [encoder setBuffer:minBuffer offset:0 atIndex:4];
    [encoder setBytes:&M length:sizeof(M) atIndex:5];
    [encoder setBytes:&N length:sizeof(N) atIndex:6];
    [encoder setBytes:&K length:sizeof(K) atIndex:7];

    // Grid: x = output column, y = output row. Partial threadgroups are rounded up;
    // the kernel's bounds check discards the excess threads.
    const NSUInteger groupWidth = pipeline.threadExecutionWidth;
    const NSUInteger groupHeight =
        MAX((NSUInteger)1, pipeline.maxTotalThreadsPerThreadgroup / groupWidth);
    MTLSize threadsPerGroup = MTLSizeMake(groupWidth, groupHeight, 1);
    MTLSize groupCount = MTLSizeMake((cols + groupWidth - 1) / groupWidth,
                                     (rows + groupHeight - 1) / groupHeight,
                                     1);
    [encoder dispatchThreadgroups:groupCount threadsPerThreadgroup:threadsPerGroup];
    [encoder endEncoding];

    [commandBuffer commit];
    [commandBuffer waitUntilCompleted];
    if (commandBuffer.status != MTLCommandBufferStatusCompleted) {
        NSLog(@"IXMetalBackend: q4_gemm GPU execution failed: %@", commandBuffer.error);
        return;
    }
    memcpy(output, outputBuffer.contents, outputBytes);
}

@end
|