1use ark_ff::{BigInteger, PrimeField, ToConstraintField};
2use ark_r1cs_std::{prelude::*, uint64::UInt64};
3use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError};
4use decaf377::{Fq, Fr};
5use penumbra_sdk_proto::{penumbra::core::num::v1 as pb, DomainType};
6use serde::{Deserialize, Serialize};
7use std::{fmt::Display, iter::Sum, num::NonZeroU128, ops};
8
9use crate::fixpoint::{bit_constrain, U128x128, U128x128Var};
10use decaf377::r1cs::FqVar;
11
/// A 128-bit unsigned integer amount.
///
/// Serde (de)serialization goes through the protobuf type `pb::Amount` (see the `serde`
/// attribute), which carries the value as two 64-bit halves, `lo` and `hi`.
#[derive(Serialize, Default, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
#[serde(try_from = "pb::Amount", into = "pb::Amount")]
pub struct Amount {
    /// The raw value.
    inner: u128,
}
17
18impl std::fmt::Debug for Amount {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{}", self.inner)
21 }
22}
23
24impl Amount {
25 pub fn value(&self) -> u128 {
26 self.inner
27 }
28
29 pub fn zero() -> Self {
30 Self { inner: 0 }
31 }
32
33 pub fn to_le_bytes(&self) -> [u8; 16] {
35 self.inner.to_le_bytes()
36 }
37
38 pub fn to_be_bytes(&self) -> [u8; 16] {
39 self.inner.to_be_bytes()
40 }
41
42 pub fn from_le_bytes(bytes: [u8; 16]) -> Amount {
43 Amount {
44 inner: u128::from_le_bytes(bytes),
45 }
46 }
47
48 pub fn from_be_bytes(bytes: [u8; 16]) -> Amount {
49 Amount {
50 inner: u128::from_be_bytes(bytes),
51 }
52 }
53
54 pub fn checked_sub(&self, rhs: &Self) -> Option<Self> {
55 self.inner
56 .checked_sub(rhs.inner)
57 .map(|inner| Self { inner })
58 }
59
60 pub fn checked_add(&self, rhs: &Self) -> Option<Self> {
61 self.inner
62 .checked_add(rhs.inner)
63 .map(|inner| Self { inner })
64 }
65
66 pub fn checked_mul(&self, rhs: &Self) -> Option<Self> {
67 self.inner
68 .checked_mul(rhs.inner)
69 .map(|inner| Self { inner })
70 }
71
72 pub fn saturating_add(&self, rhs: &Self) -> Self {
73 Self {
74 inner: self.inner.saturating_add(rhs.inner),
75 }
76 }
77
78 pub fn saturating_sub(&self, rhs: &Self) -> Self {
79 Self {
80 inner: self.inner.saturating_sub(rhs.inner),
81 }
82 }
83}
84
85impl ops::Not for Amount {
86 type Output = Self;
87
88 fn not(self) -> Self::Output {
89 Self { inner: !self.inner }
90 }
91}
92
/// In-circuit counterpart of [`Amount`], held as a single `Fq` field variable.
///
/// The `AllocVar` impls in this module call `bit_constrain(.., 128)` when allocating. The
/// `Add`/`Sub`/`Mul` operators, however, are plain field operations, and their results are
/// not range-checked again.
#[derive(Clone)]
pub struct AmountVar {
    /// The underlying field variable holding the amount.
    pub amount: FqVar,
}
97
98impl ToConstraintField<Fq> for Amount {
99 fn to_field_elements(&self) -> Option<Vec<Fq>> {
100 let mut elements = Vec::new();
101 elements.extend_from_slice(&[Fq::from(self.inner)]);
102 Some(elements)
103 }
104}
105
106pub fn is_bit_constrained(
108 cs: ConstraintSystemRef<Fq>,
109 value: FqVar,
110 n: usize,
111) -> Result<Boolean<Fq>, SynthesisError> {
112 let inner = value.value().unwrap_or(Fq::from(1u64));
113
114 let inner_bigint = inner.into_bigint();
116 let bits = &inner_bigint.to_bits_le()[0..n];
117
118 let mut boolean_constraints = Vec::new();
120 for bit in bits {
121 let boolean = Boolean::new_witness(cs.clone(), || Ok(bit))?;
122 boolean_constraints.push(boolean);
123 }
124
125 let constructed_fqvar = Boolean::<Fq>::le_bits_to_fp_var(&boolean_constraints.to_bits_le()?)
127 .expect("can convert to bits");
128 constructed_fqvar.is_eq(&value)
129}
130
impl AmountVar {
    /// Returns the additive inverse of the underlying field variable.
    ///
    /// This is negation in `Fq`, not an integer operation, so the result is generally not
    /// within the 128-bit amount range.
    pub fn negate(&self) -> Result<Self, SynthesisError> {
        Ok(Self {
            amount: self.amount.negate()?,
        })
    }

    /// In-circuit integer division with remainder. Returns `(quo, rem)`, constrained so that
    /// `self == quo * divisor_var + rem` and `rem < divisor_var`.
    ///
    /// NOTE(review): the sequence of witness allocations and constraints below defines the
    /// circuit's shape. Reordering these statements changes the constraint system.
    ///
    /// # Errors
    /// Propagates any [`SynthesisError`] from witness allocation or constraint generation.
    /// A zero divisor does not produce an error here; instead it makes the final
    /// `rem < divisor` constraint unsatisfiable.
    pub fn quo_rem(
        &self,
        divisor_var: &AmountVar,
    ) -> Result<(AmountVar, AmountVar), SynthesisError> {
        // Native witness computation. Only the low 16 bytes of each assignment are read,
        // interpreted little-endian as a u128. With no assignment, the default field
        // element (presumably zero) is used instead.
        let current_amount_bytes: [u8; 16] = self.amount.value().unwrap_or_default().to_bytes()
            [0..16]
            .try_into()
            .expect("amounts should fit in 16 bytes");
        let current_amount = u128::from_le_bytes(current_amount_bytes);
        let divisor_bytes: [u8; 16] = divisor_var.amount.value().unwrap_or_default().to_bytes()
            [0..16]
            .try_into()
            .expect("amounts should fit in 16 bytes");
        let divisor = u128::from_le_bytes(divisor_bytes);

        // Natively, division by zero yields (0, 0). The `rem < divisor` constraint below
        // is what rejects a zero divisor.
        let quo = current_amount.checked_div(divisor).unwrap_or(0);
        let rem = current_amount.checked_rem(divisor).unwrap_or(0);

        // Quotient and remainder are witnesses. The `AllocVar` impls of `AmountVar` call
        // `bit_constrain(.., 128)` on allocation.
        let quo_var = AmountVar::new_witness(self.cs(), || Ok(Fq::from(quo)))?;
        let rem_var = AmountVar::new_witness(self.cs(), || Ok(Fq::from(rem)))?;

        // Require the quotient or the divisor to fit in 64 bits. This bounds the product
        // `quo * divisor`, presumably so it cannot wrap around the field modulus.
        let q_is_64_bits = is_bit_constrained(self.cs(), quo_var.amount.clone(), 64)?;
        let d_is_64_bits = is_bit_constrained(self.cs(), divisor_var.amount.clone(), 64)?;
        let q_or_d_is_64_bits = q_is_64_bits.or(&d_is_64_bits)?;
        q_or_d_is_64_bits.enforce_equal(&Boolean::constant(true))?;

        // Division identity: self == quo * divisor + rem.
        let numerator_var = quo_var.clone() * divisor_var.clone() + rem_var.clone();
        self.enforce_equal(&numerator_var)?;

        // rem < divisor. The comparison is strict because the "also check equality"
        // flag is false.
        rem_var
            .amount
            .enforce_cmp(&divisor_var.amount, core::cmp::Ordering::Less, false)?;
        Ok((quo_var, rem_var))
    }
}
192
193impl AllocVar<Amount, Fq> for AmountVar {
194 fn new_variable<T: std::borrow::Borrow<Amount>>(
195 cs: impl Into<ark_relations::r1cs::Namespace<Fq>>,
196 f: impl FnOnce() -> Result<T, SynthesisError>,
197 mode: ark_r1cs_std::prelude::AllocationMode,
198 ) -> Result<Self, SynthesisError> {
199 let ns = cs.into();
200 let cs = ns.cs();
201 let amount: Amount = *f()?.borrow();
202 let inner_amount_var = FqVar::new_variable(cs, || Ok(Fq::from(amount)), mode)?;
203 let _ = bit_constrain(inner_amount_var.clone(), 128);
205 Ok(Self {
206 amount: inner_amount_var,
207 })
208 }
209}
210
211impl AllocVar<Fq, Fq> for AmountVar {
212 fn new_variable<T: std::borrow::Borrow<Fq>>(
213 cs: impl Into<ark_relations::r1cs::Namespace<Fq>>,
214 f: impl FnOnce() -> Result<T, SynthesisError>,
215 mode: ark_r1cs_std::prelude::AllocationMode,
216 ) -> Result<Self, SynthesisError> {
217 let ns = cs.into();
218 let cs = ns.cs();
219 let amount: Fq = *f()?.borrow();
220 let inner_amount_var = FqVar::new_variable(cs, || Ok(amount), mode)?;
221 let _ = bit_constrain(inner_amount_var.clone(), 128);
223 Ok(Self {
224 amount: inner_amount_var,
225 })
226 }
227}
228
229impl R1CSVar<Fq> for AmountVar {
230 type Value = Amount;
231
232 fn cs(&self) -> ark_relations::r1cs::ConstraintSystemRef<Fq> {
233 self.amount.cs()
234 }
235
236 fn value(&self) -> Result<Self::Value, SynthesisError> {
237 let amount_fq = self.amount.value()?;
238 let amount_bytes = &amount_fq.to_bytes()[0..16];
239 Ok(Amount::from_le_bytes(
240 amount_bytes
241 .try_into()
242 .expect("should be able to fit in 16 bytes"),
243 ))
244 }
245}
246
247impl EqGadget<Fq> for AmountVar {
248 fn is_eq(&self, other: &Self) -> Result<Boolean<Fq>, SynthesisError> {
249 self.amount.is_eq(&other.amount)
250 }
251}
252
253impl CondSelectGadget<Fq> for AmountVar {
254 fn conditionally_select(
255 cond: &Boolean<Fq>,
256 true_value: &Self,
257 false_value: &Self,
258 ) -> Result<Self, SynthesisError> {
259 Ok(Self {
260 amount: FqVar::conditionally_select(cond, &true_value.amount, &false_value.amount)?,
261 })
262 }
263}
264
265impl std::ops::Add for AmountVar {
266 type Output = Self;
267
268 fn add(self, rhs: Self) -> Self::Output {
269 Self {
270 amount: self.amount + rhs.amount,
271 }
272 }
273}
274
275impl std::ops::Sub for AmountVar {
276 type Output = Self;
277
278 fn sub(self, rhs: Self) -> Self::Output {
279 Self {
280 amount: self.amount - rhs.amount,
281 }
282 }
283}
284
285impl std::ops::Mul for AmountVar {
286 type Output = Self;
287
288 fn mul(self, rhs: Self) -> Self::Output {
289 Self {
290 amount: self.amount * rhs.amount,
291 }
292 }
293}
294
295impl From<Amount> for pb::Amount {
296 fn from(a: Amount) -> Self {
297 let lo = a.inner as u64;
298 let hi = (a.inner >> 64) as u64;
299 pb::Amount { lo, hi }
300 }
301}
302
303impl TryFrom<pb::Amount> for Amount {
304 type Error = anyhow::Error;
305
306 fn try_from(amount: pb::Amount) -> Result<Self, Self::Error> {
307 let lo = amount.lo as u128;
308 let hi = amount.hi as u128;
309 let shifted = hi << 64;
321 let inner = shifted + lo;
323
324 Ok(Amount { inner })
325 }
326}
327
328impl TryFrom<std::string::String> for Amount {
329 type Error = anyhow::Error;
330
331 fn try_from(s: std::string::String) -> Result<Self, Self::Error> {
332 let inner = s.parse::<u128>()?;
333 Ok(Amount { inner })
334 }
335}
336
/// Declares `pb::Amount` as the protobuf type corresponding to [`Amount`].
impl DomainType for Amount {
    type Proto = pb::Amount;
}
340
341impl From<u64> for Amount {
342 fn from(amount: u64) -> Amount {
343 Amount {
344 inner: amount as u128,
345 }
346 }
347}
348
349impl From<u32> for Amount {
350 fn from(amount: u32) -> Amount {
351 Amount {
352 inner: amount as u128,
353 }
354 }
355}
356
357impl From<u16> for Amount {
358 fn from(amount: u16) -> Amount {
359 Amount {
360 inner: amount as u128,
361 }
362 }
363}
364
365impl From<u8> for Amount {
366 fn from(amount: u8) -> Amount {
367 Amount {
368 inner: amount as u128,
369 }
370 }
371}
372
373impl From<Amount> for f64 {
374 fn from(amount: Amount) -> f64 {
375 amount.inner as f64
376 }
377}
378
379impl Display for Amount {
380 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
381 write!(f, "{}", self.inner)
382 }
383}
384
385impl ops::Add<Amount> for Amount {
386 type Output = Amount;
387
388 fn add(self, rhs: Amount) -> Amount {
389 Amount {
390 inner: self.inner + rhs.inner,
391 }
392 }
393}
394
395impl ops::AddAssign<Amount> for Amount {
396 fn add_assign(&mut self, rhs: Amount) {
397 self.inner += rhs.inner;
398 }
399}
400
401impl ops::Sub<Amount> for Amount {
402 type Output = Amount;
403
404 fn sub(self, rhs: Amount) -> Amount {
405 Amount {
406 inner: self.inner - rhs.inner,
407 }
408 }
409}
410
411impl ops::SubAssign<Amount> for Amount {
412 fn sub_assign(&mut self, rhs: Amount) {
413 self.inner -= rhs.inner;
414 }
415}
416
417impl ops::Rem<Amount> for Amount {
418 type Output = Amount;
419
420 fn rem(self, rhs: Amount) -> Amount {
421 Amount {
422 inner: self.inner % rhs.inner,
423 }
424 }
425}
426
427impl ops::Mul<Amount> for Amount {
428 type Output = Amount;
429
430 fn mul(self, rhs: Amount) -> Amount {
431 Amount {
432 inner: self.inner * rhs.inner,
433 }
434 }
435}
436
437impl ops::Div<Amount> for Amount {
438 type Output = Amount;
439
440 fn div(self, rhs: Amount) -> Amount {
441 Amount {
442 inner: self.inner / rhs.inner,
443 }
444 }
445}
446
447impl From<NonZeroU128> for Amount {
448 fn from(n: NonZeroU128) -> Self {
449 Self { inner: n.get() }
450 }
451}
452
453impl From<Amount> for Fq {
454 fn from(amount: Amount) -> Fq {
455 Fq::from(amount.inner)
456 }
457}
458
459impl From<Amount> for Fr {
460 fn from(amount: Amount) -> Fr {
461 Fr::from(amount.inner)
462 }
463}
464
465impl From<u128> for Amount {
466 fn from(amount: u128) -> Amount {
467 Amount { inner: amount }
468 }
469}
470
471impl From<Amount> for u128 {
472 fn from(amount: Amount) -> u128 {
473 amount.inner
474 }
475}
476
477impl From<i128> for Amount {
478 fn from(amount: i128) -> Amount {
479 Amount {
480 inner: amount as u128,
481 }
482 }
483}
484
485impl From<Amount> for i128 {
486 fn from(amount: Amount) -> i128 {
487 amount.inner as i128
488 }
489}
490
491impl From<Amount> for U128x128 {
492 fn from(amount: Amount) -> U128x128 {
493 U128x128::from(amount.inner)
494 }
495}
496
497impl From<&Amount> for U128x128 {
498 fn from(value: &Amount) -> Self {
499 (*value).into()
500 }
501}
502
503impl TryFrom<U128x128> for Amount {
504 type Error = <u128 as TryFrom<U128x128>>::Error;
505 fn try_from(value: U128x128) -> Result<Self, Self::Error> {
506 Ok(Amount {
507 inner: value.try_into()?,
508 })
509 }
510}
511
512impl U128x128Var {
513 pub fn from_amount_var(amount: AmountVar) -> Result<U128x128Var, SynthesisError> {
514 let bits = amount.amount.to_bits_le()?;
515 let limb_2 = UInt64::from_bits_le(&bits[0..64]);
516 let limb_3 = UInt64::from_bits_le(&bits[64..128]);
517 Ok(Self {
518 limbs: [
519 UInt64::constant(0u64),
520 UInt64::constant(0u64),
521 limb_2,
522 limb_3,
523 ],
524 })
525 }
526}
527
528impl From<U128x128Var> for AmountVar {
529 fn from(value: U128x128Var) -> Self {
530 let mut le_bits = Vec::new();
531 le_bits.extend_from_slice(&value.limbs[2].to_bits_le()[..]);
532 le_bits.extend_from_slice(&value.limbs[3].to_bits_le()[..]);
533 Self {
534 amount: Boolean::<Fq>::le_bits_to_fp_var(&le_bits[..]).expect("can convert to bits"),
535 }
536 }
537}
538
539impl Sum for Amount {
540 fn sum<I: Iterator<Item = Amount>>(iter: I) -> Amount {
541 iter.fold(Amount::zero(), |acc, x| acc + x)
542 }
543}
544
#[cfg(test)]
mod test {
    use crate::Amount;
    use penumbra_sdk_proto::penumbra::core::num::v1 as pb;
    use rand::RngCore;
    use rand_core::OsRng;

    /// Round-trips `value` through the protobuf representation and returns the decoded value.
    fn encode_decode(value: u128) -> u128 {
        let proto = pb::Amount::from(Amount { inner: value });
        let decoded = Amount::try_from(proto).unwrap();
        decoded.inner
    }

    #[test]
    fn encode_decode_max() {
        assert_eq!(u128::MAX, encode_decode(u128::MAX))
    }

    #[test]
    fn encode_decode_zero() {
        assert_eq!(u128::MIN, encode_decode(u128::MIN))
    }

    /// Lowest bit of the `hi` half.
    #[test]
    fn encode_decode_right_border_bit() {
        let value: u128 = 1 << 64;
        assert_eq!(value, encode_decode(value))
    }

    /// Highest bit of the `lo` half.
    #[test]
    fn encode_decode_left_border_bit() {
        let value: u128 = 1 << 63;
        assert_eq!(value, encode_decode(value))
    }

    #[test]
    fn encode_decode_random() {
        let mut rng = OsRng;
        let mut bytes = [0u8; 16];
        rng.fill_bytes(&mut bytes);
        let value = u128::from_le_bytes(bytes);
        assert_eq!(value, encode_decode(value))
    }

    #[test]
    fn encode_decode_u64_max() {
        let value = u128::from(u64::MAX);
        assert_eq!(value, encode_decode(value))
    }

    #[test]
    fn encode_decode_random_lower_order_bytes() {
        let mut rng = OsRng;
        let lo = u128::from(rng.next_u64());
        assert_eq!(lo, encode_decode(lo))
    }

    #[test]
    fn encode_decode_random_higher_order_bytes() {
        let mut rng = OsRng;
        let hi = u128::from(rng.next_u64()) << 64;
        assert_eq!(hi, encode_decode(hi))
    }
}