penumbra_sdk_asset/balance/imbalance.rs

1use std::{
2    cmp::Ordering,
3    fmt::Debug,
4    num::NonZeroU128,
5    ops::{Add, Neg, Sub},
6};
7
8use serde::{Deserialize, Serialize};
9
10/// An imbalance is either a required amount or a provided amount.
11///
12/// This is used exclusively when the type contained is non-zero.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum Imbalance<T> {
15    /// Something is required, i.e. it must be cancelled out by a provided thing.
16    Required(T),
17    /// Something is provided, i.e. it must be cancelled out by a required thing.
18    Provided(T),
19}
20
21impl<T> Neg for Imbalance<T> {
22    type Output = Imbalance<T>;
23
24    fn neg(self) -> Self::Output {
25        match self {
26            Imbalance::Required(t) => Imbalance::Provided(t),
27            Imbalance::Provided(t) => Imbalance::Required(t),
28        }
29    }
30}
31
32impl Add for Imbalance<NonZeroU128> {
33    type Output = Option<Self>;
34
35    fn add(self, other: Self) -> Self::Output {
36        // We define the case where the two are the same, and where the two are different, and in
37        // the symmetric cases, use double-negation to avoid repeating the logic
38        match (self, other) {
39            (Imbalance::Required(r), Imbalance::Required(s)) => {
40                if let Some(t) = r.get().checked_add(s.get()) {
41                    Some(Imbalance::Required(
42                        NonZeroU128::new(t).expect("checked addition of NonZeroU128 is never zero"),
43                    ))
44                } else {
45                    panic!("overflow when adding imbalances")
46                }
47            }
48            (Imbalance::Required(r), Imbalance::Provided(p)) => match p.cmp(&r) {
49                Ordering::Less => Some(Imbalance::Required(
50                    NonZeroU128::new(r.get() - p.get())
51                        .expect("subtraction of lesser from greater is never zero"),
52                )),
53                Ordering::Equal => None,
54                Ordering::Greater => Some(Imbalance::Provided(
55                    NonZeroU128::new(p.get() - r.get())
56                        .expect("subtraction of lesser from greater is never zero"),
57                )),
58            },
59            (x, y) => Some(-((-x + -y)?)),
60        }
61    }
62}
63
64impl Sub for Imbalance<NonZeroU128> {
65    type Output = <Self as Add>::Output;
66
67    fn sub(self, other: Self) -> Self::Output {
68        self + -other
69    }
70}
71
72impl<T> Imbalance<T> {
73    /// Split an imbalance into its sign and the thing contained in it.
74    pub fn into_inner(self) -> (Sign, T) {
75        match self {
76            Imbalance::Required(t) => (Sign::Required, t),
77            Imbalance::Provided(t) => (Sign::Provided, t),
78        }
79    }
80
81    /// Map a function over both required or provided possibilities.
82    pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Imbalance<S> {
83        match self {
84            Imbalance::Required(t) => Imbalance::Required(f(t)),
85            Imbalance::Provided(t) => Imbalance::Provided(f(t)),
86        }
87    }
88
89    /// Filter an imbalance to get only the `T` out if it is a required thing.
90    pub fn required(self) -> Option<T> {
91        match self {
92            Imbalance::Required(t) => Some(t),
93            Imbalance::Provided(_) => None,
94        }
95    }
96
97    /// Filter an imbalance to get only the `T` out if it is a provided thing.
98    pub fn provided(self) -> Option<T> {
99        match self {
100            Imbalance::Required(_) => None,
101            Imbalance::Provided(t) => Some(t),
102        }
103    }
104
105    /// Check if an imbalance is required.
106    pub fn is_required(&self) -> bool {
107        matches!(self, Imbalance::Required(_))
108    }
109
110    /// Check if an imbalance is provided.
111    pub fn is_provided(&self) -> bool {
112        !self.is_required()
113    }
114
115    /// Get the sign of an imbalance.
116    pub fn sign(&self) -> Sign {
117        match self {
118            Imbalance::Required(_) => Sign::Required,
119            Imbalance::Provided(_) => Sign::Provided,
120        }
121    }
122}
123
124/// The sign of an imbalance is whether it is required or provided.
125#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
126pub enum Sign {
127    Required,
128    Provided,
129}
130
131impl Sign {
132    /// Check if the sign is required.
133    pub fn is_required(&self) -> bool {
134        matches!(self, Sign::Required)
135    }
136
137    /// Check if the sign if provided.
138    pub fn is_provided(&self) -> bool {
139        !self.is_required()
140    }
141
142    /// Form a new [`Imbalance`] by using the sign as a constructor for some value.
143    pub fn imbalance<T>(&self, t: T) -> Imbalance<T> {
144        match self {
145            Sign::Required => Imbalance::Required(t),
146            Sign::Provided => Imbalance::Provided(t),
147        }
148    }
149}
150
#[cfg(test)]
mod test {
    use super::*;

    /// Build a non-zero test amount; panics if given `0`.
    fn nz(x: u128) -> NonZeroU128 {
        NonZeroU128::new(x).expect("test amounts must be non-zero")
    }

    #[test]
    fn add_provided_provided() {
        let a = Imbalance::Provided(nz(1));
        let b = Imbalance::Provided(nz(2));
        assert_eq!(a + b, Some(Imbalance::Provided(nz(3))));
    }

    #[test]
    fn add_provided_required_greater() {
        let a = Imbalance::Provided(nz(2));
        let b = Imbalance::Required(nz(1));
        assert_eq!(a + b, Some(Imbalance::Provided(nz(1))));
    }

    #[test]
    fn add_provided_required_equal() {
        let a = Imbalance::Provided(nz(1));
        let b = Imbalance::Required(nz(1));
        assert_eq!(a + b, None);
    }

    #[test]
    fn add_provided_required_less() {
        let a = Imbalance::Provided(nz(1));
        let b = Imbalance::Required(nz(2));
        assert_eq!(a + b, Some(Imbalance::Required(nz(1))));
    }

    #[test]
    fn add_required_required() {
        let a = Imbalance::Required(nz(1));
        let b = Imbalance::Required(nz(2));
        assert_eq!(a + b, Some(Imbalance::Required(nz(3))));
    }

    #[test]
    fn add_required_provided_greater() {
        let a = Imbalance::Required(nz(2));
        let b = Imbalance::Provided(nz(1));
        assert_eq!(a + b, Some(Imbalance::Required(nz(1))));
    }

    #[test]
    fn add_required_provided_equal() {
        let a = Imbalance::Required(nz(1));
        let b = Imbalance::Provided(nz(1));
        assert_eq!(a + b, None);
    }

    #[test]
    fn add_required_provided_less() {
        let a = Imbalance::Required(nz(1));
        let b = Imbalance::Provided(nz(2));
        assert_eq!(a + b, Some(Imbalance::Provided(nz(1))));
    }

    // Addition of same-signed amounts is the only path that can overflow; both the direct
    // `Required + Required` arm and the double-negated `Provided + Provided` path must panic.
    #[test]
    #[should_panic(expected = "overflow when adding imbalances")]
    fn add_required_required_overflow_panics() {
        let _ = Imbalance::Required(nz(u128::MAX)) + Imbalance::Required(nz(1));
    }

    #[test]
    #[should_panic(expected = "overflow when adding imbalances")]
    fn add_provided_provided_overflow_panics() {
        let _ = Imbalance::Provided(nz(u128::MAX)) + Imbalance::Provided(nz(1));
    }

    // Opposite signs only ever shrink the magnitude, so they cannot overflow even at the
    // extremes of the `u128` range.
    #[test]
    fn add_opposite_signs_at_max_does_not_overflow() {
        let max = nz(u128::MAX);
        assert_eq!(Imbalance::Required(max) + Imbalance::Provided(max), None);
        assert_eq!(
            Imbalance::Provided(max) + Imbalance::Required(nz(1)),
            Some(Imbalance::Provided(nz(u128::MAX - 1)))
        );
    }

    #[test]
    fn sub_provided_provided_greater() {
        let a = Imbalance::Provided(nz(2));
        let b = Imbalance::Provided(nz(1));
        assert_eq!(a - b, Some(Imbalance::Provided(nz(1))));
    }

    #[test]
    fn sub_provided_provided_equal() {
        let a = Imbalance::Provided(nz(1));
        let b = Imbalance::Provided(nz(1));
        assert_eq!(a - b, None);
    }

    #[test]
    fn sub_provided_provided_less() {
        let a = Imbalance::Provided(nz(1));
        let b = Imbalance::Provided(nz(2));
        assert_eq!(a - b, Some(Imbalance::Required(nz(1))));
    }

    #[test]
    fn sub_provided_required_greater() {
        let a = Imbalance::Provided(nz(2));
        let b = Imbalance::Required(nz(1));
        assert_eq!(a - b, Some(Imbalance::Provided(nz(3))));
    }

    #[test]
    fn sub_provided_required_equal() {
        let a = Imbalance::Provided(nz(1));
        let b = Imbalance::Required(nz(1));
        assert_eq!(a - b, Some(Imbalance::Provided(nz(2))));
    }

    #[test]
    fn sub_provided_required_less() {
        let a = Imbalance::Provided(nz(1));
        let b = Imbalance::Required(nz(2));
        assert_eq!(a - b, Some(Imbalance::Provided(nz(3))));
    }

    #[test]
    fn sub_required_provided_greater() {
        let a = Imbalance::Required(nz(2));
        let b = Imbalance::Provided(nz(1));
        assert_eq!(a - b, Some(Imbalance::Required(nz(3))));
    }

    #[test]
    fn sub_required_provided_equal() {
        let a = Imbalance::Required(nz(1));
        let b = Imbalance::Provided(nz(1));
        assert_eq!(a - b, Some(Imbalance::Required(nz(2))));
    }

    #[test]
    fn sub_required_provided_less() {
        let a = Imbalance::Required(nz(1));
        let b = Imbalance::Provided(nz(2));
        assert_eq!(a - b, Some(Imbalance::Required(nz(3))));
    }

    #[test]
    fn sub_required_required_greater() {
        let a = Imbalance::Required(nz(2));
        let b = Imbalance::Required(nz(1));
        assert_eq!(a - b, Some(Imbalance::Required(nz(1))));
    }

    #[test]
    fn sub_required_required_equal() {
        let a = Imbalance::Required(nz(1));
        let b = Imbalance::Required(nz(1));
        assert_eq!(a - b, None);
    }

    #[test]
    fn sub_required_required_less() {
        let a = Imbalance::Required(nz(1));
        let b = Imbalance::Required(nz(2));
        assert_eq!(a - b, Some(Imbalance::Provided(nz(1))));
    }

    #[test]
    fn neg_swaps_variant_and_is_involutive() {
        let r = Imbalance::Required(nz(5));
        assert_eq!(-r, Imbalance::Provided(nz(5)));
        assert_eq!(-(-r), r);
    }

    #[test]
    fn sign_round_trips_through_into_inner() {
        for &imbalance in &[Imbalance::Required(nz(7)), Imbalance::Provided(nz(7))] {
            let (sign, amount) = imbalance.into_inner();
            assert_eq!(sign, imbalance.sign());
            assert_eq!(sign.imbalance(amount), imbalance);
        }
    }

    #[test]
    fn accessors_agree_with_variant() {
        let r = Imbalance::Required(nz(3));
        assert_eq!(r.required(), Some(nz(3)));
        assert_eq!(r.provided(), None);
        assert!(r.is_required() && !r.is_provided());
        assert_eq!(r.map(|n| n.get() * 2), Imbalance::Required(6));

        let p = Imbalance::Provided(nz(3));
        assert_eq!(p.provided(), Some(nz(3)));
        assert_eq!(p.required(), None);
        assert!(p.is_provided() && !p.is_required());
    }
}