penumbra_sdk_num/
fixpoint.rs

1use std::{
2    fmt::{Debug, Display},
3    iter::zip,
4};
5
6mod div;
7mod from;
8mod ops;
9
10#[cfg(test)]
11mod tests;
12
13use ark_ff::{BigInteger, PrimeField, ToConstraintField, Zero};
14use ark_r1cs_std::bits::uint64::UInt64;
15use ark_r1cs_std::fields::fp::FpVar;
16use ark_r1cs_std::prelude::*;
17use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError};
18
19use decaf377::{r1cs::FqVar, Fq};
20use ethnum::U256;
21
22use crate::{Amount, AmountVar};
23
24use self::div::stub_div_rem_u384_by_u256;
25
26#[derive(thiserror::Error, Debug)]
27pub enum Error {
28    #[error("overflow")]
29    Overflow,
30    #[error("underflow")]
31    Underflow,
32    #[error("division by zero")]
33    DivisionByZero,
34    #[error("attempted to convert invalid f64: {value:?} to a U128x128")]
35    InvalidFloat64 { value: f64 },
36    #[error("attempted to convert non-integral value {value:?} to an integer")]
37    NonIntegral { value: U128x128 },
38    #[error("attempted to decode a slice of the wrong length {0}, expected 32")]
39    SliceLength(usize),
40}
41
42#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
43pub struct U128x128(U256);
44
45impl Default for U128x128 {
46    fn default() -> Self {
47        Self::from(0u64)
48    }
49}
50
51impl Debug for U128x128 {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        let (integral, fractional) = self.0.into_words();
54        f.debug_struct("U128x128")
55            .field("integral", &integral)
56            .field("fractional", &fractional)
57            .finish()
58    }
59}
60
61impl Display for U128x128 {
62    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
63        write!(f, "{}", f64::from(*self))
64    }
65}
66
67impl U128x128 {
68    /// Encode this number as a 32-byte array.
69    ///
70    /// The encoding has the property that it preserves ordering, i.e., if `x <=
71    /// y` (with numeric ordering) then `x.to_bytes() <= y.to_bytes()` (with the
72    /// lex ordering on byte strings).
73    pub fn to_bytes(self) -> [u8; 32] {
74        // The U256 type has really weird endianness handling -- e.g., it reverses
75        // the endianness of the inner u128s (??) -- so just do it manually.
76        let mut bytes = [0u8; 32];
77        let (hi, lo) = self.0.into_words();
78        bytes[0..16].copy_from_slice(&hi.to_be_bytes());
79        bytes[16..32].copy_from_slice(&lo.to_be_bytes());
80        bytes
81    }
82
83    /// Decode this number from a 32-byte array.
84    pub fn from_bytes(bytes: [u8; 32]) -> Self {
85        // See above.
86        let hi = u128::from_be_bytes(bytes[0..16].try_into().expect("slice is 16 bytes"));
87        let lo = u128::from_be_bytes(bytes[16..32].try_into().expect("slice is 16 bytes"));
88        Self(U256::from_words(hi, lo))
89    }
90
91    pub fn ratio<T: Into<Self>>(numerator: T, denominator: T) -> Result<Self, Error> {
92        numerator.into() / denominator.into()
93    }
94
95    /// Checks whether this number is integral, i.e., whether it has no fractional part.
96    pub fn is_integral(&self) -> bool {
97        let fractional_word = self.0.into_words().1;
98        fractional_word == 0
99    }
100
101    /// Rounds the number down to the nearest integer.
102    pub fn round_down(self) -> Self {
103        let integral_word = self.0.into_words().0;
104        Self(U256::from_words(integral_word, 0u128))
105    }
106
107    /// Rounds the number up to the nearest integer.
108    pub fn round_up(&self) -> Result<Self, Error> {
109        let (integral, fractional) = self.0.into_words();
110        if fractional == 0 {
111            Ok(*self)
112        } else {
113            let integral = integral.checked_add(1).ok_or(Error::Overflow)?;
114            Ok(Self(U256::from_words(integral, 0u128)))
115        }
116    }
117
118    /// Performs checked multiplication, returning `Ok` if no overflow occurred.
119    pub fn checked_mul(self, rhs: &Self) -> Result<Self, Error> {
120        // It's important to use `into_words` because the `U256` type has an
121        // unsafe API that makes the limb ordering dependent on the host
122        // endianness.
123        let (x1, x0) = self.0.into_words();
124        let (y1, y0) = rhs.0.into_words();
125        let x0 = U256::from(x0);
126        let x1 = U256::from(x1);
127        let y0 = U256::from(y0);
128        let y1 = U256::from(y1);
129
130        // x = (x0*2^-128 + x1)*2^128
131        // y = (y0*2^-128 + y1)*2^128
132        // x*y        = (x0*y0*2^-256 + (x0*y1 + x1*y0)*2^-128 + x1*y1)*2^256
133        // x*y*2^-128 = (x0*y0*2^-256 + (x0*y1 + x1*y0)*2^-128 + x1*y1)*2^128
134        //               ^^^^^
135        //               we drop the low 128 bits of this term as rounding error
136
137        let x0y0 = x0 * y0; // cannot overflow, widening mul
138        let x0y1 = x0 * y1; // cannot overflow, widening mul
139        let x1y0 = x1 * y0; // cannot overflow, widening mul
140        let x1y1 = x1 * y1; // cannot overflow, widening mul
141
142        let (x1y1_hi, _x1y1_lo) = x1y1.into_words();
143        if x1y1_hi != 0 {
144            return Err(Error::Overflow);
145        }
146
147        x1y1.checked_shl(128)
148            .and_then(|acc| acc.checked_add(x0y1))
149            .and_then(|acc| acc.checked_add(x1y0))
150            .and_then(|acc| acc.checked_add(x0y0 >> 128))
151            .map(U128x128)
152            .ok_or(Error::Overflow)
153    }
154
155    /// Performs checked division, returning `Ok` if no overflow occurred.
156    pub fn checked_div(self, rhs: &Self) -> Result<Self, Error> {
157        stub_div_rem_u384_by_u256(self.0, rhs.0).map(|(quo, _rem)| U128x128(quo))
158    }
159
160    /// Performs checked addition, returning `Ok` if no overflow occurred.
161    pub fn checked_add(self, rhs: &Self) -> Result<Self, Error> {
162        self.0
163            .checked_add(rhs.0)
164            .map(U128x128)
165            .ok_or(Error::Overflow)
166    }
167
168    /// Performs checked subtraction, returning `Ok` if no underflow occurred.
169    pub fn checked_sub(self, rhs: &Self) -> Result<Self, Error> {
170        self.0
171            .checked_sub(rhs.0)
172            .map(U128x128)
173            .ok_or(Error::Underflow)
174    }
175
176    /// Saturating integer subtraction. Computes self - rhs, saturating at the numeric bounds instead of overflowing.
177    pub fn saturating_sub(self, rhs: &Self) -> Self {
178        U128x128(self.0.saturating_sub(rhs.0))
179    }
180
181    /// Multiply an amount by this fraction, then round down.
182    pub fn apply_to_amount(self, rhs: &Amount) -> Result<Amount, Error> {
183        let mul = (Self::from(rhs) * self)?;
184        let out = mul
185            .round_down()
186            .try_into()
187            .expect("converting integral U128xU128 into Amount will succeed");
188        Ok(out)
189    }
190}
191
192#[derive(Clone)]
193pub struct U128x128Var {
194    pub limbs: [UInt64<Fq>; 4],
195}
196
197impl AllocVar<U128x128, Fq> for U128x128Var {
198    fn new_variable<T: std::borrow::Borrow<U128x128>>(
199        cs: impl Into<ark_relations::r1cs::Namespace<Fq>>,
200        f: impl FnOnce() -> Result<T, SynthesisError>,
201        mode: ark_r1cs_std::prelude::AllocationMode,
202    ) -> Result<Self, SynthesisError> {
203        let ns = cs.into();
204        let cs = ns.cs();
205        let inner: U128x128 = *f()?.borrow();
206
207        // TODO: in the case of a constant U128x128Var, this will allocate
208        // witness vars instead of constants, but we don't have much use for
209        // constant U128x128Vars anyways, so this efficiency loss shouldn't be a
210        // problem.
211
212        let (hi_128, lo_128) = inner.0.into_words();
213        let hi_128_var = FqVar::new_variable(cs.clone(), || Ok(Fq::from(hi_128)), mode)?;
214        let lo_128_var = FqVar::new_variable(cs.clone(), || Ok(Fq::from(lo_128)), mode)?;
215
216        // Now construct the bit constraints out of thin air ...
217        let bytes = inner.to_bytes();
218        // The U128x128 type uses a big-endian encoding
219        let limb_3 = u64::from_be_bytes(bytes[0..8].try_into().expect("slice is 8 bytes"));
220        let limb_2 = u64::from_be_bytes(bytes[8..16].try_into().expect("slice is 8 bytes"));
221        let limb_1 = u64::from_be_bytes(bytes[16..24].try_into().expect("slice is 8 bytes"));
222        let limb_0 = u64::from_be_bytes(bytes[24..32].try_into().expect("slice is 8 bytes"));
223
224        let limb_0_var = UInt64::new_variable(cs.clone(), || Ok(limb_0), AllocationMode::Witness)?;
225        let limb_1_var = UInt64::new_variable(cs.clone(), || Ok(limb_1), AllocationMode::Witness)?;
226        let limb_2_var = UInt64::new_variable(cs.clone(), || Ok(limb_2), AllocationMode::Witness)?;
227        let limb_3_var = UInt64::new_variable(cs, || Ok(limb_3), AllocationMode::Witness)?;
228
229        // ... and then bind them to the input variables we created above.
230        let lo_128_bits = limb_0_var
231            .to_bits_le()
232            .into_iter()
233            .chain(limb_1_var.to_bits_le())
234            .collect::<Vec<_>>();
235        let hi_128_bits = limb_2_var
236            .to_bits_le()
237            .into_iter()
238            .chain(limb_3_var.to_bits_le())
239            .collect::<Vec<_>>();
240
241        hi_128_var.enforce_equal(&Boolean::<Fq>::le_bits_to_fp_var(
242            &(hi_128_bits[..]).to_bits_le()?,
243        )?)?;
244        lo_128_var.enforce_equal(&Boolean::<Fq>::le_bits_to_fp_var(
245            &(lo_128_bits[..]).to_bits_le()?,
246        )?)?;
247
248        Ok(Self {
249            limbs: [limb_0_var, limb_1_var, limb_2_var, limb_3_var],
250        })
251    }
252}
253
254impl R1CSVar<Fq> for U128x128Var {
255    type Value = U128x128;
256
257    fn cs(&self) -> ark_relations::r1cs::ConstraintSystemRef<Fq> {
258        self.limbs[0].cs()
259    }
260
261    fn value(&self) -> Result<Self::Value, ark_relations::r1cs::SynthesisError> {
262        let x0 = self.limbs[0].value()?;
263        let x1 = self.limbs[1].value()?;
264        let x2 = self.limbs[2].value()?;
265        let x3 = self.limbs[3].value()?;
266
267        let mut bytes = [0u8; 32];
268        bytes[0..8].copy_from_slice(x3.to_be_bytes().as_ref());
269        bytes[8..16].copy_from_slice(x2.to_be_bytes().as_ref());
270        bytes[16..24].copy_from_slice(x1.to_be_bytes().as_ref());
271        bytes[24..32].copy_from_slice(x0.to_be_bytes().as_ref());
272
273        Ok(Self::Value::from_bytes(bytes))
274    }
275}
276
277impl U128x128Var {
278    pub fn checked_add(self, rhs: &Self) -> Result<U128x128Var, SynthesisError> {
279        // x = [x0, x1, x2, x3]
280        // x = x0 + x1 * 2^64 + x2 * 2^128 + x3 * 2^192
281        // y = [y0, y1, y2, y3]
282        // y = y0 + y1 * 2^64 + y2 * 2^128 + y3 * 2^192
283        let x0 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[0].to_bits_le())?;
284        let x1 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[1].to_bits_le())?;
285        let x2 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[2].to_bits_le())?;
286        let x3 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[3].to_bits_le())?;
287
288        let y0 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[0].to_bits_le())?;
289        let y1 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[1].to_bits_le())?;
290        let y2 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[2].to_bits_le())?;
291        let y3 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[3].to_bits_le())?;
292
293        // z = x + y
294        // z = [z0, z1, z2, z3]
295        let z0_raw = &x0 + &y0;
296        let z1_raw = &x1 + &y1;
297        let z2_raw = &x2 + &y2;
298        let z3_raw = &x3 + &y3;
299
300        // z0 <= (2^64 - 1) + (2^64 - 1) < 2^(65) => 65 bits
301        let z0_bits = bit_constrain(z0_raw, 65)?; // no carry-in
302        let z0 = UInt64::from_bits_le(&z0_bits[0..64]);
303        let c1 = Boolean::<Fq>::le_bits_to_fp_var(&z0_bits[64..].to_bits_le()?)?;
304
305        // z1 <= (2^64 - 1) + (2^64 - 1) + 1 < 2^(65) => 65 bits
306        let z1_bits = bit_constrain(z1_raw + c1, 65)?; // carry-in c1
307        let z1 = UInt64::from_bits_le(&z1_bits[0..64]);
308        let c2 = Boolean::<Fq>::le_bits_to_fp_var(&z1_bits[64..].to_bits_le()?)?;
309
310        // z2 <= (2^64 - 1) + (2^64 - 1) + 1 < 2^(65) => 65 bits
311        let z2_bits = bit_constrain(z2_raw + c2, 65)?; // carry-in c2
312        let z2 = UInt64::from_bits_le(&z2_bits[0..64]);
313        let c3 = Boolean::<Fq>::le_bits_to_fp_var(&z2_bits[64..].to_bits_le()?)?;
314
315        // z3 <= (2^64 - 1) + (2^64 - 1) + 1 < 2^(65) => 65 bits
316        // However, the last bit (65th) which would be used as a final carry flag, should be 0 if there is no overflow.
317        // As such, we can constrain the length for this call to 64 bits.
318        let z3_bits = bit_constrain(z3_raw + c3, 64)?; // carry-in c3
319        let z3 = UInt64::from_bits_le(&z3_bits[0..64]);
320
321        Ok(Self {
322            limbs: [z0, z1, z2, z3],
323        })
324    }
325
326    pub fn checked_sub(
327        self,
328        _rhs: &Self,
329        _cs: ConstraintSystemRef<Fq>,
330    ) -> Result<U128x128Var, SynthesisError> {
331        todo!();
332    }
333
334    pub fn checked_mul(self, rhs: &Self) -> Result<U128x128Var, SynthesisError> {
335        // x = [x0, x1, x2, x3]
336        // x = x0 + x1 * 2^64 + x2 * 2^128 + x3 * 2^192
337        // y = [y0, y1, y2, y3]
338        // y = y0 + y1 * 2^64 + y2 * 2^128 + y3 * 2^192
339        let x0 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[0].to_bits_le())?;
340        let x1 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[1].to_bits_le())?;
341        let x2 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[2].to_bits_le())?;
342        let x3 = Boolean::<Fq>::le_bits_to_fp_var(&self.limbs[3].to_bits_le())?;
343
344        let y0 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[0].to_bits_le())?;
345        let y1 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[1].to_bits_le())?;
346        let y2 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[2].to_bits_le())?;
347        let y3 = Boolean::<Fq>::le_bits_to_fp_var(&rhs.limbs[3].to_bits_le())?;
348
349        // z = x * y
350        // z = [z0, z1, z2, z3, z4, z5, z6, z7]
351        // zi is 128 bits
352        //let z0 = x0.clone() * y0.clone();
353        let z0 = &x0 * &y0;
354        let z1 = &x0 * &y1 + &x1 * &y0;
355        let z2 = &x0 * &y2 + &x1 * &y1 + &x2 * &y0;
356        let z3 = &x0 * &y3 + &x1 * &y2 + &x2 * &y1 + &x3 * &y0;
357        let z4 = &x1 * &y3 + &x2 * &y2 + &x3 * &y1;
358        let z5 = &x2 * &y3 + &x3 * &y2;
359        let z6 = &x3 * &y3;
360        // z7 = 0
361        // z = z0 + z1 * 2^64 + z2 * 2^128 + z3 * 2^192 + z4 * 2^256 + z5 * 2^320 + z6 * 2^384
362        // z*2^-128 = z0*2^-128 + z1*2^-64 + z2 + z3*2^64 + z4*2^128 + z5*2^192 + z6*2^256
363        //
364        // w represents the limbs of the reduced result (z)
365        // w = [w0, w1, w2, w3]
366        // w0
367        // wi are 64 bits like xi and yi
368        //
369        // ti represents some temporary value (indices not necessarily meaningful)
370        let t0 = z0 + z1 * Fq::from(1u128 << 64);
371        let t0_bits = bit_constrain(t0, 193)?;
372        // Constrain: t0 fits in 193 bits
373
374        // t1 = (t0 >> 128) + z2
375        let t1 = z2 + Boolean::<Fq>::le_bits_to_fp_var(&t0_bits[128..193].to_bits_le()?)?;
376        // Constrain: t1 fits in 130 bits
377        let t1_bits = bit_constrain(t1, 130)?;
378
379        // w0 = t1 & 2^64 - 1
380        let w0 = UInt64::from_bits_le(&t1_bits[0..64]);
381
382        // t2 = (t1 >> 64) + z3
383        let t2 = z3 + Boolean::<Fq>::le_bits_to_fp_var(&t1_bits[64..129].to_bits_le()?)?;
384        // Constrain: t2 fits in 129 bits
385        let t2_bits = bit_constrain(t2, 129)?;
386
387        // w1 = t2 & 2^64 - 1
388        let w1 = UInt64::from_bits_le(&t2_bits[0..64]);
389
390        // t3 = (t2 >> 64) + z4
391        let t3 = z4 + Boolean::<Fq>::le_bits_to_fp_var(&t2_bits[64..129].to_bits_le()?)?;
392        // Constrain: t3 fits in 128 bits
393        let t3_bits = bit_constrain(t3, 128)?;
394
395        // w2 = t3 & 2^64 - 1
396        let w2 = UInt64::from_bits_le(&t3_bits[0..64]);
397
398        // t4 = (t3 >> 64) + z5
399        let t4 = z5 + Boolean::<Fq>::le_bits_to_fp_var(&t3_bits[64..128].to_bits_le()?)?;
400        // Constrain: t4 fits in 64 bits
401        let t4_bits = bit_constrain(t4, 64)?;
402        // If we didn't overflow, it will fit in 64 bits.
403
404        // w3 = t4 & 2^64 - 1
405        let w3 = UInt64::from_bits_le(&t4_bits[0..64]);
406
407        // Overflow condition. Constrain: z6 = 0.
408        z6.enforce_equal(&FqVar::zero())?;
409
410        Ok(U128x128Var {
411            limbs: [w0, w1, w2, w3],
412        })
413    }
414
415    pub fn to_bits_le(&self) -> Vec<Boolean<Fq>> {
416        let lo_128_bits = self.limbs[0]
417            .to_bits_le()
418            .into_iter()
419            .chain(self.limbs[1].to_bits_le())
420            .collect::<Vec<_>>();
421        let hi_128_bits = self.limbs[2]
422            .to_bits_le()
423            .into_iter()
424            .chain(self.limbs[3].to_bits_le())
425            .collect::<Vec<_>>();
426        lo_128_bits.into_iter().chain(hi_128_bits).collect()
427    }
428
429    /// This function enforces the ordering between `self` and `other`.
430    pub fn enforce_cmp(
431        &self,
432        other: &U128x128Var,
433        ordering: std::cmp::Ordering,
434    ) -> Result<(), SynthesisError> {
435        // Collect bits from each limb to be compared.
436        let self_bits: Vec<Boolean<Fq>> = self.to_bits_le().into_iter().rev().collect();
437        let other_bits: Vec<Boolean<Fq>> = other.to_bits_le().into_iter().rev().collect();
438
439        // Now starting at the most significant side, compare bits.
440        // `gt` is true if we have conclusively determined that self > other.
441        // `lt` is true if we have conclusively determined that self < other.
442        let mut gt: Boolean<Fq> = Boolean::constant(false);
443        let mut lt: Boolean<Fq> = Boolean::constant(false);
444        for (p, q) in zip(self_bits, other_bits) {
445            // If we've determined that self > other, that will remain
446            // true as we continue to look at other bits. Otherwise,
447            // we need to make sure that we don't have self < other.
448            // At this point, if we see a 1 bit for self and a 0 bit for other,
449            // we know that self > other.
450            gt = gt.or(&lt.not().and(&p)?.and(&q.not())?)?;
451            // The exact same logic, but swapping gt <-> lt, p <-> q
452            lt = lt.or(&gt.not().and(&q)?.and(&p.not())?)?;
453        }
454
455        match ordering {
456            std::cmp::Ordering::Greater => {
457                gt.enforce_equal(&Boolean::constant(true))?;
458                lt.enforce_equal(&Boolean::constant(false))?;
459            }
460            std::cmp::Ordering::Less => {
461                gt.enforce_equal(&Boolean::constant(false))?;
462                lt.enforce_equal(&Boolean::constant(true))?;
463            }
464            std::cmp::Ordering::Equal => {
465                unimplemented!("use EqGadget for efficiency");
466            }
467        }
468
469        Ok(())
470    }
471
472    pub fn checked_div(
473        self,
474        rhs: &Self,
475        cs: ConstraintSystemRef<Fq>,
476    ) -> Result<U128x128Var, SynthesisError> {
477        // Similar to AmountVar::quo_rem
478        // x = q * y + r
479        // Constrain 0 <= r
480        // Constrain r < q
481        // y will be 256 bits wide
482        // x will be 384 bits wide
483
484        // x = self (logical 128-bit)
485        // y = rhs (logical 128-bit)
486        // xbar = x * 2^128 (representative 256-bit)
487        // ybar = y * 2^128 (representative 256-bit)
488
489        // q = x / y
490        // qbar = q * 2^128 (256 bit value)
491
492        // xbar / ybar = x / y * 1
493        // qbar = xbar / ybar * 2^128
494        // xbar * 2^128 = qbar * ybar + r
495
496        // use a division oracle to compute (qbar, r) out-of-circuit (OOC)
497        // Constrain: divisor is non-zero
498        rhs.enforce_not_equal(&U128x128Var::zero())?;
499
500        // OOC division
501        let xbar_ooc = self.value().unwrap_or_default();
502        let ybar_ooc = rhs.value().unwrap_or(U128x128::from(1u64));
503        let Ok((quo_ooc, rem_ooc)) = stub_div_rem_u384_by_u256(xbar_ooc.0, ybar_ooc.0) else {
504            return Err(SynthesisError::Unsatisfiable);
505        };
506        // Constrain: xbar * 2^128 = qbar * ybar + r
507        // We already have xbar as bits, so we have xbar * 2^128 "for free" by rearranging limbs
508        // Need the bits of qbar * ybar + r => need bits of qbar, ybar, r + mul constraint
509
510        let x = self;
511        let y = rhs;
512        let q = U128x128Var::new_witness(cs.clone(), || Ok(U128x128(quo_ooc)))?;
513        // r isn't a U128x128, but we can reuse the codepath for constraining its bits as limb values
514        let r_var = U128x128Var::new_witness(cs, || Ok(U128x128(rem_ooc)))?;
515        // Constrain r < ybar: this also constrains that r is non-negative, i.e. that 0 <= r
516        // i.e. the remainder cannot be greater than the divisor (`y` also called `rhs`)
517        r_var.enforce_cmp(rhs, core::cmp::Ordering::Less)?;
518
519        let r = r_var.limbs;
520        let qbar = &q.limbs;
521        let ybar = &y.limbs;
522        let xbar = &x.limbs;
523
524        // qbar = qbar0 + qbar1 * 2^64 + qbar2 * 2^128 + qbar3 * 2^192
525        // ybar = ybar0 + ybar1 * 2^64 + ybar2 * 2^128 + ybar3 * 2^192
526        //    r =    r0 +    r1 * 2^64 +    r2 * 2^128 +    r3 * 2^192
527
528        let xbar0 = Boolean::<Fq>::le_bits_to_fp_var(&xbar[0].to_bits_le())?;
529        let xbar1 = Boolean::<Fq>::le_bits_to_fp_var(&xbar[1].to_bits_le())?;
530        let xbar2 = Boolean::<Fq>::le_bits_to_fp_var(&xbar[2].to_bits_le())?;
531        let xbar3 = Boolean::<Fq>::le_bits_to_fp_var(&xbar[3].to_bits_le())?;
532
533        let ybar0 = Boolean::<Fq>::le_bits_to_fp_var(&ybar[0].to_bits_le())?;
534        let ybar1 = Boolean::<Fq>::le_bits_to_fp_var(&ybar[1].to_bits_le())?;
535        let ybar2 = Boolean::<Fq>::le_bits_to_fp_var(&ybar[2].to_bits_le())?;
536        let ybar3 = Boolean::<Fq>::le_bits_to_fp_var(&ybar[3].to_bits_le())?;
537
538        let qbar0 = Boolean::<Fq>::le_bits_to_fp_var(&qbar[0].to_bits_le())?;
539        let qbar1 = Boolean::<Fq>::le_bits_to_fp_var(&qbar[1].to_bits_le())?;
540        let qbar2 = Boolean::<Fq>::le_bits_to_fp_var(&qbar[2].to_bits_le())?;
541        let qbar3 = Boolean::<Fq>::le_bits_to_fp_var(&qbar[3].to_bits_le())?;
542
543        let r0 = Boolean::<Fq>::le_bits_to_fp_var(&r[0].to_bits_le())?;
544        let r1 = Boolean::<Fq>::le_bits_to_fp_var(&r[1].to_bits_le())?;
545        let r2 = Boolean::<Fq>::le_bits_to_fp_var(&r[2].to_bits_le())?;
546        let r3 = Boolean::<Fq>::le_bits_to_fp_var(&r[3].to_bits_le())?;
547
548        // Let z = qbar * ybar + r.  Then z will be 513 bits in general; we want
549        // to constrain it to be equal to xbar * 2^128 so we need the low 384
550        // bits -- we'll constrain the low 128 as 0 and the upper 256 as xbar --
551        // and constrain everything above as 0 (not necessarily as bit
552        // constraints)
553
554        // Write z as:
555        //    z =    z0 +    z1 * 2^64 +    z2 * 2^128 +    z3 * 2^192 +    z4 * 2^256 +    z5 * 2^320 +    z6 * 2^384
556        // Without carrying, the limbs of z are:
557        // z0_raw = r0 + qbar0 * ybar0
558        // z1_raw = r1 + qbar1 * ybar0 + qbar0 * ybar1
559        // z2_raw = r2 + qbar2 * ybar0 + qbar1 * ybar1 + qbar0 * ybar2
560        // z3_raw = r3 + qbar3 * ybar0 + qbar2 * ybar1 + qbar1 * ybar2 + qbar0 * ybar3
561        // z4_raw =                      qbar3 * ybar1 + qbar2 * ybar2 + qbar1 * ybar3
562        // z5_raw =                                      qbar3 * ybar2 + qbar2 * ybar3
563        // z6_raw =                                                      qbar3 * ybar3
564
565        let z0_raw = r0 + &qbar0 * &ybar0;
566        let z1_raw = r1 + &qbar1 * &ybar0 + &qbar0 * &ybar1;
567        let z2_raw = r2 + &qbar2 * &ybar0 + &qbar1 * &ybar1 + &qbar0 * &ybar2;
568        let z3_raw = r3 + &qbar3 * &ybar0 + &qbar2 * &ybar1 + &qbar1 * &ybar2 + &qbar0 * &ybar3;
569        let z4_raw = /*__________________*/ &qbar3 * &ybar1 + &qbar2 * &ybar2 + &qbar1 * &ybar3;
570        let z5_raw = /*____________________________________*/ &qbar3 * &ybar2 + &qbar2 * &ybar3;
571        let z6_raw = /*______________________________________________________*/ &qbar3 * &ybar3;
572        /* ------------------------------------------------------------------------------------^ 384 + 128 = 512 */
573
574        // These terms are overlapping, and we need to carry to compute the
575        // canonical limbs.
576        //
577        // We want to constrain
578        //    z =    z0 +    z1 * 2^64 +    z2 * 2^128 +    z3 * 2^192 +    z4 * 2^256 +    z5 * 2^320 +    z6 * 2^384
579        // ==         0       0          xbar0           xbar1           xbar2           xbar3              0
580        // We need to bit-constrain z0 and z1 to be able to compute the carry to
581        // get the canonical z2, z3, z4, z5 values, but don't need bit constraints
582        // for the upper terms, we just need to enforce that they're 0, without the
583        // possibility of wrapping in the finite field.
584
585        // z0 <= (2^64 - 1)(2^64 - 1) + (2^64 - 1) => 128 bits
586        let z0_bits = bit_constrain(z0_raw, 128)?; // no carry-in
587        let z0 = Boolean::<Fq>::le_bits_to_fp_var(&z0_bits[0..64].to_bits_le()?)?;
588        let c1 = Boolean::<Fq>::le_bits_to_fp_var(&z0_bits[64..].to_bits_le()?)?; // 64 bits
589
590        // z1 <= 2*(2^64 - 1)(2^64 - 1) + (2^64 - 1) + carry (2^64 - 1) => 129 bits
591        let z1_bits = bit_constrain(z1_raw + c1, 129)?; // carry-in c1
592        let z1 = Boolean::<Fq>::le_bits_to_fp_var(&z1_bits[0..64].to_bits_le()?)?;
593        let c2 = Boolean::<Fq>::le_bits_to_fp_var(&z1_bits[64..].to_bits_le()?)?; // 65 bits
594
595        // z2 <= 3*(2^64 - 1)(2^64 - 1) + (2^64 - 1) + carry (2^65 - 2) => 130 bits
596        let z2_bits = bit_constrain(z2_raw + c2, 130)?; // carry-in c2
597        let z2 = Boolean::<Fq>::le_bits_to_fp_var(&z2_bits[0..64].to_bits_le()?)?;
598        let c3 = Boolean::<Fq>::le_bits_to_fp_var(&z2_bits[64..].to_bits_le()?)?; // 66 bits
599
600        // z3 <= 4*(2^64 - 1)(2^64 - 1) + (2^64 - 1) + carry (2^66 - 1) => 130 bits
601        let z3_bits = bit_constrain(z3_raw + c3, 130)?; // carry-in c3
602        let z3 = Boolean::<Fq>::le_bits_to_fp_var(&z3_bits[0..64].to_bits_le()?)?;
603        let c4 = Boolean::<Fq>::le_bits_to_fp_var(&z3_bits[64..].to_bits_le()?)?; // 66 bits
604
605        // z4 <= 3*(2^64 - 1)(2^64 - 1) + carry (2^66 - 1) => 130 bits
606        // But extra bits beyond 128 spill into z6, which should be zero, so we can constrain to 128 bits.
607        let z4_bits = bit_constrain(z4_raw + c4, 128)?; // carry-in c4
608        let z4 = Boolean::<Fq>::le_bits_to_fp_var(&z4_bits[0..64].to_bits_le()?)?;
609        let c5 = Boolean::<Fq>::le_bits_to_fp_var(&z4_bits[64..].to_bits_le()?)?; // 64 bits
610
611        // z5 <= 2*(2^64 - 1)(2^64 - 1) + (2^64 - 1)
612        // But if there is no overflow, the final carry (which would be c6 constructed from z5_bits[64..])
613        // should be zero. So instead of constructing that final carry, we can instead bit constrain z5 to
614        // the first 64 bits to save constraints.
615        let z5_bits = bit_constrain(z5_raw + c5, 64)?; // carry-in c5
616        let z5 = Boolean::<Fq>::le_bits_to_fp_var(&z5_bits[0..64].to_bits_le()?)?;
617
618        // Repeat:
619        // We want to constrain
620        //    z =    z0 +    z1 * 2^64 +    z2 * 2^128 +    z3 * 2^192 +    z4 * 2^256 +    z5 * 2^320 +    z6 * 2^384
621        // ==         0       0          xbar0           xbar1           xbar2           xbar3              0
622        z0.enforce_equal(&FqVar::zero())?;
623        z1.enforce_equal(&FqVar::zero())?;
624        z2.enforce_equal(&xbar0)?;
625        z3.enforce_equal(&xbar1)?;
626        z4.enforce_equal(&xbar2)?;
627        z5.enforce_equal(&xbar3)?;
628        // z6_raw should be zero if there was no overflow.
629        z6_raw.enforce_equal(&FqVar::zero())?;
630
631        Ok(q)
632    }
633
634    pub fn round_down(self) -> U128x128Var {
635        Self {
636            limbs: [
637                UInt64::constant(0u64),
638                UInt64::constant(0u64),
639                self.limbs[2].clone(),
640                self.limbs[3].clone(),
641            ],
642        }
643    }
644    /// Multiply an amount by this fraction, then round down.
645    pub fn apply_to_amount(self, rhs: AmountVar) -> Result<AmountVar, SynthesisError> {
646        U128x128Var::from_amount_var(rhs)?
647            .checked_mul(&self)?
648            .round_down_to_amount()
649    }
650
651    pub fn round_down_to_amount(self) -> Result<AmountVar, SynthesisError> {
652        let bits = self.limbs[2]
653            .to_bits_le()
654            .into_iter()
655            .chain(self.limbs[3].to_bits_le().into_iter())
656            .collect::<Vec<Boolean<Fq>>>();
657        Ok(AmountVar {
658            amount: Boolean::<Fq>::le_bits_to_fp_var(&bits)?,
659        })
660    }
661
662    pub fn zero() -> U128x128Var {
663        Self {
664            limbs: [
665                UInt64::constant(0u64),
666                UInt64::constant(0u64),
667                UInt64::constant(0u64),
668                UInt64::constant(0u64),
669            ],
670        }
671    }
672}
673
674impl EqGadget<Fq> for U128x128Var {
675    fn is_eq(&self, other: &Self) -> Result<Boolean<Fq>, SynthesisError> {
676        let limb_1_eq = self.limbs[0].is_eq(&other.limbs[0])?;
677        let limb_2_eq = self.limbs[1].is_eq(&other.limbs[1])?;
678        let limb_3_eq = self.limbs[2].is_eq(&other.limbs[2])?;
679        let limb_4_eq = self.limbs[3].is_eq(&other.limbs[3])?;
680
681        let limb_12_eq = limb_1_eq.and(&limb_2_eq)?;
682        let limb_34_eq = limb_3_eq.and(&limb_4_eq)?;
683
684        limb_12_eq.and(&limb_34_eq)
685    }
686}
687
688impl ToConstraintField<Fq> for U128x128 {
689    fn to_field_elements(&self) -> Option<Vec<Fq>> {
690        let (hi_128, lo_128) = self.0.into_words();
691        Some(vec![Fq::from(hi_128), Fq::from(lo_128)])
692    }
693}
694
695impl CondSelectGadget<Fq> for U128x128Var {
696    fn conditionally_select(
697        cond: &Boolean<Fq>,
698        true_value: &Self,
699        false_value: &Self,
700    ) -> Result<Self, SynthesisError> {
701        let limb0 = cond.select(&true_value.limbs[0], &false_value.limbs[0])?;
702        let limb1 = cond.select(&true_value.limbs[1], &false_value.limbs[1])?;
703        let limb2 = cond.select(&true_value.limbs[2], &false_value.limbs[2])?;
704        let limb3 = cond.select(&true_value.limbs[3], &false_value.limbs[3])?;
705        Ok(Self {
706            limbs: [limb0, limb1, limb2, limb3],
707        })
708    }
709}
710
711/// Convert Uint64 into an FqVar
712pub fn convert_uint64_to_fqvar<F: PrimeField>(value: &UInt64<F>) -> FpVar<F> {
713    Boolean::<F>::le_bits_to_fp_var(&value.to_bits_le()).expect("can convert to bits")
714}
715
716/// Bit constrain for FqVar and return number of bits
717pub fn bit_constrain(value: FqVar, n: usize) -> Result<Vec<Boolean<Fq>>, SynthesisError> {
718    let inner = value.value().unwrap_or(Fq::zero());
719
720    // Get only first n bits based on that value (OOC)
721    let inner_bigint = inner.into_bigint();
722    let bits = &inner_bigint.to_bits_le()[0..n];
723
724    // Allocate Boolean vars for first n bits
725    let mut boolean_constraints = Vec::new();
726    for bit in bits {
727        let boolean = Boolean::new_witness(value.cs().clone(), || Ok(bit))?;
728        boolean_constraints.push(boolean);
729    }
730
731    // Construct an FqVar from those n Boolean constraints, and constrain it to be equal to the original value
732    let constructed_fqvar = Boolean::<Fq>::le_bits_to_fp_var(&boolean_constraints.to_bits_le()?)
733        .expect("can convert to bits");
734    constructed_fqvar.enforce_equal(&value)?;
735
736    Ok(boolean_constraints)
737}
738
#[cfg(test)]
mod test {
    use ark_groth16::{r1cs_to_qap::LibsnarkReduction, Groth16, ProvingKey, VerifyingKey};
    use ark_relations::r1cs::ConstraintSynthesizer;
    use ark_snark::SNARK;
    use decaf377::Bls12_377;
    use proptest::prelude::*;
    use rand_core::OsRng;

    use crate::Amount;

    use super::*;

    /// Proves `circuit` under `pk` and asserts that the proof verifies under `vk`
    /// with the given public inputs.
    ///
    /// `Groth16::verify` yields `Result<bool, _>`: a proof can fail verification
    /// while still returning `Ok(false)`. Checking only `is_ok()` would therefore
    /// miss a failed verification, so this helper requires `Ok(true)`.
    fn assert_proof_verifies<C: ConstraintSynthesizer<Fq>>(
        pk: &ProvingKey<Bls12_377>,
        vk: &VerifyingKey<Bls12_377>,
        circuit: C,
        public_inputs: &[Fq],
    ) {
        let mut rng = OsRng;
        let proof = Groth16::<Bls12_377, LibsnarkReduction>::prove(pk, circuit, &mut rng)
            .expect("should be able to form proof");
        let verified = Groth16::<Bls12_377, LibsnarkReduction>::verify(vk, public_inputs, &proof)
            .expect("verification should not error");
        assert!(verified, "proof should verify");
    }

    // The circuit below enforces `a < a * b`. Sampling `a >= 1` and `b >= 2`
    // guarantees that relation holds for every generated case.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1))]
        #[test]
        fn multiply_and_round(
            a_int in 1..=u64::MAX,
            a_frac in any::<u64>(),
            b_int in 2..=u64::MAX,
            b_frac in any::<u64>(),
        ) {
            let a = U128x128(U256::from_words(a_int.into(), a_frac.into()));
            let b = U128x128(U256::from_words(b_int.into(), b_frac.into()));

            let expected_c = a.checked_mul(&b).expect("result should not overflow");
            let rounded_down_c = expected_c.round_down();

            let circuit = TestMultiplicationCircuit {
                a,
                b,
                c: expected_c,
                rounded_down_c,
            };

            let (pk, vk) = TestMultiplicationCircuit::generate_test_parameters();

            // Public inputs in the order the circuit allocates them: `c`, then `rounded_down_c`.
            let mut pi = Vec::new();
            pi.extend_from_slice(&expected_c.to_field_elements().unwrap());
            pi.extend_from_slice(&rounded_down_c.to_field_elements().unwrap());
            assert_proof_verifies(&pk, &vk, circuit, &pi);
        }
    }

    struct TestMultiplicationCircuit {
        a: U128x128,
        b: U128x128,

        // c = a * b
        pub c: U128x128,
        pub rounded_down_c: U128x128,
    }

    impl ConstraintSynthesizer<Fq> for TestMultiplicationCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            let a_var = U128x128Var::new_witness(cs.clone(), || Ok(self.a))?;
            let b_var = U128x128Var::new_witness(cs.clone(), || Ok(self.b))?;
            let c_public_var = U128x128Var::new_input(cs.clone(), || Ok(self.c))?;
            let c_public_rounded_down_var = U128x128Var::new_input(cs, || Ok(self.rounded_down_c))?;
            let c_var = a_var.clone().checked_mul(&b_var)?;
            c_var.enforce_equal(&c_public_var)?;
            let c_rounded_down = c_var.clone().round_down();
            c_rounded_down.enforce_equal(&c_public_rounded_down_var)?;

            // Also check that a < c
            a_var.enforce_cmp(&c_var, std::cmp::Ordering::Less)?;

            // Also check that c > a
            c_var.enforce_cmp(&a_var, std::cmp::Ordering::Greater)?;
            Ok(())
        }
    }

    impl TestMultiplicationCircuit {
        fn generate_test_parameters() -> (ProvingKey<Bls12_377>, VerifyingKey<Bls12_377>) {
            let num: [u8; 32] = [0u8; 32];
            let a = U128x128::from_bytes(num);
            let b = U128x128::from_bytes(num);
            let c = a.checked_mul(&b).unwrap();
            let rounded_down_c = c.round_down();
            let circuit = TestMultiplicationCircuit {
                a,
                b,
                c,
                rounded_down_c,
            };
            let (pk, vk) = Groth16::<Bls12_377, LibsnarkReduction>::circuit_specific_setup(
                circuit, &mut OsRng,
            )
            .expect("can perform circuit specific setup");
            (pk, vk)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]
        #[test]
        fn add(
            a_int in any::<u64>(),
            a_frac in any::<u128>(),
            b_int in any::<u64>(),
            b_frac in any::<u128>(),
        ) {
            let a = U128x128(U256::from_words(a_int.into(), a_frac));
            let b = U128x128(U256::from_words(b_int.into(), b_frac));

            let expected_c = match a.checked_add(&b) {
                Ok(c) => c,
                // If the result overflowed, then we can't construct a valid proof.
                Err(_) => return Ok(()),
            };

            let circuit = TestAdditionCircuit {
                a,
                b,
                c: expected_c,
            };

            let (pk, vk) = TestAdditionCircuit::generate_test_parameters();
            assert_proof_verifies(&pk, &vk, circuit, &expected_c.to_field_elements().unwrap());
        }
    }

    #[test]
    fn max_u64_addition() {
        let a = U128x128(U256::from_words(0, u64::MAX as u128));
        let b = U128x128(U256::from_words(0, u64::MAX as u128));

        let expected_c = a.checked_add(&b).expect("result should not overflow");

        let circuit = TestAdditionCircuit {
            a,
            b,
            c: expected_c,
        };

        let (pk, vk) = TestAdditionCircuit::generate_test_parameters();
        assert_proof_verifies(&pk, &vk, circuit, &expected_c.to_field_elements().unwrap());
    }

    struct TestAdditionCircuit {
        a: U128x128,
        b: U128x128,

        // c = a + b
        pub c: U128x128,
    }

    impl ConstraintSynthesizer<Fq> for TestAdditionCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            let a_var = U128x128Var::new_witness(cs.clone(), || Ok(self.a))?;
            let b_var = U128x128Var::new_witness(cs.clone(), || Ok(self.b))?;
            let c_public_var = U128x128Var::new_input(cs, || Ok(self.c))?;
            let c_var = a_var.checked_add(&b_var)?;
            c_var.enforce_equal(&c_public_var)?;
            Ok(())
        }
    }

    impl TestAdditionCircuit {
        fn generate_test_parameters() -> (ProvingKey<Bls12_377>, VerifyingKey<Bls12_377>) {
            let num: [u8; 32] = [0u8; 32];
            let a = U128x128::from_bytes(num);
            let b = U128x128::from_bytes(num);
            let circuit = TestAdditionCircuit {
                a,
                b,
                c: a.checked_add(&b).unwrap(),
            };
            let (pk, vk) = Groth16::<Bls12_377, LibsnarkReduction>::circuit_specific_setup(
                circuit, &mut OsRng,
            )
            .expect("can perform circuit specific setup");
            (pk, vk)
        }
    }

    #[test]
    fn max_division() {
        // b = 1.0 (integer word 1, fractional word 0); a = the largest representable value.
        let b = U128x128(U256::from_words(1, 0));
        let a = U128x128(U256::from_words(u128::MAX, u128::MAX));

        let expected_c = a.checked_div(&b).expect("result should not overflow");

        let circuit = TestDivisionCircuit {
            a,
            b,
            c: expected_c,
        };

        let (pk, vk) = TestDivisionCircuit::generate_test_parameters();
        assert_proof_verifies(&pk, &vk, circuit, &expected_c.to_field_elements().unwrap());
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn division(
            a_int in any::<u64>(),
            a_frac in any::<u64>(),
            b_int in any::<u128>(),
            b_frac in any::<u128>(),
        ) {
            let a = U128x128(U256::from_words(a_int.into(), a_frac.into()));
            let b = U128x128(U256::from_words(b_int, b_frac));

            // Require b >= 1: this rules out division by zero and keeps a / b <= a,
            // so the quotient cannot overflow.
            if b_int == 0 {
                return Ok(())
            }

            let expected_c = a.checked_div(&b).expect("result should not overflow");

            let circuit = TestDivisionCircuit {
                a,
                b,
                c: expected_c,
            };

            let (pk, vk) = TestDivisionCircuit::generate_test_parameters();
            assert_proof_verifies(&pk, &vk, circuit, &expected_c.to_field_elements().unwrap());
        }
    }

    struct TestDivisionCircuit {
        a: U128x128,
        b: U128x128,

        // c = a / b
        pub c: U128x128,
    }

    impl ConstraintSynthesizer<Fq> for TestDivisionCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            let a_var = U128x128Var::new_witness(cs.clone(), || Ok(self.a))?;
            let b_var = U128x128Var::new_witness(cs.clone(), || Ok(self.b))?;
            let c_public_var = U128x128Var::new_input(cs.clone(), || Ok(self.c))?;
            let c_var = a_var.checked_div(&b_var, cs)?;
            c_var.enforce_equal(&c_public_var)?;
            Ok(())
        }
    }

    impl TestDivisionCircuit {
        fn generate_test_parameters() -> (ProvingKey<Bls12_377>, VerifyingKey<Bls12_377>) {
            let num: [u8; 32] = [1u8; 32];
            let a = U128x128::from_bytes(num);
            let b = U128x128::from_bytes(num);
            let circuit = TestDivisionCircuit {
                a,
                b,
                c: a.checked_div(&b).unwrap(),
            };
            let (pk, vk) = Groth16::<Bls12_377, LibsnarkReduction>::circuit_specific_setup(
                circuit, &mut OsRng,
            )
            .expect("can perform circuit specific setup");
            (pk, vk)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]
        #[test]
        fn compare(
            a_int in any::<u64>(),
            c_int in any::<u64>(),
        ) {
            // a < b
            let a =
                if a_int == u64::MAX {
                    U128x128::from(a_int - 1)
                } else {
                    U128x128::from(a_int)
                };
            let b = (a + U128x128::from(1u64)).expect("should not overflow");
            // c > d
            let c =
                if c_int == 0 {
                    U128x128::from(c_int + 1)
                } else {
                    U128x128::from(c_int)
                };
            let d = (c - U128x128::from(1u64)).expect("should not underflow");

            let circuit = TestComparisonCircuit {
                a,
                b,
                c,
                d,
            };

            let (pk, vk) = TestComparisonCircuit::generate_test_parameters();
            assert_proof_verifies(&pk, &vk, circuit, &[]);
        }
    }

    struct TestComparisonCircuit {
        a: U128x128,
        b: U128x128,
        c: U128x128,
        d: U128x128,
    }

    impl ConstraintSynthesizer<Fq> for TestComparisonCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            // a < b
            let a_var = U128x128Var::new_witness(cs.clone(), || Ok(self.a))?;
            let b_var = U128x128Var::new_witness(cs.clone(), || Ok(self.b))?;
            a_var.enforce_cmp(&b_var, std::cmp::Ordering::Less)?;
            // c > d
            let c_var = U128x128Var::new_witness(cs.clone(), || Ok(self.c))?;
            let d_var = U128x128Var::new_witness(cs, || Ok(self.d))?;
            c_var.enforce_cmp(&d_var, std::cmp::Ordering::Greater)?;

            Ok(())
        }
    }

    impl TestComparisonCircuit {
        fn generate_test_parameters() -> (ProvingKey<Bls12_377>, VerifyingKey<Bls12_377>) {
            let num: [u8; 32] = [0u8; 32];
            let a = U128x128::from_bytes(num);
            let b = U128x128::from_bytes(num);
            let c = U128x128::from_bytes(num);
            let d = U128x128::from_bytes(num);
            let circuit = TestComparisonCircuit { a, b, c, d };
            let (pk, vk) = Groth16::<Bls12_377, LibsnarkReduction>::circuit_specific_setup(
                circuit, &mut OsRng,
            )
            .expect("can perform circuit specific setup");
            (pk, vk)
        }
    }

    struct TestGreaterInvalidComparisonCircuit {
        a: U128x128,
        b: U128x128,
    }

    impl ConstraintSynthesizer<Fq> for TestGreaterInvalidComparisonCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            // In reality a < b, but we're asserting that a > b here (should panic)
            let a_var = U128x128Var::new_witness(cs.clone(), || Ok(self.a))?;
            let b_var = U128x128Var::new_witness(cs, || Ok(self.b))?;
            a_var.enforce_cmp(&b_var, std::cmp::Ordering::Greater)?;

            Ok(())
        }
    }

    impl TestGreaterInvalidComparisonCircuit {
        fn generate_test_parameters(
        ) -> Result<(ProvingKey<Bls12_377>, VerifyingKey<Bls12_377>), SynthesisError> {
            let num: [u8; 32] = [0u8; 32];
            let a = U128x128::from_bytes(num);
            let b = U128x128::from_bytes(num);
            let circuit = TestGreaterInvalidComparisonCircuit { a, b };
            Groth16::<Bls12_377, LibsnarkReduction>::circuit_specific_setup(circuit, &mut OsRng)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]
        #[should_panic]
        #[test]
        fn invalid_greater_compare(
            a_int in any::<u128>(),
        ) {
            // a < b
            let a =
                if a_int == u128::MAX {
                    U128x128::from(a_int - 1)
                } else {
                    U128x128::from(a_int)
                };
            let b = (a + U128x128::from(1u64)).expect("should not overflow");

            let circuit = TestGreaterInvalidComparisonCircuit {
                a,
                b,
            };

            let (pk, vk) = TestGreaterInvalidComparisonCircuit::generate_test_parameters().expect("can perform setup");
            let mut rng = OsRng;

            let proof = Groth16::<Bls12_377, LibsnarkReduction>::prove(&pk, circuit, &mut rng)
            .expect("in debug mode only, we assert that the circuit is satisfied, so we will panic here");

            let proof_result = Groth16::<Bls12_377, LibsnarkReduction>::verify(
                &vk,
                &[],
                &proof,
            ).expect("in release mode, we will be able to construct the proof, so we can unwrap the result");

            // We want the same behavior in release or debug mode, so we will panic if the proof does not verify.
            if !proof_result {
                panic!("should not be able to verify proof");
            }
        }
    }

    struct TestLessInvalidComparisonCircuit {
        c: U128x128,
        d: U128x128,
    }

    impl ConstraintSynthesizer<Fq> for TestLessInvalidComparisonCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            // In reality c > d, but we're asserting that c < d here (should panic)
            let c_var = U128x128Var::new_witness(cs.clone(), || Ok(self.c))?;
            let d_var = U128x128Var::new_witness(cs, || Ok(self.d))?;
            c_var.enforce_cmp(&d_var, std::cmp::Ordering::Less)?;

            Ok(())
        }
    }

    impl TestLessInvalidComparisonCircuit {
        fn generate_test_parameters(
        ) -> Result<(ProvingKey<Bls12_377>, VerifyingKey<Bls12_377>), SynthesisError> {
            let num: [u8; 32] = [0u8; 32];
            let c = U128x128::from_bytes(num);
            let d = U128x128::from_bytes(num);
            let circuit = TestLessInvalidComparisonCircuit { c, d };
            Groth16::<Bls12_377, LibsnarkReduction>::circuit_specific_setup(circuit, &mut OsRng)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]
        #[should_panic]
        #[test]
        fn invalid_less_compare(
            c_int in any::<u128>(),
        ) {
            // c > d
            let c =
                if c_int == 0 {
                    U128x128::from(c_int + 1)
                } else {
                    U128x128::from(c_int)
                };
            let d = (c - U128x128::from(1u64)).expect("should not underflow");

            let circuit = TestLessInvalidComparisonCircuit {
                c,
                d,
            };

            let (pk, vk) = TestLessInvalidComparisonCircuit::generate_test_parameters().expect("can perform setup");
            let mut rng = OsRng;

            let proof = Groth16::<Bls12_377, LibsnarkReduction>::prove(&pk, circuit, &mut rng)
            .expect("in debug mode only, we assert that the circuit is satisfied, so we will panic here");

            let proof_result = Groth16::<Bls12_377, LibsnarkReduction>::verify(
                &vk,
                &[],
                &proof,
            ).expect("in release mode, we will be able to construct the proof, so we can unwrap the result");

            // We want the same behavior in release or debug mode, so we will panic if the proof does not verify.
            if !proof_result {
                panic!("should not be able to verify proof");
            }
        }
    }

    #[should_panic]
    #[test]
    fn regression_invalid_less_compare() {
        // c > d in reality, the circuit will attempt to prove c < d (should panic)
        let c = U128x128::from(354389783742u64);
        let d = U128x128::from(17u64);

        let circuit = TestLessInvalidComparisonCircuit { c, d };

        let (pk, vk) = TestLessInvalidComparisonCircuit::generate_test_parameters()
            .expect("can perform setup");
        let mut rng = OsRng;

        let proof = Groth16::<Bls12_377, LibsnarkReduction>::prove(&pk, circuit, &mut rng).expect(
            "in debug mode only, we assert that the circuit is satisfied, so we will panic here",
        );

        let proof_result = Groth16::<Bls12_377, LibsnarkReduction>::verify(&vk, &[], &proof)
            .expect(
            "in release mode, we will be able to construct the proof, so we can unwrap the result",
        );

        // We want the same behavior in release or debug mode, so we will panic if the proof does not verify.
        if !proof_result {
            panic!("should not be able to verify proof");
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(5))]
        #[test]
        fn round_down_to_amount(
            a_int in any::<u128>(),
            a_frac in any::<u128>(),
        ) {
            let a = U128x128(U256::from_words(a_int, a_frac));

            let expected_c: Amount = a.round_down().try_into().expect("should be able to round down OOC");

            let circuit = TestRoundDownCircuit {
                a,
                c: expected_c,
            };

            let (pk, vk) = TestRoundDownCircuit::generate_test_parameters();
            assert_proof_verifies(&pk, &vk, circuit, &expected_c.to_field_elements().unwrap());
        }
    }

    struct TestRoundDownCircuit {
        a: U128x128,

        // `c` is expected to be `a` rounded down to an `Amount`
        pub c: Amount,
    }

    impl ConstraintSynthesizer<Fq> for TestRoundDownCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            let a_var = U128x128Var::new_witness(cs.clone(), || Ok(self.a))?;
            let c_public_var = AmountVar::new_input(cs, || Ok(self.c))?;
            let c_var = a_var.round_down_to_amount()?;
            c_var.enforce_equal(&c_public_var)?;
            Ok(())
        }
    }

    impl TestRoundDownCircuit {
        fn generate_test_parameters() -> (ProvingKey<Bls12_377>, VerifyingKey<Bls12_377>) {
            let num: [u8; 32] = [0u8; 32];
            let a = U128x128::from_bytes(num);
            let c: Amount = a
                .round_down()
                .try_into()
                .expect("should be able to round down OOC");
            let circuit = TestRoundDownCircuit { a, c };
            let (pk, vk) = Groth16::<Bls12_377, LibsnarkReduction>::circuit_specific_setup(
                circuit, &mut OsRng,
            )
            .expect("can perform circuit specific setup");
            (pk, vk)
        }
    }
}