penumbra_sdk_num/
amount.rs

1use ark_ff::{BigInteger, PrimeField, ToConstraintField};
2use ark_r1cs_std::{prelude::*, uint64::UInt64};
3use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError};
4use decaf377::{Fq, Fr};
5use penumbra_sdk_proto::{penumbra::core::num::v1 as pb, DomainType};
6use serde::{Deserialize, Serialize};
7use std::{fmt::Display, iter::Sum, num::NonZeroU128, ops};
8
9use crate::fixpoint::{bit_constrain, U128x128, U128x128Var};
10use decaf377::r1cs::FqVar;
11
12#[derive(Serialize, Default, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
13#[serde(try_from = "pb::Amount", into = "pb::Amount")]
14pub struct Amount {
15    inner: u128,
16}
17
18impl std::fmt::Debug for Amount {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        write!(f, "{}", self.inner)
21    }
22}
23
24impl Amount {
25    pub fn value(&self) -> u128 {
26        self.inner
27    }
28
29    pub fn zero() -> Self {
30        Self { inner: 0 }
31    }
32
33    // We need fixed length encoding to produce encrypted `Note`s.
34    pub fn to_le_bytes(&self) -> [u8; 16] {
35        self.inner.to_le_bytes()
36    }
37
38    pub fn to_be_bytes(&self) -> [u8; 16] {
39        self.inner.to_be_bytes()
40    }
41
42    pub fn from_le_bytes(bytes: [u8; 16]) -> Amount {
43        Amount {
44            inner: u128::from_le_bytes(bytes),
45        }
46    }
47
48    pub fn from_be_bytes(bytes: [u8; 16]) -> Amount {
49        Amount {
50            inner: u128::from_be_bytes(bytes),
51        }
52    }
53
54    pub fn checked_sub(&self, rhs: &Self) -> Option<Self> {
55        self.inner
56            .checked_sub(rhs.inner)
57            .map(|inner| Self { inner })
58    }
59
60    pub fn checked_add(&self, rhs: &Self) -> Option<Self> {
61        self.inner
62            .checked_add(rhs.inner)
63            .map(|inner| Self { inner })
64    }
65
66    pub fn checked_mul(&self, rhs: &Self) -> Option<Self> {
67        self.inner
68            .checked_mul(rhs.inner)
69            .map(|inner| Self { inner })
70    }
71
72    pub fn saturating_add(&self, rhs: &Self) -> Self {
73        Self {
74            inner: self.inner.saturating_add(rhs.inner),
75        }
76    }
77
78    pub fn saturating_sub(&self, rhs: &Self) -> Self {
79        Self {
80            inner: self.inner.saturating_sub(rhs.inner),
81        }
82    }
83}
84
85impl ops::Not for Amount {
86    type Output = Self;
87
88    fn not(self) -> Self::Output {
89        Self { inner: !self.inner }
90    }
91}
92
93#[derive(Clone)]
94pub struct AmountVar {
95    pub amount: FqVar,
96}
97
98impl ToConstraintField<Fq> for Amount {
99    fn to_field_elements(&self) -> Option<Vec<Fq>> {
100        let mut elements = Vec::new();
101        elements.extend_from_slice(&[Fq::from(self.inner)]);
102        Some(elements)
103    }
104}
105
106/// Return a boolean constraint indicating if the FqVar can be represented using n bits
107pub fn is_bit_constrained(
108    cs: ConstraintSystemRef<Fq>,
109    value: FqVar,
110    n: usize,
111) -> Result<Boolean<Fq>, SynthesisError> {
112    let inner = value.value().unwrap_or(Fq::from(1u64));
113
114    // Get only first n bits based on that value (OOC)
115    let inner_bigint = inner.into_bigint();
116    let bits = &inner_bigint.to_bits_le()[0..n];
117
118    // Allocate Boolean vars for first n bits
119    let mut boolean_constraints = Vec::new();
120    for bit in bits {
121        let boolean = Boolean::new_witness(cs.clone(), || Ok(bit))?;
122        boolean_constraints.push(boolean);
123    }
124
125    // Construct an FqVar from those n Boolean constraints
126    let constructed_fqvar = Boolean::<Fq>::le_bits_to_fp_var(&boolean_constraints.to_bits_le()?)
127        .expect("can convert to bits");
128    constructed_fqvar.is_eq(&value)
129}
130
impl AmountVar {
    /// Returns the additive inverse of this amount in the field.
    ///
    /// The result is built directly (not via `AllocVar`), so no range check is
    /// applied to it.
    pub fn negate(&self) -> Result<Self, SynthesisError> {
        Ok(Self {
            amount: self.amount.negate()?,
        })
    }

    /// In-circuit integer division: returns `(quotient, remainder)` of `self / divisor_var`.
    ///
    /// The quotient and remainder are computed out of circuit and allocated as
    /// witnesses through `AllocVar<Fq, Fq>`. They are then constrained so that:
    /// - the quotient or the divisor fits in 64 bits,
    /// - `self == quo * divisor + rem`,
    /// - `rem < divisor`.
    ///
    /// If the divisor's assigned value is zero, both witnesses are set to 0. The
    /// `rem < divisor` constraint then cannot be satisfied.
    ///
    /// NOTE(review): the order of allocations and constraints below determines
    /// the circuit. Reordering them changes the constraint system.
    pub fn quo_rem(
        &self,
        divisor_var: &AmountVar,
    ) -> Result<(AmountVar, AmountVar), SynthesisError> {
        // Read both operands out of circuit. Only the first 16 bytes of the
        // field encoding are used, interpreted little-endian. A missing
        // assignment (e.g. during setup) defaults to zero.
        let current_amount_bytes: [u8; 16] = self.amount.value().unwrap_or_default().to_bytes()
            [0..16]
            .try_into()
            .expect("amounts should fit in 16 bytes");
        let current_amount = u128::from_le_bytes(current_amount_bytes);
        let divisor_bytes: [u8; 16] = divisor_var.amount.value().unwrap_or_default().to_bytes()
            [0..16]
            .try_into()
            .expect("amounts should fit in 16 bytes");
        let divisor = u128::from_le_bytes(divisor_bytes);

        // Out of circuit: division by zero yields 0 for both results rather than panicking.
        let quo = current_amount.checked_div(divisor).unwrap_or(0);
        let rem = current_amount.checked_rem(divisor).unwrap_or(0);

        // Add corresponding in-circuit variables (allocated as witnesses via AllocVar<Fq, Fq>)
        let quo_var = AmountVar::new_witness(self.cs(), || Ok(Fq::from(quo)))?;
        let rem_var = AmountVar::new_witness(self.cs(), || Ok(Fq::from(rem)))?;

        // Constrain either quo_var or divisor_var to be 64 bits to guard against overflow.
        // With one factor below 2^64 and the other at most 128 bits, quo * divisor + rem
        // stays far below the field modulus. The equation below therefore holds over the
        // integers, not merely modulo p (assuming Fq is ~253 bits — decaf377 base field).
        let q_is_64_bits = is_bit_constrained(self.cs(), quo_var.amount.clone(), 64)?;
        let d_is_64_bits = is_bit_constrained(self.cs(), divisor_var.amount.clone(), 64)?;
        let q_or_d_is_64_bits = q_is_64_bits.or(&d_is_64_bits)?;
        q_or_d_is_64_bits.enforce_equal(&Boolean::constant(true))?;

        // Constrain: numerator = quo * divisor + rem
        let numerator_var = quo_var.clone() * divisor_var.clone() + rem_var.clone();
        self.enforce_equal(&numerator_var)?;

        // In this stanza we constrain: 0 <= rem < divisor.
        //
        // We do not need to explicitly constrain 0 <= rem, as that is done
        // inside the `FqVar::enforce_cmp` function, which verifies the inputs are
        // of size <(p-1)/2.
        //
        // See: https://docs.rs/ark-r1cs-std/latest/ark_r1cs_std/fields/fp/enum.FpVar.html#method.enforce_cmp
        //
        // Constrain: 0 <= rem < divisor
        rem_var
            .amount
            .enforce_cmp(&divisor_var.amount, core::cmp::Ordering::Less, false)?;
        // As above, `FpVar::enforce_cmp` requires that the amounts have size <(p-1)/2 which is
        // true for amounts as they are 128 bits at most.

        // We do not need to check the divisor is non-zero, as that is already
        // enforced by 0 <= r < d above.

        Ok((quo_var, rem_var))
    }
}
192
193impl AllocVar<Amount, Fq> for AmountVar {
194    fn new_variable<T: std::borrow::Borrow<Amount>>(
195        cs: impl Into<ark_relations::r1cs::Namespace<Fq>>,
196        f: impl FnOnce() -> Result<T, SynthesisError>,
197        mode: ark_r1cs_std::prelude::AllocationMode,
198    ) -> Result<Self, SynthesisError> {
199        let ns = cs.into();
200        let cs = ns.cs();
201        let amount: Amount = *f()?.borrow();
202        let inner_amount_var = FqVar::new_variable(cs, || Ok(Fq::from(amount)), mode)?;
203        // Check the amounts are 128 bits maximum
204        let _ = bit_constrain(inner_amount_var.clone(), 128);
205        Ok(Self {
206            amount: inner_amount_var,
207        })
208    }
209}
210
211impl AllocVar<Fq, Fq> for AmountVar {
212    fn new_variable<T: std::borrow::Borrow<Fq>>(
213        cs: impl Into<ark_relations::r1cs::Namespace<Fq>>,
214        f: impl FnOnce() -> Result<T, SynthesisError>,
215        mode: ark_r1cs_std::prelude::AllocationMode,
216    ) -> Result<Self, SynthesisError> {
217        let ns = cs.into();
218        let cs = ns.cs();
219        let amount: Fq = *f()?.borrow();
220        let inner_amount_var = FqVar::new_variable(cs, || Ok(amount), mode)?;
221        // Check the amounts are 128 bits maximum
222        let _ = bit_constrain(inner_amount_var.clone(), 128);
223        Ok(Self {
224            amount: inner_amount_var,
225        })
226    }
227}
228
229impl R1CSVar<Fq> for AmountVar {
230    type Value = Amount;
231
232    fn cs(&self) -> ark_relations::r1cs::ConstraintSystemRef<Fq> {
233        self.amount.cs()
234    }
235
236    fn value(&self) -> Result<Self::Value, SynthesisError> {
237        let amount_fq = self.amount.value()?;
238        let amount_bytes = &amount_fq.to_bytes()[0..16];
239        Ok(Amount::from_le_bytes(
240            amount_bytes
241                .try_into()
242                .expect("should be able to fit in 16 bytes"),
243        ))
244    }
245}
246
247impl EqGadget<Fq> for AmountVar {
248    fn is_eq(&self, other: &Self) -> Result<Boolean<Fq>, SynthesisError> {
249        self.amount.is_eq(&other.amount)
250    }
251}
252
253impl CondSelectGadget<Fq> for AmountVar {
254    fn conditionally_select(
255        cond: &Boolean<Fq>,
256        true_value: &Self,
257        false_value: &Self,
258    ) -> Result<Self, SynthesisError> {
259        Ok(Self {
260            amount: FqVar::conditionally_select(cond, &true_value.amount, &false_value.amount)?,
261        })
262    }
263}
264
265impl std::ops::Add for AmountVar {
266    type Output = Self;
267
268    fn add(self, rhs: Self) -> Self::Output {
269        Self {
270            amount: self.amount + rhs.amount,
271        }
272    }
273}
274
275impl std::ops::Sub for AmountVar {
276    type Output = Self;
277
278    fn sub(self, rhs: Self) -> Self::Output {
279        Self {
280            amount: self.amount - rhs.amount,
281        }
282    }
283}
284
285impl std::ops::Mul for AmountVar {
286    type Output = Self;
287
288    fn mul(self, rhs: Self) -> Self::Output {
289        Self {
290            amount: self.amount * rhs.amount,
291        }
292    }
293}
294
295impl From<Amount> for pb::Amount {
296    fn from(a: Amount) -> Self {
297        let lo = a.inner as u64;
298        let hi = (a.inner >> 64) as u64;
299        pb::Amount { lo, hi }
300    }
301}
302
303impl TryFrom<pb::Amount> for Amount {
304    type Error = anyhow::Error;
305
306    fn try_from(amount: pb::Amount) -> Result<Self, Self::Error> {
307        let lo = amount.lo as u128;
308        let hi = amount.hi as u128;
309        // `hi` and `lo` represent the high/low order bytes respectively.
310        //
311        // We want to decode `hi` and `lo` into a single `u128` of the form:
312        //
313        //            hi: u64                          lo: u64
314        // ┌───┬───┬───┬───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┬───┬───┬───┐
315        // │   │   │   │   │   │   │   │   │ │   │   │   │   │   │   │   │   │
316        // └───┴───┴───┴───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┴───┴───┴───┘
317        //   15  14  13  12  11  10  9   8     7   6   5   4   3   2   1   0
318        //
319        // To achieve this, we shift `hi` 8 bytes to the left:
320        let shifted = hi << 64;
321        // and then add the lower order bytes:
322        let inner = shifted + lo;
323
324        Ok(Amount { inner })
325    }
326}
327
328impl TryFrom<std::string::String> for Amount {
329    type Error = anyhow::Error;
330
331    fn try_from(s: std::string::String) -> Result<Self, Self::Error> {
332        let inner = s.parse::<u128>()?;
333        Ok(Amount { inner })
334    }
335}
336
337impl DomainType for Amount {
338    type Proto = pb::Amount;
339}
340
341impl From<u64> for Amount {
342    fn from(amount: u64) -> Amount {
343        Amount {
344            inner: amount as u128,
345        }
346    }
347}
348
349impl From<u32> for Amount {
350    fn from(amount: u32) -> Amount {
351        Amount {
352            inner: amount as u128,
353        }
354    }
355}
356
357impl From<u16> for Amount {
358    fn from(amount: u16) -> Amount {
359        Amount {
360            inner: amount as u128,
361        }
362    }
363}
364
365impl From<u8> for Amount {
366    fn from(amount: u8) -> Amount {
367        Amount {
368            inner: amount as u128,
369        }
370    }
371}
372
373impl From<Amount> for f64 {
374    fn from(amount: Amount) -> f64 {
375        amount.inner as f64
376    }
377}
378
379impl Display for Amount {
380    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
381        write!(f, "{}", self.inner)
382    }
383}
384
385impl ops::Add<Amount> for Amount {
386    type Output = Amount;
387
388    fn add(self, rhs: Amount) -> Amount {
389        Amount {
390            inner: self.inner + rhs.inner,
391        }
392    }
393}
394
395impl ops::AddAssign<Amount> for Amount {
396    fn add_assign(&mut self, rhs: Amount) {
397        self.inner += rhs.inner;
398    }
399}
400
401impl ops::Sub<Amount> for Amount {
402    type Output = Amount;
403
404    fn sub(self, rhs: Amount) -> Amount {
405        Amount {
406            inner: self.inner - rhs.inner,
407        }
408    }
409}
410
411impl ops::SubAssign<Amount> for Amount {
412    fn sub_assign(&mut self, rhs: Amount) {
413        self.inner -= rhs.inner;
414    }
415}
416
417impl ops::Rem<Amount> for Amount {
418    type Output = Amount;
419
420    fn rem(self, rhs: Amount) -> Amount {
421        Amount {
422            inner: self.inner % rhs.inner,
423        }
424    }
425}
426
427impl ops::Mul<Amount> for Amount {
428    type Output = Amount;
429
430    fn mul(self, rhs: Amount) -> Amount {
431        Amount {
432            inner: self.inner * rhs.inner,
433        }
434    }
435}
436
437impl ops::Div<Amount> for Amount {
438    type Output = Amount;
439
440    fn div(self, rhs: Amount) -> Amount {
441        Amount {
442            inner: self.inner / rhs.inner,
443        }
444    }
445}
446
447impl From<NonZeroU128> for Amount {
448    fn from(n: NonZeroU128) -> Self {
449        Self { inner: n.get() }
450    }
451}
452
453impl From<Amount> for Fq {
454    fn from(amount: Amount) -> Fq {
455        Fq::from(amount.inner)
456    }
457}
458
459impl From<Amount> for Fr {
460    fn from(amount: Amount) -> Fr {
461        Fr::from(amount.inner)
462    }
463}
464
465impl From<u128> for Amount {
466    fn from(amount: u128) -> Amount {
467        Amount { inner: amount }
468    }
469}
470
471impl From<Amount> for u128 {
472    fn from(amount: Amount) -> u128 {
473        amount.inner
474    }
475}
476
477impl From<i128> for Amount {
478    fn from(amount: i128) -> Amount {
479        Amount {
480            inner: amount as u128,
481        }
482    }
483}
484
485impl From<Amount> for i128 {
486    fn from(amount: Amount) -> i128 {
487        amount.inner as i128
488    }
489}
490
491impl From<Amount> for U128x128 {
492    fn from(amount: Amount) -> U128x128 {
493        U128x128::from(amount.inner)
494    }
495}
496
497impl From<&Amount> for U128x128 {
498    fn from(value: &Amount) -> Self {
499        (*value).into()
500    }
501}
502
503impl TryFrom<U128x128> for Amount {
504    type Error = <u128 as TryFrom<U128x128>>::Error;
505    fn try_from(value: U128x128) -> Result<Self, Self::Error> {
506        Ok(Amount {
507            inner: value.try_into()?,
508        })
509    }
510}
511
512impl U128x128Var {
513    pub fn from_amount_var(amount: AmountVar) -> Result<U128x128Var, SynthesisError> {
514        let bits = amount.amount.to_bits_le()?;
515        let limb_2 = UInt64::from_bits_le(&bits[0..64]);
516        let limb_3 = UInt64::from_bits_le(&bits[64..128]);
517        Ok(Self {
518            limbs: [
519                UInt64::constant(0u64),
520                UInt64::constant(0u64),
521                limb_2,
522                limb_3,
523            ],
524        })
525    }
526}
527
528impl From<U128x128Var> for AmountVar {
529    fn from(value: U128x128Var) -> Self {
530        let mut le_bits = Vec::new();
531        le_bits.extend_from_slice(&value.limbs[2].to_bits_le()[..]);
532        le_bits.extend_from_slice(&value.limbs[3].to_bits_le()[..]);
533        Self {
534            amount: Boolean::<Fq>::le_bits_to_fp_var(&le_bits[..]).expect("can convert to bits"),
535        }
536    }
537}
538
539impl Sum for Amount {
540    fn sum<I: Iterator<Item = Amount>>(iter: I) -> Amount {
541        iter.fold(Amount::zero(), |acc, x| acc + x)
542    }
543}
544
#[cfg(test)]
mod test {
    use crate::Amount;
    use penumbra_sdk_proto::penumbra::core::num::v1 as pb;
    use rand::RngCore;
    use rand_core::OsRng;

    /// Encodes `value` as a protobuf `Amount`, decodes it back, and returns the raw result.
    fn encode_decode(value: u128) -> u128 {
        let encoded = pb::Amount::from(Amount { inner: value });
        let decoded = Amount::try_from(encoded).unwrap();
        decoded.inner
    }

    /// Asserts that `value` survives the protobuf round-trip unchanged.
    fn assert_round_trips(value: u128) {
        assert_eq!(value, encode_decode(value))
    }

    #[test]
    fn encode_decode_max() {
        assert_round_trips(u128::MAX);
    }

    #[test]
    fn encode_decode_zero() {
        assert_round_trips(u128::MIN);
    }

    /// Lowest bit of the `hi` half.
    #[test]
    fn encode_decode_right_border_bit() {
        assert_round_trips(1u128 << 64);
    }

    /// Highest bit of the `lo` half.
    #[test]
    fn encode_decode_left_border_bit() {
        assert_round_trips(1u128 << 63);
    }

    #[test]
    fn encode_decode_random() {
        let mut rng = OsRng;
        let mut bytes = [0u8; 16];
        rng.fill_bytes(&mut bytes);
        assert_round_trips(u128::from_le_bytes(bytes));
    }

    #[test]
    fn encode_decode_u64_max() {
        assert_round_trips(u128::from(u64::MAX));
    }

    #[test]
    fn encode_decode_random_lower_order_bytes() {
        let mut rng = OsRng;
        assert_round_trips(u128::from(rng.next_u64()));
    }

    #[test]
    fn encode_decode_random_higher_order_bytes() {
        let mut rng = OsRng;
        assert_round_trips(u128::from(rng.next_u64()) << 64);
    }
}