// penumbra_sdk_dex/component/position_manager/counter.rs

1use anyhow::bail;
2use async_trait::async_trait;
3use cnidarium::{StateRead, StateWrite};
4
5use crate::lp::position::{self, Position};
6use crate::state_key::engine;
7use crate::TradingPair;
8use anyhow::Result;
9
10#[async_trait]
11pub(crate) trait PositionCounterRead: StateRead {
12    /// Returns the number of position for a [`TradingPair`].
13    /// If there were no counter initialized for a given pair, this default to zero.
14    async fn get_position_count(&self, trading_pair: &TradingPair) -> u32 {
15        let path = engine::counter::num_positions::by_trading_pair(trading_pair);
16        self.get_position_count_from_key(path).await
17    }
18
19    async fn get_position_count_from_key(&self, path: [u8; 99]) -> u32 {
20        let Some(raw_count) = self
21            .nonverifiable_get_raw(&path)
22            .await
23            .expect("no deserialization failure")
24        else {
25            return 0;
26        };
27
28        // This is safe because we only increment the counter via [`Self::increase_position_counter`].
29        let raw_count: [u8; 4] = raw_count
30            .try_into()
31            .expect("position counter is at most two bytes");
32        u32::from_be_bytes(raw_count)
33    }
34}
35
36impl<T: StateRead + ?Sized> PositionCounterRead for T {}
37
38#[async_trait]
39pub(crate) trait PositionCounter: StateWrite {
40    async fn update_trading_pair_position_counter(
41        &mut self,
42        prev_state: &Option<Position>,
43        new_state: &Position,
44    ) -> Result<()> {
45        use position::State::*;
46        let trading_pair = new_state.phi.pair;
47        match (prev_state.as_ref().map(|p| p.state), new_state.state) {
48            // Increment the counter whenever a new position is opened
49            (None, Opened) => {
50                let _ = self.increment_position_counter(&trading_pair).await?;
51            }
52            // Decrement the counter whenever an opened position is closed
53            (Some(Opened), Closed) => {
54                let _ = self.decrement_position_counter(&trading_pair).await?;
55            }
56            // Other state transitions don't affect the opened position counter
57            _ => {}
58        }
59        Ok(())
60    }
61}
62impl<T: StateWrite + ?Sized> PositionCounter for T {}
63
64trait Inner: StateWrite {
65    /// Increment the number of position for a [`TradingPair`].
66    /// Returns the updated total, or an error if overflow occurred.
67    async fn increment_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u32> {
68        let path = engine::counter::num_positions::by_trading_pair(trading_pair);
69        let prev = self.get_position_count_from_key(path).await;
70
71        let Some(new_total) = prev.checked_add(1) else {
72            bail!("incrementing position counter would overflow")
73        };
74        self.nonverifiable_put_raw(path.to_vec(), new_total.to_be_bytes().to_vec());
75        Ok(new_total)
76    }
77
78    /// Decrement the number of positions for a [`TradingPair`], unless it would underflow.
79    /// Returns the updated total, or an error if underflow occurred.
80    async fn decrement_position_counter(&mut self, trading_pair: &TradingPair) -> Result<u32> {
81        let path = engine::counter::num_positions::by_trading_pair(trading_pair);
82        let prev = self.get_position_count_from_key(path).await;
83
84        let Some(new_total) = prev.checked_sub(1) else {
85            bail!("decrementing position counter would underflow")
86        };
87        self.nonverifiable_put_raw(path.to_vec(), new_total.to_be_bytes().to_vec());
88        Ok(new_total)
89    }
90}
91
92impl<T: StateWrite + ?Sized> Inner for T {}
93
// Gated on `cfg(test)`: without it the module is compiled into regular builds,
// where the `#[tokio::test]` functions are stripped and every import looks unused.
#[cfg(test)]
mod tests {
    use cnidarium::{StateDelta, StateWrite, TempStorage};
    use penumbra_sdk_asset::asset::REGISTRY;

    use crate::component::position_manager::counter::{Inner, PositionCounterRead};
    use crate::state_key::engine;
    use crate::TradingPair;

    /// Test that overflows are detected and handled: the increment returns an
    /// error, nothing panics, and the stored counter is left unchanged.
    #[tokio::test]
    async fn test_no_overflow() -> anyhow::Result<()> {
        let asset_a = REGISTRY.parse_denom("upenumbra").unwrap().id();
        let asset_b = REGISTRY.parse_denom("pizza").unwrap().id();
        let trading_pair = TradingPair::new(asset_a, asset_b);

        let storage = TempStorage::new().await?;
        let mut delta = StateDelta::new(storage.latest_snapshot());
        let path = engine::counter::num_positions::by_trading_pair(&trading_pair);
        // Manually set the counter to the maximum value
        delta.nonverifiable_put_raw(path.to_vec(), u32::MAX.to_be_bytes().to_vec());

        // Check that the counter is at the maximum value
        let total = delta.get_position_count(&trading_pair).await;
        assert_eq!(total, u32::MAX);

        // Check that we can handle an overflow
        assert!(delta
            .increment_position_counter(&trading_pair)
            .await
            .is_err());
        assert_eq!(delta.get_position_count(&trading_pair).await, u32::MAX);

        Ok(())
    }

    /// Test that underflows are detected and handled: the decrement returns an
    /// error, nothing panics, and the counter stays at zero.
    #[tokio::test]
    async fn test_no_underflow() -> anyhow::Result<()> {
        let asset_a = REGISTRY.parse_denom("upenumbra").unwrap().id();
        let asset_b = REGISTRY.parse_denom("pizza").unwrap().id();
        let trading_pair = TradingPair::new(asset_a, asset_b);

        let storage = TempStorage::new().await?;
        let mut delta = StateDelta::new(storage.latest_snapshot());

        let maybe_total = delta.decrement_position_counter(&trading_pair).await;
        assert!(maybe_total.is_err());

        let counter = delta.get_position_count(&trading_pair).await;
        assert_eq!(counter, 0u32);
        Ok(())
    }
}