jmt/types/nibble/nibble_path.rs

1// Copyright (c) The Diem Core Contributors
2// SPDX-License-Identifier: Apache-2.0
3
4//! NibblePath library simplify operations with nibbles in a compact format for modified sparse
5//! Merkle tree by providing powerful iterators advancing by either bit or nibble.
6
7use alloc::vec;
8use core::{fmt, iter::FromIterator};
9
10use alloc::vec::Vec;
11use mirai_annotations::*;
12#[cfg(any(test))]
13use proptest::{collection::vec, prelude::*};
14use serde::{Deserialize, Serialize};
15
16use crate::types::nibble::{Nibble, ROOT_NIBBLE_HEIGHT};
17
/// NibblePath defines a path in the Merkle tree in units of nibbles (4 bits each).
///
/// Nibbles are packed two per byte, high half first. The derived `Ord`/`PartialOrd` compare
/// fields in declaration order, so paths are ordered by `num_nibbles` first and only then by
/// `bytes`. The derived serde and borsh implementations also emit the fields in declaration
/// order, so the fields below must not be reordered.
#[derive(
    Clone,
    Hash,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    borsh::BorshSerialize,
    borsh::BorshDeserialize,
)]
pub struct NibblePath {
    /// Total number of nibbles stored in `bytes`. Either `bytes.len() * 2 - 1` (odd length) or
    /// `bytes.len() * 2` (even length).
    // Declared first on purpose: derived ordering compares fields top-to-bottom, which makes
    // this field the primary sort key.
    num_nibbles: usize,
    /// The underlying bytes that store the path, 2 nibbles per byte (high half first). If the
    /// number of nibbles is odd, the low half of the last byte must be 0.
    bytes: Vec<u8>,
    // invariant num_nibbles <= ROOT_NIBBLE_HEIGHT
}
42
43/// Supports debug format by concatenating nibbles literally. For example, [0x12, 0xa0] with 3
44/// nibbles will be printed as "12a".
45impl fmt::Debug for NibblePath {
46    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
47        self.nibbles().try_for_each(|x| write!(f, "{:x}", x))
48    }
49}
50
51/// Convert a vector of bytes into `NibblePath` using the lower 4 bits of each byte as nibble.
52impl FromIterator<Nibble> for NibblePath {
53    fn from_iter<I: IntoIterator<Item = Nibble>>(iter: I) -> Self {
54        let mut nibble_path = NibblePath::new(vec![]);
55        for nibble in iter {
56            nibble_path.push(nibble);
57        }
58        nibble_path
59    }
60}
61
// Test-only: lets proptest generate arbitrary `NibblePath`s (e.g. via `any::<NibblePath>()`).
#[cfg(test)]
impl Arbitrary for NibblePath {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Delegates to `arb_nibble_path`, boxing the strategy to match `Self::Strategy`.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        arb_nibble_path().boxed()
    }
}
70
// Test-only strategy: draws up to `ROOT_NIBBLE_HEIGHT / 2` random bytes and, when `is_odd` is
// set and at least one byte was drawn, turns them into an odd-length path by zeroing the low
// half of the last byte.
#[cfg(test)]
prop_compose! {
    fn arb_nibble_path()(
        mut bytes in vec(any::<u8>(), 0..=ROOT_NIBBLE_HEIGHT/2),
        is_odd in any::<bool>()
    ) -> NibblePath {
        // An odd-length path needs at least one byte to drop a nibble from.
        if is_odd && !bytes.is_empty() {
            if let Some(last_byte) = bytes.last_mut() {
                // `new_odd` requires the unused low half of the last byte to be zero.
                *last_byte &= 0xf0;
            }
            NibblePath::new_odd(bytes)
        } else {
            NibblePath::new(bytes)
        }
    }
}
86
// Test-only strategy: like `arb_nibble_path`, but rejects paths of exactly
// `ROOT_NIBBLE_HEIGHT` nibbles, which the filter below treats as leaf paths.
#[cfg(test)]
prop_compose! {
    pub(crate) fn arb_internal_nibble_path()(
        nibble_path in arb_nibble_path().prop_filter(
            "Filter out leaf paths.",
            |p| p.num_nibbles() < ROOT_NIBBLE_HEIGHT,
        )
    ) -> NibblePath {
        nibble_path
    }
}
98
99impl NibblePath {
100    /// Creates a new `NibblePath` from a vector of bytes assuming each byte has 2 nibbles.
101    pub(crate) fn new(bytes: Vec<u8>) -> Self {
102        checked_precondition!(bytes.len() <= ROOT_NIBBLE_HEIGHT / 2);
103        let num_nibbles = bytes.len() * 2;
104        NibblePath { num_nibbles, bytes }
105    }
106
107    /// Similar to `new()` but assumes that the bytes have one less nibble.
108    // Unlike `new`, this function is not used under all feature combinations - so
109    // we #[allow(unused)] to silence the warnings
110    #[allow(unused)]
111    pub(crate) fn new_odd(bytes: Vec<u8>) -> Self {
112        checked_precondition!(bytes.len() <= ROOT_NIBBLE_HEIGHT / 2);
113        assert_eq!(
114            bytes.last().expect("Should have odd number of nibbles.") & 0x0f,
115            0,
116            "Last nibble must be 0."
117        );
118        let num_nibbles = bytes.len() * 2 - 1;
119        NibblePath { num_nibbles, bytes }
120    }
121
122    /// Adds a nibble to the end of the nibble path.
123    pub(crate) fn push(&mut self, nibble: Nibble) {
124        assert!(ROOT_NIBBLE_HEIGHT > self.num_nibbles);
125        if self.num_nibbles % 2 == 0 {
126            self.bytes.push(u8::from(nibble) << 4);
127        } else {
128            self.bytes[self.num_nibbles / 2] |= u8::from(nibble);
129        }
130        self.num_nibbles += 1;
131    }
132
133    /// Pops a nibble from the end of the nibble path.
134    pub(crate) fn pop(&mut self) -> Option<Nibble> {
135        let poped_nibble = if self.num_nibbles % 2 == 0 {
136            self.bytes.last_mut().map(|last_byte| {
137                let nibble = *last_byte & 0x0f;
138                *last_byte &= 0xf0;
139                Nibble::from(nibble)
140            })
141        } else {
142            self.bytes.pop().map(|byte| Nibble::from(byte >> 4))
143        };
144        if poped_nibble.is_some() {
145            self.num_nibbles -= 1;
146        }
147        poped_nibble
148    }
149
150    /// Returns the last nibble.
151    pub fn last(&self) -> Option<Nibble> {
152        let last_byte_option = self.bytes.last();
153        if self.num_nibbles % 2 == 0 {
154            last_byte_option.map(|last_byte| Nibble::from(*last_byte & 0x0f))
155        } else {
156            let last_byte = last_byte_option.expect("Last byte must exist if num_nibbles is odd.");
157            Some(Nibble::from(*last_byte >> 4))
158        }
159    }
160
161    /// Get the i-th bit.
162    pub(crate) fn get_bit(&self, i: usize) -> bool {
163        assert!(i < self.num_nibbles * 4);
164        let pos = i / 8;
165        let bit = 7 - i % 8;
166        ((self.bytes[pos] >> bit) & 1) != 0
167    }
168
169    /// Get the i-th nibble.
170    pub fn get_nibble(&self, i: usize) -> Nibble {
171        assert!(i < self.num_nibbles);
172        Nibble::from((self.bytes[i / 2] >> (if i % 2 == 1 { 0 } else { 4 })) & 0xf)
173    }
174
175    /// Get a bit iterator iterates over the whole nibble path.
176    pub fn bits(&self) -> BitIterator {
177        assume!(self.num_nibbles <= ROOT_NIBBLE_HEIGHT); // invariant
178        BitIterator {
179            nibble_path: self,
180            pos: (0..self.num_nibbles * 4),
181        }
182    }
183
184    /// Get a nibble iterator iterates over the whole nibble path.
185    pub fn nibbles(&self) -> NibbleIterator {
186        assume!(self.num_nibbles <= ROOT_NIBBLE_HEIGHT); // invariant
187        NibbleIterator::new(self, 0, self.num_nibbles)
188    }
189
190    /// Get the total number of nibbles stored.
191    pub fn num_nibbles(&self) -> usize {
192        self.num_nibbles
193    }
194
195    ///  Returns `true` if the nibbles contains no elements.
196    pub fn is_empty(&self) -> bool {
197        self.num_nibbles() == 0
198    }
199
200    /// Get the underlying bytes storing nibbles.
201    pub(crate) fn bytes(&self) -> &[u8] {
202        &self.bytes
203    }
204}
205
/// An iterator that can report the item `next()` would yield without consuming it.
pub trait Peekable: Iterator {
    /// Returns what the next call to `next()` would produce, leaving the iterator untouched.
    fn peek(&self) -> Option<<Self as Iterator>::Item>;
}
210
211/// BitIterator iterates a nibble path by bit.
212pub struct BitIterator<'a> {
213    nibble_path: &'a NibblePath,
214    pos: core::ops::Range<usize>,
215}
216
217impl<'a> Peekable for BitIterator<'a> {
218    /// Returns the `next()` value without advancing the iterator.
219    fn peek(&self) -> Option<Self::Item> {
220        if self.pos.start < self.pos.end {
221            Some(self.nibble_path.get_bit(self.pos.start))
222        } else {
223            None
224        }
225    }
226}
227
228/// BitIterator spits out a boolean each time. True/false denotes 1/0.
229impl<'a> Iterator for BitIterator<'a> {
230    type Item = bool;
231    fn next(&mut self) -> Option<Self::Item> {
232        self.pos.next().map(|i| self.nibble_path.get_bit(i))
233    }
234}
235
236/// Support iterating bits in reversed order.
237impl<'a> DoubleEndedIterator for BitIterator<'a> {
238    fn next_back(&mut self) -> Option<Self::Item> {
239        self.pos.next_back().map(|i| self.nibble_path.get_bit(i))
240    }
241}
242
/// NibbleIterator iterates a nibble path by nibble.
///
/// It borrows the whole `NibblePath` but only covers the sub-range `start..pos.end`. The part
/// `start..pos.start` has already been yielded, and `pos.start..pos.end` is still pending.
// NOTE: the field order below is observable through the derived `Debug` output.
#[derive(Debug, Clone)]
pub struct NibbleIterator<'a> {
    /// The underlying nibble path that stores the nibbles.
    nibble_path: &'a NibblePath,

    /// The nibbles not yet yielded. `pos.start` is the index of the next nibble; it is bumped
    /// by 1 on every call to `next()` until `pos.start == pos.end`.
    pos: core::ops::Range<usize>,

    /// The start index of the iterator. At the beginning, `pos.start == start`. The half-open
    /// range `start..pos.end` is the part of `nibble_path` this iterator covers. `nibble_path`
    /// refers to the entire underlying buffer, but that range may be only part of it.
    start: usize,
    // invariant self.start <= self.pos.start;
    // invariant self.pos.start <= self.pos.end;
    // invariant self.pos.end <= ROOT_NIBBLE_HEIGHT;
}
261
262/// NibbleIterator spits out a byte each time. Each byte must be in range [0, 16).
263impl<'a> Iterator for NibbleIterator<'a> {
264    type Item = Nibble;
265    fn next(&mut self) -> Option<Self::Item> {
266        self.pos.next().map(|i| self.nibble_path.get_nibble(i))
267    }
268}
269
270impl<'a> Peekable for NibbleIterator<'a> {
271    /// Returns the `next()` value without advancing the iterator.
272    fn peek(&self) -> Option<Self::Item> {
273        if self.pos.start < self.pos.end {
274            Some(self.nibble_path.get_nibble(self.pos.start))
275        } else {
276            None
277        }
278    }
279}
280
281impl<'a> NibbleIterator<'a> {
282    fn new(nibble_path: &'a NibblePath, start: usize, end: usize) -> Self {
283        precondition!(start <= end);
284        precondition!(start <= ROOT_NIBBLE_HEIGHT);
285        precondition!(end <= ROOT_NIBBLE_HEIGHT);
286        Self {
287            nibble_path,
288            pos: (start..end),
289            start,
290        }
291    }
292
293    /// Returns a nibble iterator that iterates all visited nibbles.
294    pub fn visited_nibbles(&self) -> NibbleIterator<'a> {
295        assume!(self.start <= self.pos.start); // invariant
296        assume!(self.pos.start <= ROOT_NIBBLE_HEIGHT); // invariant
297        Self::new(self.nibble_path, self.start, self.pos.start)
298    }
299
300    /// Returns a nibble iterator that iterates all remaining nibbles.
301    pub fn remaining_nibbles(&self) -> NibbleIterator<'a> {
302        assume!(self.pos.start <= self.pos.end); // invariant
303        assume!(self.pos.end <= ROOT_NIBBLE_HEIGHT); // invariant
304        Self::new(self.nibble_path, self.pos.start, self.pos.end)
305    }
306
307    /// Turn it into a `BitIterator`.
308    pub fn bits(&self) -> BitIterator<'a> {
309        assume!(self.pos.start <= self.pos.end); // invariant
310        assume!(self.pos.end <= ROOT_NIBBLE_HEIGHT); // invariant
311        BitIterator {
312            nibble_path: self.nibble_path,
313            pos: (self.pos.start * 4..self.pos.end * 4),
314        }
315    }
316
317    /// Cut and return the range of the underlying `nibble_path` that this iterator is iterating
318    /// over as a new `NibblePath`
319    pub fn get_nibble_path(&self) -> NibblePath {
320        self.visited_nibbles()
321            .chain(self.remaining_nibbles())
322            .collect()
323    }
324
325    /// Get the number of nibbles that this iterator covers.
326    pub fn num_nibbles(&self) -> usize {
327        assume!(self.start <= self.pos.end); // invariant
328        self.pos.end - self.start
329    }
330
331    /// Return `true` if the iteration is over.
332    pub fn is_finished(&self) -> bool {
333        self.peek().is_none()
334    }
335}
336
337/// Advance both iterators if their next nibbles are the same until either reaches the end or
338/// the find a mismatch. Return the number of matched nibbles.
339pub fn skip_common_prefix<I1, I2>(x: &mut I1, y: &mut I2) -> usize
340where
341    I1: Iterator + Peekable,
342    I2: Iterator + Peekable,
343    <I1 as Iterator>::Item: core::cmp::PartialEq<<I2 as Iterator>::Item>,
344{
345    let mut count = 0;
346    loop {
347        let x_peek = x.peek();
348        let y_peek = y.peek();
349        if x_peek.is_none()
350            || y_peek.is_none()
351            || x_peek.expect("cannot be none") != y_peek.expect("cannot be none")
352        {
353            break;
354        }
355        count += 1;
356        x.next();
357        y.next();
358    }
359    count
360}