penumbra_sdk_asset/balance/
imbalance.rs1use std::{
2 cmp::Ordering,
3 fmt::Debug,
4 num::NonZeroU128,
5 ops::{Add, Neg, Sub},
6};
7
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub enum Imbalance<T> {
15 Required(T),
17 Provided(T),
19}
20
21impl<T> Neg for Imbalance<T> {
22 type Output = Imbalance<T>;
23
24 fn neg(self) -> Self::Output {
25 match self {
26 Imbalance::Required(t) => Imbalance::Provided(t),
27 Imbalance::Provided(t) => Imbalance::Required(t),
28 }
29 }
30}
31
32impl Add for Imbalance<NonZeroU128> {
33 type Output = Option<Self>;
34
35 fn add(self, other: Self) -> Self::Output {
36 match (self, other) {
39 (Imbalance::Required(r), Imbalance::Required(s)) => {
40 if let Some(t) = r.get().checked_add(s.get()) {
41 Some(Imbalance::Required(
42 NonZeroU128::new(t).expect("checked addition of NonZeroU128 is never zero"),
43 ))
44 } else {
45 panic!("overflow when adding imbalances")
46 }
47 }
48 (Imbalance::Required(r), Imbalance::Provided(p)) => match p.cmp(&r) {
49 Ordering::Less => Some(Imbalance::Required(
50 NonZeroU128::new(r.get() - p.get())
51 .expect("subtraction of lesser from greater is never zero"),
52 )),
53 Ordering::Equal => None,
54 Ordering::Greater => Some(Imbalance::Provided(
55 NonZeroU128::new(p.get() - r.get())
56 .expect("subtraction of lesser from greater is never zero"),
57 )),
58 },
59 (x, y) => Some(-((-x + -y)?)),
60 }
61 }
62}
63
64impl Sub for Imbalance<NonZeroU128> {
65 type Output = <Self as Add>::Output;
66
67 fn sub(self, other: Self) -> Self::Output {
68 self + -other
69 }
70}
71
72impl<T> Imbalance<T> {
73 pub fn into_inner(self) -> (Sign, T) {
75 match self {
76 Imbalance::Required(t) => (Sign::Required, t),
77 Imbalance::Provided(t) => (Sign::Provided, t),
78 }
79 }
80
81 pub fn map<S>(self, f: impl FnOnce(T) -> S) -> Imbalance<S> {
83 match self {
84 Imbalance::Required(t) => Imbalance::Required(f(t)),
85 Imbalance::Provided(t) => Imbalance::Provided(f(t)),
86 }
87 }
88
89 pub fn required(self) -> Option<T> {
91 match self {
92 Imbalance::Required(t) => Some(t),
93 Imbalance::Provided(_) => None,
94 }
95 }
96
97 pub fn provided(self) -> Option<T> {
99 match self {
100 Imbalance::Required(_) => None,
101 Imbalance::Provided(t) => Some(t),
102 }
103 }
104
105 pub fn is_required(&self) -> bool {
107 matches!(self, Imbalance::Required(_))
108 }
109
110 pub fn is_provided(&self) -> bool {
112 !self.is_required()
113 }
114
115 pub fn sign(&self) -> Sign {
117 match self {
118 Imbalance::Required(_) => Sign::Required,
119 Imbalance::Provided(_) => Sign::Provided,
120 }
121 }
122}
123
/// The direction of an [`Imbalance`], without the associated value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Sign {
    /// The direction of [`Imbalance::Required`].
    Required,
    /// The direction of [`Imbalance::Provided`].
    Provided,
}
130
131impl Sign {
132 pub fn is_required(&self) -> bool {
134 matches!(self, Sign::Required)
135 }
136
137 pub fn is_provided(&self) -> bool {
139 !self.is_required()
140 }
141
142 pub fn imbalance<T>(&self, t: T) -> Imbalance<T> {
144 match self {
145 Sign::Required => Imbalance::Required(t),
146 Sign::Provided => Imbalance::Provided(t),
147 }
148 }
149}
150
#[cfg(test)]
mod test {
    use super::*;

    /// Shorthand for building a `NonZeroU128` in tests; panics if `n` is zero.
    fn nz(n: u128) -> NonZeroU128 {
        NonZeroU128::new(n).unwrap()
    }

    #[test]
    fn add_provided_provided() {
        let a = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(2).unwrap());
        let c = a + b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(3).unwrap())));
    }

    #[test]
    fn add_provided_required_greater() {
        let a = Imbalance::Provided(NonZeroU128::new(2).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let c = a + b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(1).unwrap())));
    }

    #[test]
    fn add_provided_required_equal() {
        let a = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let c = a + b;
        assert_eq!(c, None);
    }

    #[test]
    fn add_provided_required_less() {
        let a = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(2).unwrap());
        let c = a + b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(1).unwrap())));
    }

    #[test]
    fn add_required_required() {
        let a = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(2).unwrap());
        let c = a + b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(3).unwrap())));
    }

    #[test]
    fn add_required_provided_greater() {
        let a = Imbalance::Required(NonZeroU128::new(2).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let c = a + b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(1).unwrap())));
    }

    #[test]
    fn add_required_provided_equal() {
        let a = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let c = a + b;
        assert_eq!(c, None);
    }

    #[test]
    fn add_required_provided_less() {
        let a = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(2).unwrap());
        let c = a + b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(1).unwrap())));
    }

    #[test]
    fn sub_provided_provided_greater() {
        let a = Imbalance::Provided(NonZeroU128::new(2).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(1).unwrap())));
    }

    #[test]
    fn sub_provided_provided_equal() {
        let a = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, None);
    }

    #[test]
    fn sub_provided_provided_less() {
        let a = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(2).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(1).unwrap())));
    }

    #[test]
    fn sub_provided_required_greater() {
        let a = Imbalance::Provided(NonZeroU128::new(2).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(3).unwrap())));
    }

    #[test]
    fn sub_provided_required_equal() {
        let a = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(2).unwrap())));
    }

    #[test]
    fn sub_provided_required_less() {
        let a = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(2).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(3).unwrap())));
    }

    #[test]
    fn sub_required_provided_greater() {
        let a = Imbalance::Required(NonZeroU128::new(2).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(3).unwrap())));
    }

    #[test]
    fn sub_required_provided_equal() {
        let a = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(2).unwrap())));
    }

    #[test]
    fn sub_required_provided_less() {
        let a = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Provided(NonZeroU128::new(2).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(3).unwrap())));
    }

    #[test]
    fn sub_required_required_greater() {
        let a = Imbalance::Required(NonZeroU128::new(2).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Required(NonZeroU128::new(1).unwrap())));
    }

    #[test]
    fn sub_required_required_equal() {
        let a = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let c = a - b;
        assert_eq!(c, None);
    }

    #[test]
    fn sub_required_required_less() {
        let a = Imbalance::Required(NonZeroU128::new(1).unwrap());
        let b = Imbalance::Required(NonZeroU128::new(2).unwrap());
        let c = a - b;
        assert_eq!(c, Some(Imbalance::Provided(NonZeroU128::new(1).unwrap())));
    }

    // `Add` panics rather than wrapping when same-direction magnitudes overflow `u128`.
    #[test]
    #[should_panic(expected = "overflow when adding imbalances")]
    fn add_required_required_overflow_panics() {
        let _ = Imbalance::Required(nz(u128::MAX)) + Imbalance::Required(nz(1));
    }

    // The `Provided + Provided` case goes through the double-negation fallback and must
    // hit the same overflow panic.
    #[test]
    #[should_panic(expected = "overflow when adding imbalances")]
    fn add_provided_provided_overflow_panics() {
        let _ = Imbalance::Provided(nz(u128::MAX)) + Imbalance::Provided(nz(1));
    }

    // Subtracting an opposite-direction amount is an addition of same-direction amounts.
    #[test]
    #[should_panic(expected = "overflow when adding imbalances")]
    fn sub_provided_required_overflow_panics() {
        let _ = Imbalance::Provided(nz(u128::MAX)) - Imbalance::Required(nz(1));
    }

    #[test]
    fn add_opposite_max_cancels() {
        let c = Imbalance::Required(nz(u128::MAX)) + Imbalance::Provided(nz(u128::MAX));
        assert_eq!(c, None);
    }

    #[test]
    fn neg_swaps_direction() {
        assert_eq!(-Imbalance::Required(nz(5)), Imbalance::Provided(nz(5)));
        assert_eq!(-Imbalance::Provided(nz(5)), Imbalance::Required(nz(5)));
    }

    #[test]
    fn accessors_agree_with_variant() {
        let r = Imbalance::Required(7u8);
        let p = Imbalance::Provided(7u8);
        assert_eq!(r.into_inner(), (Sign::Required, 7));
        assert_eq!(p.into_inner(), (Sign::Provided, 7));
        assert_eq!(r.required(), Some(7));
        assert_eq!(r.provided(), None);
        assert_eq!(p.required(), None);
        assert_eq!(p.provided(), Some(7));
        assert!(r.is_required() && !r.is_provided());
        assert!(p.is_provided() && !p.is_required());
        assert_eq!(r.sign(), Sign::Required);
        assert_eq!(p.sign(), Sign::Provided);
        assert_eq!(r.map(u32::from), Imbalance::Required(7u32));
        assert_eq!(p.map(u32::from), Imbalance::Provided(7u32));
    }

    #[test]
    fn sign_round_trips_through_imbalance() {
        for &sign in &[Sign::Required, Sign::Provided] {
            let imbalance = sign.imbalance(3u8);
            assert_eq!(imbalance.sign(), sign);
            assert_eq!(imbalance.into_inner(), (sign, 3));
            assert_eq!(sign.is_required(), !sign.is_provided());
        }
    }
}