jmt/
bytes32ext.rs

1use core::ops::Index;
2
3use mirai_annotations::*;
4
5pub trait Bytes32Ext: Index<usize> + Sized {
6    /// Returns the `index`-th nibble.
7    fn get_nibble(&self, index: usize) -> crate::types::nibble::Nibble;
8    /// Returns the length of common prefix of `self` and `other` in bits.
9    fn common_prefix_bits_len(&self, other: &[u8; 32]) -> usize;
10    /// Returns a `HashValueBitIterator` over all the bits that represent this hash value.
11    fn iter_bits(&self) -> HashValueBitIterator<'_>;
12    /// Returns the `index`-th nibble in the bytes.
13    fn nibble(&self, index: usize) -> u8;
14    /// Returns the length of common prefix of `self` and `other` in nibbles.
15    fn common_prefix_nibbles_len(&self, other: &[u8; 32]) -> usize {
16        self.common_prefix_bits_len(other) / 4
17    }
18    /// Constructs a `HashValue` from an iterator of bits.
19    fn from_bit_iter(iter: impl ExactSizeIterator<Item = bool>) -> Option<Self>;
20}
21
22impl Bytes32Ext for [u8; 32] {
23    fn get_nibble(&self, index: usize) -> crate::types::nibble::Nibble {
24        crate::types::nibble::Nibble::from(if index % 2 == 0 {
25            self[index / 2] >> 4
26        } else {
27            self[index / 2] & 0x0F
28        })
29    }
30
31    fn common_prefix_bits_len(&self, other: &[u8; 32]) -> usize {
32        self.iter_bits()
33            .zip(other.iter_bits())
34            .take_while(|(x, y)| x == y)
35            .count()
36    }
37
38    fn iter_bits(&self) -> HashValueBitIterator<'_> {
39        HashValueBitIterator::new(self)
40    }
41
42    fn nibble(&self, index: usize) -> u8 {
43        assume!(index < 32 * 2); // assumed precondition
44        let pos = index / 2;
45        let shift = if index % 2 == 0 { 4 } else { 0 };
46        (self[pos] >> shift) & 0x0f
47    }
48
49    /// Constructs a `HashValue` from an iterator of bits.
50    fn from_bit_iter(iter: impl ExactSizeIterator<Item = bool>) -> Option<Self> {
51        if iter.len() != 256 {
52            return None;
53        }
54
55        let mut buf = [0; 32];
56        for (i, bit) in iter.enumerate() {
57            if bit {
58                buf[i / 8] |= 1 << (7 - i % 8);
59            }
60        }
61        Some(buf)
62    }
63}
64
/// An iterator over the bits of a 32-byte value, yielding one `bool` per bit.
///
/// Within each byte the most significant bit comes first.
#[derive(Debug, Clone)]
pub struct HashValueBitIterator<'a> {
    /// The bytes whose bits are being iterated.
    hash_bytes: &'a [u8],
    /// The bit indices that have not been yielded yet. `start` advances when
    /// iterating from the front and `end` shrinks when iterating from the
    /// back.
    pos: core::ops::Range<usize>,
    // invariant hash_bytes.len() == 32;
    // invariant pos.end <= hash_bytes.len() * 8;
}
73
74impl<'a> HashValueBitIterator<'a> {
75    /// Constructs a new `HashValueBitIterator` using given `HashValue`.
76    fn new(hash_value: &'a [u8; 32]) -> Self {
77        HashValueBitIterator {
78            hash_bytes: hash_value.as_ref(),
79            pos: (0..32 * 8),
80        }
81    }
82
83    /// Returns the `index`-th bit in the bytes.
84    fn get_bit(&self, index: usize) -> bool {
85        assume!(index < self.pos.end); // assumed precondition
86        assume!(self.hash_bytes.len() == 32); // invariant
87        assume!(self.pos.end == self.hash_bytes.len() * 8); // invariant
88        let pos = index / 8;
89        let bit = 7 - index % 8;
90        (self.hash_bytes[pos] >> bit) & 1 != 0
91    }
92}
93
94impl<'a> core::iter::Iterator for HashValueBitIterator<'a> {
95    type Item = bool;
96
97    fn next(&mut self) -> Option<Self::Item> {
98        self.pos.next().map(|x| self.get_bit(x))
99    }
100
101    fn size_hint(&self) -> (usize, Option<usize>) {
102        self.pos.size_hint()
103    }
104}
105
106impl<'a> core::iter::DoubleEndedIterator for HashValueBitIterator<'a> {
107    fn next_back(&mut self) -> Option<Self::Item> {
108        self.pos.next_back().map(|x| self.get_bit(x))
109    }
110}
111
112impl<'a> core::iter::ExactSizeIterator for HashValueBitIterator<'a> {}