Better output from the same model. Fused computation, adaptive precision, surgical expert loading. 305 KB, 19 backends, zero dependencies. https://inference-x.com
721 lines
30 KiB
C++
721 lines
30 KiB
C++
// ═══════════════════════════════════════════════════════════════════════════════
|
|
// INFERENCE-X — Transformer v6 Forward Pass
|
|
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
|
|
// Licensed under the Business Source License 1.1 (BSL-1.1)
|
|
// See LICENSE file for full terms.
|
|
//
|
|
// INTELLECTUAL PROPERTY PROTECTION:
|
|
// - INPI eSoleau deposit: 7phf-Ueye-2nWr-Vsgu (16/02/2026)
|
|
// - GitHub: github.com/ElmadaniS/inference-x
|
|
// - Author: Salka Elmadani | Morocco | Morocco
|
|
//
|
|
// MANUFACTURER NOTICE: Any manufacturer, company, or entity that
|
|
// incorporates, embeds, distributes, or commercially uses Inference-X
|
|
// or any derivative work without explicit written authorization from
|
|
// the copyright holder is in violation of BSL-1.1 and applicable
|
|
// intellectual property laws. This includes but is not limited to:
|
|
// hardware vendors, cloud providers, SaaS platforms, and OEMs.
|
|
//
|
|
// Contact: Elmadani.SALKA@proton.me for licensing.
|
|
// ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
#pragma once
|
|
|
|
// Inference-X Transformer — Salka Elmadani — Morocco
|
|
#define IX_TRANSFORMER_SIGNATURE 0x935
|
|
#define IX_TRANSFORMER_MARK "Inference-X-Transformer-935-Elmadani"
|
|
|
|
// FIX: uint32_t was used below without any prior #include in this header —
// a translation unit including this file first failed to compile. <cstdint>
// must precede the first use.
#include <cstdint>

// Inference-X Signature — integral to compilation
namespace ix {

// Build-time signature constants checked by signature::verify().
constexpr uint32_t SIGNATURE = 935;
constexpr uint32_t FINGERPRINT = 0x935E1DAD;
constexpr const char* AUTHOR = "Salka Elmadani";

}
|
|
|
|
#include "../core/z_core.h"
#include "gguf.h"
#include "kernels.h"
#include "gemm.h"
#include "kernel_dispatch.h"
#include "attention.h"
#include "moe_mla.h"

#include <algorithm>  // std::min/max, std::sort, std::partial_sort, std::max_element
#include <chrono>
#include <cmath>      // sqrtf, tanhf
#include <cstdint>    // int32_t, uint32_t
#include <cstdio>
#include <cstdlib>    // strtoull
#include <cstring>    // strncmp
#include <random>
#include <string>     // std::string tensor names
#include <utility>    // std::pair
#include <vector>
|
|
|
namespace ix {
|
|
|
|
// ═══════════════════════════════════════════════════════════════════════════════
|
|
// LAYER WEIGHTS — Per-layer pointers into mmap'd GGUF
|
|
// Handles both dense and MoE layers, both GQA and MLA attention
|
|
// ═══════════════════════════════════════════════════════════════════════════════
|
|
// Per-layer weight table. Every pointer is a NON-OWNING view into the
// mmap'd GGUF file and stays valid for the lifetime of that mapping.
// Quantized tensors carry a matching dtype tag so the GEMM dispatch can
// select the right kernel. A nullptr means "tensor absent for this model".
struct LayerWeightsV6 {
    // === Attention norms ===
    const float* attn_norm = nullptr;       // pre-attention RMSNorm
    const float* ffn_norm = nullptr;        // pre-FFN RMSNorm
    const float* post_attn_norm = nullptr;  // Gemma-2: post-attention norm
    const float* post_ffn_norm = nullptr;   // Gemma-2: post-FFN norm

    // Fused tensor adapter (Phi-3: QKV fused, gate+up fused) —
    // split on the fly at compute time, see layer_forward().
    const void* w_qkv_fused = nullptr;         dtype t_qkv_fused = dtype::F32;
    const void* w_ffn_gate_up_fused = nullptr; dtype t_ffn_gate_up_fused = dtype::F32;
    bool has_fused_qkv = false;
    bool has_fused_gate_up = false;

    // === MLA Attention weights (DeepSeek-style latent attention) ===
    const void* w_q_a = nullptr;  dtype t_q_a = dtype::F32;   // Q compress
    const float* q_a_norm = nullptr;                          // Q norm
    const void* w_q_b = nullptr;  dtype t_q_b = dtype::F32;   // Q decompress
    const void* w_kv_a = nullptr; dtype t_kv_a = dtype::F32;  // KV compress (MQA)
    const float* kv_a_norm = nullptr;                         // KV norm
    const void* w_k_b = nullptr;  dtype t_k_b = dtype::F32;   // K decompress
    const void* w_v_b = nullptr;  dtype t_v_b = dtype::F32;   // V decompress
    const void* w_o = nullptr;    dtype t_o = dtype::F32;     // output proj (shared with GQA path)

    // === GQA Attention weights (for dense layers if using standard attn) ===
    const void* wq = nullptr; dtype tq = dtype::F32;
    const void* wk = nullptr; dtype tk = dtype::F32;
    const void* wv = nullptr; dtype tv = dtype::F32;

    // === QKV Bias (Qwen2) — nullptr for bias-free architectures ===
    const float* bq = nullptr;
    const float* bk = nullptr;
    const float* bv = nullptr;

    // === Dense FFN (for layer 0 / leading dense layers) ===
    const void* w_ffn_gate = nullptr; dtype t_ffn_gate = dtype::F32;
    const void* w_ffn_up = nullptr;   dtype t_ffn_up = dtype::F32;
    const void* w_ffn_down = nullptr; dtype t_ffn_down = dtype::F32;

    // === MoE weights ===
    const float* gate_inp = nullptr;  // Router [dim, n_experts] F32
    const void* gate_exps = nullptr; dtype t_gate_exps = dtype::F32; // [dim, ffn, n_exp]
    const void* up_exps = nullptr;   dtype t_up_exps = dtype::F32;
    const void* down_exps = nullptr; dtype t_down_exps = dtype::F32;

    // === Shared expert (DeepSeek/Qwen-MoE; always active when present) ===
    const void* gate_shexp = nullptr; dtype t_gate_shexp = dtype::F32;
    const void* up_shexp = nullptr;   dtype t_up_shexp = dtype::F32;
    const void* down_shexp = nullptr; dtype t_down_shexp = dtype::F32;

    // Dispatch flags set by load_weights_v6 based on which tensors exist.
    bool is_moe = false;
    bool is_mla = false;
};
|
|
|
|
// Top-level weight table: embedding, per-layer weights, final norm and the
// output (LM head) projection. All pointers are non-owning views into the
// mmap'd GGUF file.
struct WeightsV6 {
    const void* token_embd = nullptr;    // token embedding table
    dtype t_embd = dtype::F32;
    std::vector<LayerWeightsV6> layers;  // one entry per transformer layer
    const float* output_norm = nullptr;  // final RMSNorm weights
    const void* output = nullptr;        // LM head; may alias token_embd (weight tying)
    dtype t_output = dtype::F32;
};
|
|
|
|
// ═══════════════════════════════════════════════════════════════════════════════
|
|
// LOAD WEIGHTS FROM GGUF — Architecture-aware
|
|
// ═══════════════════════════════════════════════════════════════════════════════
|
|
// Resolve all model weights from a parsed GGUF file into `w`.
// Architecture-aware: per layer it detects MLA vs GQA attention and MoE vs
// dense FFN purely from which tensors exist, and sets the layer's dispatch
// flags accordingly. Returns false only when the mandatory token embedding
// is missing; all other tensors are optional and stay nullptr when absent.
inline bool load_weights_v6(WeightsV6& w, const GGUF& gguf, const Config& cfg) {
    // Resolve a named tensor to its mmap'd data pointer + dtype tag.
    // Leaves ptr/type untouched and returns false when the tensor is absent.
    auto get = [&](const std::string& name, const void*& ptr, dtype& type) -> bool {
        const TensorInfo* ti = gguf.tensor(name);
        if (!ti) return false;
        ptr = ti->data;
        type = ti->type;
        return true;
    };

    // Same, but for small tensors used directly as float (norms, biases,
    // router). NOTE(review): this blindly casts to float* — assumes these
    // tensors are stored as F32 in the GGUF; confirm for exotic quant dumps.
    auto get_f32 = [&](const std::string& name) -> const float* {
        const TensorInfo* ti = gguf.tensor(name);
        return ti ? static_cast<const float*>(ti->data) : nullptr;
    };

    // Token embedding (mandatory)
    if (!get("token_embd.weight", w.token_embd, w.t_embd)) {
        printf("[WEIGHTS] ERROR: token_embd.weight not found\n");
        return false;
    }

    // Output head resolution: output.weight → lm_head.weight → tied embedding
    w.output_norm = get_f32("output_norm.weight");
    if (!get("output.weight", w.output, w.t_output)) {
        // Try lm_head
        if (!get("lm_head.weight", w.output, w.t_output)) {
            // Weight tying: reuse the embedding table as the LM head
            w.output = w.token_embd;
            w.t_output = w.t_embd;
        }
    }

    // Per-layer tensors
    w.layers.resize(cfg.n_layers);
    int loaded_mla = 0, loaded_moe = 0, loaded_dense = 0;

    for (int l = 0; l < cfg.n_layers; ++l) {
        std::string p = "blk." + std::to_string(l) + ".";
        LayerWeightsV6& lw = w.layers[l];

        // Norms (always present)
        lw.attn_norm = get_f32(p + "attn_norm.weight");
        lw.ffn_norm = get_f32(p + "ffn_norm.weight");

        // === Try MLA attention (presence of attn_q_a is the discriminator) ===
        bool has_mla = false;
        if (cfg.is_mla()) {
            has_mla |= get(p + "attn_q_a.weight", lw.w_q_a, lw.t_q_a);
            lw.q_a_norm = get_f32(p + "attn_q_a_norm.weight");
            get(p + "attn_q_b.weight", lw.w_q_b, lw.t_q_b);
            get(p + "attn_kv_a_mqa.weight", lw.w_kv_a, lw.t_kv_a);
            lw.kv_a_norm = get_f32(p + "attn_kv_a_norm.weight");
            get(p + "attn_k_b.weight", lw.w_k_b, lw.t_k_b);
            get(p + "attn_v_b.weight", lw.w_v_b, lw.t_v_b);
            get(p + "attn_output.weight", lw.w_o, lw.t_o);
            lw.is_mla = has_mla;
            if (has_mla) loaded_mla++;
        }

        // === Fallback: standard GQA attention ===
        if (!has_mla) {
            if (!get(p + "attn_q.weight", lw.wq, lw.tq)) {
                // ═══ TENSOR ADAPTER: Fused QKV (Phi-3, Falcon, StarCoder2) ═══
                if (get(p + "attn_qkv.weight", lw.w_qkv_fused, lw.t_qkv_fused)) {
                    lw.has_fused_qkv = true;
                    if (l == 0) printf("[ADAPTER] Fused QKV detected → will split at compute\n");
                }
            } else {
                get(p + "attn_k.weight", lw.wk, lw.tk);
                get(p + "attn_v.weight", lw.wv, lw.tv);
            }
            get(p + "attn_output.weight", lw.w_o, lw.t_o);
            // QKV bias (Qwen2 architecture) — nullptr when absent
            lw.bq = get_f32(p + "attn_q.bias");
            lw.bk = get_f32(p + "attn_k.bias");
            lw.bv = get_f32(p + "attn_v.bias");

            // Post-norms (Gemma-2) — nullptr for other architectures
            lw.post_attn_norm = get_f32(p + "post_attention_norm.weight");
            lw.post_ffn_norm = get_f32(p + "post_ffw_norm.weight");
        }

        // === Try MoE FFN (only past the leading dense layers) ===
        bool has_moe = false;
        if (cfg.is_moe() && l >= cfg.n_dense_layers) {
            // The router tensor is the discriminator: present ⇒ MoE layer
            lw.gate_inp = get_f32(p + "ffn_gate_inp.weight");
            has_moe |= (lw.gate_inp != nullptr);
            get(p + "ffn_gate_exps.weight", lw.gate_exps, lw.t_gate_exps);
            get(p + "ffn_up_exps.weight", lw.up_exps, lw.t_up_exps);
            get(p + "ffn_down_exps.weight", lw.down_exps, lw.t_down_exps);

            // Shared expert (absent tensors stay nullptr)
            get(p + "ffn_gate_shexp.weight", lw.gate_shexp, lw.t_gate_shexp);
            get(p + "ffn_up_shexp.weight", lw.up_shexp, lw.t_up_shexp);
            get(p + "ffn_down_shexp.weight", lw.down_shexp, lw.t_down_shexp);

            lw.is_moe = has_moe;
            if (has_moe) loaded_moe++;
        }

        // === Dense FFN (for leading dense layers) ===
        if (!has_moe) {
            if (!get(p + "ffn_gate.weight", lw.w_ffn_gate, lw.t_ffn_gate)) {
                // No separate gate — try fused gate+up (Phi-3)
                if (get(p + "ffn_up.weight", lw.w_ffn_gate_up_fused, lw.t_ffn_gate_up_fused)) {
                    lw.has_fused_gate_up = true;
                    if (l == 0) printf("[ADAPTER] Fused gate+up detected\n");
                }
            } else {
                get(p + "ffn_up.weight", lw.w_ffn_up, lw.t_ffn_up);
            }
            get(p + "ffn_down.weight", lw.w_ffn_down, lw.t_ffn_down);
            // NOTE(review): counted even when no FFN tensor was actually
            // found, so the summary can over-report dense layers.
            loaded_dense++;
        }
    }

    printf("[WEIGHTS] Loaded: %d MLA layers, %d MoE layers, %d dense layers\n",
           loaded_mla, loaded_moe, loaded_dense);
    printf("[WEIGHTS] Embedding: %s, Output: %s\n",
           dtype_name(w.t_embd), dtype_name(w.t_output));

    return true;
}
|
|
|
|
// ═══════════════════════════════════════════════════════════════════════════════
|
|
// TRANSFORMER V6 — DeepSeek V3 Complete Forward Pass
|
|
// ═══════════════════════════════════════════════════════════════════════════════
|
|
class TransformerV6 {
|
|
public:
|
|
bool init(const GGUF& gguf, int max_ctx = 4096) {
|
|
if (!signature::verify()) return false;
|
|
|
|
cfg_ = gguf.extract_config();
|
|
// FIX: clamp max_seq_len to actual max_ctx
|
|
cfg_.max_seq_len = std::min(cfg_.max_seq_len, max_ctx);
|
|
|
|
// Universal: auto-cap context to available RAM
|
|
{
|
|
// KV cache size = 2 * n_kv_heads * max_seq * head_dim * sizeof(float) * n_layers
|
|
size_t kv_per_layer = 2ULL * cfg_.n_kv_heads * cfg_.max_seq_len * cfg_.head_dim * sizeof(float);
|
|
size_t kv_total = kv_per_layer * cfg_.n_layers;
|
|
// Get available RAM (rough: read /proc/meminfo)
|
|
size_t avail_ram = 0;
|
|
FILE* mi = fopen("/proc/meminfo", "r");
|
|
if (mi) {
|
|
char line[256];
|
|
while (fgets(line, sizeof(line), mi)) {
|
|
if (strncmp(line, "MemAvailable:", 13) == 0) {
|
|
avail_ram = strtoull(line + 13, nullptr, 10) * 1024;
|
|
break;
|
|
}
|
|
}
|
|
fclose(mi);
|
|
}
|
|
if (avail_ram > 0 && kv_total > avail_ram / 2) {
|
|
// KV cache would use >50% of available RAM → reduce context
|
|
int safe_ctx = (int)(avail_ram / 2 / (kv_per_layer / cfg_.max_seq_len));
|
|
safe_ctx = std::max(512, std::min(safe_ctx, cfg_.max_seq_len));
|
|
if (safe_ctx < cfg_.max_seq_len) {
|
|
printf("[ADAPTER] Context capped: %d → %d (KV cache %.1f GB > %.1f GB avail/2)\n",
|
|
cfg_.max_seq_len, safe_ctx,
|
|
kv_total / 1e9, avail_ram / 2e9);
|
|
cfg_.max_seq_len = safe_ctx;
|
|
}
|
|
}
|
|
}
|
|
cfg_.print();
|
|
|
|
printf("\n[TRANSFORMER] Loading weights...\n");
|
|
if (!load_weights_v6(weights_, gguf, cfg_)) {
|
|
printf("[TRANSFORMER] ERROR: Failed to load weights\n");
|
|
return false;
|
|
}
|
|
|
|
// FIX: derive vocab_size from output weight tensor dims
|
|
{
|
|
const TensorInfo* oti = gguf.tensor("output.weight");
|
|
if (!oti) oti = gguf.tensor("lm_head.weight");
|
|
if (oti && oti->dims[1] > 0 && (int)oti->dims[1] != cfg_.vocab_size) {
|
|
printf("[FIX] vocab_size: %d -> %lu (from output weight)\n", cfg_.vocab_size, (unsigned long)oti->dims[1]);
|
|
cfg_.vocab_size = (int)oti->dims[1];
|
|
}
|
|
}
|
|
|
|
// Initialize components based on architecture
|
|
if (cfg_.is_mla()) {
|
|
printf("[TRANSFORMER] Initializing MLA attention (kv_lora=%d, rope=%d)\n",
|
|
cfg_.kv_lora_rank, cfg_.rope_dim);
|
|
mla_attn_.init(cfg_, max_ctx);
|
|
mla_cache_.init(cfg_, max_ctx);
|
|
mla_rope_.init(cfg_.rope_dim, max_ctx, cfg_.rope_theta,
|
|
cfg_.rope_scaling_factor);
|
|
}
|
|
if ((cfg_.n_dense_layers > 0 || !cfg_.is_moe()) && !cfg_.is_mla()) {
|
|
printf("[TRANSFORMER] Initializing standard GQA attention\n");
|
|
gqa_cache_.init(cfg_);
|
|
gqa_rope_.init(cfg_.head_dim, cfg_.max_seq_len, cfg_.rope_theta);
|
|
gqa_attn_.init(cfg_);
|
|
}
|
|
|
|
if (cfg_.is_moe()) {
|
|
printf("[TRANSFORMER] Initializing MoE (%d experts, %d active, %d shared)\n",
|
|
cfg_.n_experts, cfg_.n_experts_used, cfg_.n_expert_shared);
|
|
router_.init(cfg_);
|
|
expert_ffn_.init(cfg_);
|
|
shared_ffn_.init(cfg_);
|
|
expert_cache_.init(cfg_.n_layers, cfg_.n_experts);
|
|
}
|
|
if (cfg_.n_dense_layers > 0) {
|
|
dense_ffn_.init(cfg_);
|
|
}
|
|
|
|
// Allocate buffers
|
|
x_.resize(cfg_.dim);
|
|
xb_.resize(cfg_.dim);
|
|
expert_out_.resize(cfg_.dim);
|
|
ffn_out_.resize(cfg_.dim);
|
|
logits_.resize(cfg_.vocab_size);
|
|
|
|
printf("[TRANSFORMER] Ready. dim=%d, layers=%d, vocab=%d, ctx=%d\n",
|
|
cfg_.dim, cfg_.n_layers, cfg_.vocab_size, max_ctx);
|
|
printf("[TRANSFORMER] MLA KV cache: %.1f MB\n",
|
|
cfg_.is_mla() ? mla_cache_.memory_bytes() / 1e6 : 0.0);
|
|
|
|
initialized_ = true;
|
|
return true;
|
|
}
|
|
|
|
void reset() {
|
|
if (cfg_.is_mla()) mla_cache_.clear();
|
|
else gqa_cache_.clear();
|
|
}
|
|
|
|
// ═══════════════════════════════════════════════════════════════════════════
|
|
// FORWARD — Single token
|
|
// ═══════════════════════════════════════════════════════════════════════════
|
|
// Run one token through the full network and return a pointer to the
// vocab_size logits. The buffer is owned by this object and overwritten on
// the next call (sample() also mutates it in place). Advances the active
// KV cache position; returns nullptr if init() has not succeeded.
const float* forward(int32_t token) {
    if (!initialized_) return nullptr;

    // Token embedding → x_
    embed(token);

    // Transformer layers (attention + FFN with residuals, in place on x_)
    for (int l = 0; l < cfg_.n_layers; ++l) {
        layer_forward(l);
    }

    // Final norm + output projection to vocabulary logits
    kernel::rms_norm(x_.data(), x_.data(), weights_.output_norm,
                     cfg_.dim, cfg_.rms_norm_eps);
    gemm::matmul(logits_.data(), weights_.output, weights_.t_output,
                 x_.data(), cfg_.vocab_size, cfg_.dim);

    // Final logit soft-capping (Gemma-2): logit → cap * tanh(logit / cap)
    if (cfg_.final_logit_softcap > 0.0f) {
        float cap = cfg_.final_logit_softcap;
        for (int i = 0; i < cfg_.vocab_size; ++i)
            logits_[i] = cap * tanhf(logits_[i] / cap);
    }

    // One token consumed → advance whichever KV cache is active
    if (cfg_.is_mla()) mla_cache_.advance();
    else gqa_cache_.advance();

    return logits_.data();
}
|
|
|
|
// ═══════════════════════════════════════════════════════════════════════════
|
|
// GENERATE
|
|
// ═══════════════════════════════════════════════════════════════════════════
|
|
std::vector<int32_t> generate(
|
|
const std::vector<int32_t>& prompt,
|
|
int max_new_tokens = 256,
|
|
float temperature = 0.7f,
|
|
float top_p = 0.9f,
|
|
int top_k = 40
|
|
) {
|
|
reset();
|
|
std::vector<int32_t> output;
|
|
|
|
auto t0 = std::chrono::high_resolution_clock::now();
|
|
|
|
// Prefill
|
|
printf("[GEN] Prefilling %zu tokens...\n", prompt.size());
|
|
for (size_t i = 0; i < prompt.size(); ++i) {
|
|
forward(prompt[i]);
|
|
if ((i + 1) % 100 == 0) printf(" [%zu/%zu]\n", i + 1, prompt.size());
|
|
}
|
|
|
|
auto t1 = std::chrono::high_resolution_clock::now();
|
|
double prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
|
|
printf("[GEN] Prefill: %.0f ms (%.1f tok/s)\n",
|
|
prefill_ms, prompt.size() * 1000.0 / prefill_ms);
|
|
|
|
// Generate
|
|
printf("[GEN] Generating up to %d tokens...\n", max_new_tokens);
|
|
for (int i = 0; i < max_new_tokens; ++i) {
|
|
int32_t next = sample(temperature, top_p, top_k);
|
|
if (is_eos(next)) break;
|
|
output.push_back(next);
|
|
forward(next);
|
|
}
|
|
|
|
auto t2 = std::chrono::high_resolution_clock::now();
|
|
double gen_ms = std::chrono::duration<double, std::milli>(t2 - t1).count();
|
|
printf("[GEN] Generated %zu tokens in %.0f ms (%.2f tok/s)\n",
|
|
output.size(), gen_ms,
|
|
output.size() > 0 ? output.size() * 1000.0 / gen_ms : 0);
|
|
|
|
// Print expert cache stats
|
|
if (cfg_.is_moe()) {
|
|
expert_cache_.print_stats();
|
|
}
|
|
|
|
return output;
|
|
}
|
|
|
|
// Generate with streaming callback
|
|
template<typename Callback>
|
|
void generate_stream(
|
|
const std::vector<int32_t>& prompt,
|
|
int max_new_tokens,
|
|
float temperature, float top_p, int top_k,
|
|
Callback&& on_token
|
|
) {
|
|
reset();
|
|
for (size_t i = 0; i < prompt.size(); ++i) forward(prompt[i]);
|
|
for (int i = 0; i < max_new_tokens; ++i) {
|
|
int32_t next = sample(temperature, top_p, top_k);
|
|
if (is_eos(next)) break;
|
|
if (!on_token(next)) break;
|
|
forward(next);
|
|
}
|
|
}
|
|
|
|
// Model configuration (read-only / mutable access).
const Config& config() const { return cfg_; }
Config& config_mut() { return cfg_; }

// Replace the EOS set with a single token id.
void set_eos_token(int32_t eos) { eos_tokens_ = {eos}; }

// Register an additional EOS token id (deduplicated; some models use several).
void add_eos_token(int32_t eos) {
    for (auto e : eos_tokens_) if (e == eos) return;
    eos_tokens_.push_back(eos);
}

// True if `tok` matches any registered EOS token.
bool is_eos(int32_t tok) const {
    for (auto e : eos_tokens_) if (tok == e) return true;
    return false;
}

// Direct access to the MoE expert cache (stats inspection / tuning).
ExpertCache& expert_cache_ref() { return expert_cache_; }
|
|
|
private:
    Config cfg_;                   // model hyperparameters (parsed from GGUF)
    WeightsV6 weights_;            // non-owning pointers into the mmap'd GGUF
    bool initialized_ = false;     // set by init(); forward() refuses until then
    std::vector<int32_t> eos_tokens_ = {2};  // default EOS id 2; see set/add_eos_token

    // MLA components (used only when cfg_.is_mla())
    MLAAttention mla_attn_;
    MLAKVCache mla_cache_;
    MLARoPE mla_rope_;

    // GQA fallback (standard grouped-query attention)
    Attention gqa_attn_;
    KVCache gqa_cache_;
    kernel::RoPE gqa_rope_;

    // MoE components (used only when cfg_.is_moe())
    MoERouter router_;
    ExpertFFN expert_ffn_;
    SharedExpertFFN shared_ffn_;
    ExpertCache expert_cache_;

    // Dense FFN fallback (leading dense layers / non-MoE models)
    FFN dense_ffn_;

    // Scratch buffers: hidden state, residual copy, expert scratch,
    // accumulated FFN output, and final vocab logits.
    std::vector<float> x_, xb_, expert_out_, ffn_out_, logits_;
    std::mt19937 rng_{42};         // fixed seed → deterministic sampling
|
|
|
// ─── EMBEDDING ─────────────────────────────────────────────────────────
|
|
// Look up the embedding row for `token` into x_, applying Gemma's
// sqrt(dim) embedding scale when the config requests it.
void embed(int32_t token) {
    gemm::embed_lookup(x_.data(), weights_.token_embd, weights_.t_embd,
                       token, cfg_.dim);
    if (cfg_.embed_scale_sqrt_dim) {
        const float scale_factor = sqrtf((float)cfg_.dim);
        for (float& v : x_) {
            v *= scale_factor;
        }
    }
}
|
|
|
|
// ─── LAYER FORWARD ─────────────────────────────────────────────────────
|
|
// Run one transformer layer in place on x_:
//   x += Attention(rms_norm(x))   [optional post-attention norm]
//   x += FFN(rms_norm(x))         [optional post-FFN norm]
// Dispatches between MLA/GQA attention and MoE/dense FFN according to the
// flags set at load time. xb_ holds the residual copy across each sublayer.
void layer_forward(int layer) {
    const LayerWeightsV6& lw = weights_.layers[layer];

    // === ATTENTION ===
    kernel::vec_copy(xb_.data(), x_.data(), cfg_.dim);  // save residual

    if (lw.attn_norm) {
        kernel::rms_norm(x_.data(), x_.data(), lw.attn_norm,
                         cfg_.dim, cfg_.rms_norm_eps);
    }

    if (lw.is_mla) {
        // MLA attention path (compressed-KV latent attention)
        mla_attn_.forward(
            x_.data(), x_.data(),
            lw.w_q_a, lw.t_q_a,
            lw.q_a_norm,
            lw.w_q_b, lw.t_q_b,
            lw.w_kv_a, lw.t_kv_a,
            lw.kv_a_norm,
            lw.w_k_b, lw.t_k_b,
            lw.w_v_b, lw.t_v_b,
            lw.w_o, lw.t_o,
            mla_cache_, mla_rope_, layer
        );
    } else if (!lw.is_mla) {  // NOTE(review): redundant condition — always true in this branch
        if (lw.has_fused_qkv) {
            // ═══ TENSOR ADAPTER: Split fused QKV on the fly ═══
            // QKV is [dim, q_dim + k_dim + v_dim] contiguous in memory
            int q_dim = cfg_.dim;
            int kv_dim = cfg_.n_kv_heads * cfg_.head_dim;
            // For quantized: compute byte offsets based on type.
            // Assumes bytes_per_element(t, dim) is the byte size of one full
            // row of length dim, so offsets are per output row — TODO confirm.
            const char* base = (const char*)lw.w_qkv_fused;
            size_t bpe = gemm::bytes_per_element(lw.t_qkv_fused, cfg_.dim);
            size_t q_off = 0;
            size_t k_off = q_dim * bpe;
            size_t v_off = (q_dim + kv_dim) * bpe;
            gqa_attn_.forward(
                x_.data(), x_.data(),
                (const void*)(base + q_off), lw.t_qkv_fused,
                (const void*)(base + k_off), lw.t_qkv_fused,
                (const void*)(base + v_off), lw.t_qkv_fused,
                lw.w_o, lw.t_o,
                gqa_cache_, gqa_rope_, layer,
                lw.bq, lw.bk, lw.bv
            );
        } else {
            // Standard GQA attention path (for dense layers)
            gqa_attn_.forward(
                x_.data(), x_.data(),
                lw.wq, lw.tq,
                lw.wk, lw.tk,
                lw.wv, lw.tv,
                lw.w_o, lw.t_o,
                gqa_cache_, gqa_rope_, layer,
                lw.bq, lw.bk, lw.bv
            );
        }
    }

    // Residual connection around attention
    kernel::vec_add(x_.data(), xb_.data(), x_.data(), cfg_.dim);

    // Post-attention norm (Gemma-2)
    if (lw.post_attn_norm) {
        kernel::rms_norm(x_.data(), x_.data(), lw.post_attn_norm,
                         cfg_.dim, cfg_.rms_norm_eps);
    }

    // === FFN ===
    kernel::vec_copy(xb_.data(), x_.data(), cfg_.dim);  // save residual

    if (lw.ffn_norm) {
        kernel::rms_norm(x_.data(), x_.data(), lw.ffn_norm,
                         cfg_.dim, cfg_.rms_norm_eps);
    }

    if (lw.is_moe) {
        // ─── MoE FORWARD ───────────────────────────────────────────────
        kernel::vec_zero(ffn_out_.data(), cfg_.dim);

        // Route: select top-K experts for this token
        auto selected = router_.route(lw.gate_inp, x_.data());
        expert_cache_.record_batch(layer, selected);

        // ─── SURGICAL MMAP PREFETCH (÷48 I/O) ────────────────────────
        // Hint the dispatcher about which expert pages are about to be read.
        {
            std::vector<int> active_ids;
            for (const auto& s : selected) active_ids.push_back(s.expert_id);
            KernelDispatch::instance().prefetch_experts(
                layer, active_ids.data(), (int)active_ids.size());
        }

        // Dispatch to selected experts, accumulating router-weighted outputs
        for (const auto& sel : selected) {
            expert_ffn_.forward(
                expert_out_.data(), x_.data(),
                sel.expert_id,
                lw.gate_exps, lw.t_gate_exps,
                lw.up_exps, lw.t_up_exps,
                lw.down_exps, lw.t_down_exps
            );

            // Weighted accumulation
            for (int d = 0; d < cfg_.dim; ++d) {
                ffn_out_[d] += sel.weight * expert_out_[d];
            }
        }

        // Shared expert (always active when present)
        if (lw.gate_shexp) {
            shared_ffn_.forward(
                expert_out_.data(), x_.data(),
                lw.gate_shexp, lw.t_gate_shexp,
                lw.up_shexp, lw.t_up_shexp,
                lw.down_shexp, lw.t_down_shexp
            );

            // Add shared expert output
            kernel::vec_add(ffn_out_.data(), ffn_out_.data(),
                            expert_out_.data(), cfg_.dim);
        }

        kernel::vec_copy(x_.data(), ffn_out_.data(), cfg_.dim);

        // ─── EVICT INACTIVE EXPERT PAGES ──────────────────────────────
        KernelDispatch::instance().evict_layer(layer);
    }
    if (!lw.is_moe && lw.w_ffn_gate) {
        // ─── Dense FFN (separate gate/up/down tensors) ────────────────
        dense_ffn_.forward(
            x_.data(), x_.data(),
            lw.w_ffn_gate, lw.t_ffn_gate,
            lw.w_ffn_up, lw.t_ffn_up,
            lw.w_ffn_down, lw.t_ffn_down
        );
    } else if (!lw.is_moe && lw.has_fused_gate_up) {
        // ═══ TENSOR ADAPTER: Fused gate+up (Phi-3) ═══
        // Split [dim, 2*intermediate] → gate [dim, intermediate] + up [dim, intermediate]
        int inter = cfg_.intermediate;
        size_t bpe = gemm::bytes_per_element(lw.t_ffn_gate_up_fused, cfg_.dim);
        const char* base = (const char*)lw.w_ffn_gate_up_fused;
        dense_ffn_.forward(
            x_.data(), x_.data(),
            (const void*)(base), lw.t_ffn_gate_up_fused, // gate half
            (const void*)(base + inter * bpe), lw.t_ffn_gate_up_fused, // up half
            lw.w_ffn_down, lw.t_ffn_down
        );
    }

    // FFN residual connection
    kernel::vec_add(x_.data(), xb_.data(), x_.data(), cfg_.dim);

    // Post-FFN norm (Gemma-2)
    if (lw.post_ffn_norm) {
        kernel::rms_norm(x_.data(), x_.data(), lw.post_ffn_norm,
                         cfg_.dim, cfg_.rms_norm_eps);
    }
}
|
|
|
|
// ─── SAMPLING ──────────────────────────────────────────────────────────
|
|
int32_t sample(float temperature, float top_p, int top_k) {
|
|
int n = cfg_.vocab_size;
|
|
|
|
if (temperature < 1e-6f) {
|
|
return static_cast<int32_t>(
|
|
std::max_element(logits_.begin(), logits_.end()) - logits_.begin()
|
|
);
|
|
}
|
|
|
|
for (int i = 0; i < n; ++i) logits_[i] /= temperature;
|
|
kernel::softmax(logits_.data(), n);
|
|
|
|
// Top-K
|
|
if (top_k > 0 && top_k < n) {
|
|
std::vector<std::pair<float, int>> indexed(n);
|
|
for (int i = 0; i < n; ++i) indexed[i] = {logits_[i], i};
|
|
std::partial_sort(indexed.begin(), indexed.begin() + top_k, indexed.end(),
|
|
[](const auto& a, const auto& b) { return a.first > b.first; });
|
|
for (int i = top_k; i < n; ++i) logits_[indexed[i].second] = 0.0f;
|
|
}
|
|
|
|
// Top-P
|
|
if (top_p < 1.0f) {
|
|
std::vector<std::pair<float, int>> sorted(n);
|
|
for (int i = 0; i < n; ++i) sorted[i] = {logits_[i], i};
|
|
std::sort(sorted.begin(), sorted.end(),
|
|
[](const auto& a, const auto& b) { return a.first > b.first; });
|
|
float cumsum = 0.0f;
|
|
for (int i = 0; i < n; ++i) {
|
|
cumsum += sorted[i].first;
|
|
if (cumsum > top_p) {
|
|
for (int j = i + 1; j < n; ++j) logits_[sorted[j].second] = 0.0f;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Renormalize
|
|
float sum = 0.0f;
|
|
for (int i = 0; i < n; ++i) sum += logits_[i];
|
|
if (sum > 0.0f) for (int i = 0; i < n; ++i) logits_[i] /= sum;
|
|
|
|
// Sample
|
|
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
|
float r = dist(rng_);
|
|
float cumsum = 0.0f;
|
|
for (int i = 0; i < n; ++i) {
|
|
cumsum += logits_[i];
|
|
if (r < cumsum) return static_cast<int32_t>(i);
|
|
}
|
|
return static_cast<int32_t>(n - 1);
|
|
}
|
|
};
|
|
|
|
} // namespace ix
|