diff --git a/src/lib.zig b/src/lib.zig index 7fd12fb..44f5a22 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -880,6 +880,26 @@ fn packBits(bits: []const bool, l: *ArrayList(u8), allocator: Allocator) ![]chun pub fn hashTreeRoot(Hasher: type, T: type, value: T, out: *[Hasher.digest_length]u8, allocator: Allocator) !void { // Check if type has its own hashTreeRoot method at compile time if (comptime std.meta.hasFn(T, "hashTreeRoot")) { + // value is a by-value copy; if hashTreeRoot lazily allocates a new cache + // on this copy, we must free it to prevent leaks. But if the original + // already had a cache (shallow-copied into value), we must not free it. + const has_optional_cache = comptime blk: { + if (!std.meta.hasFn(T, "deinitCache")) break :blk false; + if (!@hasField(T, "cache")) break :blk false; + break :blk @typeInfo(@FieldType(T, "cache")) == .optional; + }; + if (has_optional_cache) { + const cache_before = value.cache; + var mutable = value; + errdefer if (cache_before == null and mutable.cache != null) { + mutable.deinitCache(); + }; + try mutable.hashTreeRoot(Hasher, out, allocator); + if (cache_before == null and mutable.cache != null) { + mutable.deinitCache(); + } + return; + } return value.hashTreeRoot(Hasher, out, allocator); } diff --git a/src/merkle_cache.zig b/src/merkle_cache.zig new file mode 100644 index 0000000..483c8cc --- /dev/null +++ b/src/merkle_cache.zig @@ -0,0 +1,270 @@ +const std = @import("std"); +const zeros = @import("./zeros.zig"); + +const BYTES_PER_CHUNK = 32; +const chunk = [BYTES_PER_CHUNK]u8; +const zero_chunk: chunk = [_]u8{0} ** BYTES_PER_CHUNK; + +/// A cached Merkle tree using a flat 1-indexed array representation. +/// Node 1 is the root. Node i has children 2i and 2i+1. +/// Leaves occupy indices [capacity .. 2*capacity). +pub fn MerkleCache(comptime Hasher: type) type { + return struct { + const Self = @This(); + const hashes_of_zero = zeros.buildHashesOfZero(Hasher, 32, 256); + + /// Flat array of tree nodes, 1-indexed. Length = 2 * capacity. + /// Index 0 is unused. nodes[1] = root. nodes[capacity..2*capacity] = leaves. + nodes: []chunk, + /// Number of leaf slots (next power of 2 of the limit). + capacity: usize, + /// Dirty leaf range (0-based, relative to leaf start). + /// dirty_low > dirty_high means no dirty leaves. + dirty_low: usize, + dirty_high: usize, + /// Whether the full tree has been computed at least once. + initialized: bool, + /// Cached length for mixInLength detection. + cached_length: usize, + /// Final root after mixInLength. + cached_root: chunk, + /// Whether cached_root is valid. 
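+        /// Set to true by recomputeWithLength and cleared by markDirty/markAllDirty,
+        /// so an unchanged tree with an unchanged length never rehashes the mix-in step.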
+ root_valid: bool, + + pub fn init(allocator: std.mem.Allocator, limit: usize) !Self { + const capacity = if (limit > 0) try std.math.ceilPowerOfTwo(usize, limit) else 1; + const nodes = try allocator.alloc([BYTES_PER_CHUNK]u8, 2 * capacity); + @memset(nodes, zero_chunk); + return .{ + .nodes = nodes, + .capacity = capacity, + .dirty_low = 0, + .dirty_high = 0, + .initialized = false, + .cached_length = 0, + .cached_root = zero_chunk, + .root_valid = false, + }; + } + + pub fn deinit(self: *Self, allocator: std.mem.Allocator) void { + allocator.free(self.nodes); + } + + pub fn markDirty(self: *Self, leaf_index: usize) void { + if (self.dirty_low > self.dirty_high) { + // Currently clean — set range to single leaf + self.dirty_low = leaf_index; + self.dirty_high = leaf_index; + } else { + if (leaf_index < self.dirty_low) self.dirty_low = leaf_index; + if (leaf_index > self.dirty_high) self.dirty_high = leaf_index; + } + self.root_valid = false; + } + + pub fn markAllDirty(self: *Self) void { + self.dirty_low = 0; + self.dirty_high = self.capacity - 1; + self.root_valid = false; + } + + /// Mark clean (dirty_low > dirty_high). + fn markClean(self: *Self) void { + self.dirty_low = 1; + self.dirty_high = 0; + } + + /// Set a leaf chunk value and mark it dirty. + pub fn setLeaf(self: *Self, index: usize, value: chunk) void { + self.nodes[self.capacity + index] = value; + self.markDirty(index); + } + + /// Recompute the Merkle root, only rehashing dirty paths. + /// `num_chunks` is the number of actual data chunks (rest are zero-padded). + /// Returns the data root (before mixInLength). + pub fn recompute(self: *Self, num_chunks: usize) chunk { + if (!self.initialized) { + // First time: set all leaves beyond data to zero, hash everything bottom-up + for (num_chunks..self.capacity) |i| { + self.nodes[self.capacity + i] = zero_chunk; + } + // Hash all internal nodes bottom-up + var level_size = self.capacity; + while (level_size > 1) : (level_size /= 2) { + const level_start = level_size; // start index of this level + var i = level_start; + while (i < level_start + level_size) : (i += 2) { + self.hashPair(i / 2, i, i + 1); + } + } + self.initialized = true; + self.markClean(); + return self.nodes[1]; + } + + // If nothing is dirty, return cached root + if (self.dirty_low > self.dirty_high) { + return self.nodes[1]; + } + + // Incremental update: rehash only dirty paths + // Start at leaf level, process dirty range, then walk up + var lo = self.dirty_low + self.capacity; + var hi = self.dirty_high + self.capacity; + + // Ensure parents of the dirty range are rehashed at each level + while (lo > 1) { + // Align to pairs: we need to hash the parent of each node in [lo, hi] + const pair_lo = lo - (lo % 2); // round down to even (left sibling) + const pair_hi = hi + 1 - (hi % 2); // round up to odd (right sibling) + + var i = pair_lo; + while (i < pair_hi) : (i += 2) { + self.hashPair(i / 2, i, i + 1); + } + + // Move to parent level + lo = pair_lo / 2; + hi = pair_hi / 2; + } + + self.markClean(); + return self.nodes[1]; + } + + fn hashPair(self: *Self, parent: usize, left: usize, right: usize) void { + var hasher = Hasher.init(Hasher.Options{}); + hasher.update(&self.nodes[left]); + hasher.update(&self.nodes[right]); + hasher.final(&self.nodes[parent]); + } + + /// Convenience: compute root with mixInLength applied. 
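+        /// Per SSZ, the final root is hash(data_root || length), where length is written
+        /// as a little-endian u64 into a zero-padded 32-byte chunk. If nothing is dirty
+        /// and the length is unchanged, the previously mixed-in root is returned without
+        /// hashing. Illustrative usage sketch (names are placeholders):
+        ///
+        ///   cache.setLeaf(i, leaf_chunk); // stores the chunk and marks leaf i dirty
+        ///   const root = cache.recomputeWithLength(num_data_chunks, element_count);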
+ pub fn recomputeWithLength(self: *Self, num_chunks: usize, length: usize) chunk { + const data_root = self.recompute(num_chunks); + + if (self.root_valid and self.cached_length == length) { + return self.cached_root; + } + + // Apply mixInLength + var length_buf: chunk = zero_chunk; + std.mem.writeInt(u64, length_buf[0..8], @intCast(length), .little); + + var hasher = Hasher.init(Hasher.Options{}); + hasher.update(&data_root); + hasher.update(&length_buf); + hasher.final(&self.cached_root); + self.cached_length = length; + self.root_valid = true; + return self.cached_root; + } + }; +} + +// Tests +const Sha256 = std.crypto.hash.sha2.Sha256; +const lib = @import("./lib.zig"); + +test "MerkleCache produces same root as merkleize for single chunk" { + const cache_type = MerkleCache(Sha256); + var cache = try cache_type.init(std.testing.allocator, 1); + defer cache.deinit(std.testing.allocator); + + const data: chunk = [_]u8{0xAB} ** 32; + cache.setLeaf(0, data); + + const cached_root = cache.recompute(1); + + var expected: chunk = undefined; + var chunks = [_]chunk{data}; + try lib.merkleize(Sha256, &chunks, 1, &expected); + + try std.testing.expectEqualSlices(u8, &expected, &cached_root); +} + +test "MerkleCache produces same root as merkleize for multiple chunks" { + const cache_type = MerkleCache(Sha256); + var cache = try cache_type.init(std.testing.allocator, 4); + defer cache.deinit(std.testing.allocator); + + var chunks: [3]chunk = undefined; + for (0..3) |i| { + const byte: u8 = @intCast(i + 1); + chunks[i] = [_]u8{byte} ** 32; + cache.setLeaf(i, chunks[i]); + } + + const cached_root = cache.recompute(3); + + var expected: chunk = undefined; + try lib.merkleize(Sha256, &chunks, 4, &expected); + + try std.testing.expectEqualSlices(u8, &expected, &cached_root); +} + +test "MerkleCache incremental update matches full rebuild" { + const cache_type = MerkleCache(Sha256); + var cache = try cache_type.init(std.testing.allocator, 4); + defer cache.deinit(std.testing.allocator); + + // Initial build with 4 chunks + var chunks: [4]chunk = undefined; + for (0..4) |i| { + const byte: u8 = @intCast(i + 1); + chunks[i] = [_]u8{byte} ** 32; + cache.setLeaf(i, chunks[i]); + } + _ = cache.recompute(4); + + // Modify one chunk and recompute incrementally + chunks[2] = [_]u8{0xFF} ** 32; + cache.setLeaf(2, chunks[2]); + const incremental_root = cache.recompute(4); + + // Full rebuild for comparison + var expected: chunk = undefined; + try lib.merkleize(Sha256, &chunks, 4, &expected); + + try std.testing.expectEqualSlices(u8, &expected, &incremental_root); +} + +test "MerkleCache recomputeWithLength matches merkleize + mixInLength" { + const cache_type = MerkleCache(Sha256); + var cache = try cache_type.init(std.testing.allocator, 8); + defer cache.deinit(std.testing.allocator); + + var chunks: [3]chunk = undefined; + for (0..3) |i| { + const byte: u8 = @intCast(i + 10); + chunks[i] = [_]u8{byte} ** 32; + cache.setLeaf(i, chunks[i]); + } + + const cached = cache.recomputeWithLength(3, 3); + + // Compare: merkleize then mixInLength2 + var data_root: chunk = undefined; + try lib.merkleize(Sha256, &chunks, 8, &data_root); + var expected: chunk = undefined; + lib.mixInLength2(Sha256, data_root, 3, &expected); + + try std.testing.expectEqualSlices(u8, &expected, &cached); +} + +test "MerkleCache empty chunks" { + const cache_type = MerkleCache(Sha256); + var cache = try cache_type.init(std.testing.allocator, 4); + defer cache.deinit(std.testing.allocator); + + // No leaves set — all zeros + const 
cached_root = cache.recompute(0); + + var expected: chunk = undefined; + const empty: []chunk = &.{}; + try lib.merkleize(Sha256, empty, 4, &expected); + + try std.testing.expectEqualSlices(u8, &expected, &cached_root); +} diff --git a/src/tests.zig b/src/tests.zig index 365413b..1ab283a 100644 --- a/src/tests.zig +++ b/src/tests.zig @@ -2300,6 +2300,134 @@ test "roundtrip: [2][4]u32 (nested fixed array) preserves inner elements" { try expect(std.mem.eql(u32, &out[1], &.{ 5, 6, 7, 8 })); } +test "Cached hashTreeRoot for List(u64) matches uncached" { + const ListU64 = utils.List(u64, 1024); + var list = try ListU64.init(std.testing.allocator); + defer list.deinit(); + + try list.append(1); + try list.append(2); + try list.append(3); + + // Compute hash via lib.hashTreeRoot (value copy, cache cleaned up after) + var uncached: [32]u8 = undefined; + try hashTreeRoot(Sha256, ListU64, list, &uncached, std.testing.allocator); + + // Compute hash directly (cache persists on instance) + var cached: [32]u8 = undefined; + try list.hashTreeRoot(Sha256, &cached, std.testing.allocator); + + try expect(std.mem.eql(u8, &cached, &uncached)); + + // Modify one element and verify cached recompute matches fresh uncached + try list.set(1, 42); + + var cached2: [32]u8 = undefined; + try list.hashTreeRoot(Sha256, &cached2, std.testing.allocator); + + // Build a fresh list with same data for comparison + var fresh = try ListU64.init(std.testing.allocator); + defer fresh.deinit(); + try fresh.append(1); + try fresh.append(42); + try fresh.append(3); + + var expected: [32]u8 = undefined; + try hashTreeRoot(Sha256, ListU64, fresh, &expected, std.testing.allocator); + + try expect(std.mem.eql(u8, &cached2, &expected)); + + // Verify it differs from original + try expect(!std.mem.eql(u8, &cached2, &uncached)); +} + +test "Cached hashTreeRoot for List with append" { + const ListU64 = utils.List(u64, 1024); + var list = try ListU64.init(std.testing.allocator); + defer list.deinit(); + + try list.append(10); + + var hash1: [32]u8 = undefined; + try list.hashTreeRoot(Sha256, &hash1, std.testing.allocator); + + // Append and recompute + try list.append(20); + var hash2: [32]u8 = undefined; + try list.hashTreeRoot(Sha256, &hash2, std.testing.allocator); + + // Compare with fresh uncached + var fresh = try ListU64.init(std.testing.allocator); + defer fresh.deinit(); + try fresh.append(10); + try fresh.append(20); + + var expected: [32]u8 = undefined; + try hashTreeRoot(Sha256, ListU64, fresh, &expected, std.testing.allocator); + + try expect(std.mem.eql(u8, &hash2, &expected)); + try expect(!std.mem.eql(u8, &hash1, &hash2)); +} + +test "Cached hashTreeRoot for composite List" { + const Point = struct { x: u32, y: u32 }; + const ListOfPoint = utils.List(Point, 100); + + var list = try ListOfPoint.init(std.testing.allocator); + defer list.deinit(); + + try list.append(.{ .x = 1, .y = 2 }); + try list.append(.{ .x = 3, .y = 4 }); + + var cached: [32]u8 = undefined; + try list.hashTreeRoot(Sha256, &cached, std.testing.allocator); + + var uncached: [32]u8 = undefined; + try hashTreeRoot(Sha256, ListOfPoint, list, &uncached, std.testing.allocator); + + try expect(std.mem.eql(u8, &cached, &uncached)); +} + +test "Cached hashTreeRoot for Bitlist matches uncached" { + const TestBitlist = utils.Bitlist(256); + var bl = try TestBitlist.init(std.testing.allocator); + defer bl.deinit(); + + try bl.append(true); + try bl.append(false); + try bl.append(true); + try bl.append(true); + + // Uncached + var uncached: [32]u8 = undefined; + try 
hashTreeRoot(Sha256, TestBitlist, bl, &uncached, std.testing.allocator); + + var cached: [32]u8 = undefined; + try bl.hashTreeRoot(Sha256, &cached, std.testing.allocator); + + try expect(std.mem.eql(u8, &cached, &uncached)); + + // Modify a bit and verify incremental update + try bl.set(1, true); + var cached2: [32]u8 = undefined; + try bl.hashTreeRoot(Sha256, &cached2, std.testing.allocator); + + // Fresh comparison + var fresh = try TestBitlist.init(std.testing.allocator); + defer fresh.deinit(); + try fresh.append(true); + try fresh.append(true); + try fresh.append(true); + try fresh.append(true); + + var expected: [32]u8 = undefined; + try hashTreeRoot(Sha256, TestBitlist, fresh, &expected, std.testing.allocator); + + try expect(std.mem.eql(u8, &cached2, &expected)); + try expect(!std.mem.eql(u8, &cached2, &uncached)); +} + test { _ = @import("beacon_tests.zig"); + _ = @import("merkle_cache.zig"); } diff --git a/src/utils.zig b/src/utils.zig index c699451..3caf8a2 100644 --- a/src/utils.zig +++ b/src/utils.zig @@ -8,6 +8,7 @@ const isFixedSizeObject = lib.isFixedSizeObject; const ArrayList = std.ArrayList; const Allocator = std.mem.Allocator; const hashes_of_zero = @import("./zeros.zig").hashes_of_zero; +const merkle_cache = @import("./merkle_cache.zig"); // SSZ specification constants const BYTES_PER_CHUNK = 32; @@ -30,6 +31,12 @@ pub fn List(T: type, comptime N: usize) type { inner: Inner, allocator: Allocator, + cache: ?*CacheType = null, + + /// Hasher-agnostic cache type. We use SHA256 as the default since that's the + /// standard SSZ hasher, but the cache is only used when hashTreeRoot is called + /// with a matching hasher. + const CacheType = merkle_cache.MerkleCache(std.crypto.hash.sha2.Sha256); pub fn sszEncode(self: *const Self, l: *ArrayList(u8), allocator: Allocator) !void { try serialize([]const Item, self.constSlice(), l, allocator); @@ -120,16 +127,18 @@ pub fn List(T: type, comptime N: usize) type { } pub fn deinit(self: *Self) void { + if (self.cache) |c| { + c.deinit(self.allocator); + self.allocator.destroy(c); + self.cache = null; + } self.inner.deinit(self.allocator); } pub fn append(self: *Self, item: Self.Item) error{ Overflow, OutOfMemory }!void { if (self.inner.items.len >= N) return error.Overflow; - return self.inner.append(self.allocator, item); - } - - pub fn slice(self: *Self) []T { - return self.inner.items; + try self.inner.append(self.allocator, item); + self.invalidateCacheForIndex(self.inner.items.len - 1); } pub fn constSlice(self: *const Self) []const T { @@ -151,6 +160,7 @@ pub fn List(T: type, comptime N: usize) type { pub fn set(self: *Self, i: usize, item: T) error{IndexOutOfBounds}!void { if (i >= self.inner.items.len) return error.IndexOutOfBounds; self.inner.items[i] = item; + self.invalidateCacheForIndex(i); } pub fn len(self: *const Self) usize { @@ -163,6 +173,29 @@ pub fn List(T: type, comptime N: usize) type { } pub fn hashTreeRoot(self: *const Self, Hasher: type, out: *[32]u8, allocator: Allocator) !void { + if (Hasher == std.crypto.hash.sha2.Sha256) { + // Lazily initialize cache using self.allocator for consistent ownership + if (self.cache == null) { + const limit = switch (@typeInfo(Item)) { + .int => blk: { + const bytes_per_item = @sizeOf(Item); + const items_per_chunk = BYTES_PER_CHUNK / bytes_per_item; + break :blk (N + items_per_chunk - 1) / items_per_chunk; + }, + else => N, + }; + const c = try self.allocator.create(CacheType); + errdefer self.allocator.destroy(c); + c.* = try CacheType.init(self.allocator, limit); + 
@constCast(&self.cache).* = c; + } + return self.hashTreeRootCached(out, allocator); + } + + return self.hashTreeRootUncached(Hasher, out, allocator); + } + + fn hashTreeRootUncached(self: *const Self, Hasher: type, out: *[32]u8, allocator: Allocator) !void { const items = self.constSlice(); switch (@typeInfo(Item)) { @@ -186,14 +219,118 @@ pub fn List(T: type, comptime N: usize) type { try lib.hashTreeRoot(Hasher, Item, item, &tmp, allocator); try chunks.append(allocator, tmp); } - // Always use N (max capacity) for merkleization, even when empty - // This ensures proper tree depth according to SSZ specification try lib.merkleize(Hasher, chunks.items, N, &tmp); lib.mixInLength2(Hasher, tmp, items.len, out); }, } } + /// Free the Merkle cache without freeing the list itself. + /// Called by lib.hashTreeRoot to clean up caches on value copies. + pub fn deinitCache(self: *Self) void { + if (self.cache) |c| { + c.deinit(self.allocator); + self.allocator.destroy(c); + self.cache = null; + } + } + + fn hashTreeRootCached(self: *const Self, out: *[32]u8, allocator: Allocator) !void { + const Sha256 = std.crypto.hash.sha2.Sha256; + const items = self.constSlice(); + const cache = self.cache.?; + + switch (@typeInfo(Item)) { + .int => { + // Pack items into chunks and update dirty leaves + const bytes_per_item = @sizeOf(Item); + const items_per_chunk = BYTES_PER_CHUNK / bytes_per_item; + const chunks_for_max_capacity = (N + items_per_chunk - 1) / items_per_chunk; + + if (!cache.initialized) { + // First time: set all leaf chunks + for (0..items.len) |i| { + const chunk_idx = i / items_per_chunk; + const pos_in_chunk = (i % items_per_chunk) * bytes_per_item; + var leaf = cache.nodes[cache.capacity + chunk_idx]; + // SSZ requires little-endian encoding + std.mem.writeInt(Item, leaf[pos_in_chunk..][0..bytes_per_item], items[i], .little); + cache.nodes[cache.capacity + chunk_idx] = leaf; + } + cache.markAllDirty(); + } + // For dirty chunks, rebuild from current items + if (cache.dirty_low <= cache.dirty_high) { + const lo = cache.dirty_low; + const hi = @min(cache.dirty_high, if (items.len > 0) (items.len - 1) / items_per_chunk else 0); + for (lo..hi + 1) |chunk_idx| { + var leaf: chunk = zero_chunk; + const start_item = chunk_idx * items_per_chunk; + const end_item = @min(start_item + items_per_chunk, items.len); + for (start_item..end_item) |item_i| { + const pos = (item_i % items_per_chunk) * bytes_per_item; + // SSZ requires little-endian encoding + std.mem.writeInt(Item, leaf[pos..][0..bytes_per_item], items[item_i], .little); + } + cache.nodes[cache.capacity + chunk_idx] = leaf; + } + // Zero out chunks beyond data + for ((if (items.len > 0) (items.len - 1) / items_per_chunk + 1 else 0)..chunks_for_max_capacity) |chunk_idx| { + if (chunk_idx >= cache.dirty_low and chunk_idx <= cache.dirty_high) { + cache.nodes[cache.capacity + chunk_idx] = zero_chunk; + } + } + } + + const num_data_chunks = if (items.len > 0) (items.len - 1) / items_per_chunk + 1 else 0; + out.* = cache.recomputeWithLength(num_data_chunks, items.len); + }, + else => { + // Composite types: each item is its own chunk (hash tree root of item) + // Composite items may contain pointers to mutable data that + // can change without going through set(), so we must always + // recompute all item hashes. We compare against cached leaves + // to only mark actually-changed nodes dirty for tree rehashing. + var tmp: chunk = undefined; + for (items, 0..) 
|item, i| { + try lib.hashTreeRoot(Sha256, Item, item, &tmp, allocator); + if (!std.mem.eql(u8, &tmp, &cache.nodes[cache.capacity + i])) { + cache.nodes[cache.capacity + i] = tmp; + cache.markDirty(i); + } + } + // Zero out any previously-occupied slots beyond current length + if (cache.initialized) { + for (items.len..cache.capacity) |i| { + if (!std.mem.eql(u8, &zero_chunk, &cache.nodes[cache.capacity + i])) { + cache.nodes[cache.capacity + i] = zero_chunk; + cache.markDirty(i); + } + } + } else { + cache.markAllDirty(); + } + + out.* = cache.recomputeWithLength(items.len, items.len); + }, + } + } + + /// Compute the chunk index for a given element index and mark it dirty in the cache. + fn invalidateCacheForIndex(self: *Self, element_index: usize) void { + if (self.cache) |cache| { + const chunk_idx = switch (@typeInfo(Item)) { + .int => blk: { + const bytes_per_item = @sizeOf(Item); + const items_per_chunk = BYTES_PER_CHUNK / bytes_per_item; + break :blk element_index / items_per_chunk; + }, + else => element_index, + }; + cache.markDirty(chunk_idx); + } + } + /// Decodes and validates the length from dynamic input pub fn decodeDynamicLength(buf: []const u8) !u32 { if (buf.len == 0) { @@ -228,6 +365,9 @@ pub fn Bitlist(comptime N: usize) type { inner: Inner, allocator: Allocator, length: usize, + cache: ?*BitCacheType = null, + + const BitCacheType = merkle_cache.MerkleCache(std.crypto.hash.sha2.Sha256); pub fn sszEncode(self: *const Self, l: *ArrayList(u8), allocator: Allocator) !void { if (self.length == 0) { @@ -315,6 +455,7 @@ pub fn Bitlist(comptime N: usize) type { const mask = ~@shlExact(@as(u8, 1), @truncate(i % 8)); const b = if (bit) @shlExact(@as(u8, 1), @truncate(i % 8)) else 0; self.inner.items[i / 8] = @truncate((self.inner.items[i / 8] & mask) | b); + if (self.cache) |c| c.markDirty(i / 256); } pub fn append(self: *Self, item: bool) error{ Overflow, OutOfMemory, IndexOutOfBounds }!void { @@ -331,6 +472,11 @@ pub fn Bitlist(comptime N: usize) type { } pub fn deinit(self: *Self) void { + if (self.cache) |c| { + c.deinit(self.allocator); + self.allocator.destroy(c); + self.cache = null; + } self.inner.deinit(self.allocator); } @@ -344,36 +490,91 @@ pub fn Bitlist(comptime N: usize) type { } pub fn hashTreeRoot(self: *const Self, Hasher: type, out: *[32]u8, allocator: Allocator) !void { + if (Hasher == std.crypto.hash.sha2.Sha256) { + // Lazily initialize cache using self.allocator for consistent ownership + if (self.cache == null) { + const chunk_count_limit = (N + 255) / 256; + const c = try self.allocator.create(BitCacheType); + errdefer self.allocator.destroy(c); + c.* = try BitCacheType.init(self.allocator, chunk_count_limit); + @constCast(&self.cache).* = c; + } + return self.hashTreeRootCached(out); + } + return self.hashTreeRootUncached(Hasher, out, allocator); + } + + fn hashTreeRootUncached(self: *const Self, Hasher: type, out: *[32]u8, allocator: Allocator) !void { const bit_length = self.length; var bitfield_bytes: ArrayList(u8) = .empty; defer bitfield_bytes.deinit(allocator); if (bit_length > 0) { - // Get the internal bit data since we don't store delimiter const sl = self.inner.items; try bitfield_bytes.appendSlice(allocator, sl[0..sl.len]); - // Remove trailing zeros but keep at least one byte - // This avoids the wasteful pattern of removing all zeros then adding back a chunk while (bitfield_bytes.items.len > 1 and bitfield_bytes.items[bitfield_bytes.items.len - 1] == 0) { _ = bitfield_bytes.pop(); } } - // Pack bits into chunks (pad to chunk boundary) 
const padding_size = (BYTES_PER_CHUNK - bitfield_bytes.items.len % BYTES_PER_CHUNK) % BYTES_PER_CHUNK; _ = try bitfield_bytes.appendSlice(allocator, zero_chunk[0..padding_size]); const chunks = std.mem.bytesAsSlice(chunk, bitfield_bytes.items); var tmp: chunk = undefined; - // Use chunk_count limit as per SSZ specification const chunk_count_limit = (N + 255) / 256; try lib.merkleize(Hasher, chunks, chunk_count_limit, &tmp); lib.mixInLength2(Hasher, tmp, bit_length, out); } + fn hashTreeRootCached(self: *const Self, out: *[32]u8) void { + const cache = self.cache.?; + const bit_length = self.length; + + // Update dirty leaf chunks from current bit data + if (!cache.initialized) { + // Set all leaf chunks from bit data + const sl = self.inner.items; + const num_data_chunks = if (sl.len > 0) (sl.len - 1) / BYTES_PER_CHUNK + 1 else 0; + for (0..num_data_chunks) |ci| { + var leaf: chunk = zero_chunk; + const start = ci * BYTES_PER_CHUNK; + const end = @min(start + BYTES_PER_CHUNK, sl.len); + @memcpy(leaf[0 .. end - start], sl[start..end]); + cache.nodes[cache.capacity + ci] = leaf; + } + cache.markAllDirty(); + } else if (cache.dirty_low <= cache.dirty_high) { + const sl = self.inner.items; + const hi = @min(cache.dirty_high, if (sl.len > 0) (sl.len - 1) / BYTES_PER_CHUNK else 0); + for (cache.dirty_low..hi + 1) |ci| { + var leaf: chunk = zero_chunk; + const start = ci * BYTES_PER_CHUNK; + const end = @min(start + BYTES_PER_CHUNK, sl.len); + if (start < sl.len) { + @memcpy(leaf[0 .. end - start], sl[start..end]); + } + cache.nodes[cache.capacity + ci] = leaf; + } + } + + const num_data_chunks = if (self.inner.items.len > 0) (self.inner.items.len - 1) / BYTES_PER_CHUNK + 1 else 0; + out.* = cache.recomputeWithLength(num_data_chunks, bit_length); + } + + /// Free the Merkle cache without freeing the bitlist itself. + /// Called by lib.hashTreeRoot to clean up caches on value copies. + pub fn deinitCache(self: *Self) void { + if (self.cache) |c| { + c.deinit(self.allocator); + self.allocator.destroy(c); + self.cache = null; + } + } + /// Validates that the bitlist is correctly formed pub fn validateBitlist(buf: []const u8) !void { const byte_len = buf.len;