Skip to content

Commit bf6ffa7

Browse files
authored
Merge pull request #1467 from Roasbeef/copy-tree
mssmt: add new Copy method and InsertMany to optimize slightly
2 parents b4feb3e + bd32d61 commit bf6ffa7

File tree

4 files changed

+590
-1
lines changed

4 files changed

+590
-1
lines changed

mssmt/compacted_tree.go

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,191 @@ func (t *CompactedTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
392392

393393
return NewProof(proof), nil
394394
}
395+
396+
// collectLeavesRecursive is a recursive helper function that's used to traverse
397+
// down an MS-SMT tree and collect all leaf nodes. It returns a map of leaf
398+
// nodes indexed by their hash.
399+
func collectLeavesRecursive(ctx context.Context, tx TreeStoreViewTx, node Node,
400+
depth int) (map[[hashSize]byte]*LeafNode, error) {
401+
402+
// Base case: If it's a compacted leaf node.
403+
if compactedLeaf, ok := node.(*CompactedLeafNode); ok {
404+
if compactedLeaf.LeafNode.IsEmpty() {
405+
return make(map[[hashSize]byte]*LeafNode), nil
406+
}
407+
return map[[hashSize]byte]*LeafNode{
408+
compactedLeaf.Key(): compactedLeaf.LeafNode,
409+
}, nil
410+
}
411+
412+
// Recursive step: If it's a branch node.
413+
if branchNode, ok := node.(*BranchNode); ok {
414+
// Optimization: if the branch is empty, return early.
415+
if depth < MaxTreeLevels &&
416+
IsEqualNode(branchNode, EmptyTree[depth]) {
417+
418+
return make(map[[hashSize]byte]*LeafNode), nil
419+
}
420+
421+
// Handle case where depth might exceed EmptyTree bounds if
422+
// logic error exists
423+
if depth >= MaxTreeLevels {
424+
// This shouldn't happen if called correctly, implies a
425+
// leaf.
426+
return nil, fmt.Errorf("invalid depth %d for branch "+
427+
"node", depth)
428+
}
429+
430+
left, right, err := tx.GetChildren(depth, branchNode.NodeHash())
431+
if err != nil {
432+
// If children not found, it might be an empty branch
433+
// implicitly Check if the error indicates "not found"
434+
// or similar Depending on store impl, this might be how
435+
// empty is signaled For now, treat error as fatal.
436+
return nil, fmt.Errorf("error getting children for "+
437+
"branch %s at depth %d: %w",
438+
branchNode.NodeHash(), depth, err)
439+
}
440+
441+
leftLeaves, err := collectLeavesRecursive(
442+
ctx, tx, left, depth+1,
443+
)
444+
if err != nil {
445+
return nil, err
446+
}
447+
448+
rightLeaves, err := collectLeavesRecursive(
449+
ctx, tx, right, depth+1,
450+
)
451+
if err != nil {
452+
return nil, err
453+
}
454+
455+
// Merge the results.
456+
for k, v := range rightLeaves {
457+
// Check for duplicate keys, although this shouldn't
458+
// happen in a valid SMT.
459+
if _, exists := leftLeaves[k]; exists {
460+
return nil, fmt.Errorf("duplicate key %x "+
461+
"found during leaf collection", k)
462+
}
463+
leftLeaves[k] = v
464+
}
465+
466+
return leftLeaves, nil
467+
}
468+
469+
// Handle unexpected node types or implicit empty nodes. If node is nil
470+
// or explicitly an EmptyLeafNode representation
471+
if node == nil || IsEqualNode(node, EmptyLeafNode) {
472+
return make(map[[hashSize]byte]*LeafNode), nil
473+
}
474+
475+
// Check against EmptyTree branches if possible (requires depth)
476+
if depth < MaxTreeLevels && IsEqualNode(node, EmptyTree[depth]) {
477+
return make(map[[hashSize]byte]*LeafNode), nil
478+
}
479+
480+
return nil, fmt.Errorf("unexpected node type %T encountered "+
481+
"during leaf collection at depth %d", node, depth)
482+
}
483+
484+
// Copy copies all the key-value pairs from the source tree into the target
485+
// tree.
486+
func (t *CompactedTree) Copy(ctx context.Context, targetTree Tree) error {
487+
var leaves map[[hashSize]byte]*LeafNode
488+
err := t.store.View(ctx, func(tx TreeStoreViewTx) error {
489+
root, err := tx.RootNode()
490+
if err != nil {
491+
return fmt.Errorf("error getting root node: %w", err)
492+
}
493+
494+
// Optimization: If the source tree is empty, there's nothing to
495+
// copy.
496+
if IsEqualNode(root, EmptyTree[0]) {
497+
leaves = make(map[[hashSize]byte]*LeafNode)
498+
return nil
499+
}
500+
501+
// Start recursive collection from the root at depth 0.
502+
leaves, err = collectLeavesRecursive(ctx, tx, root, 0)
503+
if err != nil {
504+
return fmt.Errorf("error collecting leaves: %w", err)
505+
}
506+
507+
return nil
508+
})
509+
if err != nil {
510+
return err
511+
}
512+
513+
// Insert all found leaves into the target tree using InsertMany for
514+
// efficiency.
515+
_, err = targetTree.InsertMany(ctx, leaves)
516+
if err != nil {
517+
return fmt.Errorf("error inserting leaves into "+
518+
"target tree: %w", err)
519+
}
520+
521+
return nil
522+
}
523+
524+
// InsertMany inserts multiple leaf nodes provided in the leaves map within a
525+
// single database transaction.
526+
func (t *CompactedTree) InsertMany(ctx context.Context,
527+
leaves map[[hashSize]byte]*LeafNode) (Tree, error) {
528+
529+
if len(leaves) == 0 {
530+
return t, nil
531+
}
532+
533+
dbErr := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error {
534+
currentRoot, err := tx.RootNode()
535+
if err != nil {
536+
return err
537+
}
538+
rootBranch := currentRoot.(*BranchNode)
539+
540+
for key, leaf := range leaves {
541+
// Check for potential sum overflow before each
542+
// insertion.
543+
sumRoot := rootBranch.NodeSum()
544+
sumLeaf := leaf.NodeSum()
545+
err = CheckSumOverflowUint64(sumRoot, sumLeaf)
546+
if err != nil {
547+
return fmt.Errorf("compact tree leaf insert "+
548+
"sum overflow, root: %d, leaf: %d; %w",
549+
sumRoot, sumLeaf, err)
550+
}
551+
552+
// Insert the leaf using the internal helper.
553+
newRoot, err := t.insert(
554+
tx, &key, 0, rootBranch, leaf,
555+
)
556+
if err != nil {
557+
return fmt.Errorf("error inserting leaf "+
558+
"with key %x: %w", key, err)
559+
}
560+
rootBranch = newRoot
561+
562+
// Update the root within the transaction for
563+
// consistency, even though the insert logic passes the
564+
// root explicitly.
565+
err = tx.UpdateRoot(rootBranch)
566+
if err != nil {
567+
return fmt.Errorf("error updating root "+
568+
"during InsertMany: %w", err)
569+
}
570+
}
571+
572+
// The root is already updated by the last iteration of the
573+
// loop. No final update needed here, but returning nil error
574+
// signals success.
575+
return nil
576+
})
577+
if dbErr != nil {
578+
return nil, dbErr
579+
}
580+
581+
return t, nil
582+
}

mssmt/interface.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,13 @@ type Tree interface {
3030
// proof. This is noted by the returned `Proof` containing an empty
3131
// leaf.
3232
MerkleProof(ctx context.Context, key [hashSize]byte) (*Proof, error)
33+
34+
// InsertMany inserts multiple leaf nodes provided in the leaves map
35+
// within a single database transaction.
36+
InsertMany(ctx context.Context, leaves map[[hashSize]byte]*LeafNode) (
37+
Tree, error)
38+
39+
// Copy copies all the key-value pairs from the source tree into the
40+
// target tree.
41+
Copy(ctx context.Context, targetTree Tree) error
3342
}

mssmt/tree.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ func bitIndex(idx uint8, key *[hashSize]byte) byte {
9797
return (byteVal >> (idx % 8)) & 1
9898
}
9999

100+
// setBit returns a copy of the key with the bit at the given depth set to 1.
101+
func setBit(key [hashSize]byte, depth int) [hashSize]byte {
102+
byteIndex := depth / 8
103+
bitIndex := depth % 8
104+
key[byteIndex] |= (1 << bitIndex)
105+
return key
106+
}
107+
100108
// iterFunc is a type alias for closures to be invoked at every iteration of
101109
// walking through a tree.
102110
type iterFunc = func(height int, current, sibling, parent Node) error
@@ -333,6 +341,162 @@ func (t *FullTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
333341
return NewProof(proof), nil
334342
}
335343

344+
// findLeaves recursively traverses the tree represented by the given node and
345+
// collects all non-empty leaf nodes along with their reconstructed keys.
346+
func findLeaves(ctx context.Context, tx TreeStoreViewTx, node Node,
347+
keyPrefix [hashSize]byte,
348+
depth int) (map[[hashSize]byte]*LeafNode, error) {
349+
350+
// Base case: If it's a leaf node.
351+
if leafNode, ok := node.(*LeafNode); ok {
352+
if leafNode.IsEmpty() {
353+
return make(map[[hashSize]byte]*LeafNode), nil
354+
}
355+
return map[[hashSize]byte]*LeafNode{keyPrefix: leafNode}, nil
356+
}
357+
358+
// Recursive step: If it's a branch node.
359+
if branchNode, ok := node.(*BranchNode); ok {
360+
// Optimization: if the branch is empty, return early.
361+
if IsEqualNode(branchNode, EmptyTree[depth]) {
362+
return make(map[[hashSize]byte]*LeafNode), nil
363+
}
364+
365+
left, right, err := tx.GetChildren(depth, branchNode.NodeHash())
366+
if err != nil {
367+
return nil, fmt.Errorf("error getting children for "+
368+
"branch %s at depth %d: %w",
369+
branchNode.NodeHash(), depth, err)
370+
}
371+
372+
// Recursively find leaves in the left subtree. The key prefix
373+
// remains the same as the 0 bit is implicitly handled by the
374+
// initial keyPrefix state.
375+
leftLeaves, err := findLeaves(
376+
ctx, tx, left, keyPrefix, depth+1,
377+
)
378+
if err != nil {
379+
return nil, err
380+
}
381+
382+
// Recursively find leaves in the right subtree. Set the bit
383+
// corresponding to the current depth in the key prefix.
384+
rightKeyPrefix := setBit(keyPrefix, depth)
385+
386+
rightLeaves, err := findLeaves(
387+
ctx, tx, right, rightKeyPrefix, depth+1,
388+
)
389+
if err != nil {
390+
return nil, err
391+
}
392+
393+
// Merge the results.
394+
for k, v := range rightLeaves {
395+
leftLeaves[k] = v
396+
}
397+
return leftLeaves, nil
398+
}
399+
400+
// Handle unexpected node types.
401+
return nil, fmt.Errorf("unexpected node type %T encountered "+
402+
"during leaf collection", node)
403+
}
404+
405+
// Copy copies all the key-value pairs from the source tree into the target
406+
// tree.
407+
func (t *FullTree) Copy(ctx context.Context, targetTree Tree) error {
408+
var leaves map[[hashSize]byte]*LeafNode
409+
err := t.store.View(ctx, func(tx TreeStoreViewTx) error {
410+
root, err := tx.RootNode()
411+
if err != nil {
412+
return fmt.Errorf("error getting root node: %w", err)
413+
}
414+
415+
// Optimization: If the source tree is empty, there's nothing
416+
// to copy.
417+
if IsEqualNode(root, EmptyTree[0]) {
418+
leaves = make(map[[hashSize]byte]*LeafNode)
419+
return nil
420+
}
421+
422+
leaves, err = findLeaves(ctx, tx, root, [hashSize]byte{}, 0)
423+
if err != nil {
424+
return fmt.Errorf("error finding leaves: %w", err)
425+
}
426+
return nil
427+
})
428+
if err != nil {
429+
return err
430+
}
431+
432+
// Insert all found leaves into the target tree using InsertMany for
433+
// efficiency.
434+
_, err = targetTree.InsertMany(ctx, leaves)
435+
if err != nil {
436+
return fmt.Errorf("error inserting leaves into target "+
437+
"tree: %w", err)
438+
}
439+
440+
return nil
441+
}
442+
443+
// InsertMany inserts multiple leaf nodes provided in the leaves map within a
444+
// single database transaction.
445+
func (t *FullTree) InsertMany(ctx context.Context,
446+
leaves map[[hashSize]byte]*LeafNode) (Tree, error) {
447+
448+
if len(leaves) == 0 {
449+
return t, nil
450+
}
451+
452+
err := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error {
453+
currentRoot, err := tx.RootNode()
454+
if err != nil {
455+
return err
456+
}
457+
rootBranch := currentRoot.(*BranchNode)
458+
459+
for key, leaf := range leaves {
460+
// Check for potential sum overflow before each
461+
// insertion.
462+
sumRoot := rootBranch.NodeSum()
463+
sumLeaf := leaf.NodeSum()
464+
err = CheckSumOverflowUint64(sumRoot, sumLeaf)
465+
if err != nil {
466+
return fmt.Errorf("full tree leaf insert sum "+
467+
"overflow, root: %d, leaf: %d; %w",
468+
sumRoot, sumLeaf, err)
469+
}
470+
471+
// Insert the leaf using the internal helper.
472+
newRoot, err := t.insert(tx, &key, leaf)
473+
if err != nil {
474+
return fmt.Errorf("error inserting leaf "+
475+
"with key %x: %w", key, err)
476+
}
477+
rootBranch = newRoot
478+
479+
// Update the root within the transaction so subsequent
480+
// inserts in this batch read the correct state.
481+
err = tx.UpdateRoot(rootBranch)
482+
if err != nil {
483+
return fmt.Errorf("error updating root "+
484+
"during InsertMany: %w", err)
485+
}
486+
}
487+
488+
// The root is already updated by the last iteration of the
489+
// loop. No final update needed here, but returning nil error
490+
// signals success.
491+
return nil
492+
})
493+
if err != nil {
494+
return nil, err
495+
}
496+
497+
return t, nil
498+
}
499+
336500
// VerifyMerkleProof determines whether a merkle proof for the leaf found at the
337501
// given key is valid.
338502
func VerifyMerkleProof(key [hashSize]byte, leaf *LeafNode, proof *Proof,

0 commit comments

Comments
 (0)