1use std::{
4 fmt::{Debug, Display},
5 ops::Range,
6};
7
8use crate::prelude::*;
9
10#[doc(inline)]
11pub use crate::internal::hash::{Forgotten, Hash};
12
13pub(crate) trait Any<'tree>: GetHash + sealed::Sealed {
16 fn children(&'tree self) -> Vec<HashOrNode<'tree>>;
18
19 fn kind(&self) -> Kind;
22
23 fn forgotten(&self) -> Forgotten;
25}
26
27impl GetHash for &dyn Any<'_> {
28 fn hash(&self) -> Hash {
29 (**self).hash()
30 }
31
32 fn cached_hash(&self) -> Option<Hash> {
33 (**self).cached_hash()
34 }
35
36 fn clear_cached_hash(&self) {
37 (**self).clear_cached_hash()
38 }
39}
40
41impl<'tree, T: Any<'tree>> Any<'tree> for &T {
42 fn kind(&self) -> Kind {
43 (**self).kind()
44 }
45
46 fn forgotten(&self) -> Forgotten {
47 (**self).forgotten()
48 }
49
50 fn children(&'tree self) -> Vec<HashOrNode<'tree>> {
51 (**self).children()
52 }
53}
54
55#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
57pub enum Kind {
58 Leaf {
60 commitment: Option<StateCommitment>,
62 },
63 Internal {
65 height: u8,
67 },
68}
69
70impl Display for Kind {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 match self {
73 Kind::Leaf { .. } => write!(f, "Leaf",),
74 Kind::Internal { .. } => write!(f, "Node"),
75 }
76 }
77}
78
/// Where a node sits relative to the frontier of the tree, as computed by [`Node::place`].
///
/// The derived ordering follows declaration order: `Complete` sorts before `Frontier`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Place {
    /// The node is not on the frontier.
    Complete,
    /// The node is on the frontier.
    Frontier,
}
90
91impl Display for Place {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 match self {
94 Place::Frontier => write!(f, "frontier"),
95 Place::Complete => write!(f, "complete"),
96 }
97 }
98}
99
100#[derive(Clone, Copy)]
102pub struct Node<'tree> {
103 offset: u64,
104 global_position: Option<Position>,
105 this: HashOrNode<'tree>,
106}
107
108impl GetHash for Node<'_> {
109 fn hash(&self) -> Hash {
110 self.this.hash()
111 }
112
113 fn cached_hash(&self) -> Option<Hash> {
114 self.this.cached_hash()
115 }
116
117 fn clear_cached_hash(&self) {
118 self.this.clear_cached_hash()
119 }
120}
121
122#[derive(Clone, Copy)]
123pub(crate) enum HashOrNode<'tree> {
124 Hash(HashedNode),
125 Node(&'tree dyn Any<'tree>),
126}
127
128impl GetHash for HashOrNode<'_> {
129 fn hash(&self) -> Hash {
130 match self {
131 HashOrNode::Hash(hashed) => hashed.hash(),
132 HashOrNode::Node(node) => node.hash(),
133 }
134 }
135
136 fn cached_hash(&self) -> Option<Hash> {
137 match self {
138 HashOrNode::Hash(hashed) => Some(hashed.hash()),
139 HashOrNode::Node(node) => node.cached_hash(),
140 }
141 }
142
143 fn clear_cached_hash(&self) {
144 if let HashOrNode::Node(node) = self {
145 node.clear_cached_hash()
146 }
147 }
148}
149
150#[derive(Clone, Copy)]
151pub(crate) struct HashedNode {
152 pub hash: Hash,
153 pub height: u8,
154 pub forgotten: Forgotten,
155}
156
157impl GetHash for HashedNode {
158 fn hash(&self) -> Hash {
159 self.hash
160 }
161
162 fn cached_hash(&self) -> Option<Hash> {
163 Some(self.hash)
164 }
165
166 fn clear_cached_hash(&self) {}
167}
168
169impl Debug for Node<'_> {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 let name = format!("{}::{}", self.place(), self.kind());
172 let mut s = f.debug_struct(&name);
173 if self.height() != 0 {
174 s.field("height", &(*self).height());
175 }
176 s.field("position", &u64::from(self.position()));
177 if self.forgotten() != Forgotten::default() {
178 s.field("forgotten", &self.forgotten());
179 }
180 if let Some(hash) = self.cached_hash() {
181 s.field("hash", &hash);
182 }
183 if let Kind::Leaf {
184 commitment: Some(commitment),
185 } = self.kind()
186 {
187 s.field("commitment", &commitment);
188 }
189 let children = self.children();
190 if !children.is_empty() {
191 s.field("children", &children);
192 }
193 s.finish()
194 }
195}
196
197impl Display for Node<'_> {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct(&format!("{}::{}", self.place(), self.kind()))
200 .field("height", &self.height())
201 .field("position", &self.position())
202 .finish_non_exhaustive()
203 }
204}
205
206impl<'tree> Node<'tree> {
207 pub(crate) fn root<R: Any<'tree> + GetPosition>(node: &'tree R) -> Self {
209 Self {
210 offset: 0,
211 global_position: node.position().map(Into::into),
212 this: HashOrNode::Node(node),
213 }
214 }
215
216 pub fn hash(&self) -> Hash {
218 self.this.hash()
219 }
220
221 pub fn cached_hash(&self) -> Option<Hash> {
223 self.this.cached_hash()
224 }
225
226 pub fn kind(&self) -> Kind {
229 match self.this {
230 HashOrNode::Hash(HashedNode { height, .. }) => Kind::Internal { height },
231 HashOrNode::Node(node) => node.kind(),
232 }
233 }
234
235 pub fn forgotten(&self) -> Forgotten {
237 match self.this {
238 HashOrNode::Hash(HashedNode { forgotten, .. }) => forgotten,
239 HashOrNode::Node(node) => node.forgotten(),
240 }
241 }
242
243 pub fn children(&self) -> Vec<Node<'tree>> {
245 match self.this {
246 HashOrNode::Hash(_) => Vec::new(),
247 HashOrNode::Node(node) => node
248 .children()
249 .into_iter()
250 .enumerate()
251 .map(|(i, hash_or_node)| Node {
252 global_position: self.global_position,
253 offset: self.offset * 4 + (i as u64),
254 this: hash_or_node,
255 })
256 .collect(),
257 }
258 }
259
260 pub fn index(&self) -> u64 {
264 self.offset
265 }
266
267 pub fn height(&self) -> u8 {
269 match self.kind() {
270 Kind::Internal { height } => height,
271 Kind::Leaf { .. } => 0,
272 }
273 }
274
275 pub fn position(&self) -> Position {
277 (4u64.pow(self.height() as u32) * self.index()).into()
278 }
279
280 pub fn stride(&self) -> u64 {
282 4u64.pow(self.height() as u32)
283 }
284
285 pub fn range(&self) -> Range<Position> {
288 let position: u64 = self.position().into();
289 position.into()..(position + self.stride()).min(4u64.pow(24) - 1).into()
290 }
291
292 pub fn global_position(&self) -> Option<Position> {
294 self.global_position
295 }
296
297 pub fn place(&self) -> Place {
299 if let Some(global_position) = self.global_position() {
300 if let Some(frontier_tip) = u64::from(global_position).checked_sub(1) {
301 let height = self.height();
302 let position = u64::from(self.position());
303 if position >> (height * 2) == frontier_tip >> (height * 2) {
304 Place::Frontier
307 } else {
308 Place::Complete
311 }
312 } else {
313 Place::Frontier
316 }
317 } else {
318 Place::Complete
320 }
321 }
322}
323
324mod sealed {
325 use super::*;
326
327 pub trait Sealed: Send + Sync {}
328
329 impl<T: Sealed> Sealed for &T {}
330 impl Sealed for Node<'_> {}
331
332 impl Sealed for complete::Item {}
333 impl<T: Sealed> Sealed for complete::Leaf<T> {}
334 impl<T: Sealed + Clone> Sealed for complete::Node<T> {}
335 impl<T: Sealed + Height + GetHash + Clone> Sealed for complete::Tier<T> {}
336 impl<T: Sealed + Height + GetHash + Clone> Sealed for complete::Top<T> {}
337
338 impl Sealed for frontier::Item {}
339 impl<T: Sealed> Sealed for frontier::Leaf<T> {}
340 impl<T: Sealed + Focus> Sealed for frontier::Node<T> where T::Complete: Send + Sync {}
341 impl<T: Sealed + Height + GetHash + Focus + Clone> Sealed for frontier::Tier<T> where
342 T::Complete: Send + Sync + Clone
343 {
344 }
345 impl<T: Sealed + Height + GetHash + Focus + Clone> Sealed for frontier::Top<T> where
346 T::Complete: Send + Sync + Clone
347 {
348 }
349}
350
#[cfg(test)]
mod test {
    use super::*;

    /// Walking the tree depth-first, the nodes seen at each height must carry the indices
    /// 0, 1, 2, ... in the order they are visited.
    #[test]
    fn indexing_correct() {
        const MAX_SIZE_TO_TEST: u16 = 100;

        /// `next_index` holds, per height, the index the next node at that height must have.
        fn check_indices(next_index: &mut [u64; 9], node: Node) {
            let height = usize::from(node.height());
            assert_eq!(node.index(), next_index[height], "{node}");

            next_index[height] += 1;

            for child in node.children() {
                check_indices(next_index, child);
            }
        }

        let mut top: frontier::Top<Item> = frontier::Top::new(frontier::TrackForgotten::No);
        for i in 0..MAX_SIZE_TO_TEST {
            top.insert(StateCommitment(i.into()).into()).unwrap();
        }

        check_indices(&mut [0; 9], Node::root(&top));
    }

    /// After every insertion the root is on the frontier, and among the children of any node only
    /// the last one shares its parent's place; all earlier siblings are complete.
    #[test]
    fn place_correct() {
        const MAX_SIZE_TO_TEST: u16 = 100;

        fn check(node: Node, expected: Place) {
            assert_eq!(node.place(), expected);

            let children = node.children();
            if children.len() > 4 {
                unreachable!("nodes can't have > 4 children");
            }

            if let Some((last, earlier)) = children.split_last() {
                for sibling in earlier {
                    check(*sibling, Place::Complete);
                }
                check(*last, expected);
            }
        }

        let mut top: frontier::Top<Item> = frontier::Top::new(frontier::TrackForgotten::No);
        for i in 0..MAX_SIZE_TO_TEST {
            top.insert(StateCommitment(i.into()).into()).unwrap();
            let root = Node::root(&top);
            check(root, Place::Frontier);
        }
    }

    /// The root of a full tree structure has height 24, and each level of children is exactly one
    /// lower than its parent.
    #[test]
    fn height_correct() {
        const MAX_SIZE_TO_TEST: u16 = 100;

        fn check(node: Node, expected: u8) {
            assert_eq!(node.height(), expected, "{node}");
            for child in node.children() {
                check(child, expected - 1);
            }
        }

        let mut tree = crate::Tree::new();

        for i in 0..MAX_SIZE_TO_TEST {
            tree.insert(crate::Witness::Keep, StateCommitment(i.into()))
                .unwrap();
            let root = tree.structure();
            check(root, 24);
        }
    }
}