jmt/
restore.rs

1// Copyright (c) The Diem Core Contributors
2// SPDX-License-Identifier: Apache-2.0
3
4//! This module implements the functionality to restore a
5//! [`JellyfishMerkleTree`](crate::JellyfishMerkleTree) from small chunks of
6//! key/value pairs.
7
8use core::marker::PhantomData;
9
10use alloc::boxed::Box;
11use alloc::vec;
12use alloc::{sync::Arc, vec::Vec};
13
14use anyhow::{bail, ensure, Result};
15use mirai_annotations::*;
16
17use crate::{
18    node_type::{
19        get_child_and_sibling_half_start, Child, Children, InternalNode, LeafNode, Node, NodeKey,
20        NodeType,
21    },
22    storage::{NodeBatch, TreeReader, TreeWriter},
23    types::{
24        nibble::{
25            nibble_path::{NibbleIterator, NibblePath},
26            Nibble,
27        },
28        proof::{SparseMerkleInternalNode, SparseMerkleLeafNode, SparseMerkleRangeProof},
29        Version,
30    },
31    Bytes32Ext, KeyHash, OwnedValue, RootHash, SimpleHasher, ValueHash, ROOT_NIBBLE_HEIGHT,
32    SPARSE_MERKLE_PLACEHOLDER_HASH,
33};
34
35#[derive(Clone, Debug, Eq, PartialEq)]
36enum ChildInfo {
37    /// This child is an internal node. The hash of the internal node is stored here if it is
38    /// known, otherwise it is `None`. In the process of restoring a tree, we will only know the
39    /// hash of an internal node after we see all the keys that share the same prefix.
40    Internal {
41        hash: Option<[u8; 32]>,
42        leaf_count: usize,
43    },
44
45    /// This child is a leaf node.
46    Leaf { node: LeafNode },
47}
48
49impl ChildInfo {
50    /// Converts `self` to a child, assuming the hash is known if it's an internal node.
51    fn into_child<H: SimpleHasher>(self, version: Version) -> Child {
52        match self {
53            Self::Internal { hash, leaf_count } => Child::new(
54                hash.expect("Must have been initialized."),
55                version,
56                NodeType::Internal { leaf_count },
57            ),
58            Self::Leaf { node } => Child::new(node.hash::<H>(), version, NodeType::Leaf),
59        }
60    }
61}
62
63#[derive(Clone, Debug)]
64struct InternalInfo {
65    /// The node key of this internal node.
66    node_key: NodeKey,
67
68    /// The existing children. Every time a child appears, the corresponding position will be set
69    /// to `Some`.
70    children: [Option<ChildInfo>; 16],
71}
72
73impl InternalInfo {
74    /// Creates an empty internal node with no children.
75    fn new_empty(node_key: NodeKey) -> Self {
76        Self {
77            node_key,
78            children: Default::default(),
79        }
80    }
81
82    fn set_child(&mut self, index: usize, child_info: ChildInfo) {
83        precondition!(index < 16);
84        self.children[index] = Some(child_info);
85    }
86
87    /// Converts `self` to an internal node, assuming all of its children are already known and
88    /// fully initialized.
89    fn into_internal_node<H: SimpleHasher>(mut self, version: Version) -> (NodeKey, InternalNode) {
90        let mut children = Children::new();
91
92        // Calling `into_iter` on an array is equivalent to calling `iter`:
93        // https://github.com/rust-lang/rust/issues/25725. So we use `iter_mut` and `take`.
94        for (index, child_info_option) in self.children.iter_mut().enumerate() {
95            if let Some(child_info) = child_info_option.take() {
96                children.insert((index as u8).into(), child_info.into_child::<H>(version));
97            }
98        }
99
100        (self.node_key, InternalNode::new(children))
101    }
102}
103
104/// Implements the functionality to restore a
105/// [`JellyfishMerkleTree`](crate::JellyfishMerkleTree) from small chunks of
106/// key-value pairs.
107pub struct JellyfishMerkleRestore<H: SimpleHasher> {
108    /// The underlying storage.
109    store: Arc<dyn TreeWriter>,
110
111    /// The version of the tree we are restoring.
112    version: Version,
113
114    /// The nodes we have partially restored. Each `partial_nodes[i-1]` is the parent of
115    /// `partial_nodes[i]`. If a node `partial_nodes[i-1]` has multiple children, only the
116    /// rightmost known child will appear here as `partial_nodes[i]`, because any other children on
117    /// the left would have been frozen.
118    ///
119    /// At any point in time, the structure looks like the following:
120    ///
121    /// ```text
122    /// +----+----+----+----+----+----+----+----+
123    /// |    |    |    |    |    |    |    | C  |  partial_nodes[0]
124    /// +----+----+----+----+----+----+----+----+
125    ///   |         |              |
126    ///   |         |              |
127    ///   |         |              |
128    ///   v         v              v
129    /// Frozen    Frozen     +----+----+----+----+----+----+----+----+
130    ///                      |    |    |    | B  |    |    | A  |    |  partial_nodes[1]
131    ///                      +----+----+----+----+----+----+----+----+
132    ///                             |         |
133    ///                             |         |
134    ///                             |         |
135    ///                             v         v
136    ///                            Frozen    Previously inserted account
137    /// ```
138    ///
139    /// We insert the accounts from left to right. So if the next account appears at position `A`,
140    /// it will cause the leaf at position `B` to be frozen. If it appears at position `B`, it
141    /// might cause a few internal nodes to be created additionally. If it appears at position `C`,
142    /// it will also cause `partial_nodes[1]` to be added to `frozen_nodes` as an internal node and
143    /// be removed from `partial_nodes`.
144    partial_nodes: Vec<InternalInfo>,
145
146    /// The nodes that have been fully restored and are ready to be written to storage.
147    frozen_nodes: NodeBatch,
148
149    /// The most recently added leaf. This is used to ensure the keys come in increasing order and
150    /// do proof verification.
151    previous_leaf: Option<LeafNode>,
152
153    /// The number of keys we have received since the most recent restart.
154    num_keys_received: u64,
155
156    /// When the restoration process finishes, we expect the tree to have this root hash.
157    expected_root_hash: RootHash,
158
159    _phantom_hasher: PhantomData<H>,
160}
161
162impl<H: SimpleHasher> JellyfishMerkleRestore<H> {
163    pub fn new<D: 'static + TreeReader + TreeWriter>(
164        store: Arc<D>,
165        version: Version,
166        expected_root_hash: RootHash,
167    ) -> Result<Self> {
168        let tree_reader = Arc::clone(&store);
169        let (partial_nodes, previous_leaf) =
170            if let Some((node_key, leaf_node)) = tree_reader.get_rightmost_leaf()? {
171                // TODO: confirm rightmost leaf is at the desired version
172                // If the system crashed in the middle of the previous restoration attempt, we need
173                // to recover the partial nodes to the state right before the crash.
174                (
175                    Self::recover_partial_nodes(tree_reader.as_ref(), version, node_key)?,
176                    Some(leaf_node),
177                )
178            } else {
179                (
180                    vec![InternalInfo::new_empty(NodeKey::new_empty_path(version))],
181                    None,
182                )
183            };
184
185        Ok(Self {
186            store,
187            version,
188            partial_nodes,
189            frozen_nodes: Default::default(),
190            previous_leaf,
191            num_keys_received: 0,
192            expected_root_hash,
193            _phantom_hasher: Default::default(),
194        })
195    }
196
197    pub fn new_overwrite<D: 'static + TreeWriter>(
198        store: Arc<D>,
199        version: Version,
200        expected_root_hash: RootHash,
201    ) -> Result<Self> {
202        Ok(Self {
203            store,
204            version,
205            partial_nodes: vec![InternalInfo::new_empty(NodeKey::new_empty_path(version))],
206            frozen_nodes: Default::default(),
207            previous_leaf: None,
208            num_keys_received: 0,
209            expected_root_hash,
210            _phantom_hasher: Default::default(),
211        })
212    }
213
214    /// Recovers partial nodes from storage. We do this by looking at all the ancestors of the
215    /// rightmost leaf. The ones do not exist in storage are the partial nodes.
216    fn recover_partial_nodes(
217        store: &dyn TreeReader,
218        version: Version,
219        rightmost_leaf_node_key: NodeKey,
220    ) -> Result<Vec<InternalInfo>> {
221        ensure!(
222            !rightmost_leaf_node_key.nibble_path().is_empty(),
223            "Root node would not be written until entire restoration process has completed \
224             successfully.",
225        );
226
227        // Start from the parent of the rightmost leaf. If this internal node exists in storage, it
228        // is not a partial node. Go to the parent node and repeat until we see a node that does
229        // not exist. This node and all its ancestors will be the partial nodes.
230        let mut node_key = rightmost_leaf_node_key.gen_parent_node_key();
231        while store.get_node_option(&node_key)?.is_some() {
232            node_key = node_key.gen_parent_node_key();
233        }
234
235        // Next we reconstruct all the partial nodes up to the root node, starting from the bottom.
236        // For all of them, we scan all its possible child positions and see if there is one at
237        // each position. If the node is not the bottom one, there is additionally a partial node
238        // child at the position `previous_child_index`.
239        let mut partial_nodes = vec![];
240        // Initialize `previous_child_index` to `None` for the first iteration of the loop so the
241        // code below treats it differently.
242        let mut previous_child_index = None;
243
244        loop {
245            let mut internal_info = InternalInfo::new_empty(node_key.clone());
246
247            for i in 0..previous_child_index.unwrap_or(16) {
248                let child_node_key = node_key.gen_child_node_key(version, (i as u8).into());
249                if let Some(node) = store.get_node_option(&child_node_key)? {
250                    let child_info = match node {
251                        Node::Internal(internal_node) => ChildInfo::Internal {
252                            hash: Some(internal_node.hash::<H>()),
253                            leaf_count: internal_node.leaf_count(),
254                        },
255                        Node::Leaf(leaf_node) => ChildInfo::Leaf { node: leaf_node },
256                        Node::Null => bail!("Null node should not appear in storage."),
257                    };
258                    internal_info.set_child(i, child_info);
259                }
260            }
261
262            // If this is not the lowest partial node, it will have a partial node child at
263            // `previous_child_index`. Set the hash of this child to `None` because it is a
264            // partial node and we do not know its hash yet. For the lowest partial node, we just
265            // find all its known children from storage in the loop above.
266            if let Some(index) = previous_child_index {
267                internal_info.set_child(
268                    index,
269                    ChildInfo::Internal {
270                        hash: None,
271                        leaf_count: 0,
272                    },
273                );
274            }
275
276            partial_nodes.push(internal_info);
277            if node_key.nibble_path().is_empty() {
278                break;
279            }
280            previous_child_index = node_key.nibble_path().last().map(|x| u8::from(x) as usize);
281            node_key = node_key.gen_parent_node_key();
282        }
283
284        partial_nodes.reverse();
285        Ok(partial_nodes)
286    }
287
288    /// Restores a chunk of accounts. This function will verify that the given chunk is correct
289    /// using the proof and root hash, then write things to storage. If the chunk is invalid, an
290    /// error will be returned and nothing will be written to storage.
291    fn add_chunk_impl(
292        &mut self,
293        chunk: Vec<(KeyHash, OwnedValue)>,
294        proof: SparseMerkleRangeProof<H>,
295    ) -> Result<()> {
296        ensure!(!chunk.is_empty(), "Should not add empty chunks.");
297
298        for (key, value) in chunk {
299            if let Some(ref prev_leaf) = self.previous_leaf {
300                ensure!(
301                    key > prev_leaf.key_hash(),
302                    "Account keys must come in increasing order.",
303                );
304            }
305            let value_hash = ValueHash::with::<H>(value.as_slice());
306            self.frozen_nodes.insert_value(self.version, key, value);
307
308            self.add_one(key, value_hash);
309            self.previous_leaf.replace(LeafNode::new(key, value_hash));
310            self.num_keys_received += 1;
311        }
312
313        // Verify what we have added so far is all correct.
314        self.verify(proof)?;
315
316        // Write the frozen nodes to storage.
317        self.store.write_node_batch(&self.frozen_nodes)?;
318        self.frozen_nodes.clear();
319
320        Ok(())
321    }
322
323    /// Restores one account.
324    fn add_one(&mut self, new_key: KeyHash, value_hash: ValueHash) {
325        let nibble_path = NibblePath::new(new_key.0.to_vec());
326        let mut nibbles = nibble_path.nibbles();
327
328        for i in 0..ROOT_NIBBLE_HEIGHT {
329            let child_index = u8::from(nibbles.next().expect("This nibble must exist.")) as usize;
330
331            assert!(i < self.partial_nodes.len());
332            match self.partial_nodes[i].children[child_index] {
333                Some(ref child_info) => {
334                    // If there exists an internal node at this position, we just continue the loop
335                    // with the next nibble. Here we deal with the leaf case.
336                    if let ChildInfo::Leaf { node } = child_info {
337                        assert_eq!(
338                            i,
339                            self.partial_nodes.len() - 1,
340                            "If we see a leaf, there will be no more partial internal nodes on \
341                             lower level, since they would have been frozen.",
342                        );
343
344                        let existing_leaf = node.clone();
345                        self.insert_at_leaf(
346                            child_index,
347                            existing_leaf,
348                            new_key,
349                            value_hash,
350                            nibbles,
351                        );
352                        break;
353                    }
354                }
355                None => {
356                    // This means that we are going to put a leaf in this position. For all the
357                    // descendants on the left, they are now frozen.
358                    self.freeze(i + 1);
359
360                    // Mark this position as a leaf child.
361                    self.partial_nodes[i].set_child(
362                        child_index,
363                        ChildInfo::Leaf {
364                            node: LeafNode::new(new_key, value_hash),
365                        },
366                    );
367
368                    // We do not add this leaf node to self.frozen_nodes because we don't know its
369                    // node key yet. We will know its node key when the next account comes.
370                    break;
371                }
372            }
373        }
374    }
375
376    /// Inserts a new account at the position of the existing leaf node. We may need to create
377    /// multiple internal nodes depending on the length of the common prefix of the existing key
378    /// and the new key.
379    fn insert_at_leaf(
380        &mut self,
381        child_index: usize,
382        existing_leaf: LeafNode,
383        new_key: KeyHash,
384        value_hash: ValueHash,
385        mut remaining_nibbles: NibbleIterator,
386    ) {
387        let num_existing_partial_nodes = self.partial_nodes.len();
388
389        // The node at this position becomes an internal node. Since we may insert more nodes at
390        // this position in the future, we do not know its hash yet.
391        self.partial_nodes[num_existing_partial_nodes - 1].set_child(
392            child_index,
393            ChildInfo::Internal {
394                hash: None,
395                leaf_count: 0,
396            },
397        );
398
399        // Next we build the new internal nodes from top to bottom. All these internal node except
400        // the bottom one will now have a single internal node child.
401        let common_prefix_len = existing_leaf
402            .key_hash()
403            .0
404            .common_prefix_nibbles_len(&new_key.0);
405        for _ in num_existing_partial_nodes..common_prefix_len {
406            let visited_nibbles = remaining_nibbles.visited_nibbles().collect();
407            let next_nibble = remaining_nibbles.next().expect("This nibble must exist.");
408            let new_node_key = NodeKey::new(self.version, visited_nibbles);
409
410            let mut internal_info = InternalInfo::new_empty(new_node_key);
411            internal_info.set_child(
412                u8::from(next_nibble) as usize,
413                ChildInfo::Internal {
414                    hash: None,
415                    leaf_count: 0,
416                },
417            );
418            self.partial_nodes.push(internal_info);
419        }
420
421        // The last internal node will have two leaf node children.
422        let visited_nibbles = remaining_nibbles.visited_nibbles().collect();
423        let new_node_key = NodeKey::new(self.version, visited_nibbles);
424        let mut internal_info = InternalInfo::new_empty(new_node_key);
425
426        // Next we put the existing leaf as a child of this internal node.
427        let existing_child_index = existing_leaf.key_hash().0.get_nibble(common_prefix_len);
428        internal_info.set_child(
429            u8::from(existing_child_index) as usize,
430            ChildInfo::Leaf {
431                node: existing_leaf,
432            },
433        );
434
435        // Do not set the new child for now. We always call `freeze` first, then set the new child
436        // later, because this way it's easier in `freeze` to find the correct leaf to freeze --
437        // it's always the rightmost leaf on the lowest level.
438        self.partial_nodes.push(internal_info);
439        self.freeze(self.partial_nodes.len());
440
441        // Now we set the new child.
442        let new_child_index = new_key.0.get_nibble(common_prefix_len);
443        assert!(
444            new_child_index > existing_child_index,
445            "New leaf must be on the right.",
446        );
447        self.partial_nodes
448            .last_mut()
449            .expect("This node must exist.")
450            .set_child(
451                u8::from(new_child_index) as usize,
452                ChildInfo::Leaf {
453                    node: LeafNode::new(new_key, value_hash),
454                },
455            );
456    }
457
458    /// Puts the nodes that will not be changed later in `self.frozen_nodes`.
459    fn freeze(&mut self, num_remaining_partial_nodes: usize) {
460        self.freeze_previous_leaf();
461        self.freeze_internal_nodes(num_remaining_partial_nodes);
462    }
463
464    /// Freezes the previously added leaf node. It should always be the rightmost leaf node on the
465    /// lowest level, inserted in the previous `add_one` call.
466    fn freeze_previous_leaf(&mut self) {
467        // If this is the very first key, there is no previous leaf to freeze.
468        if self.num_keys_received == 0 {
469            return;
470        }
471
472        let last_node = self
473            .partial_nodes
474            .last()
475            .expect("Must have at least one partial node.");
476        let rightmost_child_index = last_node
477            .children
478            .iter()
479            .rposition(|x| x.is_some())
480            .expect("Must have at least one child.");
481
482        match last_node.children[rightmost_child_index] {
483            Some(ChildInfo::Leaf { ref node }) => {
484                let child_node_key = last_node
485                    .node_key
486                    .gen_child_node_key(self.version, (rightmost_child_index as u8).into());
487                self.frozen_nodes
488                    .insert_node(child_node_key, node.clone().into());
489            }
490            _ => panic!("Must have at least one child and must not have further internal nodes."),
491        }
492    }
493
494    /// Freeze extra internal nodes. Only `num_remaining_nodes` partial internal nodes will be kept
495    /// and the ones on the lower level will be frozen.
496    fn freeze_internal_nodes(&mut self, num_remaining_nodes: usize) {
497        while self.partial_nodes.len() > num_remaining_nodes {
498            let last_node = self.partial_nodes.pop().expect("This node must exist.");
499            let (node_key, internal_node) = last_node.into_internal_node::<H>(self.version);
500            // Keep the hash of this node before moving it into `frozen_nodes`, so we can update
501            // its parent later.
502            let node_hash = internal_node.hash::<H>();
503            let node_leaf_count = internal_node.leaf_count();
504            self.frozen_nodes
505                .insert_node(node_key, internal_node.into());
506
507            // Now that we have computed the hash of the internal node above, we will also update
508            // its parent unless it is root node.
509            if let Some(parent_node) = self.partial_nodes.last_mut() {
510                // This internal node must be the rightmost child of its parent at the moment.
511                let rightmost_child_index = parent_node
512                    .children
513                    .iter()
514                    .rposition(|x| x.is_some())
515                    .expect("Must have at least one child.");
516
517                match parent_node.children[rightmost_child_index] {
518                    Some(ChildInfo::Internal {
519                        ref mut hash,
520                        ref mut leaf_count,
521                    }) => {
522                        assert_eq!(hash.replace(node_hash), None);
523                        assert_eq!(*leaf_count, 0);
524                        *leaf_count = node_leaf_count;
525                    }
526                    _ => panic!(
527                        "Must have at least one child and the rightmost child must not be a leaf."
528                    ),
529                }
530            }
531        }
532    }
533
534    /// Verifies that all accounts that have been added so far (from the leftmost one to
535    /// `self.previous_leaf`) are correct, i.e., we are able to construct `self.expected_root_hash`
536    /// by combining all existing accounts and `proof`.
537    #[allow(clippy::collapsible_if)]
538    fn verify(&self, proof: SparseMerkleRangeProof<H>) -> Result<()> {
539        let previous_leaf = self
540            .previous_leaf
541            .as_ref()
542            .expect("The previous leaf must exist.");
543        let previous_key = previous_leaf.key_hash();
544
545        // If we have all siblings on the path from root to `previous_key`, we should be able to
546        // compute the root hash. The siblings on the right are already in the proof. Now we
547        // compute the siblings on the left side, which represent all the accounts that have ever
548        // been added.
549        let mut left_siblings = vec![];
550
551        // The following process might add some extra placeholder siblings on the left, but it is
552        // nontrivial to determine when the loop should stop. So instead we just add these
553        // siblings for now and get rid of them in the next step.
554        let mut num_visited_right_siblings = 0;
555        for (i, bit) in previous_key.0.iter_bits().enumerate() {
556            if bit {
557                // This node is a right child and there should be a sibling on the left.
558                let sibling = if i >= self.partial_nodes.len() * 4 {
559                    SPARSE_MERKLE_PLACEHOLDER_HASH
560                } else {
561                    Self::compute_left_sibling(
562                        &self.partial_nodes[i / 4],
563                        previous_key.0.get_nibble(i / 4),
564                        (3 - i % 4) as u8,
565                    )
566                };
567                left_siblings.push(sibling);
568            } else {
569                // This node is a left child and there should be a sibling on the right.
570                num_visited_right_siblings += 1;
571            }
572        }
573        ensure!(
574            num_visited_right_siblings >= proof.right_siblings().len(),
575            "Too many right siblings in the proof.",
576        );
577
578        // Now we remove any extra placeholder siblings at the bottom. We keep removing the last
579        // sibling if 1) it's a placeholder 2) it's a sibling on the left.
580        for bit in previous_key.0.iter_bits().rev() {
581            if bit {
582                if *left_siblings.last().expect("This sibling must exist.")
583                    == SPARSE_MERKLE_PLACEHOLDER_HASH
584                {
585                    left_siblings.pop();
586                } else {
587                    break;
588                }
589            } else if num_visited_right_siblings > proof.right_siblings().len() {
590                num_visited_right_siblings -= 1;
591            } else {
592                break;
593            }
594        }
595
596        // Left siblings must use the same ordering as the right siblings in the proof
597        left_siblings.reverse();
598
599        // Verify the proof now that we have all the siblings
600        proof.verify(
601            self.expected_root_hash,
602            SparseMerkleLeafNode::new(previous_key, previous_leaf.value_hash()),
603            left_siblings,
604        )
605    }
606
607    /// Computes the sibling on the left for the `n`-th child.
608    fn compute_left_sibling(partial_node: &InternalInfo, n: Nibble, height: u8) -> [u8; 32] {
609        assert!(height < 4);
610        let width = 1usize << height;
611        let start = get_child_and_sibling_half_start(n, height).1 as usize;
612        Self::compute_left_sibling_impl(&partial_node.children[start..start + width]).0
613    }
614
615    /// Returns the hash for given portion of the subtree and whether this part is a leaf node.
616    fn compute_left_sibling_impl(children: &[Option<ChildInfo>]) -> ([u8; 32], bool) {
617        assert!(!children.is_empty());
618
619        let num_children = children.len();
620        assert!(num_children.is_power_of_two());
621
622        if num_children == 1 {
623            match &children[0] {
624                Some(ChildInfo::Internal { hash, .. }) => {
625                    (*hash.as_ref().expect("The hash must be known."), false)
626                }
627                Some(ChildInfo::Leaf { node }) => (node.hash::<H>(), true),
628                None => (SPARSE_MERKLE_PLACEHOLDER_HASH, true),
629            }
630        } else {
631            let (left_hash, left_is_leaf) =
632                Self::compute_left_sibling_impl(&children[..num_children / 2]);
633            let (right_hash, right_is_leaf) =
634                Self::compute_left_sibling_impl(&children[num_children / 2..]);
635
636            if left_hash == SPARSE_MERKLE_PLACEHOLDER_HASH && right_is_leaf {
637                (right_hash, true)
638            } else if left_is_leaf && right_hash == SPARSE_MERKLE_PLACEHOLDER_HASH {
639                (left_hash, true)
640            } else {
641                (
642                    SparseMerkleInternalNode::new(left_hash, right_hash).hash::<H>(),
643                    false,
644                )
645            }
646        }
647    }
648
649    /// Finishes the restoration process. This tells the code that there is no more account,
650    /// otherwise we can not freeze the rightmost leaf and its ancestors.
651    fn finish_impl(mut self) -> Result<()> {
652        // Deal with the special case when the entire tree has a single leaf.
653        if self.partial_nodes.len() == 1 {
654            let mut num_children = 0;
655            let mut leaf = None;
656            for i in 0..16 {
657                if let Some(ref child_info) = self.partial_nodes[0].children[i] {
658                    num_children += 1;
659                    if let ChildInfo::Leaf { node } = child_info {
660                        leaf = Some(node.clone());
661                    }
662                }
663            }
664
665            if num_children == 1 {
666                if let Some(node) = leaf {
667                    let node_key = NodeKey::new_empty_path(self.version);
668                    assert!(self.frozen_nodes.is_empty());
669                    self.frozen_nodes.insert_node(node_key, node.into());
670                    self.store.write_node_batch(&self.frozen_nodes)?;
671                    return Ok(());
672                }
673            }
674        }
675
676        self.freeze(0);
677        self.store.write_node_batch(&self.frozen_nodes)
678    }
679}
680
681/// The interface used with [`JellyfishMerkleRestore`], taken from the Diem `storage-interface` crate.
682pub trait StateSnapshotReceiver<H: SimpleHasher> {
683    fn add_chunk(
684        &mut self,
685        chunk: Vec<(KeyHash, OwnedValue)>,
686        proof: SparseMerkleRangeProof<H>,
687    ) -> Result<()>;
688
689    fn finish(self) -> Result<()>;
690
691    fn finish_box(self: Box<Self>) -> Result<()>;
692}
693
694impl<H: SimpleHasher> StateSnapshotReceiver<H> for JellyfishMerkleRestore<H> {
695    fn add_chunk(
696        &mut self,
697        chunk: Vec<(KeyHash, OwnedValue)>,
698        proof: SparseMerkleRangeProof<H>,
699    ) -> Result<()> {
700        self.add_chunk_impl(chunk, proof)
701    }
702
703    fn finish(self) -> Result<()> {
704        self.finish_impl()
705    }
706
707    fn finish_box(self: Box<Self>) -> Result<()> {
708        self.finish_impl()
709    }
710}