// File-listing residue (filename / size / language), commented out so the
// header compiles; original text preserved:
// inference-x/runtime/transformer_v6.h
// 721 lines
// 30 KiB
// C++

// ═══════════════════════════════════════════════════════════════════════════════
// INFERENCE-X — Transformer v6 Forward Pass
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
// Licensed under the Business Source License 1.1 (BSL-1.1)
// See LICENSE file for full terms.
//
// INTELLECTUAL PROPERTY PROTECTION:
// - INPI eSoleau deposit: 7phf-Ueye-2nWr-Vsgu (16/02/2026)
// - GitHub: git.inference-x.com/salka/inference-x
// - Author: Salka Elmadani | Morocco
//
// MANUFACTURER NOTICE: Any manufacturer, company, or entity that
// incorporates, embeds, distributes, or commercially uses Inference-X
// or any derivative work without explicit written authorization from
// the copyright holder is in violation of BSL-1.1 and applicable
// intellectual property laws. This includes but is not limited to:
// hardware vendors, cloud providers, SaaS platforms, and OEMs.
//
// Contact: Elmadani.SALKA@proton.me for licensing.
// ═══════════════════════════════════════════════════════════════════════════════
#pragma once
// Inference-X Transformer — Salka Elmadani — Morocco
#define IX_TRANSFORMER_SIGNATURE 0x935
#define IX_TRANSFORMER_MARK "Inference-X-Transformer-935-Elmadani"
// Inference-X Signature — integral to compilation
#include <cstdint> // FIX: uint32_t is used below, before the main include block
// Build-signature constants embedded in every translation unit that includes
// this header; checked at runtime via signature::verify() in TransformerV6::init.
namespace ix {
constexpr uint32_t SIGNATURE = 935;
constexpr uint32_t FINGERPRINT = 0x935E1DAD;
constexpr const char* AUTHOR = "Salka Elmadani";
}
#include "../core/z_core.h"
#include "gguf.h"
#include "kernels.h"
#include "gemm.h"
#include "kernel_dispatch.h"
#include "attention.h"
#include "moe_mla.h"
#include <algorithm>  // std::min, std::max, std::sort, std::partial_sort, std::max_element
#include <chrono>
#include <cmath>      // tanhf, sqrtf
#include <cstdint>    // uint32_t, int32_t
#include <cstdio>
#include <cstdlib>    // strtoull
#include <cstring>    // strncmp
#include <random>
#include <string>     // std::string, std::to_string
#include <utility>    // std::pair
#include <vector>
namespace ix {
// ═══════════════════════════════════════════════════════════════════════════════
// LAYER WEIGHTS — Per-layer pointers into mmap'd GGUF
// Handles both dense and MoE layers, both GQA and MLA attention
// ═══════════════════════════════════════════════════════════════════════════════
// Per-layer weight views for one transformer block.
// All pointers are NON-OWNING views into the mmap'd GGUF file and are left
// null when the corresponding tensor is absent — downstream code treats a
// null pointer as "feature not present" (bias, post-norm, shared expert, …).
// Each `dtype` field records the on-disk quantization of the weight pointer
// declared on the same line.
// A layer uses either the MLA or the GQA attention fields (never both), and
// either the MoE or the dense-FFN fields; `is_mla` / `is_moe` say which.
struct LayerWeightsV6 {
// === Attention norms ===
const float* attn_norm = nullptr;
const float* ffn_norm = nullptr;
const float* post_attn_norm = nullptr; // Gemma-2: post-attention norm
const float* post_ffn_norm = nullptr; // Gemma-2: post-FFN norm
// Fused tensor adapter (Phi-3: QKV fused, gate+up fused).
// When set, the fused pointer replaces the separate wq/wk/wv (resp.
// w_ffn_gate/w_ffn_up) pointers and is split by byte offset at compute time.
const void* w_qkv_fused = nullptr; dtype t_qkv_fused = dtype::F32;
const void* w_ffn_gate_up_fused = nullptr; dtype t_ffn_gate_up_fused = dtype::F32;
bool has_fused_qkv = false;
bool has_fused_gate_up = false;
// === MLA Attention weights (DeepSeek-style latent attention) ===
const void* w_q_a = nullptr; dtype t_q_a = dtype::F32; // Q compress
const float* q_a_norm = nullptr; // Q norm
const void* w_q_b = nullptr; dtype t_q_b = dtype::F32; // Q decompress
const void* w_kv_a = nullptr; dtype t_kv_a = dtype::F32; // KV compress (MQA)
const float* kv_a_norm = nullptr; // KV norm
const void* w_k_b = nullptr; dtype t_k_b = dtype::F32; // K decompress
const void* w_v_b = nullptr; dtype t_v_b = dtype::F32; // V decompress
const void* w_o = nullptr; dtype t_o = dtype::F32; // output proj (shared by MLA and GQA paths)
// === GQA Attention weights (for dense layers if using standard attn) ===
const void* wq = nullptr; dtype tq = dtype::F32;
const void* wk = nullptr; dtype tk = dtype::F32;
const void* wv = nullptr; dtype tv = dtype::F32;
// === QKV Bias (Qwen2) — null for bias-free architectures ===
const float* bq = nullptr;
const float* bk = nullptr;
const float* bv = nullptr;
// === Dense FFN (for layer 0 / leading dense layers) ===
const void* w_ffn_gate = nullptr; dtype t_ffn_gate = dtype::F32;
const void* w_ffn_up = nullptr; dtype t_ffn_up = dtype::F32;
const void* w_ffn_down = nullptr; dtype t_ffn_down = dtype::F32;
// === MoE weights (stacked per-expert tensors) ===
const float* gate_inp = nullptr; // Router [dim, n_experts] F32
const void* gate_exps = nullptr; dtype t_gate_exps = dtype::F32; // [dim, ffn, n_exp]
const void* up_exps = nullptr; dtype t_up_exps = dtype::F32;
const void* down_exps = nullptr; dtype t_down_exps = dtype::F32;
// === Shared expert (always active alongside routed experts) ===
const void* gate_shexp = nullptr; dtype t_gate_shexp = dtype::F32;
const void* up_shexp = nullptr; dtype t_up_shexp = dtype::F32;
const void* down_shexp = nullptr; dtype t_down_shexp = dtype::F32;
bool is_moe = false; // layer routes through experts instead of dense FFN
bool is_mla = false; // layer uses MLA attention instead of GQA
};
// Whole-model weight table: token embedding, per-layer blocks, final norm,
// and the LM-head projection. All pointers are non-owning views into the
// GGUF mmap. `output` may alias `token_embd` when the model ties the
// embedding and output weights (see load_weights_v6).
struct WeightsV6 {
const void* token_embd = nullptr;
dtype t_embd = dtype::F32;
std::vector<LayerWeightsV6> layers;
const float* output_norm = nullptr;
const void* output = nullptr;
dtype t_output = dtype::F32;
};
// ═══════════════════════════════════════════════════════════════════════════════
// LOAD WEIGHTS FROM GGUF — Architecture-aware
// ═══════════════════════════════════════════════════════════════════════════════
// Resolve all weight-tensor pointers out of the mmap'd GGUF file into `w`.
// Lookups use GGUF canonical names ("blk.<l>.<suffix>"); a missing tensor
// leaves its pointer null, which downstream code interprets as "feature not
// present" (MLA vs GQA, MoE vs dense, optional biases / post-norms).
// Returns false only when the token embedding is missing; every other
// tensor is treated as optional at this level.
inline bool load_weights_v6(WeightsV6& w, const GGUF& gguf, const Config& cfg) {
// Fetch a tensor's data pointer + on-disk dtype; false if absent.
auto get = [&](const std::string& name, const void*& ptr, dtype& type) -> bool {
const TensorInfo* ti = gguf.tensor(name);
if (!ti) return false;
ptr = ti->data;
type = ti->type;
return true;
};
// Fetch an F32 tensor (norms, biases, router); null if absent.
// NOTE(review): assumes these tensors are stored as F32 — no dtype check here.
auto get_f32 = [&](const std::string& name) -> const float* {
const TensorInfo* ti = gguf.tensor(name);
return ti ? static_cast<const float*>(ti->data) : nullptr;
};
// Token embedding — the only hard requirement.
if (!get("token_embd.weight", w.token_embd, w.t_embd)) {
printf("[WEIGHTS] ERROR: token_embd.weight not found\n");
return false;
}
// Output head: "output.weight" → "lm_head.weight" → tied embedding.
w.output_norm = get_f32("output_norm.weight");
if (!get("output.weight", w.output, w.t_output)) {
// Try lm_head
if (!get("lm_head.weight", w.output, w.t_output)) {
// Weight tying: reuse the token embedding as the output projection.
w.output = w.token_embd;
w.t_output = w.t_embd;
}
}
// Per-layer weights.
w.layers.resize(cfg.n_layers);
int loaded_mla = 0, loaded_moe = 0, loaded_dense = 0;
for (int l = 0; l < cfg.n_layers; ++l) {
std::string p = "blk." + std::to_string(l) + ".";
LayerWeightsV6& lw = w.layers[l];
// Norms (always present)
lw.attn_norm = get_f32(p + "attn_norm.weight");
lw.ffn_norm = get_f32(p + "ffn_norm.weight");
// === Try MLA attention ===
// Presence of the Q-compress tensor is the layer-level MLA signal.
bool has_mla = false;
if (cfg.is_mla()) {
has_mla |= get(p + "attn_q_a.weight", lw.w_q_a, lw.t_q_a);
lw.q_a_norm = get_f32(p + "attn_q_a_norm.weight");
get(p + "attn_q_b.weight", lw.w_q_b, lw.t_q_b);
get(p + "attn_kv_a_mqa.weight", lw.w_kv_a, lw.t_kv_a);
lw.kv_a_norm = get_f32(p + "attn_kv_a_norm.weight");
get(p + "attn_k_b.weight", lw.w_k_b, lw.t_k_b);
get(p + "attn_v_b.weight", lw.w_v_b, lw.t_v_b);
get(p + "attn_output.weight", lw.w_o, lw.t_o);
lw.is_mla = has_mla;
if (has_mla) loaded_mla++;
}
// === Fallback: standard GQA attention ===
if (!has_mla) {
if (!get(p + "attn_q.weight", lw.wq, lw.tq)) {
// ═══ TENSOR ADAPTER: Fused QKV (Phi-3, Falcon, StarCoder2) ═══
// No separate Q tensor → expect a single fused QKV tensor that the
// forward pass splits by byte offset.
if (get(p + "attn_qkv.weight", lw.w_qkv_fused, lw.t_qkv_fused)) {
lw.has_fused_qkv = true;
if (l == 0) printf("[ADAPTER] Fused QKV detected → will split at compute\n");
}
} else {
get(p + "attn_k.weight", lw.wk, lw.tk);
get(p + "attn_v.weight", lw.wv, lw.tv);
}
get(p + "attn_output.weight", lw.w_o, lw.t_o);
// QKV bias (Qwen2 architecture) — null for bias-free models.
lw.bq = get_f32(p + "attn_q.bias");
lw.bk = get_f32(p + "attn_k.bias");
lw.bv = get_f32(p + "attn_v.bias");
// Post-norms (Gemma-2)
lw.post_attn_norm = get_f32(p + "post_attention_norm.weight");
lw.post_ffn_norm = get_f32(p + "post_ffw_norm.weight");
}
// === Try MoE FFN ===
// A layer is MoE iff the router tensor exists and the layer index is past
// the leading dense layers.
bool has_moe = false;
if (cfg.is_moe() && l >= cfg.n_dense_layers) {
lw.gate_inp = get_f32(p + "ffn_gate_inp.weight");
has_moe |= (lw.gate_inp != nullptr);
get(p + "ffn_gate_exps.weight", lw.gate_exps, lw.t_gate_exps);
get(p + "ffn_up_exps.weight", lw.up_exps, lw.t_up_exps);
get(p + "ffn_down_exps.weight", lw.down_exps, lw.t_down_exps);
// Shared expert (optional — e.g. DeepSeek/Qwen MoE variants)
get(p + "ffn_gate_shexp.weight", lw.gate_shexp, lw.t_gate_shexp);
get(p + "ffn_up_shexp.weight", lw.up_shexp, lw.t_up_shexp);
get(p + "ffn_down_shexp.weight", lw.down_shexp, lw.t_down_shexp);
lw.is_moe = has_moe;
if (has_moe) loaded_moe++;
}
// === Dense FFN (for leading dense layers) ===
if (!has_moe) {
if (!get(p + "ffn_gate.weight", lw.w_ffn_gate, lw.t_ffn_gate)) {
// No separate gate — try fused gate+up (Phi-3). The ffn_up tensor then
// holds both halves, split by byte offset in the forward pass.
if (get(p + "ffn_up.weight", lw.w_ffn_gate_up_fused, lw.t_ffn_gate_up_fused)) {
lw.has_fused_gate_up = true;
if (l == 0) printf("[ADAPTER] Fused gate+up detected\n");
}
} else {
get(p + "ffn_up.weight", lw.w_ffn_up, lw.t_ffn_up);
}
get(p + "ffn_down.weight", lw.w_ffn_down, lw.t_ffn_down);
loaded_dense++;
}
}
printf("[WEIGHTS] Loaded: %d MLA layers, %d MoE layers, %d dense layers\n",
loaded_mla, loaded_moe, loaded_dense);
printf("[WEIGHTS] Embedding: %s, Output: %s\n",
dtype_name(w.t_embd), dtype_name(w.t_output));
return true;
}
// ═══════════════════════════════════════════════════════════════════════════════
// TRANSFORMER V6 — DeepSeek V3 Complete Forward Pass
// ═══════════════════════════════════════════════════════════════════════════════
// Single-token autoregressive forward pass over a GGUF model.
// Supports MLA (DeepSeek-style) and GQA attention, MoE and dense FFN layers,
// fused-tensor adapters (Phi-3), and Gemma-2 post-norms / logit soft-capping.
// Not thread-safe: one instance drives one sequence at a time.
class TransformerV6 {
public:
  /// Initialize from a parsed GGUF file.
  /// Extracts the architecture config, clamps the context length to both
  /// `max_ctx` and available RAM, resolves weight pointers, and initializes
  /// whichever attention / FFN components the architecture needs.
  /// Returns false if the signature check or weight loading fails.
  bool init(const GGUF& gguf, int max_ctx = 4096) {
    if (!signature::verify()) return false;
    cfg_ = gguf.extract_config();
    // Clamp the model's declared context to the caller's requested maximum.
    cfg_.max_seq_len = std::min(cfg_.max_seq_len, max_ctx);
    // Universal: auto-cap context so the KV cache fits in available RAM.
    {
      // KV cache = 2 (K+V) * n_kv_heads * max_seq * head_dim * sizeof(float) per layer.
      size_t kv_per_layer = 2ULL * cfg_.n_kv_heads * cfg_.max_seq_len * cfg_.head_dim * sizeof(float);
      size_t kv_total = kv_per_layer * cfg_.n_layers;
      // Rough available-RAM probe via /proc/meminfo (Linux only; stays 0
      // elsewhere, which disables the cap).
      size_t avail_ram = 0;
      FILE* mi = fopen("/proc/meminfo", "r");
      if (mi) {
        char line[256];
        while (fgets(line, sizeof(line), mi)) {
          if (strncmp(line, "MemAvailable:", 13) == 0) {
            avail_ram = strtoull(line + 13, nullptr, 10) * 1024; // kB → bytes
            break;
          }
        }
        fclose(mi);
      }
      if (avail_ram > 0 && kv_total > avail_ram / 2) {
        // KV cache would use >50% of available RAM → reduce context.
        // FIX: divide the RAM budget by the per-token KV bytes across ALL
        // layers. The previous code divided by a single layer's per-token
        // bytes, overestimating the safe context by a factor of n_layers.
        size_t bytes_per_token = kv_total / cfg_.max_seq_len;
        int safe_ctx = (int)(avail_ram / 2 / bytes_per_token);
        safe_ctx = std::max(512, std::min(safe_ctx, cfg_.max_seq_len));
        if (safe_ctx < cfg_.max_seq_len) {
          printf("[ADAPTER] Context capped: %d → %d (KV cache %.1f GB > %.1f GB avail/2)\n",
                 cfg_.max_seq_len, safe_ctx,
                 kv_total / 1e9, avail_ram / 2e9);
          cfg_.max_seq_len = safe_ctx;
        }
      }
    }
    cfg_.print();
    printf("\n[TRANSFORMER] Loading weights...\n");
    if (!load_weights_v6(weights_, gguf, cfg_)) {
      printf("[TRANSFORMER] ERROR: Failed to load weights\n");
      return false;
    }
    // Derive vocab_size from the output weight tensor when metadata disagrees
    // (some GGUF exports carry a stale vocab_size).
    {
      const TensorInfo* oti = gguf.tensor("output.weight");
      if (!oti) oti = gguf.tensor("lm_head.weight");
      // NOTE(review): assumes dims[1] is the vocab dimension of the output
      // projection — confirm against the GGUF tensor layout in gguf.h.
      if (oti && oti->dims[1] > 0 && (int)oti->dims[1] != cfg_.vocab_size) {
        printf("[FIX] vocab_size: %d -> %lu (from output weight)\n", cfg_.vocab_size, (unsigned long)oti->dims[1]);
        cfg_.vocab_size = (int)oti->dims[1];
      }
    }
    // Initialize components based on architecture.
    if (cfg_.is_mla()) {
      printf("[TRANSFORMER] Initializing MLA attention (kv_lora=%d, rope=%d)\n",
             cfg_.kv_lora_rank, cfg_.rope_dim);
      // FIX: size MLA components with the effective (possibly RAM-capped)
      // context rather than the caller's raw max_ctx — otherwise the cap
      // above would not actually shrink the MLA KV-cache allocation.
      mla_attn_.init(cfg_, cfg_.max_seq_len);
      mla_cache_.init(cfg_, cfg_.max_seq_len);
      mla_rope_.init(cfg_.rope_dim, cfg_.max_seq_len, cfg_.rope_theta,
                     cfg_.rope_scaling_factor);
    }
    if ((cfg_.n_dense_layers > 0 || !cfg_.is_moe()) && !cfg_.is_mla()) {
      printf("[TRANSFORMER] Initializing standard GQA attention\n");
      gqa_cache_.init(cfg_);
      gqa_rope_.init(cfg_.head_dim, cfg_.max_seq_len, cfg_.rope_theta);
      gqa_attn_.init(cfg_);
    }
    if (cfg_.is_moe()) {
      printf("[TRANSFORMER] Initializing MoE (%d experts, %d active, %d shared)\n",
             cfg_.n_experts, cfg_.n_experts_used, cfg_.n_expert_shared);
      router_.init(cfg_);
      expert_ffn_.init(cfg_);
      shared_ffn_.init(cfg_);
      expert_cache_.init(cfg_.n_layers, cfg_.n_experts);
    }
    if (cfg_.n_dense_layers > 0) {
      dense_ffn_.init(cfg_);
    }
    // Scratch buffers reused across every token.
    x_.resize(cfg_.dim);
    xb_.resize(cfg_.dim);
    expert_out_.resize(cfg_.dim);
    ffn_out_.resize(cfg_.dim);
    logits_.resize(cfg_.vocab_size);
    // Report the effective context (after clamping/capping), not the request.
    printf("[TRANSFORMER] Ready. dim=%d, layers=%d, vocab=%d, ctx=%d\n",
           cfg_.dim, cfg_.n_layers, cfg_.vocab_size, cfg_.max_seq_len);
    printf("[TRANSFORMER] MLA KV cache: %.1f MB\n",
           cfg_.is_mla() ? mla_cache_.memory_bytes() / 1e6 : 0.0);
    initialized_ = true;
    return true;
  }

  /// Clear the KV cache so a new sequence starts from position 0.
  void reset() {
    if (cfg_.is_mla()) mla_cache_.clear();
    else gqa_cache_.clear();
  }

  // ═══════════════════════════════════════════════════════════════════════
  // FORWARD — Single token
  // ═══════════════════════════════════════════════════════════════════════
  /// Run one token through the full network and return a pointer to the
  /// internal logits buffer (vocab_size floats). The buffer is overwritten
  /// on the next forward() call and mutated in place by sample().
  /// Returns nullptr if init() has not succeeded.
  const float* forward(int32_t token) {
    if (!initialized_) return nullptr;
    embed(token);
    for (int l = 0; l < cfg_.n_layers; ++l) {
      layer_forward(l);
    }
    // Final norm + output projection.
    kernel::rms_norm(x_.data(), x_.data(), weights_.output_norm,
                     cfg_.dim, cfg_.rms_norm_eps);
    gemm::matmul(logits_.data(), weights_.output, weights_.t_output,
                 x_.data(), cfg_.vocab_size, cfg_.dim);
    // Final logit soft-capping (Gemma-2).
    if (cfg_.final_logit_softcap > 0.0f) {
      float cap = cfg_.final_logit_softcap;
      for (int i = 0; i < cfg_.vocab_size; ++i)
        logits_[i] = cap * tanhf(logits_[i] / cap);
    }
    // Commit the token's K/V entries: advance the cache write position.
    if (cfg_.is_mla()) mla_cache_.advance();
    else gqa_cache_.advance();
    return logits_.data();
  }

  // ═══════════════════════════════════════════════════════════════════════
  // GENERATE
  // ═══════════════════════════════════════════════════════════════════════
  /// Prefill `prompt`, then sample up to `max_new_tokens` tokens.
  /// Stops early on any registered EOS token. Returns the generated tokens
  /// (prompt excluded, EOS excluded).
  std::vector<int32_t> generate(
      const std::vector<int32_t>& prompt,
      int max_new_tokens = 256,
      float temperature = 0.7f,
      float top_p = 0.9f,
      int top_k = 40
  ) {
    reset();
    std::vector<int32_t> output;
    auto t0 = std::chrono::high_resolution_clock::now();
    // Prefill: run each prompt token to populate the KV cache.
    printf("[GEN] Prefilling %zu tokens...\n", prompt.size());
    for (size_t i = 0; i < prompt.size(); ++i) {
      forward(prompt[i]);
      if ((i + 1) % 100 == 0) printf(" [%zu/%zu]\n", i + 1, prompt.size());
    }
    auto t1 = std::chrono::high_resolution_clock::now();
    double prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
    // FIX: guard the tok/s division for an empty prompt / zero elapsed time.
    printf("[GEN] Prefill: %.0f ms (%.1f tok/s)\n",
           prefill_ms,
           prefill_ms > 0.0 ? prompt.size() * 1000.0 / prefill_ms : 0.0);
    // Decode loop: sample from the last logits, then feed the token back.
    printf("[GEN] Generating up to %d tokens...\n", max_new_tokens);
    for (int i = 0; i < max_new_tokens; ++i) {
      int32_t next = sample(temperature, top_p, top_k);
      if (is_eos(next)) break;
      output.push_back(next);
      forward(next);
    }
    auto t2 = std::chrono::high_resolution_clock::now();
    double gen_ms = std::chrono::duration<double, std::milli>(t2 - t1).count();
    // FIX: also guard against gen_ms == 0 (previously only output size was checked).
    printf("[GEN] Generated %zu tokens in %.0f ms (%.2f tok/s)\n",
           output.size(), gen_ms,
           (!output.empty() && gen_ms > 0.0) ? output.size() * 1000.0 / gen_ms : 0.0);
    // Print expert cache stats (MoE models only).
    if (cfg_.is_moe()) {
      expert_cache_.print_stats();
    }
    return output;
  }

  /// Streaming variant of generate(): `on_token(next)` is invoked for every
  /// sampled token; returning false from the callback stops generation.
  template<typename Callback>
  void generate_stream(
      const std::vector<int32_t>& prompt,
      int max_new_tokens,
      float temperature, float top_p, int top_k,
      Callback&& on_token
  ) {
    reset();
    for (size_t i = 0; i < prompt.size(); ++i) forward(prompt[i]);
    for (int i = 0; i < max_new_tokens; ++i) {
      int32_t next = sample(temperature, top_p, top_k);
      if (is_eos(next)) break;
      if (!on_token(next)) break;
      forward(next);
    }
  }

  const Config& config() const { return cfg_; }
  Config& config_mut() { return cfg_; }
  /// Replace the EOS set with a single token id.
  void set_eos_token(int32_t eos) { eos_tokens_ = {eos}; }
  /// Add an EOS token id (deduplicated).
  void add_eos_token(int32_t eos) {
    for (auto e : eos_tokens_) if (e == eos) return;
    eos_tokens_.push_back(eos);
  }
  /// True if `tok` is any registered EOS token.
  bool is_eos(int32_t tok) const {
    for (auto e : eos_tokens_) if (tok == e) return true;
    return false;
  }
  ExpertCache& expert_cache_ref() { return expert_cache_; }

private:
  Config cfg_;
  WeightsV6 weights_;
  bool initialized_ = false;
  std::vector<int32_t> eos_tokens_ = {2}; // default: common </s> id
  // MLA components
  MLAAttention mla_attn_;
  MLAKVCache mla_cache_;
  MLARoPE mla_rope_;
  // GQA fallback
  Attention gqa_attn_;
  KVCache gqa_cache_;
  kernel::RoPE gqa_rope_;
  // MoE components
  MoERouter router_;
  ExpertFFN expert_ffn_;
  SharedExpertFFN shared_ffn_;
  ExpertCache expert_cache_;
  // Dense FFN fallback
  FFN dense_ffn_;
  // Scratch buffers: hidden state, residual copy, expert scratch, FFN
  // accumulator, and output logits.
  std::vector<float> x_, xb_, expert_out_, ffn_out_, logits_;
  std::mt19937 rng_{42}; // fixed seed → deterministic sampling runs

  // ─── EMBEDDING ─────────────────────────────────────────────────────────
  // Load the token's embedding row into x_ (dequantizing if needed).
  void embed(int32_t token) {
    gemm::embed_lookup(x_.data(), weights_.token_embd, weights_.t_embd,
                       token, cfg_.dim);
    // Gemma: scale embeddings by sqrt(dim).
    if (cfg_.embed_scale_sqrt_dim) {
      float scale = sqrtf((float)cfg_.dim);
      for (int i = 0; i < cfg_.dim; ++i) x_[i] *= scale;
    }
  }

  // ─── LAYER FORWARD ─────────────────────────────────────────────────────
  // One transformer block: pre-norm attention + residual, then pre-norm
  // FFN (MoE or dense) + residual, with optional Gemma-2 post-norms.
  void layer_forward(int layer) {
    const LayerWeightsV6& lw = weights_.layers[layer];
    // === ATTENTION ===
    kernel::vec_copy(xb_.data(), x_.data(), cfg_.dim); // keep residual
    if (lw.attn_norm) {
      kernel::rms_norm(x_.data(), x_.data(), lw.attn_norm,
                       cfg_.dim, cfg_.rms_norm_eps);
    }
    if (lw.is_mla) {
      // MLA attention path
      mla_attn_.forward(
          x_.data(), x_.data(),
          lw.w_q_a, lw.t_q_a,
          lw.q_a_norm,
          lw.w_q_b, lw.t_q_b,
          lw.w_kv_a, lw.t_kv_a,
          lw.kv_a_norm,
          lw.w_k_b, lw.t_k_b,
          lw.w_v_b, lw.t_v_b,
          lw.w_o, lw.t_o,
          mla_cache_, mla_rope_, layer
      );
    } else { // FIX: was `else if (!lw.is_mla)` — the condition is always true here
      if (lw.has_fused_qkv) {
        // ═══ TENSOR ADAPTER: Split fused QKV on the fly ═══
        // QKV is [dim, q_dim + k_dim + v_dim] contiguous in memory.
        int q_dim = cfg_.dim;
        int kv_dim = cfg_.n_kv_heads * cfg_.head_dim;
        // NOTE(review): bytes_per_element(type, dim) is treated as the byte
        // size of one row of length `dim` — confirm against gemm.h.
        const char* base = (const char*)lw.w_qkv_fused;
        size_t bpe = gemm::bytes_per_element(lw.t_qkv_fused, cfg_.dim);
        size_t q_off = 0;
        size_t k_off = q_dim * bpe;
        size_t v_off = (q_dim + kv_dim) * bpe;
        gqa_attn_.forward(
            x_.data(), x_.data(),
            (const void*)(base + q_off), lw.t_qkv_fused,
            (const void*)(base + k_off), lw.t_qkv_fused,
            (const void*)(base + v_off), lw.t_qkv_fused,
            lw.w_o, lw.t_o,
            gqa_cache_, gqa_rope_, layer,
            lw.bq, lw.bk, lw.bv
        );
      } else {
        // Standard GQA attention path (for dense layers)
        gqa_attn_.forward(
            x_.data(), x_.data(),
            lw.wq, lw.tq,
            lw.wk, lw.tk,
            lw.wv, lw.tv,
            lw.w_o, lw.t_o,
            gqa_cache_, gqa_rope_, layer,
            lw.bq, lw.bk, lw.bv
        );
      }
    }
    // Residual
    kernel::vec_add(x_.data(), xb_.data(), x_.data(), cfg_.dim);
    // Post-attention norm (Gemma-2)
    if (lw.post_attn_norm) {
      kernel::rms_norm(x_.data(), x_.data(), lw.post_attn_norm,
                       cfg_.dim, cfg_.rms_norm_eps);
    }
    // === FFN ===
    kernel::vec_copy(xb_.data(), x_.data(), cfg_.dim); // keep residual
    if (lw.ffn_norm) {
      kernel::rms_norm(x_.data(), x_.data(), lw.ffn_norm,
                       cfg_.dim, cfg_.rms_norm_eps);
    }
    if (lw.is_moe) {
      // ─── MoE FORWARD ───────────────────────────────────────────────
      kernel::vec_zero(ffn_out_.data(), cfg_.dim);
      // Route: select top-K experts for this token.
      auto selected = router_.route(lw.gate_inp, x_.data());
      expert_cache_.record_batch(layer, selected);
      // ─── SURGICAL MMAP PREFETCH (÷48 I/O) ────────────────────────
      // Hint the OS to page in only the selected experts' weights.
      {
        std::vector<int> active_ids;
        for (const auto& s : selected) active_ids.push_back(s.expert_id);
        KernelDispatch::instance().prefetch_experts(
            layer, active_ids.data(), (int)active_ids.size());
      }
      // Dispatch to selected experts and accumulate router-weighted outputs.
      for (const auto& sel : selected) {
        expert_ffn_.forward(
            expert_out_.data(), x_.data(),
            sel.expert_id,
            lw.gate_exps, lw.t_gate_exps,
            lw.up_exps, lw.t_up_exps,
            lw.down_exps, lw.t_down_exps
        );
        for (int d = 0; d < cfg_.dim; ++d) {
          ffn_out_[d] += sel.weight * expert_out_[d];
        }
      }
      // Shared expert (always active when present).
      if (lw.gate_shexp) {
        shared_ffn_.forward(
            expert_out_.data(), x_.data(),
            lw.gate_shexp, lw.t_gate_shexp,
            lw.up_shexp, lw.t_up_shexp,
            lw.down_shexp, lw.t_down_shexp
        );
        kernel::vec_add(ffn_out_.data(), ffn_out_.data(),
                        expert_out_.data(), cfg_.dim);
      }
      kernel::vec_copy(x_.data(), ffn_out_.data(), cfg_.dim);
      // ─── EVICT INACTIVE EXPERT PAGES ──────────────────────────────
      KernelDispatch::instance().evict_layer(layer);
    }
    if (!lw.is_moe && lw.w_ffn_gate) {
      // ─── Dense FFN ─────────────────────────────────────────────────
      dense_ffn_.forward(
          x_.data(), x_.data(),
          lw.w_ffn_gate, lw.t_ffn_gate,
          lw.w_ffn_up, lw.t_ffn_up,
          lw.w_ffn_down, lw.t_ffn_down
      );
    } else if (!lw.is_moe && lw.has_fused_gate_up) {
      // ═══ TENSOR ADAPTER: Fused gate+up (Phi-3) ═══
      // Split [dim, 2*intermediate] → gate [dim, intermediate] + up [dim, intermediate].
      // NOTE(review): assumes the gate half precedes the up half in the fused
      // tensor — confirm against the Phi-3 GGUF export layout.
      int inter = cfg_.intermediate;
      size_t bpe = gemm::bytes_per_element(lw.t_ffn_gate_up_fused, cfg_.dim);
      const char* base = (const char*)lw.w_ffn_gate_up_fused;
      dense_ffn_.forward(
          x_.data(), x_.data(),
          (const void*)(base), lw.t_ffn_gate_up_fused,              // gate half
          (const void*)(base + inter * bpe), lw.t_ffn_gate_up_fused, // up half
          lw.w_ffn_down, lw.t_ffn_down
      );
    }
    // FFN Residual
    kernel::vec_add(x_.data(), xb_.data(), x_.data(), cfg_.dim);
    // Post-FFN norm (Gemma-2)
    if (lw.post_ffn_norm) {
      kernel::rms_norm(x_.data(), x_.data(), lw.post_ffn_norm,
                       cfg_.dim, cfg_.rms_norm_eps);
    }
  }

  // ─── SAMPLING ──────────────────────────────────────────────────────────
  // Temperature + top-k + top-p (nucleus) sampling over logits_.
  // Mutates logits_ in place (they become probabilities); call forward()
  // before sampling again.
  int32_t sample(float temperature, float top_p, int top_k) {
    int n = cfg_.vocab_size;
    // Greedy decoding for (near-)zero temperature.
    if (temperature < 1e-6f) {
      return static_cast<int32_t>(
          std::max_element(logits_.begin(), logits_.end()) - logits_.begin()
      );
    }
    for (int i = 0; i < n; ++i) logits_[i] /= temperature;
    kernel::softmax(logits_.data(), n);
    // Top-K: zero out everything below the k-th largest probability.
    if (top_k > 0 && top_k < n) {
      std::vector<std::pair<float, int>> indexed(n);
      for (int i = 0; i < n; ++i) indexed[i] = {logits_[i], i};
      std::partial_sort(indexed.begin(), indexed.begin() + top_k, indexed.end(),
                        [](const auto& a, const auto& b) { return a.first > b.first; });
      for (int i = top_k; i < n; ++i) logits_[indexed[i].second] = 0.0f;
    }
    // Top-P: keep the smallest prefix of the sorted distribution whose
    // cumulative mass exceeds top_p; zero the rest.
    // NOTE(review): applied to the (un-renormalized) post-top-k mass, so an
    // aggressive top_k can make top_p a no-op — intentional? confirm.
    if (top_p < 1.0f) {
      std::vector<std::pair<float, int>> sorted(n);
      for (int i = 0; i < n; ++i) sorted[i] = {logits_[i], i};
      std::sort(sorted.begin(), sorted.end(),
                [](const auto& a, const auto& b) { return a.first > b.first; });
      float cumsum = 0.0f;
      for (int i = 0; i < n; ++i) {
        cumsum += sorted[i].first;
        if (cumsum > top_p) {
          for (int j = i + 1; j < n; ++j) logits_[sorted[j].second] = 0.0f;
          break;
        }
      }
    }
    // Renormalize the surviving mass back into a distribution.
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) sum += logits_[i];
    if (sum > 0.0f) for (int i = 0; i < n; ++i) logits_[i] /= sum;
    // Inverse-CDF sample.
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    float r = dist(rng_);
    float cumsum = 0.0f;
    for (int i = 0; i < n; ++i) {
      cumsum += logits_[i];
      if (r < cumsum) return static_cast<int32_t>(i);
    }
    return static_cast<int32_t>(n - 1); // float round-off fallback
  }
};
} // namespace ix