// ═══════════════════════════════════════════════════════════════════════════════
// INFERENCE-X — Transformer v6 Forward Pass
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
// Licensed under the Business Source License 1.1 (BSL-1.1)
// See LICENSE file for full terms.
//
// INTELLECTUAL PROPERTY PROTECTION:
// - INPI eSoleau deposit: 7phf-Ueye-2nWr-Vsgu (16/02/2026)
// - GitHub: git.inference-x.com/salka/inference-x
// - Author: Salka Elmadani | Morocco
//
// MANUFACTURER NOTICE: Any manufacturer, company, or entity that
// incorporates, embeds, distributes, or commercially uses Inference-X
// or any derivative work without explicit written authorization from
// the copyright holder is in violation of BSL-1.1 and applicable
// intellectual property laws. This includes but is not limited to:
// hardware vendors, cloud providers, SaaS platforms, and OEMs.
//
// Contact: Elmadani.SALKA@proton.me for licensing.
// ═══════════════════════════════════════════════════════════════════════════════
#pragma once

// Inference-X Transformer — Salka Elmadani — Morocco
#define IX_TRANSFORMER_SIGNATURE 0x935
#define IX_TRANSFORMER_MARK "Inference-X-Transformer-935-Elmadani"

#include <cstdint>

// Inference-X Signature — integral to compilation
namespace ix {
constexpr uint32_t SIGNATURE   = 935;
constexpr uint32_t FINGERPRINT = 0x935E1DAD;
constexpr const char* AUTHOR   = "Salka Elmadani";
}

#include "../core/z_core.h"
#include "gguf.h"
#include "kernels.h"
#include "gemm.h"
#include "kernel_dispatch.h"
#include "attention.h"
#include "moe_mla.h"

// Standard library headers used below
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <random>
#include <string>
#include <utility>
#include <vector>

namespace ix {

// ═══════════════════════════════════════════════════════════════════════════════
// LAYER WEIGHTS — Per-layer pointers into mmap'd GGUF
// Handles both dense and MoE layers, both GQA and MLA attention
// ═══════════════════════════════════════════════════════════════════════════════
struct LayerWeightsV6 {
    // === Attention norms ===
    const float* attn_norm      = nullptr;
    const float* ffn_norm       = nullptr;
    const float* post_attn_norm = nullptr;  // Gemma-2: post-attention norm
    const float* post_ffn_norm  = nullptr;  // Gemma-2: post-FFN norm

    // Fused tensor adapter (Phi-3: QKV fused, gate+up fused)
    const void* w_qkv_fused         = nullptr;  dtype t_qkv_fused         = dtype::F32;
    const void* w_ffn_gate_up_fused = nullptr;  dtype t_ffn_gate_up_fused = dtype::F32;
    bool has_fused_qkv     = false;
    bool has_fused_gate_up = false;

    // === MLA Attention weights ===
    const void* w_q_a     = nullptr;  dtype t_q_a  = dtype::F32;  // Q compress
    const float* q_a_norm = nullptr;                              // Q norm
    const void* w_q_b     = nullptr;  dtype t_q_b  = dtype::F32;  // Q decompress
    const void* w_kv_a    = nullptr;  dtype t_kv_a = dtype::F32;  // KV compress (MQA)
    const float* kv_a_norm = nullptr;                             // KV norm
    const void* w_k_b     = nullptr;  dtype t_k_b  = dtype::F32;  // K decompress
    const void* w_v_b     = nullptr;  dtype t_v_b  = dtype::F32;  // V decompress
    const void* w_o       = nullptr;  dtype t_o    = dtype::F32;  // output proj

    // === GQA Attention weights (for dense layers if using standard attn) ===
    const void* wq = nullptr;  dtype tq = dtype::F32;
    const void* wk = nullptr;  dtype tk = dtype::F32;
    const void* wv = nullptr;  dtype tv = dtype::F32;

    // === QKV Bias (Qwen2) ===
    const float* bq = nullptr;
    const float* bk = nullptr;
    const float* bv = nullptr;

    // === Dense FFN (for layer 0 / leading dense layers) ===
    const void* w_ffn_gate = nullptr;  dtype t_ffn_gate = dtype::F32;
    const void* w_ffn_up   = nullptr;  dtype t_ffn_up   = dtype::F32;
    const void* w_ffn_down = nullptr;  dtype t_ffn_down = dtype::F32;
    // === MoE weights ===
    const float* gate_inp = nullptr;                                   // Router [dim, n_experts] F32
    const void* gate_exps = nullptr;  dtype t_gate_exps = dtype::F32;  // [dim, ffn, n_exp]
    const void* up_exps   = nullptr;  dtype t_up_exps   = dtype::F32;
    const void* down_exps = nullptr;  dtype t_down_exps = dtype::F32;

    // === Shared expert ===
    const void* gate_shexp = nullptr;  dtype t_gate_shexp = dtype::F32;
    const void* up_shexp   = nullptr;  dtype t_up_shexp   = dtype::F32;
    const void* down_shexp = nullptr;  dtype t_down_shexp = dtype::F32;

    bool is_moe = false;
    bool is_mla = false;
};

struct WeightsV6 {
    const void* token_embd = nullptr;  dtype t_embd = dtype::F32;
    std::vector<LayerWeightsV6> layers;
    const float* output_norm = nullptr;
    const void* output = nullptr;      dtype t_output = dtype::F32;
};

// ═══════════════════════════════════════════════════════════════════════════════
// LOAD WEIGHTS FROM GGUF — Architecture-aware
// ═══════════════════════════════════════════════════════════════════════════════
inline bool load_weights_v6(WeightsV6& w, const GGUF& gguf, const Config& cfg) {
    auto get = [&](const std::string& name, const void*& ptr, dtype& type) -> bool {
        const TensorInfo* ti = gguf.tensor(name);
        if (!ti) return false;
        ptr = ti->data;
        type = ti->type;
        return true;
    };
    auto get_f32 = [&](const std::string& name) -> const float* {
        const TensorInfo* ti = gguf.tensor(name);
        return ti ? static_cast<const float*>(ti->data) : nullptr;
    };

    // Token embedding
    if (!get("token_embd.weight", w.token_embd, w.t_embd)) {
        printf("[WEIGHTS] ERROR: token_embd.weight not found\n");
        return false;
    }

    // Output
    w.output_norm = get_f32("output_norm.weight");
    if (!get("output.weight", w.output, w.t_output)) {
        // Try lm_head
        if (!get("lm_head.weight", w.output, w.t_output)) {
            // Weight tying
            w.output   = w.token_embd;
            w.t_output = w.t_embd;
        }
    }

    // Layers
    w.layers.resize(cfg.n_layers);
    int loaded_mla = 0, loaded_moe = 0, loaded_dense = 0;
    for (int l = 0; l < cfg.n_layers; ++l) {
        std::string p = "blk." + std::to_string(l) + ".";
        LayerWeightsV6& lw = w.layers[l];

        // Norms (always present)
        lw.attn_norm = get_f32(p + "attn_norm.weight");
        lw.ffn_norm  = get_f32(p + "ffn_norm.weight");

        // === Try MLA attention ===
        bool has_mla = false;
        if (cfg.is_mla()) {
            has_mla |= get(p + "attn_q_a.weight", lw.w_q_a, lw.t_q_a);
            lw.q_a_norm = get_f32(p + "attn_q_a_norm.weight");
            get(p + "attn_q_b.weight", lw.w_q_b, lw.t_q_b);
            get(p + "attn_kv_a_mqa.weight", lw.w_kv_a, lw.t_kv_a);
            lw.kv_a_norm = get_f32(p + "attn_kv_a_norm.weight");
            get(p + "attn_k_b.weight", lw.w_k_b, lw.t_k_b);
            get(p + "attn_v_b.weight", lw.w_v_b, lw.t_v_b);
            get(p + "attn_output.weight", lw.w_o, lw.t_o);
            lw.is_mla = has_mla;
            if (has_mla) loaded_mla++;
        }

        // === Fallback: standard GQA attention ===
        if (!has_mla) {
            if (!get(p + "attn_q.weight", lw.wq, lw.tq)) {
                // ═══ TENSOR ADAPTER: Fused QKV (Phi-3, Falcon, StarCoder2) ═══
                if (get(p + "attn_qkv.weight", lw.w_qkv_fused, lw.t_qkv_fused)) {
                    lw.has_fused_qkv = true;
                    if (l == 0) printf("[ADAPTER] Fused QKV detected → will split at compute\n");
                }
            } else {
                get(p + "attn_k.weight", lw.wk, lw.tk);
                get(p + "attn_v.weight", lw.wv, lw.tv);
            }
            get(p + "attn_output.weight", lw.w_o, lw.t_o);

            // QKV bias (Qwen2 architecture)
            lw.bq = get_f32(p + "attn_q.bias");
            lw.bk = get_f32(p + "attn_k.bias");
            lw.bv = get_f32(p + "attn_v.bias");

            // Post-norms (Gemma-2)
            lw.post_attn_norm = get_f32(p + "post_attention_norm.weight");
            lw.post_ffn_norm  = get_f32(p + "post_ffw_norm.weight");
        }

        // === Try MoE FFN ===
        bool has_moe = false;
        if (cfg.is_moe() && l >= cfg.n_dense_layers) {
            lw.gate_inp = get_f32(p + "ffn_gate_inp.weight");
            has_moe |= (lw.gate_inp != nullptr);
            get(p + "ffn_gate_exps.weight", lw.gate_exps, lw.t_gate_exps);
            get(p + "ffn_up_exps.weight", lw.up_exps, lw.t_up_exps);
            get(p + "ffn_down_exps.weight", lw.down_exps, lw.t_down_exps);
            // Shared expert
            get(p + "ffn_gate_shexp.weight", lw.gate_shexp, lw.t_gate_shexp);
            get(p + "ffn_up_shexp.weight", lw.up_shexp, lw.t_up_shexp);
            get(p + "ffn_down_shexp.weight", lw.down_shexp, lw.t_down_shexp);
            lw.is_moe = has_moe;
            if (has_moe) loaded_moe++;
        }

        // === Dense FFN (for leading dense layers) ===
        if (!has_moe) {
            if (!get(p + "ffn_gate.weight", lw.w_ffn_gate, lw.t_ffn_gate)) {
                // No separate gate — try fused gate+up (Phi-3)
                if (get(p + "ffn_up.weight", lw.w_ffn_gate_up_fused, lw.t_ffn_gate_up_fused)) {
                    lw.has_fused_gate_up = true;
                    if (l == 0) printf("[ADAPTER] Fused gate+up detected\n");
                }
            } else {
                get(p + "ffn_up.weight", lw.w_ffn_up, lw.t_ffn_up);
            }
            get(p + "ffn_down.weight", lw.w_ffn_down, lw.t_ffn_down);
            loaded_dense++;
        }
    }

    printf("[WEIGHTS] Loaded: %d MLA layers, %d MoE layers, %d dense layers\n",
           loaded_mla, loaded_moe, loaded_dense);
    printf("[WEIGHTS] Embedding: %s, Output: %s\n",
           dtype_name(w.t_embd), dtype_name(w.t_output));
    return true;
}

// ═══════════════════════════════════════════════════════════════════════════════
// TRANSFORMER V6 — DeepSeek V3 Complete Forward Pass
// ═══════════════════════════════════════════════════════════════════════════════
class TransformerV6 {
public:
    bool init(const GGUF& gguf, int max_ctx = 4096) {
        if (!signature::verify()) return false;

        cfg_ = gguf.extract_config();

        // FIX: clamp max_seq_len to actual max_ctx
        cfg_.max_seq_len = std::min(cfg_.max_seq_len, max_ctx);

        // Universal: auto-cap context to available RAM
        {
            // KV cache size = 2 * n_kv_heads * max_seq * head_dim * sizeof(float) * n_layers
            size_t kv_per_layer = 2ULL * cfg_.n_kv_heads * cfg_.max_seq_len *
                                  cfg_.head_dim * sizeof(float);
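            // Illustrative sizing (hypothetical model: 32 layers, 8 KV heads,
            // head_dim 128, ctx 4096): 2 * 8 * 4096 * 128 * 4 B ≈ 33.5 MB per layer,
            // ≈ 1.07 GB total — the figure the MemAvailable check below guards against.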
            size_t kv_total = kv_per_layer * cfg_.n_layers;

            // Get available RAM (rough: read /proc/meminfo)
            size_t avail_ram = 0;
            FILE* mi = fopen("/proc/meminfo", "r");
            if (mi) {
                char line[256];
                while (fgets(line, sizeof(line), mi)) {
                    if (strncmp(line, "MemAvailable:", 13) == 0) {
                        avail_ram = strtoull(line + 13, nullptr, 10) * 1024;
                        break;
                    }
                }
                fclose(mi);
            }

            if (avail_ram > 0 && kv_total > avail_ram / 2) {
                // KV cache would use >50% of available RAM → reduce context
                int safe_ctx = (int)(avail_ram / 2 / (kv_per_layer / cfg_.max_seq_len));
                safe_ctx = std::max(512, std::min(safe_ctx, cfg_.max_seq_len));
                if (safe_ctx < cfg_.max_seq_len) {
                    printf("[ADAPTER] Context capped: %d → %d (KV cache %.1f GB > %.1f GB avail/2)\n",
                           cfg_.max_seq_len, safe_ctx, kv_total / 1e9, avail_ram / 2e9);
                    cfg_.max_seq_len = safe_ctx;
                }
            }
        }

        cfg_.print();

        printf("\n[TRANSFORMER] Loading weights...\n");
        if (!load_weights_v6(weights_, gguf, cfg_)) {
            printf("[TRANSFORMER] ERROR: Failed to load weights\n");
            return false;
        }

        // FIX: derive vocab_size from output weight tensor dims
        {
            const TensorInfo* oti = gguf.tensor("output.weight");
            if (!oti) oti = gguf.tensor("lm_head.weight");
            if (oti && oti->dims[1] > 0 && (int)oti->dims[1] != cfg_.vocab_size) {
                printf("[FIX] vocab_size: %d -> %lu (from output weight)\n",
                       cfg_.vocab_size, (unsigned long)oti->dims[1]);
                cfg_.vocab_size = (int)oti->dims[1];
            }
        }

        // Initialize components based on architecture
        if (cfg_.is_mla()) {
            printf("[TRANSFORMER] Initializing MLA attention (kv_lora=%d, rope=%d)\n",
                   cfg_.kv_lora_rank, cfg_.rope_dim);
            mla_attn_.init(cfg_, max_ctx);
            mla_cache_.init(cfg_, max_ctx);
            mla_rope_.init(cfg_.rope_dim, max_ctx, cfg_.rope_theta, cfg_.rope_scaling_factor);
        }

        if ((cfg_.n_dense_layers > 0 || !cfg_.is_moe()) && !cfg_.is_mla()) {
            printf("[TRANSFORMER] Initializing standard GQA attention\n");
            gqa_cache_.init(cfg_);
            gqa_rope_.init(cfg_.head_dim, cfg_.max_seq_len, cfg_.rope_theta);
            gqa_attn_.init(cfg_);
        }

        if (cfg_.is_moe()) {
            printf("[TRANSFORMER] Initializing MoE (%d experts, %d active, %d shared)\n",
                   cfg_.n_experts, cfg_.n_experts_used, cfg_.n_expert_shared);
            router_.init(cfg_);
            expert_ffn_.init(cfg_);
            shared_ffn_.init(cfg_);
            expert_cache_.init(cfg_.n_layers, cfg_.n_experts);
        }

        if (cfg_.n_dense_layers > 0) {
            dense_ffn_.init(cfg_);
        }

        // Allocate buffers
        x_.resize(cfg_.dim);
        xb_.resize(cfg_.dim);
        expert_out_.resize(cfg_.dim);
        ffn_out_.resize(cfg_.dim);
        logits_.resize(cfg_.vocab_size);

        printf("[TRANSFORMER] Ready. dim=%d, layers=%d, vocab=%d, ctx=%d\n",
               cfg_.dim, cfg_.n_layers, cfg_.vocab_size, max_ctx);
        printf("[TRANSFORMER] MLA KV cache: %.1f MB\n",
               cfg_.is_mla() ? mla_cache_.memory_bytes() / 1e6 : 0.0);

        initialized_ = true;
        return true;
    }

    void reset() {
        if (cfg_.is_mla()) mla_cache_.clear();
        else gqa_cache_.clear();
    }

    // ═══════════════════════════════════════════════════════════════════════════
    // FORWARD — Single token
    // ═══════════════════════════════════════════════════════════════════════════
    const float* forward(int32_t token) {
        if (!initialized_) return nullptr;

        // Embedding
        embed(token);

        // Transformer layers
        for (int l = 0; l < cfg_.n_layers; ++l) {
            layer_forward(l);
        }

        // Final norm + output projection
        kernel::rms_norm(x_.data(), x_.data(), weights_.output_norm, cfg_.dim, cfg_.rms_norm_eps);
        gemm::matmul(logits_.data(), weights_.output, weights_.t_output,
                     x_.data(), cfg_.vocab_size, cfg_.dim);

        // Final logit soft-capping (Gemma-2)
        if (cfg_.final_logit_softcap > 0.0f) {
            float cap = cfg_.final_logit_softcap;
            for (int i = 0; i < cfg_.vocab_size; ++i)
                logits_[i] = cap * tanhf(logits_[i] / cap);
        }

        if (cfg_.is_mla()) mla_cache_.advance();
        else gqa_cache_.advance();

        return logits_.data();
    }

    // ═══════════════════════════════════════════════════════════════════════════
    // GENERATE
    // ═══════════════════════════════════════════════════════════════════════════
    std::vector<int32_t> generate(
        const std::vector<int32_t>& prompt,
        int max_new_tokens = 256,
        float temperature  = 0.7f,
        float top_p        = 0.9f,
        int top_k          = 40
    ) {
        reset();
        std::vector<int32_t> output;

        auto t0 = std::chrono::high_resolution_clock::now();

        // Prefill
        printf("[GEN] Prefilling %zu tokens...\n", prompt.size());
        for (size_t i = 0; i < prompt.size(); ++i) {
            forward(prompt[i]);
            if ((i + 1) % 100 == 0) printf("  [%zu/%zu]\n", i + 1, prompt.size());
        }

        auto t1 = std::chrono::high_resolution_clock::now();
        double prefill_ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
        printf("[GEN] Prefill: %.0f ms (%.1f tok/s)\n",
               prefill_ms, prompt.size() * 1000.0 / prefill_ms);

        // Generate
        printf("[GEN] Generating up to %d tokens...\n", max_new_tokens);
        for (int i = 0; i < max_new_tokens; ++i) {
            int32_t next = sample(temperature, top_p, top_k);
            if (is_eos(next)) break;
            output.push_back(next);
            forward(next);
        }

        auto t2 = std::chrono::high_resolution_clock::now();
        double gen_ms = std::chrono::duration<double, std::milli>(t2 - t1).count();
        printf("[GEN] Generated %zu tokens in %.0f ms (%.2f tok/s)\n",
               output.size(), gen_ms,
               output.size() > 0 ? output.size() * 1000.0 / gen_ms : 0.0);

        // Print expert cache stats
        if (cfg_.is_moe()) {
            expert_cache_.print_stats();
        }

        return output;
    }

    // Generate with streaming callback
    template <typename Callback>
    void generate_stream(
        const std::vector<int32_t>& prompt,
        int max_new_tokens,
        float temperature,
        float top_p,
        int top_k,
        Callback&& on_token
    ) {
        reset();
        for (size_t i = 0; i < prompt.size(); ++i) forward(prompt[i]);
        for (int i = 0; i < max_new_tokens; ++i) {
            int32_t next = sample(temperature, top_p, top_k);
            if (is_eos(next)) break;
            if (!on_token(next)) break;
            forward(next);
        }
    }

    const Config& config() const { return cfg_; }
    Config& config_mut() { return cfg_; }

    void set_eos_token(int32_t eos) { eos_tokens_ = {eos}; }
    void add_eos_token(int32_t eos) {
        for (auto e : eos_tokens_) if (e == eos) return;
        eos_tokens_.push_back(eos);
    }
    bool is_eos(int32_t tok) const {
        for (auto e : eos_tokens_) if (tok == e) return true;
        return false;
    }

    ExpertCache& expert_cache_ref() { return expert_cache_; }

private:
    Config cfg_;
    WeightsV6 weights_;
    bool initialized_ = false;
    std::vector<int32_t> eos_tokens_ = {2};

    // MLA components
    MLAAttention mla_attn_;
    MLAKVCache mla_cache_;
    MLARoPE mla_rope_;

    // GQA fallback
    Attention gqa_attn_;
    KVCache gqa_cache_;
    kernel::RoPE gqa_rope_;

    // MoE components
    MoERouter router_;
    ExpertFFN expert_ffn_;
    SharedExpertFFN shared_ffn_;
    ExpertCache expert_cache_;

    // Dense FFN fallback
    FFN dense_ffn_;

    // Buffers
    std::vector<float> x_, xb_, expert_out_, ffn_out_, logits_;
    std::mt19937 rng_{42};

    // ─── EMBEDDING ───────────────────────────────────────────────────────────
    void embed(int32_t token) {
        gemm::embed_lookup(x_.data(), weights_.token_embd, weights_.t_embd, token, cfg_.dim);
        // Gemma: scale embeddings by sqrt(dim)
        if (cfg_.embed_scale_sqrt_dim) {
            float scale = sqrtf((float)cfg_.dim);
            for (int i = 0; i < cfg_.dim; ++i) x_[i] *= scale;
        }
    }

    // ─── LAYER FORWARD ───────────────────────────────────────────────────────
    void layer_forward(int layer) {
        const LayerWeightsV6& lw = weights_.layers[layer];

        // === ATTENTION ===
        kernel::vec_copy(xb_.data(), x_.data(), cfg_.dim);
        if (lw.attn_norm) {
            kernel::rms_norm(x_.data(), x_.data(), lw.attn_norm, cfg_.dim, cfg_.rms_norm_eps);
        }

        if (lw.is_mla) {
            // MLA attention path
            mla_attn_.forward(
                x_.data(), x_.data(),
                lw.w_q_a, lw.t_q_a, lw.q_a_norm,
                lw.w_q_b, lw.t_q_b,
                lw.w_kv_a, lw.t_kv_a, lw.kv_a_norm,
                lw.w_k_b, lw.t_k_b,
                lw.w_v_b, lw.t_v_b,
                lw.w_o, lw.t_o,
                mla_cache_, mla_rope_, layer
            );
        } else {
            if (lw.has_fused_qkv) {
                // ═══ TENSOR ADAPTER: Split fused QKV on the fly ═══
                // QKV is [dim, q_dim + k_dim + v_dim] contiguous in memory
                int q_dim  = cfg_.dim;
                int kv_dim = cfg_.n_kv_heads * cfg_.head_dim;
                // For quantized: compute byte offsets based on type
                const char* base = (const char*)lw.w_qkv_fused;
                size_t bpe   = gemm::bytes_per_element(lw.t_qkv_fused, cfg_.dim);
                size_t q_off = 0;
                size_t k_off = q_dim * bpe;
                size_t v_off = (q_dim + kv_dim) * bpe;

                gqa_attn_.forward(
                    x_.data(), x_.data(),
                    (const void*)(base + q_off), lw.t_qkv_fused,
                    (const void*)(base + k_off), lw.t_qkv_fused,
                    (const void*)(base + v_off), lw.t_qkv_fused,
                    lw.w_o, lw.t_o,
                    gqa_cache_, gqa_rope_, layer,
                    lw.bq, lw.bk, lw.bv
                );
            } else {
                // Standard GQA attention path (for dense layers)
                gqa_attn_.forward(
                    x_.data(), x_.data(),
                    lw.wq, lw.tq,
                    lw.wk, lw.tk,
                    lw.wv, lw.tv,
                    lw.w_o, lw.t_o,
                    gqa_cache_, gqa_rope_, layer,
                    lw.bq, lw.bk, lw.bv
                );
            }
        }

        // Residual
        kernel::vec_add(x_.data(), xb_.data(), x_.data(), cfg_.dim);

        // Post-attention norm (Gemma-2)
        if (lw.post_attn_norm) {
            kernel::rms_norm(x_.data(), x_.data(), lw.post_attn_norm, cfg_.dim, cfg_.rms_norm_eps);
        }

        // === FFN ===
        kernel::vec_copy(xb_.data(), x_.data(), cfg_.dim);
        if (lw.ffn_norm) {
            kernel::rms_norm(x_.data(), x_.data(), lw.ffn_norm, cfg_.dim, cfg_.rms_norm_eps);
        }

        if (lw.is_moe) {
            // ─── MoE FORWARD ─────────────────────────────────────────────────
            kernel::vec_zero(ffn_out_.data(), cfg_.dim);

            // Route: select top-K experts
            auto selected = router_.route(lw.gate_inp, x_.data());
            expert_cache_.record_batch(layer, selected);

            // ─── SURGICAL MMAP PREFETCH (÷48 I/O) ────────────────────────────
            {
                std::vector<int> active_ids;
                for (const auto& s : selected) active_ids.push_back(s.expert_id);
                KernelDispatch::instance().prefetch_experts(
                    layer, active_ids.data(), (int)active_ids.size());
            }

            // Dispatch to selected experts
            for (const auto& sel : selected) {
                expert_ffn_.forward(
                    expert_out_.data(), x_.data(), sel.expert_id,
                    lw.gate_exps, lw.t_gate_exps,
                    lw.up_exps, lw.t_up_exps,
                    lw.down_exps, lw.t_down_exps
                );
                // Weighted accumulation
                for (int d = 0; d < cfg_.dim; ++d) {
                    ffn_out_[d] += sel.weight * expert_out_[d];
                }
            }

            // Shared expert (always active)
            if (lw.gate_shexp) {
                shared_ffn_.forward(
                    expert_out_.data(), x_.data(),
                    lw.gate_shexp, lw.t_gate_shexp,
                    lw.up_shexp, lw.t_up_shexp,
                    lw.down_shexp, lw.t_down_shexp
                );
                // Add shared expert output
                kernel::vec_add(ffn_out_.data(), ffn_out_.data(), expert_out_.data(), cfg_.dim);
            }

            kernel::vec_copy(x_.data(), ffn_out_.data(), cfg_.dim);

            // ─── EVICT INACTIVE EXPERT PAGES ─────────────────────────────────
            KernelDispatch::instance().evict_layer(layer);
        }

        if (!lw.is_moe && lw.w_ffn_gate) {
            // ─── Dense FFN ───────────────────────────────────────────────────
            dense_ffn_.forward(
                x_.data(), x_.data(),
                lw.w_ffn_gate, lw.t_ffn_gate,
                lw.w_ffn_up, lw.t_ffn_up,
                lw.w_ffn_down, lw.t_ffn_down
            );
        } else if (!lw.is_moe && lw.has_fused_gate_up) {
            // ═══ TENSOR ADAPTER: Fused gate+up (Phi-3) ═══
            // Split [dim, 2*intermediate] → gate [dim, intermediate] + up [dim, intermediate]
            int inter = cfg_.intermediate;
            size_t bpe = gemm::bytes_per_element(lw.t_ffn_gate_up_fused, cfg_.dim);
            const char* base = (const char*)lw.w_ffn_gate_up_fused;
            dense_ffn_.forward(
                x_.data(), x_.data(),
                (const void*)(base),               lw.t_ffn_gate_up_fused,  // gate half
                (const void*)(base + inter * bpe), lw.t_ffn_gate_up_fused,  // up half
                lw.w_ffn_down, lw.t_ffn_down
            );
        }

        // FFN Residual
        kernel::vec_add(x_.data(), xb_.data(), x_.data(), cfg_.dim);

        // Post-FFN norm (Gemma-2)
        if (lw.post_ffn_norm) {
            kernel::rms_norm(x_.data(), x_.data(), lw.post_ffn_norm, cfg_.dim, cfg_.rms_norm_eps);
        }
    }

    // ─── SAMPLING ────────────────────────────────────────────────────────────
    int32_t sample(float temperature, float top_p, int top_k) {
        int n = cfg_.vocab_size;

        if (temperature < 1e-6f) {
            // Greedy: argmax over the raw logits
            return static_cast<int32_t>(
                std::max_element(logits_.begin(), logits_.end()) - logits_.begin()
            );
        }

        for (int i = 0; i < n; ++i) logits_[i] /= temperature;
        kernel::softmax(logits_.data(), n);

        // Top-K
        if (top_k > 0 && top_k < n) {
            std::vector<std::pair<float, int>> indexed(n);
            for (int i = 0; i < n; ++i) indexed[i] = {logits_[i], i};
            std::partial_sort(indexed.begin(), indexed.begin() + top_k, indexed.end(),
                              [](const auto& a, const auto& b) { return a.first > b.first; });
            for (int i = top_k; i < n; ++i) logits_[indexed[i].second] = 0.0f;
        }

        // Top-P
        if (top_p < 1.0f) {
            std::vector<std::pair<float, int>> sorted(n);
            for (int i = 0; i < n; ++i) sorted[i] = {logits_[i], i};
            std::sort(sorted.begin(), sorted.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });
            float cumsum = 0.0f;
            for (int i = 0; i < n; ++i) {
                cumsum += sorted[i].first;
                if (cumsum > top_p) {
                    // Keep the token that crosses the top_p mass; zero everything after it
                    for (int j = i + 1; j < n; ++j) logits_[sorted[j].second] = 0.0f;
                    break;
                }
            }
        }

        // Renormalize
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) sum += logits_[i];
        if (sum > 0.0f) for (int i = 0; i < n; ++i) logits_[i] /= sum;

        // Sample
        std::uniform_real_distribution<float> dist(0.0f, 1.0f);
        float r = dist(rng_);
        float cumsum = 0.0f;
        for (int i = 0; i < n; ++i) {
            cumsum += logits_[i];
            if (r < cumsum) return static_cast<int32_t>(i);
        }
        return static_cast<int32_t>(n - 1);
    }
};

} // namespace ix
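// ─────────────────────────────────────────────────────────────────────────────
// Usage sketch (illustrative only): the GGUF loading entry point below is an
// assumption — use the project's actual GGUF loader. Everything else calls the
// TransformerV6 API defined above; token ids come from a tokenizer outside
// this header.
//
//   ix::GGUF gguf;                                    // hypothetical: load "model.gguf" here
//   ix::TransformerV6 model;
//   if (!model.init(gguf, /*max_ctx=*/4096)) return;  // mmap weights, size KV cache
//   std::vector<int32_t> prompt = {1, 15043};         // example token ids
//   std::vector<int32_t> out = model.generate(prompt, /*max_new_tokens=*/64,
//                                             /*temperature=*/0.7f, /*top_p=*/0.9f, /*top_k=*/40);
// ─────────────────────────────────────────────────────────────────────────────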