1use core::marker::PhantomData;
9
10use alloc::boxed::Box;
11use alloc::vec;
12use alloc::{sync::Arc, vec::Vec};
13
14use anyhow::{bail, ensure, Result};
15use mirai_annotations::*;
16
17use crate::{
18 node_type::{
19 get_child_and_sibling_half_start, Child, Children, InternalNode, LeafNode, Node, NodeKey,
20 NodeType,
21 },
22 storage::{NodeBatch, TreeReader, TreeWriter},
23 types::{
24 nibble::{
25 nibble_path::{NibbleIterator, NibblePath},
26 Nibble,
27 },
28 proof::{SparseMerkleInternalNode, SparseMerkleLeafNode, SparseMerkleRangeProof},
29 Version,
30 },
31 Bytes32Ext, KeyHash, OwnedValue, RootHash, SimpleHasher, ValueHash, ROOT_NIBBLE_HEIGHT,
32 SPARSE_MERKLE_PLACEHOLDER_HASH,
33};
34
/// What a restore-time partial node knows about one of its children.
#[derive(Clone, Debug, Eq, PartialEq)]
enum ChildInfo {
    /// An internal child. Its `hash` starts out as `None` and is filled in only once the
    /// whole subtree below it has been frozen; `leaf_count` is back-filled at the same
    /// time (it is 0 while the subtree is still under construction).
    Internal {
        hash: Option<[u8; 32]>,
        leaf_count: usize,
    },

    /// A leaf child; its hash can always be recomputed from the node itself.
    Leaf { node: LeafNode },
}
48
49impl ChildInfo {
50 fn into_child<H: SimpleHasher>(self, version: Version) -> Child {
52 match self {
53 Self::Internal { hash, leaf_count } => Child::new(
54 hash.expect("Must have been initialized."),
55 version,
56 NodeType::Internal { leaf_count },
57 ),
58 Self::Leaf { node } => Child::new(node.hash::<H>(), version, NodeType::Leaf),
59 }
60 }
61}
62
/// An internal node still under construction: it sits on the rightmost path of the tree
/// being restored, so more children may still be added before it is frozen.
#[derive(Clone, Debug)]
struct InternalInfo {
    /// The key this node will be stored under once frozen.
    node_key: NodeKey,

    /// One slot per nibble value (0..16); `None` means no child at that position yet.
    children: [Option<ChildInfo>; 16],
}
72
73impl InternalInfo {
74 fn new_empty(node_key: NodeKey) -> Self {
76 Self {
77 node_key,
78 children: Default::default(),
79 }
80 }
81
82 fn set_child(&mut self, index: usize, child_info: ChildInfo) {
83 precondition!(index < 16);
84 self.children[index] = Some(child_info);
85 }
86
87 fn into_internal_node<H: SimpleHasher>(mut self, version: Version) -> (NodeKey, InternalNode) {
90 let mut children = Children::new();
91
92 for (index, child_info_option) in self.children.iter_mut().enumerate() {
95 if let Some(child_info) = child_info_option.take() {
96 children.insert((index as u8).into(), child_info.into_child::<H>(version));
97 }
98 }
99
100 (self.node_key, InternalNode::new(children))
101 }
102}
103
/// Restores a Jellyfish Merkle tree at a fixed version from chunks of key/value pairs
/// streamed in increasing key order, verifying each chunk with a sparse Merkle range
/// proof before persisting it.
pub struct JellyfishMerkleRestore<H: SimpleHasher> {
    /// Storage backend the finished (frozen) nodes are written to.
    store: Arc<dyn TreeWriter>,

    /// The version every restored node is written at.
    version: Version,

    /// Stack of in-progress internal nodes along the rightmost path, root first.
    /// Subtrees strictly to the left of this path are complete and get frozen.
    partial_nodes: Vec<InternalInfo>,

    /// Finalized nodes awaiting the next batched write to `store`.
    frozen_nodes: NodeBatch,

    /// The most recently received leaf; new keys must sort strictly after its key.
    previous_leaf: Option<LeafNode>,

    /// Total number of keys received so far (0 means nothing to freeze yet).
    num_keys_received: u64,

    /// The root hash the fully restored tree is expected to hash to.
    expected_root_hash: RootHash,

    /// Ties the hasher type `H` to this struct without storing a hasher value.
    _phantom_hasher: PhantomData<H>,
}
161
impl<H: SimpleHasher> JellyfishMerkleRestore<H> {
    /// Creates a restore handle, resuming an interrupted restoration if `store` already
    /// contains a rightmost leaf, or starting from an empty tree otherwise.
    ///
    /// # Errors
    /// Propagates storage errors hit while locating the rightmost leaf or recovering the
    /// partial nodes along its path.
    pub fn new<D: 'static + TreeReader + TreeWriter>(
        store: Arc<D>,
        version: Version,
        expected_root_hash: RootHash,
    ) -> Result<Self> {
        let tree_reader = Arc::clone(&store);
        let (partial_nodes, previous_leaf) =
            if let Some((node_key, leaf_node)) = tree_reader.get_rightmost_leaf()? {
                // Resume: rebuild the in-memory stack of partial internal nodes on the
                // path from the root to the rightmost leaf already in storage.
                (
                    Self::recover_partial_nodes(tree_reader.as_ref(), version, node_key)?,
                    Some(leaf_node),
                )
            } else {
                // Fresh start: a single empty partial node at the root.
                (
                    vec![InternalInfo::new_empty(NodeKey::new_empty_path(version))],
                    None,
                )
            };

        Ok(Self {
            store,
            version,
            partial_nodes,
            frozen_nodes: Default::default(),
            previous_leaf,
            num_keys_received: 0,
            expected_root_hash,
            _phantom_hasher: Default::default(),
        })
    }

    /// Creates a restore handle that ignores any progress already in `store` and starts
    /// the restoration from scratch, overwriting previously written state.
    pub fn new_overwrite<D: 'static + TreeWriter>(
        store: Arc<D>,
        version: Version,
        expected_root_hash: RootHash,
    ) -> Result<Self> {
        Ok(Self {
            store,
            version,
            partial_nodes: vec![InternalInfo::new_empty(NodeKey::new_empty_path(version))],
            frozen_nodes: Default::default(),
            previous_leaf: None,
            num_keys_received: 0,
            expected_root_hash,
            _phantom_hasher: Default::default(),
        })
    }

    /// Rebuilds the stack of partial internal nodes (root first) on the path from the
    /// root to `rightmost_leaf_node_key`, using nodes previously persisted in `store`.
    fn recover_partial_nodes(
        store: &dyn TreeReader,
        version: Version,
        rightmost_leaf_node_key: NodeKey,
    ) -> Result<Vec<InternalInfo>> {
        ensure!(
            !rightmost_leaf_node_key.nibble_path().is_empty(),
            "Root node would not be written until entire restoration process has completed \
             successfully.",
        );

        // Climb from the leaf's parent towards the root while the ancestor exists in
        // storage; the first missing ancestor and everything above it are the partial
        // (not yet written) nodes we must reconstruct in memory.
        let mut node_key = rightmost_leaf_node_key.gen_parent_node_key();
        while store.get_node_option(&node_key)?.is_some() {
            node_key = node_key.gen_parent_node_key();
        }

        let mut partial_nodes = vec![];
        // Child index (within the node one level up) of the partial node reconstructed
        // in the previous iteration; `None` on the first (deepest) iteration.
        let mut previous_child_index = None;

        loop {
            let mut internal_info = InternalInfo::new_empty(node_key.clone());

            // Children left of the rightmost path are complete and were already written
            // to storage, so they can be read back directly.
            for i in 0..previous_child_index.unwrap_or(16) {
                let child_node_key = node_key.gen_child_node_key(version, (i as u8).into());
                if let Some(node) = store.get_node_option(&child_node_key)? {
                    let child_info = match node {
                        Node::Internal(internal_node) => ChildInfo::Internal {
                            hash: Some(internal_node.hash::<H>()),
                            leaf_count: internal_node.leaf_count(),
                        },
                        Node::Leaf(leaf_node) => ChildInfo::Leaf { node: leaf_node },
                        Node::Null => bail!("Null node should not appear in storage."),
                    };
                    internal_info.set_child(i, child_info);
                }
            }

            // The child on the rightmost path is itself partial: mark it as an internal
            // child whose hash and leaf count are not yet known.
            if let Some(index) = previous_child_index {
                internal_info.set_child(
                    index,
                    ChildInfo::Internal {
                        hash: None,
                        leaf_count: 0,
                    },
                );
            }

            partial_nodes.push(internal_info);
            if node_key.nibble_path().is_empty() {
                // Reached the root; the stack is complete.
                break;
            }
            previous_child_index = node_key.nibble_path().last().map(|x| u8::from(x) as usize);
            node_key = node_key.gen_parent_node_key();
        }

        // Collected bottom-up; callers expect root-first order.
        partial_nodes.reverse();
        Ok(partial_nodes)
    }

    /// Adds a chunk of key/value pairs (keys strictly increasing, also across chunks),
    /// verifies the tree built so far against `proof`, and persists all frozen nodes.
    ///
    /// # Errors
    /// Fails if the chunk is empty, a key is out of order, the range proof does not
    /// verify against the expected root hash, or writing to storage fails.
    fn add_chunk_impl(
        &mut self,
        chunk: Vec<(KeyHash, OwnedValue)>,
        proof: SparseMerkleRangeProof<H>,
    ) -> Result<()> {
        ensure!(!chunk.is_empty(), "Should not add empty chunks.");

        for (key, value) in chunk {
            if let Some(ref prev_leaf) = self.previous_leaf {
                ensure!(
                    key > prev_leaf.key_hash(),
                    "Account keys must come in increasing order.",
                );
            }
            let value_hash = ValueHash::with::<H>(value.as_slice());
            self.frozen_nodes.insert_value(self.version, key, value);

            self.add_one(key, value_hash);
            self.previous_leaf.replace(LeafNode::new(key, value_hash));
            self.num_keys_received += 1;
        }

        // Prove the tree built so far matches the expected root before persisting.
        self.verify(proof)?;

        self.store.write_node_batch(&self.frozen_nodes)?;
        self.frozen_nodes.clear();

        Ok(())
    }

    /// Inserts a single key/value-hash pair, walking the stack of partial nodes one
    /// nibble at a time until either a vacant slot or an existing leaf is found.
    fn add_one(&mut self, new_key: KeyHash, value_hash: ValueHash) {
        let nibble_path = NibblePath::new(new_key.0.to_vec());
        let mut nibbles = nibble_path.nibbles();

        for i in 0..ROOT_NIBBLE_HEIGHT {
            let child_index = u8::from(nibbles.next().expect("This nibble must exist.")) as usize;

            assert!(i < self.partial_nodes.len());
            match self.partial_nodes[i].children[child_index] {
                Some(ref child_info) => {
                    // Slot occupied. If it holds a leaf, the two keys share this prefix
                    // and the leaf must be split into a deeper subtree.
                    if let ChildInfo::Leaf { node } = child_info {
                        assert_eq!(
                            i,
                            self.partial_nodes.len() - 1,
                            "If we see a leaf, there will be no more partial internal nodes on \
                             lower level, since they would have been frozen.",
                        );

                        let existing_leaf = node.clone();
                        self.insert_at_leaf(
                            child_index,
                            existing_leaf,
                            new_key,
                            value_hash,
                            nibbles,
                        );
                        break;
                    }
                    // Otherwise the child is internal: descend a level and continue.
                }
                None => {
                    // Vacant slot: everything deeper than level `i` is final now, so
                    // freeze those partial nodes before placing the new leaf here.
                    self.freeze(i + 1);

                    self.partial_nodes[i].set_child(
                        child_index,
                        ChildInfo::Leaf {
                            node: LeafNode::new(new_key, value_hash),
                        },
                    );

                    break;
                }
            }
        }
    }

    /// Splits `existing_leaf` (currently at `child_index` of the deepest partial node)
    /// by creating internal nodes down to the end of the two keys' common nibble prefix,
    /// then placing both leaves under the last one. The new key must sort after the
    /// existing one, i.e. land in a strictly greater child slot.
    fn insert_at_leaf(
        &mut self,
        child_index: usize,
        existing_leaf: LeafNode,
        new_key: KeyHash,
        value_hash: ValueHash,
        mut remaining_nibbles: NibbleIterator,
    ) {
        let num_existing_partial_nodes = self.partial_nodes.len();

        // The slot that held the leaf becomes an internal child with pending hash.
        self.partial_nodes[num_existing_partial_nodes - 1].set_child(
            child_index,
            ChildInfo::Internal {
                hash: None,
                leaf_count: 0,
            },
        );

        // Push a chain of partial internal nodes covering the rest of the common
        // prefix; each has a single pending internal child pointing one level down.
        let common_prefix_len = existing_leaf
            .key_hash()
            .0
            .common_prefix_nibbles_len(&new_key.0);
        for _ in num_existing_partial_nodes..common_prefix_len {
            let visited_nibbles = remaining_nibbles.visited_nibbles().collect();
            let next_nibble = remaining_nibbles.next().expect("This nibble must exist.");
            let new_node_key = NodeKey::new(self.version, visited_nibbles);

            let mut internal_info = InternalInfo::new_empty(new_node_key);
            internal_info.set_child(
                u8::from(next_nibble) as usize,
                ChildInfo::Internal {
                    hash: None,
                    leaf_count: 0,
                },
            );
            self.partial_nodes.push(internal_info);
        }

        // The first node below the common prefix will hold both leaves as children.
        let visited_nibbles = remaining_nibbles.visited_nibbles().collect();
        let new_node_key = NodeKey::new(self.version, visited_nibbles);
        let mut internal_info = InternalInfo::new_empty(new_node_key);

        let existing_child_index = existing_leaf.key_hash().0.get_nibble(common_prefix_len);
        internal_info.set_child(
            u8::from(existing_child_index) as usize,
            ChildInfo::Leaf {
                node: existing_leaf,
            },
        );

        // Freeze the existing leaf (now final, being the rightmost child) before
        // placing the new leaf to its right.
        self.partial_nodes.push(internal_info);
        self.freeze(self.partial_nodes.len());

        let new_child_index = new_key.0.get_nibble(common_prefix_len);
        assert!(
            new_child_index > existing_child_index,
            "New leaf must be on the right.",
        );
        self.partial_nodes
            .last_mut()
            .expect("This node must exist.")
            .set_child(
                u8::from(new_child_index) as usize,
                ChildInfo::Leaf {
                    node: LeafNode::new(new_key, value_hash),
                },
            );
    }

    /// Freezes the current rightmost leaf, then pops and freezes partial internal nodes
    /// until only `num_remaining_partial_nodes` remain on the stack.
    fn freeze(&mut self, num_remaining_partial_nodes: usize) {
        self.freeze_previous_leaf();
        self.freeze_internal_nodes(num_remaining_partial_nodes);
    }

    /// Queues the rightmost child of the deepest partial node — which must be a leaf —
    /// for storage. No-op before the first key has been received.
    fn freeze_previous_leaf(&mut self) {
        // Nothing to freeze before any key has arrived.
        if self.num_keys_received == 0 {
            return;
        }

        let last_node = self
            .partial_nodes
            .last()
            .expect("Must have at least one partial node.");
        let rightmost_child_index = last_node
            .children
            .iter()
            .rposition(|x| x.is_some())
            .expect("Must have at least one child.");

        match last_node.children[rightmost_child_index] {
            Some(ChildInfo::Leaf { ref node }) => {
                let child_node_key = last_node
                    .node_key
                    .gen_child_node_key(self.version, (rightmost_child_index as u8).into());
                self.frozen_nodes
                    .insert_node(child_node_key, node.clone().into());
            }
            _ => panic!("Must have at least one child and must not have further internal nodes."),
        }
    }

    /// Pops partial nodes off the stack until `num_remaining_nodes` are left, converting
    /// each into a final `InternalNode`, queueing it for storage, and back-filling its
    /// now-known hash and leaf count into the parent's rightmost child slot.
    fn freeze_internal_nodes(&mut self, num_remaining_nodes: usize) {
        while self.partial_nodes.len() > num_remaining_nodes {
            let last_node = self.partial_nodes.pop().expect("This node must exist.");
            let (node_key, internal_node) = last_node.into_internal_node::<H>(self.version);
            let node_hash = internal_node.hash::<H>();
            let node_leaf_count = internal_node.leaf_count();
            self.frozen_nodes
                .insert_node(node_key, internal_node.into());

            if let Some(parent_node) = self.partial_nodes.last_mut() {
                // The node just frozen hangs off the parent's rightmost occupied slot.
                let rightmost_child_index = parent_node
                    .children
                    .iter()
                    .rposition(|x| x.is_some())
                    .expect("Must have at least one child.");

                match parent_node.children[rightmost_child_index] {
                    Some(ChildInfo::Internal {
                        ref mut hash,
                        ref mut leaf_count,
                    }) => {
                        // The slot must still be pending (hash unset, leaf_count 0).
                        assert_eq!(hash.replace(node_hash), None);
                        assert_eq!(*leaf_count, 0);
                        *leaf_count = node_leaf_count;
                    }
                    _ => panic!(
                        "Must have at least one child and the rightmost child must not be a leaf."
                    ),
                }
            }
        }
    }

    /// Verifies everything received so far against `proof` and the expected root hash,
    /// treating the last received leaf as the rightmost known leaf.
    #[allow(clippy::collapsible_if)]
    fn verify(&self, proof: SparseMerkleRangeProof<H>) -> Result<()> {
        let previous_leaf = self
            .previous_leaf
            .as_ref()
            .expect("The previous leaf must exist.");
        let previous_key = previous_leaf.key_hash();

        // Left siblings on the root-to-leaf path, collected top-down; only levels where
        // the path turns right (bit == 1) have a left sibling.
        let mut left_siblings = vec![];

        let mut num_visited_right_siblings = 0;
        for (i, bit) in previous_key.0.iter_bits().enumerate() {
            if bit {
                // Bit levels below the deepest partial node (4 bits per nibble level)
                // have no materialized sibling: use the placeholder hash.
                let sibling = if i >= self.partial_nodes.len() * 4 {
                    SPARSE_MERKLE_PLACEHOLDER_HASH
                } else {
                    Self::compute_left_sibling(
                        &self.partial_nodes[i / 4],
                        previous_key.0.get_nibble(i / 4),
                        (3 - i % 4) as u8,
                    )
                };
                left_siblings.push(sibling);
            } else {
                num_visited_right_siblings += 1;
            }
        }
        ensure!(
            num_visited_right_siblings >= proof.right_siblings().len(),
            "Too many right siblings in the proof.",
        );

        // Walk back up from the bottom, dropping levels the proof does not cover:
        // trailing placeholder left siblings and surplus right-sibling levels.
        // NOTE(review): this mirrors what `SparseMerkleRangeProof::verify` consumes —
        // confirm against that implementation if it changes.
        for bit in previous_key.0.iter_bits().rev() {
            if bit {
                if *left_siblings.last().expect("This sibling must exist.")
                    == SPARSE_MERKLE_PLACEHOLDER_HASH
                {
                    left_siblings.pop();
                } else {
                    break;
                }
            } else if num_visited_right_siblings > proof.right_siblings().len() {
                num_visited_right_siblings -= 1;
            } else {
                break;
            }
        }

        // The proof expects siblings ordered bottom-up.
        left_siblings.reverse();

        proof.verify(
            self.expected_root_hash,
            SparseMerkleLeafNode::new(previous_key, previous_leaf.value_hash()),
            left_siblings,
        )
    }

    /// Computes the hash of the left sibling of nibble `n`'s subtree at bit `height`
    /// (0 = bottom of the nibble, up to 3) within `partial_node`.
    fn compute_left_sibling(partial_node: &InternalInfo, n: Nibble, height: u8) -> [u8; 32] {
        assert!(height < 4);
        let width = 1usize << height;
        let start = get_child_and_sibling_half_start(n, height).1 as usize;
        Self::compute_left_sibling_impl(&partial_node.children[start..start + width]).0
    }

    /// Folds a power-of-two-sized slice of child slots into a single hash, also
    /// returning whether the subtree is a lone leaf (or empty) — a lone leaf collapses
    /// upward past placeholder siblings instead of forming an internal node.
    fn compute_left_sibling_impl(children: &[Option<ChildInfo>]) -> ([u8; 32], bool) {
        assert!(!children.is_empty());

        let num_children = children.len();
        assert!(num_children.is_power_of_two());

        if num_children == 1 {
            match &children[0] {
                Some(ChildInfo::Internal { hash, .. }) => {
                    (*hash.as_ref().expect("The hash must be known."), false)
                }
                Some(ChildInfo::Leaf { node }) => (node.hash::<H>(), true),
                None => (SPARSE_MERKLE_PLACEHOLDER_HASH, true),
            }
        } else {
            // Recurse on the two halves, then combine per sparse-Merkle hashing rules.
            let (left_hash, left_is_leaf) =
                Self::compute_left_sibling_impl(&children[..num_children / 2]);
            let (right_hash, right_is_leaf) =
                Self::compute_left_sibling_impl(&children[num_children / 2..]);

            if left_hash == SPARSE_MERKLE_PLACEHOLDER_HASH && right_is_leaf {
                (right_hash, true)
            } else if left_is_leaf && right_hash == SPARSE_MERKLE_PLACEHOLDER_HASH {
                (left_hash, true)
            } else {
                (
                    SparseMerkleInternalNode::new(left_hash, right_hash).hash::<H>(),
                    false,
                )
            }
        }
    }

    /// Finalizes the restoration: freezes all remaining partial nodes (or, if the whole
    /// tree is a single leaf, writes that leaf as the root) and flushes everything.
    fn finish_impl(mut self) -> Result<()> {
        // Special case: only the root partial node exists and it has exactly one child
        // that is a leaf — that leaf becomes the root node itself.
        if self.partial_nodes.len() == 1 {
            let mut num_children = 0;
            let mut leaf = None;
            for i in 0..16 {
                if let Some(ref child_info) = self.partial_nodes[0].children[i] {
                    num_children += 1;
                    if let ChildInfo::Leaf { node } = child_info {
                        leaf = Some(node.clone());
                    }
                }
            }

            if num_children == 1 {
                if let Some(node) = leaf {
                    let node_key = NodeKey::new_empty_path(self.version);
                    assert!(self.frozen_nodes.is_empty());
                    self.frozen_nodes.insert_node(node_key, node.into());
                    self.store.write_node_batch(&self.frozen_nodes)?;
                    return Ok(());
                }
            }
        }

        // General case: freeze everything down to an empty stack, then write.
        self.freeze(0);
        self.store.write_node_batch(&self.frozen_nodes)
    }
}
680
/// Receiver side of a state-snapshot transfer: accepts proof-carrying chunks of
/// key/value pairs and finalizes the restored state when the stream ends.
pub trait StateSnapshotReceiver<H: SimpleHasher> {
    /// Adds a chunk of key/value pairs together with the range proof covering them.
    fn add_chunk(
        &mut self,
        chunk: Vec<(KeyHash, OwnedValue)>,
        proof: SparseMerkleRangeProof<H>,
    ) -> Result<()>;

    /// Finishes the restoration, consuming the receiver by value.
    fn finish(self) -> Result<()>;

    /// Finishes the restoration when the receiver is held as a boxed trait object.
    fn finish_box(self: Box<Self>) -> Result<()>;
}
693
// Trait impl is pure delegation to the inherent `*_impl` methods above.
impl<H: SimpleHasher> StateSnapshotReceiver<H> for JellyfishMerkleRestore<H> {
    /// Delegates to `JellyfishMerkleRestore::add_chunk_impl`.
    fn add_chunk(
        &mut self,
        chunk: Vec<(KeyHash, OwnedValue)>,
        proof: SparseMerkleRangeProof<H>,
    ) -> Result<()> {
        self.add_chunk_impl(chunk, proof)
    }

    /// Delegates to `JellyfishMerkleRestore::finish_impl`.
    fn finish(self) -> Result<()> {
        self.finish_impl()
    }

    /// Delegates to `JellyfishMerkleRestore::finish_impl`.
    fn finish_box(self: Box<Self>) -> Result<()> {
        self.finish_impl()
    }
}