// inference-x/runtime/gguf.h — 626 lines, 30 KiB, C++

// ═══════════════════════════════════════════════════════════════════════════════
// INFERENCE-X — GGUF Multi-Shard Model Parser
// Copyright (C) 2024-2026 Salka Elmadani. All rights reserved.
// Licensed under the Business Source License 1.1 (BSL-1.1)
// See LICENSE file for full terms.
//
// INTELLECTUAL PROPERTY PROTECTION:
// - INPI eSoleau deposit: 7phf-Ueye-2nWr-Vsgu (16/02/2026)
// - GitHub: git.inference-x.com/salka/inference-x
// - Author: Salka Elmadani | Morocco
//
// MANUFACTURER NOTICE: Any manufacturer, company, or entity that
// incorporates, embeds, distributes, or commercially uses Inference-X
// or any derivative work without explicit written authorization from
// the copyright holder is in violation of BSL-1.1 and applicable
// intellectual property laws. This includes but is not limited to:
// hardware vendors, cloud providers, SaaS platforms, and OEMs.
//
// Contact: Elmadani.SALKA@proton.me for licensing.
// ═══════════════════════════════════════════════════════════════════════════════
#pragma once
// Inference-X GGUF Parser — Salka Elmadani — Morocco
#define IX_GGUF_WATERMARK "Inference-X-GGUF-935-Elmadani"
#include "../core/z_core.h"
// C system headers
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
// C standard library
#include <cstdint>   // fixed-width integer types used throughout the parser
#include <cstdio>
#include <cstdlib>   // std::atoi in shard discovery
#include <cstring>   // std::memcpy in Reader (was missing — relied on a transitive include)
// C++ standard library
#include <algorithm>
#include <filesystem>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
namespace ix {
// ═══════════════════════════════════════════════════════════════════════════════
// GGUF CONSTANTS
// ═══════════════════════════════════════════════════════════════════════════════
// File magic: the ASCII bytes "GGUF" read as a little-endian uint32_t.
static constexpr uint32_t GGUF_MAGIC = 0x46554747; // "GGUF"
// Value-type tags for GGUF metadata KV pairs (spec v2/v3).
// ARRAY is a container: an element-type tag and a u64 count follow,
// then the packed element values.
enum class GGUFType : uint32_t {
UINT8 = 0, INT8 = 1, UINT16 = 2, INT16 = 3,
UINT32 = 4, INT32 = 5, FLOAT32 = 6, BOOL = 7,
STRING = 8, ARRAY = 9, UINT64 = 10, INT64 = 11, FLOAT64 = 12
};
// ═══════════════════════════════════════════════════════════════════════════════
// TENSOR INFO — Points into mmap'd shard
// ═══════════════════════════════════════════════════════════════════════════════
// One entry of the tensor index. `data` points directly into the mmap'd
// shard once the parser has resolved the data segment base, so a TensorInfo
// must not outlive its Shard.
struct TensorInfo {
    std::string name;
    uint32_t n_dims = 0;
    uint64_t dims[4] = {1, 1, 1, 1}; // unused trailing dims stay 1
    dtype type{};
    uint64_t offset = 0;      // offset within shard's aligned data segment
    uint64_t n_elements = 0;  // product of dims[0..n_dims)
    uint64_t n_bytes = 0;     // on-disk byte size (quantization-aware)
    void* data = nullptr;     // pointer into mmap'd memory (resolved by parser)
    int shard_idx = -1;       // which shard this tensor lives in

    // Derive n_elements and n_bytes from dims/type.
    // For block-quantized dtypes, dtype_size() is the byte size of one block
    // of dtype_block_size() elements. The element count is rounded UP to
    // whole blocks: the old floor division silently undercounted any tensor
    // whose element count is not a block multiple.
    void compute_size() {
        n_elements = 1;
        const uint32_t nd = (n_dims <= 4) ? n_dims : 4; // dims[] holds at most 4
        for (uint32_t i = 0; i < nd; ++i) n_elements *= dims[i];
        uint64_t bs = static_cast<uint64_t>(dtype_block_size(type));
        if (bs == 0) bs = 1; // defensive: never divide by zero on a bad dtype
        if (bs > 1) {
            n_bytes = ((n_elements + bs - 1) / bs) * dtype_size(type);
        } else {
            n_bytes = n_elements * dtype_size(type);
        }
    }
};
// ═══════════════════════════════════════════════════════════════════════════════
// MMAP'D SHARD — One per GGUF file
// ═══════════════════════════════════════════════════════════════════════════════
// RAII wrapper for one mmap'd GGUF file: owns the fd and the read-only
// mapping. Move-only (Rule of Five): copying would double-munmap.
struct Shard {
    int fd = -1;
    uint8_t* data = nullptr;  // base of the read-only mapping
    size_t size = 0;          // mapped length in bytes
    std::string path;

    // Map the file read-only. Returns false — with fd/mapping released —
    // on open/stat/mmap failure or an empty file (mmap of length 0 is
    // EINVAL anyway; reject it explicitly).
    bool open(const std::string& p) {
        path = p;
        fd = ::open(p.c_str(), O_RDONLY);
        if (fd < 0) return false;
        struct stat st;
        if (fstat(fd, &st) < 0) { close(); return false; }
        size = static_cast<size_t>(st.st_size);
        if (size == 0) { close(); return false; }
        data = static_cast<uint8_t*>(mmap(nullptr, size, PROT_READ, MAP_PRIVATE, fd, 0));
        if (data == MAP_FAILED) { data = nullptr; close(); return false; }
        // Hint eager, sequential read-ahead so the initial pass is fast.
        // NOTE(review): the old comment promised MADV_RANDOM for MoE expert
        // access, but the code has always issued MADV_SEQUENTIAL — confirm
        // which hint matches the real access pattern before changing it.
        madvise(data, size, MADV_SEQUENTIAL);
        madvise(data, size, MADV_WILLNEED);
        printf("[GGUF] Preloading %.1f MB into RAM...\n", size / 1e6);
        return true;
    }

    // Unmap and close; idempotent, and resets size so a closed shard never
    // reports a stale length.
    void close() {
        if (data) { munmap(data, size); data = nullptr; }
        if (fd >= 0) { ::close(fd); fd = -1; }
        size = 0;
    }

    ~Shard() { close(); }
    Shard() = default;
    Shard(const Shard&) = delete;
    Shard& operator=(const Shard&) = delete;
    Shard(Shard&& o) noexcept : fd(o.fd), data(o.data), size(o.size), path(std::move(o.path)) {
        o.fd = -1; o.data = nullptr; o.size = 0;
    }
    Shard& operator=(Shard&& o) noexcept {
        if (this != &o) { // self-move guard: close() first would release our own state
            close();
            fd = o.fd; data = o.data; size = o.size; path = std::move(o.path);
            o.fd = -1; o.data = nullptr; o.size = 0;
        }
        return *this;
    }
};
// ═══════════════════════════════════════════════════════════════════════════════
// GGUF MULTI-SHARD — Mmap-based zero-copy across N shards
// ═══════════════════════════════════════════════════════════════════════════════
class GGUF {
public:
GGUF() = default;
~GGUF() { close(); }
GGUF(const GGUF&) = delete;
GGUF& operator=(const GGUF&) = delete;
// ═══════════════════════════════════════════════════════════════════════════
// OPEN — Detects single vs multi-shard automatically
// Pass any shard or the first shard; will discover the rest.
// ═══════════════════════════════════════════════════════════════════════════
bool open(const std::string& path) {
if (!signature::verify()) return false;
// If path is a directory, find GGUF files inside
std::string resolved = path;
struct stat st;
if (stat(path.c_str(), &st) == 0 && S_ISDIR(st.st_mode)) {
// Glob for *.gguf in directory
std::vector<std::string> gguf_files;
for (const auto& entry : std::filesystem::directory_iterator(path)) {
if (entry.path().extension() == ".gguf") {
gguf_files.push_back(entry.path().string());
}
}
std::sort(gguf_files.begin(), gguf_files.end());
if (gguf_files.empty()) {
printf("[GGUF] ERROR: No .gguf files in directory: %s\n", path.c_str());
return false;
}
resolved = gguf_files[0]; // Use first shard for discovery
printf("[GGUF] Directory mode: found %zu .gguf file(s)\n", gguf_files.size());
}
// Discover all shards
auto shard_paths = discover_shards(resolved);
if (shard_paths.empty()) {
shard_paths.push_back(resolved); // single file
}
printf("[GGUF] Loading %zu shard(s)...\n", shard_paths.size());
shards_.resize(shard_paths.size());
for (size_t i = 0; i < shard_paths.size(); ++i) {
if (!shards_[i].open(shard_paths[i])) {
printf("[GGUF] ERROR: Failed to mmap shard %zu: %s\n", i, shard_paths[i].c_str());
close();
return false;
}
printf("[GGUF] Shard %zu: %s (%.1f GB)\n", i, shard_paths[i].c_str(),
shards_[i].size / 1e9);
}
// Parse metadata from shard 0 (contains all KV pairs)
if (!parse_shard(0, true)) {
printf("[GGUF] ERROR: Failed to parse metadata from shard 0\n");
close();
return false;
}
// Parse tensor index from all shards
for (size_t i = 1; i < shards_.size(); ++i) {
if (!parse_shard(i, false)) {
printf("[GGUF] ERROR: Failed to parse shard %zu\n", i);
close();
return false;
}
}
printf("[GGUF] Total tensors: %zu across %zu shards\n",
tensors_.size(), shards_.size());
// Validate tensor pointers against shard boundaries
int bad = 0;
for (size_t i = 0; i < tensors_.size(); ++i) {
const auto& ti = tensors_[i];
if (!ti.data) { printf("[VALIDATE] WARN: %s has null data\n", ti.name.c_str()); bad++; continue; }
uint8_t* start = static_cast<uint8_t*>(ti.data);
uint8_t* end = start + ti.n_bytes;
uint8_t* shard_start = shards_[ti.shard_idx].data;
uint8_t* shard_end = shard_start + shards_[ti.shard_idx].size;
if (start < shard_start || end > shard_end) {
printf("[VALIDATE] BAD: %s type=%d(%s) nelem=%lu nbytes=%lu shard=%d off=%lu shard_sz=%zu overflow=%ld\n",
ti.name.c_str(), (int)ti.type, dtype_name(ti.type), ti.n_elements, ti.n_bytes, ti.shard_idx, ti.offset,
shards_[ti.shard_idx].size, (long)(end - shard_end));
bad++;
}
}
if (bad > 0) printf("[VALIDATE] %d tensors have invalid pointers!\n", bad);
else printf("[VALIDATE] All %zu tensors OK\n", tensors_.size());
return true;
}
void close() {
shards_.clear();
tensors_.clear();
tensor_map_.clear();
meta_u32_.clear();
meta_u64_.clear();
meta_f32_.clear();
meta_str_.clear();
meta_str_arr_.clear();
meta_i32_arr_.clear();
}
// ═══════════════════════════════════════════════════════════════════════════
// METADATA ACCESS
// ═══════════════════════════════════════════════════════════════════════════
uint32_t get_u32(const std::string& k, uint32_t d = 0) const {
auto it = meta_u32_.find(k); return it != meta_u32_.end() ? it->second : d;
}
uint64_t get_u64(const std::string& k, uint64_t d = 0) const {
auto it = meta_u64_.find(k); return it != meta_u64_.end() ? it->second : d;
}
float get_f32(const std::string& k, float d = 0) const {
auto it = meta_f32_.find(k); return it != meta_f32_.end() ? it->second : d;
}
std::string get_str(const std::string& k, const std::string& d = "") const {
auto it = meta_str_.find(k); return it != meta_str_.end() ? it->second : d;
}
const std::vector<std::string>* get_str_arr(const std::string& k) const {
auto it = meta_str_arr_.find(k); return it != meta_str_arr_.end() ? &it->second : nullptr;
}
const std::vector<int32_t>* get_i32_arr(const std::string& k) const {
auto it = meta_i32_arr_.find(k); return it != meta_i32_arr_.end() ? &it->second : nullptr;
}
// ═══════════════════════════════════════════════════════════════════════════
// TENSOR ACCESS
// ═══════════════════════════════════════════════════════════════════════════
const TensorInfo* tensor(const std::string& name) const {
auto it = tensor_map_.find(name);
return it != tensor_map_.end() ? &tensors_[it->second] : nullptr;
}
const std::vector<TensorInfo>& tensors() const { return tensors_; }
size_t num_shards() const { return shards_.size(); }
// ═══════════════════════════════════════════════════════════════════════════
// PREFETCH — Hint kernel to preload expert pages
// ═══════════════════════════════════════════════════════════════════════════
void prefetch_tensor(const std::string& name) const {
const TensorInfo* ti = tensor(name);
if (ti && ti->data) {
madvise(const_cast<void*>(ti->data), ti->n_bytes, MADV_WILLNEED);
}
}
void prefetch_experts(int layer, const std::vector<int>& expert_ids) const {
for (int eid __attribute__((unused)) : expert_ids) {
// Expert tensors are 3D: [expert_ffn_dim, dim, n_experts]
// We need to prefetch the slice for this expert
std::string prefix = "blk." + std::to_string(layer) + ".";
prefetch_tensor(prefix + "ffn_gate_exps.weight");
prefetch_tensor(prefix + "ffn_up_exps.weight");
prefetch_tensor(prefix + "ffn_down_exps.weight");
}
}
// ═══════════════════════════════════════════════════════════════════════════
// CONFIG EXTRACTION — Architecture-aware
// ═══════════════════════════════════════════════════════════════════════════
Config extract_config() const {
Config c;
// Detect architecture — all model families
std::string arch_str = get_str("general.architecture", "llama");
if (arch_str == "deepseek2") {
c.arch = Architecture::DEEPSEEK2;
c.activation = Activation::SILU;
} else if (arch_str == "qwen2") {
c.arch = Architecture::QWEN2;
c.activation = Activation::SILU;
} else if (arch_str == "phi3" || arch_str == "phi") {
c.arch = Architecture::PHI3;
c.activation = Activation::GELU;
} else if (arch_str == "gemma" || arch_str == "gemma2") {
c.arch = Architecture::GEMMA2;
c.activation = Activation::GELU;
} else if (arch_str == "starcoder2") {
c.arch = Architecture::STARCODER2;
c.activation = Activation::GELU;
} else if (arch_str == "command-r" || arch_str == "cohere") {
c.arch = Architecture::COMMAND_R;
c.activation = Activation::SILU;
} else {
c.arch = Architecture::LLAMA;
c.activation = Activation::SILU;
}
// Architecture prefix for KV lookups
std::string pfx = arch_str + ".";
// === Common ===
c.dim = get_u32(pfx + "embedding_length", 4096);
c.n_layers = get_u32(pfx + "block_count", 32);
c.n_heads = get_u32(pfx + "attention.head_count", 32);
c.n_kv_heads = get_u32(pfx + "attention.head_count_kv", c.n_heads);
c.vocab_size = get_u32(pfx + "vocab_size", 0);
// Fallback: count tokenizer tokens if metadata key missing
if (c.vocab_size == 0) {
auto* toks = get_str_arr("tokenizer.ggml.tokens");
if (toks && !toks->empty()) {
c.vocab_size = (uint32_t)toks->size();
fprintf(stderr, "[GGUF] vocab_size from token array: %u\n", c.vocab_size);
}
} // 0 = fixup from tokenizer
c.max_seq_len = get_u32(pfx + "context_length", 4096);
c.intermediate = get_u32(pfx + "feed_forward_length", 11008);
// Sliding window attention
c.sliding_window = get_u32(pfx + "attention.sliding_window", 0);
c.rope_theta = get_f32(pfx + "rope.freq_base", 10000.0f);
c.rms_norm_eps = get_f32(pfx + "attention.layer_norm_rms_epsilon", 1e-5f);
// === MLA (DeepSeek V3) ===
if (c.arch == Architecture::DEEPSEEK2) {
c.q_lora_rank = get_u32(pfx + "attention.q_lora_rank", 0);
c.kv_lora_rank = get_u32(pfx + "attention.kv_lora_rank", 0);
c.key_length = get_u32(pfx + "attention.key_length", 576);
c.value_length = get_u32(pfx + "attention.value_length", 512);
c.key_length_mla = get_u32(pfx + "attention.key_length_mla", 192);
c.value_length_mla = get_u32(pfx + "attention.value_length_mla", 128);
c.rope_dim = get_u32(pfx + "rope.dimension_count", 64);
// === MoE ===
c.n_experts = get_u32(pfx + "expert_count", 0);
c.n_experts_used = get_u32(pfx + "expert_used_count", 0);
c.n_expert_shared = get_u32(pfx + "expert_shared_count", 0);
c.expert_ffn_dim = get_u32(pfx + "expert_feed_forward_length", 2048);
c.n_dense_layers = get_u32(pfx + "leading_dense_block_count", 1);
c.n_expert_groups = get_u32(pfx + "expert_group_count", 1);
c.n_expert_groups_used = get_u32(pfx + "expert_group_used_count", 1);
c.expert_gating_func = get_u32(pfx + "expert_gating_func", 0);
c.expert_weights_scale = get_f32(pfx + "expert_weights_scale", 1.0f);
c.expert_weights_norm = get_u32(pfx + "expert_weights_norm", 0) != 0;
// === YaRN RoPE scaling ===
c.rope_scaling_factor = get_f32(pfx + "rope.scaling.factor", 1.0f);
c.rope_scaling_orig_ctx = get_u32(pfx + "rope.scaling.original_context_length", 4096);
c.rope_yarn_beta_fast = get_f32(pfx + "rope.scaling.yarn_beta_fast", 32.0f);
c.rope_yarn_beta_slow = get_f32(pfx + "rope.scaling.yarn_beta_slow", 1.0f);
c.rope_yarn_log_mul = get_f32(pfx + "rope.scaling.yarn_log_multiplier", 0.1f);
}
// Universal adapter: logit caps, embedding scale
c.attn_logit_softcap = get_f32(pfx + "attn_logit_softcapping", 0.0f);
c.final_logit_softcap = get_f32(pfx + "final_logit_softcapping", 0.0f);
// Gemma: embeddings scaled by sqrt(dim)
if (arch_str == "gemma" || arch_str == "gemma2") {
c.embed_scale_sqrt_dim = true;
}
c.compute_derived();
return c;
}
// ═══════════════════════════════════════════════════════════════════════════
// MEMORY STATS
// ═══════════════════════════════════════════════════════════════════════════
void print_memory_map() const {
size_t total_mmap = 0;
for (const auto& s : shards_) total_mmap += s.size;
printf("\n=== Memory Map ===\n");
printf(" Total mmap'd: %.1f GB\n", total_mmap / 1e9);
printf(" Total tensors: %zu\n", tensors_.size());
// Categorize
size_t expert_bytes = 0, attn_bytes = 0, router_bytes = 0;
size_t norm_bytes = 0, embed_bytes = 0, other_bytes = 0;
int expert_count = 0, attn_count = 0;
for (const auto& t : tensors_) {
if (t.name.find("exps") != std::string::npos) {
expert_bytes += t.n_bytes; expert_count++;
} else if (t.name.find("attn") != std::string::npos) {
attn_bytes += t.n_bytes; attn_count++;
} else if (t.name.find("gate_inp") != std::string::npos) {
router_bytes += t.n_bytes;
} else if (t.name.find("norm") != std::string::npos) {
norm_bytes += t.n_bytes;
} else if (t.name.find("embd") != std::string::npos ||
t.name.find("output") != std::string::npos) {
embed_bytes += t.n_bytes;
} else {
other_bytes += t.n_bytes;
}
}
printf(" Expert tensors: %d (%.1f GB) — MoE sleeping experts\n",
expert_count, expert_bytes / 1e9);
printf(" Attention: %d (%.1f GB) — MLA projections\n",
attn_count, attn_bytes / 1e9);
printf(" Router: %.1f MB — F32 gating (sacred)\n", router_bytes / 1e6);
printf(" Embed+Output: %.1f GB\n", embed_bytes / 1e9);
printf(" Norms: %.1f MB\n", norm_bytes / 1e6);
printf(" Other: %.1f GB\n", other_bytes / 1e9);
// Active memory estimate for MoE
Config cfg = extract_config();
if (cfg.is_moe()) {
float active_expert_ratio = (float)cfg.n_experts_used / cfg.n_experts;
float active_expert_gb = expert_bytes * active_expert_ratio / 1e9;
float active_total = active_expert_gb + attn_bytes / 1e9 +
router_bytes / 1e9 + norm_bytes / 1e9;
printf("\n=== Active per Token (MoE) ===\n");
printf(" Experts active: %d/%d (%.1f%%)\n",
cfg.n_experts_used, cfg.n_experts, active_expert_ratio * 100);
printf(" Active experts: %.1f GB\n", active_expert_gb);
printf(" + Attention: %.1f GB\n", attn_bytes / 1e9);
printf(" + Router (F32): %.1f MB\n", router_bytes / 1e6);
printf(" ≈ Active total: %.1f GB ← fits in 17GB RAM\n", active_total);
}
}
private:
std::vector<Shard> shards_;
std::vector<TensorInfo> tensors_;
std::unordered_map<std::string, size_t> tensor_map_;
std::unordered_map<std::string, uint32_t> meta_u32_;
std::unordered_map<std::string, uint64_t> meta_u64_;
std::unordered_map<std::string, float> meta_f32_;
std::unordered_map<std::string, std::string> meta_str_;
std::unordered_map<std::string, std::vector<std::string>> meta_str_arr_;
std::unordered_map<std::string, std::vector<int32_t>> meta_i32_arr_;
// ═══════════════════════════════════════════════════════════════════════════
// SHARD DISCOVERY — Find all -NNNNN-of-NNNNN.gguf siblings
// ═══════════════════════════════════════════════════════════════════════════
std::vector<std::string> discover_shards(const std::string& path) {
std::vector<std::string> result;
// Check if this looks like a split GGUF
// Pattern: *-00001-of-00005.gguf
auto pos = path.rfind("-of-");
if (pos == std::string::npos) return result; // single file
auto dash_before = path.rfind('-', pos - 1);
if (dash_before == std::string::npos) return result;
std::string prefix = path.substr(0, dash_before + 1); // includes trailing dash
std::string suffix_after_of = path.substr(pos + 4); // "00005.gguf"
// Extract total count
auto dot = suffix_after_of.find('.');
if (dot == std::string::npos) return result;
int total = std::atoi(suffix_after_of.substr(0, dot).c_str());
std::string ext = suffix_after_of.substr(dot); // ".gguf"
// Build shard paths
for (int i = 1; i <= total; ++i) {
char buf[16];
snprintf(buf, sizeof(buf), "%05d", i);
char buf2[16];
snprintf(buf2, sizeof(buf2), "%05d", total);
std::string shard_path = prefix + buf + "-of-" + buf2 + ext;
// Verify file exists
struct stat st;
if (stat(shard_path.c_str(), &st) == 0) {
result.push_back(shard_path);
} else {
printf("[GGUF] WARNING: Expected shard not found: %s\n", shard_path.c_str());
}
}
return result;
}
// ═══════════════════════════════════════════════════════════════════════════
// PARSE ONE SHARD
// ═══════════════════════════════════════════════════════════════════════════
// Binary reader helper — avoid template lambdas for C++17 compat
struct Reader {
const uint8_t*& ptr;
Reader(const uint8_t*& p) : ptr(p) {}
template<typename T> T read() {
T v; std::memcpy(&v, ptr, sizeof(T)); ptr += sizeof(T); return v;
}
std::string read_str() {
uint64_t len = read<uint64_t>();
std::string s(reinterpret_cast<const char*>(ptr), len);
ptr += len;
return s;
}
void skip_val(GGUFType t) {
switch (t) {
case GGUFType::UINT8: case GGUFType::INT8: case GGUFType::BOOL: ptr += 1; break;
case GGUFType::UINT16: case GGUFType::INT16: ptr += 2; break;
case GGUFType::UINT32: case GGUFType::INT32: case GGUFType::FLOAT32: ptr += 4; break;
case GGUFType::UINT64: case GGUFType::INT64: case GGUFType::FLOAT64: ptr += 8; break;
case GGUFType::STRING: read_str(); break;
case GGUFType::ARRAY: {
GGUFType at = static_cast<GGUFType>(read<uint32_t>());
uint64_t n = read<uint64_t>();
for (uint64_t i = 0; i < n; ++i) skip_val(at);
break;
}
}
}
};
bool parse_shard(int shard_idx, bool read_metadata) {
const uint8_t* ptr = shards_[shard_idx].data;
Reader r(ptr);
// Header
if (r.read<uint32_t>() != GGUF_MAGIC) return false;
uint32_t ver = r.read<uint32_t>();
if (ver < 2 || ver > 3) return false;
uint64_t n_tensors = r.read<uint64_t>();
uint64_t n_kv = r.read<uint64_t>();
// KV pairs
for (uint64_t i = 0; i < n_kv; ++i) {
std::string key = r.read_str();
GGUFType type = static_cast<GGUFType>(r.read<uint32_t>());
if (read_metadata) {
switch (type) {
case GGUFType::UINT8: meta_u32_[key] = r.read<uint8_t>(); break;
case GGUFType::BOOL: meta_u32_[key] = r.read<uint8_t>(); break;
case GGUFType::UINT16: meta_u32_[key] = r.read<uint16_t>(); break;
case GGUFType::INT16: meta_u32_[key] = static_cast<uint32_t>(r.read<int16_t>()); break;
case GGUFType::UINT32: meta_u32_[key] = r.read<uint32_t>(); break;
case GGUFType::INT32: meta_u32_[key] = static_cast<uint32_t>(r.read<int32_t>()); break;
case GGUFType::FLOAT32: meta_f32_[key] = r.read<float>(); break;
case GGUFType::UINT64: meta_u64_[key] = r.read<uint64_t>(); break;
case GGUFType::INT64: meta_u64_[key] = static_cast<uint64_t>(r.read<int64_t>()); break;
case GGUFType::FLOAT64: meta_f32_[key] = static_cast<float>(r.read<double>()); break;
case GGUFType::STRING: meta_str_[key] = r.read_str(); break;
case GGUFType::ARRAY: {
GGUFType at = static_cast<GGUFType>(r.read<uint32_t>());
uint64_t n = r.read<uint64_t>();
if (at == GGUFType::STRING) {
auto& arr = meta_str_arr_[key];
arr.resize(n);
for (uint64_t j = 0; j < n; ++j) arr[j] = r.read_str();
} else if (at == GGUFType::INT32 || at == GGUFType::UINT32) {
auto& arr = meta_i32_arr_[key];
arr.resize(n);
for (uint64_t j = 0; j < n; ++j) arr[j] = r.read<int32_t>();
} else {
for (uint64_t j = 0; j < n; ++j) r.skip_val(at);
}
break;
}
default: r.skip_val(type); break;
}
} else {
r.skip_val(type);
}
}
// Tensor index
size_t base_idx = tensors_.size();
tensors_.reserve(base_idx + n_tensors);
for (uint64_t i = 0; i < n_tensors; ++i) {
TensorInfo ti;
ti.name = r.read_str();
ti.n_dims = r.read<uint32_t>();
for (uint32_t d = 0; d < 4; ++d) {
ti.dims[d] = (d < ti.n_dims) ? r.read<uint64_t>() : 1;
}
ti.type = static_cast<dtype>(r.read<uint32_t>());
ti.offset = r.read<uint64_t>();
ti.shard_idx = shard_idx;
ti.compute_size();
ti.data = nullptr;
tensor_map_[ti.name] = tensors_.size();
tensors_.push_back(std::move(ti));
}
// Resolve data pointers
uint32_t align = read_metadata ?
get_u32("general.alignment", 32) : 32;
uint64_t data_start = ((ptr - shards_[shard_idx].data) + align - 1) & ~(uint64_t)(align - 1);
for (size_t i = base_idx; i < tensors_.size(); ++i) {
auto& ti = tensors_[i];
if (ti.shard_idx == shard_idx) {
ti.data = shards_[shard_idx].data + data_start + ti.offset;
}
}
return true;
}
};
} // namespace ix