penumbra_sdk_dex/component/position_manager/
counter.rs1use anyhow::bail;
2use async_trait::async_trait;
3use cnidarium::{StateRead, StateWrite};
4
5use crate::lp::position::{self, Position};
6use crate::state_key::engine;
7use crate::TradingPair;
8use anyhow::Result;
9
10#[async_trait]
11pub(crate) trait PositionCounterRead: StateRead {
12 async fn get_position_count(&self, trading_pair: &TradingPair) -> u32 {
15 let path = engine::counter::num_positions::by_trading_pair(trading_pair);
16 self.get_position_count_from_key(path).await
17 }
18
19 async fn get_position_count_from_key(&self, path: [u8; 99]) -> u32 {
20 let Some(raw_count) = self
21 .nonverifiable_get_raw(&path)
22 .await
23 .expect("no deserialization failure")
24 else {
25 return 0;
26 };
27
28 let raw_count: [u8; 4] = raw_count
30 .try_into()
31 .expect("position counter is at most two bytes");
32 u32::from_be_bytes(raw_count)
33 }
34}
35
36impl<T: StateRead + ?Sized> PositionCounterRead for T {}
37
38#[async_trait]
39pub(crate) trait PositionCounter: StateWrite {
40 async fn update_trading_pair_position_counter(
41 &mut self,
42 prev_state: &Option<Position>,
43 new_state: &Position,
44 ) -> Result<()> {
45 use position::State::*;
46 let trading_pair = new_state.phi.pair;
47 match (prev_state.as_ref().map(|p| p.state), new_state.state) {
48 (None, Opened) => {
50 let _ = self.increment_position_counter(&trading_pair).await?;
51 }
52 (Some(Opened), Closed) => {
54 let _ = self.decrement_position_counter(&trading_pair).await?;
55 }
56 _ => {}
58 }
59 Ok(())
60 }
61}
62impl<T: StateWrite + ?Sized> PositionCounter for T {}
63
64trait Inner: StateWrite {
65 async fn increment_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u32> {
68 let path = engine::counter::num_positions::by_trading_pair(trading_pair);
69 let prev = self.get_position_count_from_key(path).await;
70
71 let Some(new_total) = prev.checked_add(1) else {
72 bail!("incrementing position counter would overflow")
73 };
74 self.nonverifiable_put_raw(path.to_vec(), new_total.to_be_bytes().to_vec());
75 Ok(new_total)
76 }
77
78 async fn decrement_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u32> {
81 let path = engine::counter::num_positions::by_trading_pair(trading_pair);
82 let prev = self.get_position_count_from_key(path).await;
83
84 let Some(new_total) = prev.checked_sub(1) else {
85 bail!("decrementing position counter would underflow")
86 };
87 self.nonverifiable_put_raw(path.to_vec(), new_total.to_be_bytes().to_vec());
88 Ok(new_total)
89 }
90}
91
92impl<T: StateWrite + ?Sized> Inner for T {}
93
#[cfg(test)]
mod tests {
    use cnidarium::{StateDelta, StateWrite, TempStorage};
    use penumbra_sdk_asset::asset::REGISTRY;

    use crate::component::position_manager::counter::{Inner, PositionCounterRead};
    use crate::state_key::engine;
    use crate::TradingPair;

    /// Builds the trading pair shared by every test in this module.
    fn test_trading_pair() -> TradingPair {
        let asset_a = REGISTRY.parse_denom("upenumbra").unwrap().id();
        let asset_b = REGISTRY.parse_denom("pizza").unwrap().id();
        TradingPair::new(asset_a, asset_b)
    }

    /// A counter already at `u32::MAX` must refuse to increment and stay unchanged.
    #[tokio::test]
    async fn test_no_overflow() -> anyhow::Result<()> {
        let trading_pair = test_trading_pair();

        let storage = TempStorage::new().await?;
        let mut delta = StateDelta::new(storage.latest_snapshot());
        let path = engine::counter::num_positions::by_trading_pair(&trading_pair);
        delta.nonverifiable_put_raw(path.to_vec(), u32::MAX.to_be_bytes().to_vec());

        let total = delta.get_position_count(&trading_pair).await;
        assert_eq!(total, u32::MAX);

        assert!(delta
            .increment_position_counter(&trading_pair)
            .await
            .is_err());
        assert_eq!(delta.get_position_count(&trading_pair).await, u32::MAX);

        Ok(())
    }

    /// A missing (zero) counter must refuse to decrement and stay at zero.
    #[tokio::test]
    async fn test_no_underflow() -> anyhow::Result<()> {
        let trading_pair = test_trading_pair();

        let storage = TempStorage::new().await?;
        let mut delta = StateDelta::new(storage.latest_snapshot());

        let maybe_total = delta.decrement_position_counter(&trading_pair).await;
        assert!(maybe_total.is_err());

        let counter = delta.get_position_count(&trading_pair).await;
        assert_eq!(counter, 0u32);
        Ok(())
    }

    /// Normal increments and decrements round-trip through storage.
    #[tokio::test]
    async fn test_increment_then_decrement() -> anyhow::Result<()> {
        let trading_pair = test_trading_pair();

        let storage = TempStorage::new().await?;
        let mut delta = StateDelta::new(storage.latest_snapshot());

        assert_eq!(delta.get_position_count(&trading_pair).await, 0);
        assert_eq!(delta.increment_position_counter(&trading_pair).await?, 1);
        assert_eq!(delta.increment_position_counter(&trading_pair).await?, 2);
        assert_eq!(delta.get_position_count(&trading_pair).await, 2);

        assert_eq!(delta.decrement_position_counter(&trading_pair).await?, 1);
        assert_eq!(delta.get_position_count(&trading_pair).await, 1);

        Ok(())
    }
}