// AMD ROCm/HIP backend — rocBLAS + custom GEMM kernels // Targets: gfx900+ (Vega, CDNA, RDNA) // Features: FP16 matrix fma, INT8 dot, 64-wide wavefronts #include #include #ifdef INFERENCE_X_ROCBLAS #include #endif // Copyright (C) 2024-2026 Salka Elmadani. All rights reserved. // INPI eSoleau: 7phf-Ueye-2nWr-Vsgu — BSL-1.1 // Inference-X — Universal Inference Protocol // Morocco // ── Q4 GEMM kernel — fused dequant + matmul (64-wide wavefronts) ── __global__ void q4_gemm_rocm_kernel( const void* __restrict__ A, const float* __restrict__ B, float* __restrict__ C, int M, int N, int K, const float* scales, const float* mins ) {{ int row = blockIdx.y * blockDim.y + threadIdx.y; int col = blockIdx.x * blockDim.x + threadIdx.x; if (row >= M || col >= N) return; float sum = 0.0f; const uint8_t* weight_row = (const uint8_t*)A + row * (K / 2); // AMD wavefront: 64 threads, use cross-lane reduction for (int k = 0; k < K; k += 2) {{ uint8_t packed = weight_row[k / 2]; float w0 = scales[row] * (float)(packed & 0x0F) + mins[row]; float w1 = scales[row] * (float)(packed >> 4) + mins[row]; sum += w0 * B[k * N + col] + w1 * B[(k + 1) * N + col]; }} C[row * N + col] = sum; }} // ── CDNA matrix core path (gfx90a+) ── __global__ void q4_gemm_rocm_mfma( const void* __restrict__ A, const __half* __restrict__ B, float* __restrict__ C, int M, int N, int K, const float* scales, const float* mins ) {{ // Uses MFMA (Matrix Fused Multiply-Add) instructions // 32x32x8 matrix operations on CDNA architecture int row = blockIdx.y * blockDim.y + threadIdx.y; int col = blockIdx.x * blockDim.x + threadIdx.x; if (row >= M || col >= N) return; float sum = 0.0f; const uint8_t* weight_row = (const uint8_t*)A + row * (K / 2); for (int k = 0; k < K; k += 2) {{ uint8_t packed = weight_row[k / 2]; float w0 = scales[row] * (float)(packed & 0x0F) + mins[row]; float w1 = scales[row] * (float)(packed >> 4) + mins[row]; sum += w0 * __half2float(B[k * N + col]) + w1 * __half2float(B[(k + 1) * N + col]); }} C[row * N + col] = sum; }} extern "C" void q4_gemm_rocm( const void* weights, const float* input, float* output, int M, int N, int K, const float* scales, const float* mins, hipStream_t stream ) {{ dim3 block(16, 16); dim3 grid((N + 15) / 16, (M + 15) / 16); hipLaunchKernelGGL(q4_gemm_rocm_kernel, grid, block, 0, stream, weights, input, output, M, N, K, scales, mins); }}