ark_ff/fields/models/fp/montgomery_backend.rs

1use ark_std::{marker::PhantomData, Zero};
2
3use super::{Fp, FpConfig};
4use crate::{biginteger::arithmetic as fa, BigInt, BigInteger, PrimeField, SqrtPrecomputation};
5use ark_ff_macros::unroll_for_loops;
6
7/// A trait that specifies the constants and arithmetic procedures
8/// for Montgomery arithmetic over the prime field defined by `MODULUS`.
9///
10/// # Note
11/// Manual implementation of this trait is not recommended unless one wishes
12/// to specialize arithmetic methods. Instead, the
13/// [`MontConfig`][`ark_ff_macros::MontConfig`] derive macro should be used.
14pub trait MontConfig<const N: usize>: 'static + Sync + Send + Sized {
15    /// The modulus of the field.
16    const MODULUS: BigInt<N>;
17
18    /// Let `M` be the power of 2^64 nearest to `Self::MODULUS_BITS`. Then
19    /// `R = M % Self::MODULUS`.
20    const R: BigInt<N> = Self::MODULUS.montgomery_r();
21
22    /// R2 = R^2 % Self::MODULUS
23    const R2: BigInt<N> = Self::MODULUS.montgomery_r2();
24
25    /// INV = -MODULUS^{-1} mod 2^64
26    const INV: u64 = inv::<Self, N>();
27
28    /// A multiplicative generator of the field.
29    /// `Self::GENERATOR` is an element having multiplicative order
30    /// `Self::MODULUS - 1`.
31    const GENERATOR: Fp<MontBackend<Self, N>, N>;
32
33    /// Can we use the no-carry optimization for multiplication
34    /// outlined [here](https://hackmd.io/@gnark/modular_multiplication)?
35    ///
36    /// This optimization applies if
37    /// (a) `Self::MODULUS[N-1] < u64::MAX >> 1`, and
38    /// (b) the bits of the modulus are not all 1.
39    #[doc(hidden)]
40    const CAN_USE_NO_CARRY_MUL_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();
41
42    /// Can we use the no-carry optimization for squaring
43    /// outlined [here](https://hackmd.io/@gnark/modular_multiplication)?
44    ///
45    /// This optimization applies if
46    /// (a) `Self::MODULUS[N-1] < u64::MAX >> 2`, and
47    /// (b) the bits of the modulus are not all 1.
48    #[doc(hidden)]
49    const CAN_USE_NO_CARRY_SQUARE_OPT: bool = can_use_no_carry_mul_optimization::<Self, N>();
50
51    /// Does the modulus have a spare unused bit
52    ///
53    /// This condition applies if
54    /// (a) `Self::MODULUS[N-1] >> 63 == 0`
55    #[doc(hidden)]
56    const MODULUS_HAS_SPARE_BIT: bool = modulus_has_spare_bit::<Self, N>();
57
58    /// 2^s root of unity computed by GENERATOR^t
59    const TWO_ADIC_ROOT_OF_UNITY: Fp<MontBackend<Self, N>, N>;
60
61    /// An integer `b` such that there exists a multiplicative subgroup
62    /// of size `b^k` for some integer `k`.
63    const SMALL_SUBGROUP_BASE: Option<u32> = None;
64
65    /// The integer `k` such that there exists a multiplicative subgroup
66    /// of size `Self::SMALL_SUBGROUP_BASE^k`.
67    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = None;
68
69    /// GENERATOR^((MODULUS-1) / (2^s *
70    /// SMALL_SUBGROUP_BASE^SMALL_SUBGROUP_BASE_ADICITY)).
71    /// Used for mixed-radix FFT.
72    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<MontBackend<Self, N>, N>> = None;
73
74    /// Precomputed material for use when computing square roots.
75    /// The default is to use the standard Tonelli-Shanks algorithm.
76    const SQRT_PRECOMP: Option<SqrtPrecomputation<Fp<MontBackend<Self, N>, N>>> =
77        sqrt_precomputation::<N, Self>();
78
79    /// (MODULUS + 1) / 4 when MODULUS % 4 == 3. Used for square root precomputations.
80    #[doc(hidden)]
81    const MODULUS_PLUS_ONE_DIV_FOUR: Option<BigInt<N>> = {
82        match Self::MODULUS.mod_4() == 3 {
83            true => {
84                let (modulus_plus_one, carry) =
85                    Self::MODULUS.const_add_with_carry(&BigInt::<N>::one());
86                let mut result = modulus_plus_one.divide_by_2_round_down();
87                // Since modulus_plus_one is even, dividing by 2 results in a MSB of 0.
88                // Thus we can set MSB to `carry` to get the correct result of (MODULUS + 1) // 2:
89                result.0[N - 1] |= (carry as u64) << 63;
90                Some(result.divide_by_2_round_down())
91            },
92            false => None,
93        }
94    };
95
96    /// Sets `a = a + b`.
97    #[inline(always)]
98    fn add_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
99        // This cannot exceed the backing capacity.
100        let c = a.0.add_with_carry(&b.0);
101        // However, it may need to be reduced
102        if Self::MODULUS_HAS_SPARE_BIT {
103            a.subtract_modulus()
104        } else {
105            a.subtract_modulus_with_carry(c)
106        }
107    }
108
109    /// Sets `a = a - b`.
110    #[inline(always)]
111    fn sub_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
112        // If `other` is larger than `self`, add the modulus to self first.
113        if b.0 > a.0 {
114            a.0.add_with_carry(&Self::MODULUS);
115        }
116        a.0.sub_with_borrow(&b.0);
117    }
118
119    /// Sets `a = 2 * a`.
120    #[inline(always)]
121    fn double_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
122        // This cannot exceed the backing capacity.
123        let c = a.0.mul2();
124        // However, it may need to be reduced.
125        if Self::MODULUS_HAS_SPARE_BIT {
126            a.subtract_modulus()
127        } else {
128            a.subtract_modulus_with_carry(c)
129        }
130    }
131
132    /// Sets `a = -a`.
133    #[inline(always)]
134    fn neg_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
135        if !a.is_zero() {
136            let mut tmp = Self::MODULUS;
137            tmp.sub_with_borrow(&a.0);
138            a.0 = tmp;
139        }
140    }
141
142    /// This modular multiplication algorithm uses Montgomery
143    /// reduction for efficient implementation. It also additionally
144    /// uses the "no-carry optimization" outlined
145    /// [here](https://hackmd.io/@gnark/modular_multiplication) if
146    /// `Self::MODULUS` has (a) a non-zero MSB, and (b) at least one
147    /// zero bit in the rest of the modulus.
148    #[unroll_for_loops(12)]
149    #[inline(always)]
150    fn mul_assign(a: &mut Fp<MontBackend<Self, N>, N>, b: &Fp<MontBackend<Self, N>, N>) {
151        // No-carry optimisation applied to CIOS
152        if Self::CAN_USE_NO_CARRY_MUL_OPT {
153            if N <= 6
154                && N > 1
155                && cfg!(all(
156                    feature = "asm",
157                    target_feature = "bmi2",
158                    target_feature = "adx",
159                    target_arch = "x86_64"
160                ))
161            {
162                #[cfg(
163                    all(
164                        feature = "asm", 
165                        target_feature = "bmi2", 
166                        target_feature = "adx", 
167                        target_arch = "x86_64"
168                    )
169                )]
170                #[allow(unsafe_code, unused_mut)]
171                #[rustfmt::skip]
172
173                // Tentatively avoid using assembly for `N == 1`.
174                match N {
175                    2 => { ark_ff_asm::x86_64_asm_mul!(2, (a.0).0, (b.0).0); },
176                    3 => { ark_ff_asm::x86_64_asm_mul!(3, (a.0).0, (b.0).0); },
177                    4 => { ark_ff_asm::x86_64_asm_mul!(4, (a.0).0, (b.0).0); },
178                    5 => { ark_ff_asm::x86_64_asm_mul!(5, (a.0).0, (b.0).0); },
179                    6 => { ark_ff_asm::x86_64_asm_mul!(6, (a.0).0, (b.0).0); },
180                    _ => unsafe { ark_std::hint::unreachable_unchecked() },
181                };
182            } else {
183                let mut r = [0u64; N];
184
185                for i in 0..N {
186                    let mut carry1 = 0u64;
187                    r[0] = fa::mac(r[0], (a.0).0[0], (b.0).0[i], &mut carry1);
188
189                    let k = r[0].wrapping_mul(Self::INV);
190
191                    let mut carry2 = 0u64;
192                    fa::mac_discard(r[0], k, Self::MODULUS.0[0], &mut carry2);
193
194                    for j in 1..N {
195                        r[j] = fa::mac_with_carry(r[j], (a.0).0[j], (b.0).0[i], &mut carry1);
196                        r[j - 1] = fa::mac_with_carry(r[j], k, Self::MODULUS.0[j], &mut carry2);
197                    }
198                    r[N - 1] = carry1 + carry2;
199                }
200                (a.0).0 = r;
201            }
202            a.subtract_modulus();
203        } else {
204            // Alternative implementation
205            // Implements CIOS.
206            let (carry, res) = a.mul_without_cond_subtract(b);
207            *a = res;
208
209            if Self::MODULUS_HAS_SPARE_BIT {
210                a.subtract_modulus_with_carry(carry);
211            } else {
212                a.subtract_modulus();
213            }
214        }
215    }
216
217    #[inline(always)]
218    #[unroll_for_loops(12)]
219    fn square_in_place(a: &mut Fp<MontBackend<Self, N>, N>) {
220        if N == 1 {
221            // We default to multiplying with `a` using the `Mul` impl
222            // for the N == 1 case
223            *a *= *a;
224            return;
225        }
226        if Self::CAN_USE_NO_CARRY_SQUARE_OPT
227            && (2..=6).contains(&N)
228            && cfg!(all(
229                feature = "asm",
230                target_feature = "bmi2",
231                target_feature = "adx",
232                target_arch = "x86_64"
233            ))
234        {
235            #[cfg(all(
236                feature = "asm",
237                target_feature = "bmi2",
238                target_feature = "adx",
239                target_arch = "x86_64"
240            ))]
241            #[allow(unsafe_code, unused_mut)]
242            #[rustfmt::skip]
243            match N {
244                2 => { ark_ff_asm::x86_64_asm_square!(2, (a.0).0); },
245                3 => { ark_ff_asm::x86_64_asm_square!(3, (a.0).0); },
246                4 => { ark_ff_asm::x86_64_asm_square!(4, (a.0).0); },
247                5 => { ark_ff_asm::x86_64_asm_square!(5, (a.0).0); },
248                6 => { ark_ff_asm::x86_64_asm_square!(6, (a.0).0); },
249                _ => unsafe { ark_std::hint::unreachable_unchecked() },
250            };
251            a.subtract_modulus();
252            return;
253        }
254
255        let mut r = crate::const_helpers::MulBuffer::<N>::zeroed();
256
257        let mut carry = 0;
258        for i in 0..(N - 1) {
259            for j in (i + 1)..N {
260                r[i + j] = fa::mac_with_carry(r[i + j], (a.0).0[i], (a.0).0[j], &mut carry);
261            }
262            r.b1[i] = carry;
263            carry = 0;
264        }
265
266        r.b1[N - 1] = r.b1[N - 2] >> 63;
267        for i in 2..(2 * N - 1) {
268            r[2 * N - i] = (r[2 * N - i] << 1) | (r[2 * N - (i + 1)] >> 63);
269        }
270        r.b0[1] <<= 1;
271
272        for i in 0..N {
273            r[2 * i] = fa::mac_with_carry(r[2 * i], (a.0).0[i], (a.0).0[i], &mut carry);
274            carry = fa::adc(&mut r[2 * i + 1], 0, carry);
275        }
276        // Montgomery reduction
277        let mut carry2 = 0;
278        for i in 0..N {
279            let k = r[i].wrapping_mul(Self::INV);
280            let mut carry = 0;
281            fa::mac_discard(r[i], k, Self::MODULUS.0[0], &mut carry);
282            for j in 1..N {
283                r[j + i] = fa::mac_with_carry(r[j + i], k, Self::MODULUS.0[j], &mut carry);
284            }
285            carry2 = fa::adc(&mut r.b1[i], carry, carry2);
286        }
287        (a.0).0.copy_from_slice(&r.b1);
288        if Self::MODULUS_HAS_SPARE_BIT {
289            a.subtract_modulus();
290        } else {
291            a.subtract_modulus_with_carry(carry2 != 0);
292        }
293    }
294
295    fn inverse(a: &Fp<MontBackend<Self, N>, N>) -> Option<Fp<MontBackend<Self, N>, N>> {
296        if a.is_zero() {
297            None
298        } else {
299            // Guajardo Kumar Paar Pelzl
300            // Efficient Software-Implementation of Finite Fields with Applications to
301            // Cryptography
302            // Algorithm 16 (BEA for Inversion in Fp)
303
304            let one = BigInt::from(1u64);
305
306            let mut u = a.0;
307            let mut v = Self::MODULUS;
308            let mut b = Fp::new_unchecked(Self::R2); // Avoids unnecessary reduction step.
309            let mut c = Fp::zero();
310
311            while u != one && v != one {
312                while u.is_even() {
313                    u.div2();
314
315                    if b.0.is_even() {
316                        b.0.div2();
317                    } else {
318                        let carry = b.0.add_with_carry(&Self::MODULUS);
319                        b.0.div2();
320                        if !Self::MODULUS_HAS_SPARE_BIT && carry {
321                            (b.0).0[N - 1] |= 1 << 63;
322                        }
323                    }
324                }
325
326                while v.is_even() {
327                    v.div2();
328
329                    if c.0.is_even() {
330                        c.0.div2();
331                    } else {
332                        let carry = c.0.add_with_carry(&Self::MODULUS);
333                        c.0.div2();
334                        if !Self::MODULUS_HAS_SPARE_BIT && carry {
335                            (c.0).0[N - 1] |= 1 << 63;
336                        }
337                    }
338                }
339
340                if v < u {
341                    u.sub_with_borrow(&v);
342                    b -= &c;
343                } else {
344                    v.sub_with_borrow(&u);
345                    c -= &b;
346                }
347            }
348
349            if u == one {
350                Some(b)
351            } else {
352                Some(c)
353            }
354        }
355    }
356
357    fn from_bigint(r: BigInt<N>) -> Option<Fp<MontBackend<Self, N>, N>> {
358        let mut r = Fp::new_unchecked(r);
359        if r.is_zero() {
360            Some(r)
361        } else if r.is_geq_modulus() {
362            None
363        } else {
364            r *= &Fp::new_unchecked(Self::R2);
365            Some(r)
366        }
367    }
368
369    #[inline]
370    #[unroll_for_loops(12)]
371    #[allow(clippy::modulo_one)]
372    fn into_bigint(a: Fp<MontBackend<Self, N>, N>) -> BigInt<N> {
373        let mut tmp = a.0;
374        let mut r = tmp.0;
375        // Montgomery Reduction
376        for i in 0..N {
377            let k = r[i].wrapping_mul(Self::INV);
378            let mut carry = 0;
379
380            fa::mac_with_carry(r[i], k, Self::MODULUS.0[0], &mut carry);
381            for j in 1..N {
382                r[(j + i) % N] =
383                    fa::mac_with_carry(r[(j + i) % N], k, Self::MODULUS.0[j], &mut carry);
384            }
385            r[i % N] = carry;
386        }
387        tmp.0 = r;
388        tmp
389    }
390
391    #[unroll_for_loops(12)]
392    fn sum_of_products<const M: usize>(
393        a: &[Fp<MontBackend<Self, N>, N>; M],
394        b: &[Fp<MontBackend<Self, N>, N>; M],
395    ) -> Fp<MontBackend<Self, N>, N> {
396        // Adapted from https://github.com/zkcrypto/bls12_381/pull/84 by @str4d.
397
398        // For a single `a x b` multiplication, operand scanning (schoolbook) takes each
399        // limb of `a` in turn, and multiplies it by all of the limbs of `b` to compute
400        // the result as a double-width intermediate representation, which is then fully
401        // reduced at the carry. Here however we have pairs of multiplications (a_i, b_i),
402        // the results of which are summed.
403        //
404        // The intuition for this algorithm is two-fold:
405        // - We can interleave the operand scanning for each pair, by processing the jth
406        //   limb of each `a_i` together. As these have the same offset within the overall
407        //   operand scanning flow, their results can be summed directly.
408        // - We can interleave the multiplication and reduction steps, resulting in a
409        //   single bitshift by the limb size after each iteration. This means we only
410        //   need to store a single extra limb overall, instead of keeping around all the
411        //   intermediate results and eventually having twice as many limbs.
412
413        let modulus_size = Self::MODULUS.const_num_bits() as usize;
414        if modulus_size >= 64 * N - 1 {
415            a.iter().zip(b).map(|(a, b)| *a * b).sum()
416        } else if M == 2 {
417            // Algorithm 2, line 2
418            let result = (0..N).fold(BigInt::zero(), |mut result, j| {
419                // Algorithm 2, line 3
420                let mut carry_a = 0;
421                let mut carry_b = 0;
422                for (a, b) in a.iter().zip(b) {
423                    let a = &a.0;
424                    let b = &b.0;
425                    let mut carry2 = 0;
426                    result.0[0] = fa::mac(result.0[0], a.0[j], b.0[0], &mut carry2);
427                    for k in 1..N {
428                        result.0[k] = fa::mac_with_carry(result.0[k], a.0[j], b.0[k], &mut carry2);
429                    }
430                    carry_b = fa::adc(&mut carry_a, carry_b, carry2);
431                }
432
433                let k = result.0[0].wrapping_mul(Self::INV);
434                let mut carry2 = 0;
435                fa::mac_discard(result.0[0], k, Self::MODULUS.0[0], &mut carry2);
436                for i in 1..N {
437                    result.0[i - 1] =
438                        fa::mac_with_carry(result.0[i], k, Self::MODULUS.0[i], &mut carry2);
439                }
440                result.0[N - 1] = fa::adc_no_carry(carry_a, carry_b, &mut carry2);
441                result
442            });
443            let mut result = Fp::new_unchecked(result);
444            result.subtract_modulus();
445            debug_assert_eq!(
446                a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
447                result
448            );
449            result
450        } else {
451            let chunk_size = 2 * (N * 64 - modulus_size) - 1;
452            // chunk_size is at least 1, since MODULUS_BIT_SIZE is at most N * 64 - 1.
453            a.chunks(chunk_size)
454                .zip(b.chunks(chunk_size))
455                .map(|(a, b)| {
456                    // Algorithm 2, line 2
457                    let result = (0..N).fold(BigInt::zero(), |mut result, j| {
458                        // Algorithm 2, line 3
459                        let (temp, carry) = a.iter().zip(b).fold(
460                            (result, 0),
461                            |(mut temp, mut carry), (Fp(a, _), Fp(b, _))| {
462                                let mut carry2 = 0;
463                                temp.0[0] = fa::mac(temp.0[0], a.0[j], b.0[0], &mut carry2);
464                                for k in 1..N {
465                                    temp.0[k] =
466                                        fa::mac_with_carry(temp.0[k], a.0[j], b.0[k], &mut carry2);
467                                }
468                                carry = fa::adc_no_carry(carry, 0, &mut carry2);
469                                (temp, carry)
470                            },
471                        );
472
473                        let k = temp.0[0].wrapping_mul(Self::INV);
474                        let mut carry2 = 0;
475                        fa::mac_discard(temp.0[0], k, Self::MODULUS.0[0], &mut carry2);
476                        for i in 1..N {
477                            result.0[i - 1] =
478                                fa::mac_with_carry(temp.0[i], k, Self::MODULUS.0[i], &mut carry2);
479                        }
480                        result.0[N - 1] = fa::adc_no_carry(carry, 0, &mut carry2);
481                        result
482                    });
483                    let mut result = Fp::new_unchecked(result);
484                    result.subtract_modulus();
485                    debug_assert_eq!(
486                        a.iter().zip(b).map(|(a, b)| *a * b).sum::<Fp<_, N>>(),
487                        result
488                    );
489                    result
490                })
491                .sum()
492        }
493    }
494}
495
496/// Compute -M^{-1} mod 2^64.
497pub const fn inv<T: MontConfig<N>, const N: usize>() -> u64 {
498    // We compute this as follows.
499    // First, MODULUS mod 2^64 is just the lower 64 bits of MODULUS.
500    // Hence MODULUS mod 2^64 = MODULUS.0[0] mod 2^64.
501    //
502    // Next, computing the inverse mod 2^64 involves exponentiating by
503    // the multiplicative group order, which is euler_totient(2^64) - 1.
504    // Now, euler_totient(2^64) = 1 << 63, and so
505    // euler_totient(2^64) - 1 = (1 << 63) - 1 = 1111111... (63 digits).
506    // We compute this powering via standard square and multiply.
507    let mut inv = 1u64;
508    crate::const_for!((_i in 0..63) {
509        // Square
510        inv = inv.wrapping_mul(inv);
511        // Multiply
512        inv = inv.wrapping_mul(T::MODULUS.0[0]);
513    });
514    inv.wrapping_neg()
515}
516
517#[inline]
518pub const fn can_use_no_carry_mul_optimization<T: MontConfig<N>, const N: usize>() -> bool {
519    // Checking the modulus at compile time
520    let top_bit_is_zero = T::MODULUS.0[N - 1] >> 63 == 0;
521    let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 1;
522    crate::const_for!((i in 1..N) {
523        all_remaining_bits_are_one  &= T::MODULUS.0[N - i - 1] == u64::MAX;
524    });
525    top_bit_is_zero && !all_remaining_bits_are_one
526}
527
528#[inline]
529pub const fn modulus_has_spare_bit<T: MontConfig<N>, const N: usize>() -> bool {
530    T::MODULUS.0[N - 1] >> 63 == 0
531}
532
533#[inline]
534pub const fn can_use_no_carry_square_optimization<T: MontConfig<N>, const N: usize>() -> bool {
535    // Checking the modulus at compile time
536    let top_two_bits_are_zero = T::MODULUS.0[N - 1] >> 62 == 0;
537    let mut all_remaining_bits_are_one = T::MODULUS.0[N - 1] == u64::MAX >> 2;
538    crate::const_for!((i in 1..N) {
539        all_remaining_bits_are_one  &= T::MODULUS.0[N - i - 1] == u64::MAX;
540    });
541    top_two_bits_are_zero && !all_remaining_bits_are_one
542}
543
544pub const fn sqrt_precomputation<const N: usize, T: MontConfig<N>>(
545) -> Option<SqrtPrecomputation<Fp<MontBackend<T, N>, N>>> {
546    match T::MODULUS.mod_4() {
547        3 => match T::MODULUS_PLUS_ONE_DIV_FOUR.as_ref() {
548            Some(BigInt(modulus_plus_one_div_four)) => Some(SqrtPrecomputation::Case3Mod4 {
549                modulus_plus_one_div_four,
550            }),
551            None => None,
552        },
553        _ => Some(SqrtPrecomputation::TonelliShanks {
554            two_adicity: <MontBackend<T, N>>::TWO_ADICITY,
555            quadratic_nonresidue_to_trace: T::TWO_ADIC_ROOT_OF_UNITY,
556            trace_of_modulus_minus_one_div_two:
557                &<Fp<MontBackend<T, N>, N>>::TRACE_MINUS_ONE_DIV_TWO.0,
558        }),
559    }
560}
561
/// Construct a [`Fp<MontBackend<T, N>, N>`] element from a literal string.
///
/// This is intended mainly for constant field elements. In non-const code,
/// [`Fp::from_str`](`ark_std::str::FromStr::from_str`) is usually the
/// better choice. A leading `-` produces the additive inverse.
///
/// # Panics
///
/// If the integer represented by the string cannot fit in the number
/// of limbs of the `Fp`, this macro results in a
/// * compile-time error if used in a const context
/// * run-time error otherwise.
///
/// # Usage
///
/// ```rust
/// # use ark_test_curves::{MontFp, One};
/// # use ark_test_curves::bls12_381 as ark_bls12_381;
/// # use ark_std::str::FromStr;
/// use ark_bls12_381::Fq;
/// const ONE: Fq = MontFp!("1");
/// const NEG_ONE: Fq = MontFp!("-1");
///
/// fn check_correctness() {
///     assert_eq!(ONE, Fq::one());
///     assert_eq!(Fq::from_str("1").unwrap(), ONE);
///     assert_eq!(NEG_ONE, -Fq::one());
/// }
/// ```
#[macro_export]
macro_rules! MontFp {
    ($value:expr) => {{
        let (sign_is_positive, limb_array) = $crate::ark_ff_macros::to_sign_and_limbs!($value);
        $crate::Fp::from_sign_and_limbs(sign_is_positive, &limb_array)
    }};
}
597
598pub use ark_ff_macros::MontConfig;
599
600pub use MontFp;
601
602pub struct MontBackend<T: MontConfig<N>, const N: usize>(PhantomData<T>);
603
604impl<T: MontConfig<N>, const N: usize> FpConfig<N> for MontBackend<T, N> {
605    /// The modulus of the field.
606    const MODULUS: crate::BigInt<N> = T::MODULUS;
607
608    /// A multiplicative generator of the field.
609    /// `Self::GENERATOR` is an element having multiplicative order
610    /// `Self::MODULUS - 1`.
611    const GENERATOR: Fp<Self, N> = T::GENERATOR;
612
613    /// Additive identity of the field, i.e. the element `e`
614    /// such that, for all elements `f` of the field, `e + f = f`.
615    const ZERO: Fp<Self, N> = Fp::new_unchecked(BigInt([0u64; N]));
616
617    /// Multiplicative identity of the field, i.e. the element `e`
618    /// such that, for all elements `f` of the field, `e * f = f`.
619    const ONE: Fp<Self, N> = Fp::new_unchecked(T::R);
620
621    const TWO_ADICITY: u32 = Self::MODULUS.two_adic_valuation();
622    const TWO_ADIC_ROOT_OF_UNITY: Fp<Self, N> = T::TWO_ADIC_ROOT_OF_UNITY;
623    const SMALL_SUBGROUP_BASE: Option<u32> = T::SMALL_SUBGROUP_BASE;
624    const SMALL_SUBGROUP_BASE_ADICITY: Option<u32> = T::SMALL_SUBGROUP_BASE_ADICITY;
625    const LARGE_SUBGROUP_ROOT_OF_UNITY: Option<Fp<Self, N>> = T::LARGE_SUBGROUP_ROOT_OF_UNITY;
626    const SQRT_PRECOMP: Option<crate::SqrtPrecomputation<Fp<Self, N>>> = T::SQRT_PRECOMP;
627
628    fn add_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
629        T::add_assign(a, b)
630    }
631
632    fn sub_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
633        T::sub_assign(a, b)
634    }
635
636    fn double_in_place(a: &mut Fp<Self, N>) {
637        T::double_in_place(a)
638    }
639
640    fn neg_in_place(a: &mut Fp<Self, N>) {
641        T::neg_in_place(a)
642    }
643
644    /// This modular multiplication algorithm uses Montgomery
645    /// reduction for efficient implementation. It also additionally
646    /// uses the "no-carry optimization" outlined
647    /// [here](https://hackmd.io/@zkteam/modular_multiplication) if
648    /// `P::MODULUS` has (a) a non-zero MSB, and (b) at least one
649    /// zero bit in the rest of the modulus.
650    #[inline]
651    fn mul_assign(a: &mut Fp<Self, N>, b: &Fp<Self, N>) {
652        T::mul_assign(a, b)
653    }
654
655    fn sum_of_products<const M: usize>(a: &[Fp<Self, N>; M], b: &[Fp<Self, N>; M]) -> Fp<Self, N> {
656        T::sum_of_products(a, b)
657    }
658
659    #[inline]
660    #[allow(unused_braces, clippy::absurd_extreme_comparisons)]
661    fn square_in_place(a: &mut Fp<Self, N>) {
662        T::square_in_place(a)
663    }
664
665    fn inverse(a: &Fp<Self, N>) -> Option<Fp<Self, N>> {
666        T::inverse(a)
667    }
668
669    fn from_bigint(r: BigInt<N>) -> Option<Fp<Self, N>> {
670        T::from_bigint(r)
671    }
672
673    #[inline]
674    #[allow(clippy::modulo_one)]
675    fn into_bigint(a: Fp<Self, N>) -> BigInt<N> {
676        T::into_bigint(a)
677    }
678}
679
680impl<T: MontConfig<N>, const N: usize> Fp<MontBackend<T, N>, N> {
681    #[doc(hidden)]
682    pub const R: BigInt<N> = T::R;
683    #[doc(hidden)]
684    pub const R2: BigInt<N> = T::R2;
685    #[doc(hidden)]
686    pub const INV: u64 = T::INV;
687
688    /// Construct a new field element from its underlying
689    /// [`struct@BigInt`] data type.
690    #[inline]
691    pub const fn new(element: BigInt<N>) -> Self {
692        let mut r = Self(element, PhantomData);
693        if r.const_is_zero() {
694            r
695        } else {
696            r = r.mul(&Fp(T::R2, PhantomData));
697            r
698        }
699    }
700
701    /// Construct a new field element from its underlying
702    /// [`struct@BigInt`] data type.
703    ///
704    /// Unlike [`Self::new`], this method does not perform Montgomery reduction.
705    /// Thus, this method should be used only when constructing
706    /// an element from an integer that has already been put in
707    /// Montgomery form.
708    #[inline]
709    pub const fn new_unchecked(element: BigInt<N>) -> Self {
710        Self(element, PhantomData)
711    }
712
713    const fn const_is_zero(&self) -> bool {
714        self.0.const_is_zero()
715    }
716
717    #[doc(hidden)]
718    const fn const_neg(self) -> Self {
719        if !self.const_is_zero() {
720            Self::new_unchecked(Self::sub_with_borrow(&T::MODULUS, &self.0))
721        } else {
722            self
723        }
724    }
725
726    /// Interpret a set of limbs (along with a sign) as a field element.
727    /// For *internal* use only; please use the `ark_ff::MontFp` macro instead
728    /// of this method
729    #[doc(hidden)]
730    pub const fn from_sign_and_limbs(is_positive: bool, limbs: &[u64]) -> Self {
731        let mut repr = BigInt::<N>([0; N]);
732        assert!(limbs.len() <= N);
733        crate::const_for!((i in 0..(limbs.len())) {
734            repr.0[i] = limbs[i];
735        });
736        let res = Self::new(repr);
737        if is_positive {
738            res
739        } else {
740            res.const_neg()
741        }
742    }
743
744    const fn mul_without_cond_subtract(mut self, other: &Self) -> (bool, Self) {
745        let (mut lo, mut hi) = ([0u64; N], [0u64; N]);
746        crate::const_for!((i in 0..N) {
747            let mut carry = 0;
748            crate::const_for!((j in 0..N) {
749                let k = i + j;
750                if k >= N {
751                    hi[k - N] = mac_with_carry!(hi[k - N], (self.0).0[i], (other.0).0[j], &mut carry);
752                } else {
753                    lo[k] = mac_with_carry!(lo[k], (self.0).0[i], (other.0).0[j], &mut carry);
754                }
755            });
756            hi[i] = carry;
757        });
758        // Montgomery reduction
759        let mut carry2 = 0;
760        crate::const_for!((i in 0..N) {
761            let tmp = lo[i].wrapping_mul(T::INV);
762            let mut carry;
763            mac!(lo[i], tmp, T::MODULUS.0[0], &mut carry);
764            crate::const_for!((j in 1..N) {
765                let k = i + j;
766                if k >= N {
767                    hi[k - N] = mac_with_carry!(hi[k - N], tmp, T::MODULUS.0[j], &mut carry);
768                }  else {
769                    lo[k] = mac_with_carry!(lo[k], tmp, T::MODULUS.0[j], &mut carry);
770                }
771            });
772            hi[i] = adc!(hi[i], carry, &mut carry2);
773        });
774
775        crate::const_for!((i in 0..N) {
776            (self.0).0[i] = hi[i];
777        });
778        (carry2 != 0, self)
779    }
780
781    const fn mul(self, other: &Self) -> Self {
782        let (carry, res) = self.mul_without_cond_subtract(other);
783        if T::MODULUS_HAS_SPARE_BIT {
784            res.const_subtract_modulus()
785        } else {
786            res.const_subtract_modulus_with_carry(carry)
787        }
788    }
789
790    const fn const_is_valid(&self) -> bool {
791        crate::const_for!((i in 0..N) {
792            if (self.0).0[N - i - 1] < T::MODULUS.0[N - i - 1] {
793                return true
794            } else if (self.0).0[N - i - 1] > T::MODULUS.0[N - i - 1] {
795                return false
796            }
797        });
798        false
799    }
800
801    #[inline]
802    const fn const_subtract_modulus(mut self) -> Self {
803        if !self.const_is_valid() {
804            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
805        }
806        self
807    }
808
809    #[inline]
810    const fn const_subtract_modulus_with_carry(mut self, carry: bool) -> Self {
811        if carry || !self.const_is_valid() {
812            self.0 = Self::sub_with_borrow(&self.0, &T::MODULUS);
813        }
814        self
815    }
816
817    const fn sub_with_borrow(a: &BigInt<N>, b: &BigInt<N>) -> BigInt<N> {
818        a.const_sub_with_borrow(b).0
819    }
820}
821
#[cfg(test)]
mod test {
    use ark_std::{str::FromStr, vec::Vec};
    use ark_test_curves::secp256k1::Fr;
    use num_bigint::{BigInt, BigUint, Sign};

    /// Round-trips a decimal literal through `Fr::from_sign_and_limbs` (the
    /// function behind `MontFp!`) and back to a `BigUint`.
    #[test]
    fn test_mont_macro_correctness() {
        const DECIMAL: &str =
            "111192936301596926984056301862066282284536849596023571352007112326586892541694";
        let expected = BigUint::from_str(DECIMAL).unwrap();

        let (is_positive, limbs) = str_to_limbs_u64(DECIMAL);
        let element = Fr::from_sign_and_limbs(is_positive, &limbs);
        let result: BigUint = element.into();

        assert_eq!(result, expected);
    }

    /// Parses a signed decimal string and splits its magnitude into
    /// little-endian 64-bit limbs (16 hex digits per limb).
    fn str_to_limbs_u64(num: &str) -> (bool, Vec<u64>) {
        let parsed = BigInt::from_str(num).expect("could not parse to bigint");
        let (sign, hexits) = parsed.to_radix_le(16);
        let limbs: Vec<u64> = hexits
            .chunks(16)
            .map(|chunk| {
                chunk
                    .iter()
                    .enumerate()
                    .fold(0u64, |acc, (i, &hexit)| acc | ((hexit as u64) << (4 * i)))
            })
            .collect();
        (sign != Sign::Minus, limbs)
    }
}
862}