
1use ark_serialize::{Read, Write};
2use ark_std::ops::{Index, IndexMut};
4use crate::BigInt;
/// A helper macro for emulating `for` loops in a `const` context.
///
/// `const_for!((i in start..end) { body })` binds `i` to `start`, runs `body`,
/// then increments `i` by one, and repeats while `i < end` holds, i.e. it
/// iterates over the half-open range `start..end`. Both bounds are matched as
/// single token trees, so any bound that is not a literal or a single
/// identifier must be wrapped in parentheses.
///
/// The increment is emitted *after* the body, so a `continue` inside the body
/// skips it and (unless the body itself advances `i`) the loop never
/// terminates. Use `break` or restructure the body instead.
///
/// # Usage
/// ```rust
/// # use ark_ff::const_for;
/// const fn for_in_const() {
///     let mut array = [0usize; 4];
///     const_for!((i in 0..(array.len())) { // We need to wrap the `array.len()` in parentheses.
///         array[i] = i;
///     });
///     assert!(array[0] == 0);
///     assert!(array[1] == 1);
///     assert!(array[2] == 2);
///     assert!(array[3] == 3);
/// }
/// ```
// `#[macro_export]` places the macro at the crate root, which is what makes the
// `ark_ff::const_for` import in the example above resolve.
#[macro_export]
macro_rules! const_for {
    (($i:ident in $start:tt..$end:tt)  $code:expr ) => {{
        let mut $i = $start;
        while $i < $end {
            $code
            $i += 1;
        }
    }};
}
32/// A buffer to hold values of size 2 * N. This is mostly
33/// a hack that's necessary until `generic_const_exprs` is stable.
34#[derive(Copy, Clone)]
35#[repr(C, align(8))]
36pub(super) struct MulBuffer<const N: usize> {
37    pub(super) b0: [u64; N],
38    pub(super) b1: [u64; N],
41impl<const N: usize> MulBuffer<N> {
42    const fn new(b0: [u64; N], b1: [u64; N]) -> Self {
43        Self { b0, b1 }
44    }
46    pub(super) const fn zeroed() -> Self {
47        let b = [0u64; N];
48        Self::new(b, b)
49    }
51    #[inline(always)]
52    pub(super) const fn get(&self, index: usize) -> &u64 {
53        if index < N {
54            &self.b0[index]
55        } else {
56            &self.b1[index - N]
57        }
58    }
60    #[inline(always)]
61    pub(super) fn get_mut(&mut self, index: usize) -> &mut u64 {
62        if index < N {
63            &mut self.b0[index]
64        } else {
65            &mut self.b1[index - N]
66        }
67    }
70impl<const N: usize> Index<usize> for MulBuffer<N> {
71    type Output = u64;
72    #[inline(always)]
73    fn index(&self, index: usize) -> &Self::Output {
74        self.get(index)
75    }
78impl<const N: usize> IndexMut<usize> for MulBuffer<N> {
79    #[inline(always)]
80    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
81        self.get_mut(index)
82    }
85/// A buffer to hold values of size 8 * N + 1 bytes. This is mostly
86/// a hack that's necessary until `generic_const_exprs` is stable.
87#[derive(Copy, Clone)]
88#[repr(C, align(1))]
89pub(super) struct SerBuffer<const N: usize> {
90    pub(super) buffers: [[u8; 8]; N],
91    pub(super) last: u8,
94impl<const N: usize> SerBuffer<N> {
95    pub(super) const fn zeroed() -> Self {
96        Self {
97            buffers: [[0u8; 8]; N],
98            last: 0u8,
99        }
100    }
102    #[inline(always)]
103    pub(super) const fn get(&self, index: usize) -> &u8 {
104        if index == 8 * N {
105            &self.last
106        } else {
107            let part = index / 8;
108            let in_buffer_index = index % 8;
109            &self.buffers[part][in_buffer_index]
110        }
111    }
113    #[inline(always)]
114    pub(super) fn get_mut(&mut self, index: usize) -> &mut u8 {
115        if index == 8 * N {
116            &mut self.last
117        } else {
118            let part = index / 8;
119            let in_buffer_index = index % 8;
120            &mut self.buffers[part][in_buffer_index]
121        }
122    }
124    #[allow(unsafe_code)]
125    pub(super) fn as_slice(&self) -> &[u8] {
126        unsafe { ark_std::slice::from_raw_parts((self as *const Self) as *const u8, 8 * N + 1) }
127    }
129    #[inline(always)]
130    pub(super) fn last_n_plus_1_bytes_mut(&mut self) -> impl Iterator<Item = &mut u8> {
131        self.buffers[N - 1]
132            .iter_mut()
133            .chain(ark_std::iter::once(&mut self.last))
134    }
136    #[inline(always)]
137    pub(super) fn copy_from_u8_slice(&mut self, other: &[u8]) {
138        other.chunks(8).enumerate().for_each(|(i, chunk)| {
139            if i < N {
140                self.buffers[i][..chunk.len()].copy_from_slice(chunk);
141            } else {
142                self.last = chunk[0]
143            }
144        });
145    }
147    #[inline(always)]
148    pub(super) fn copy_from_u64_slice(&mut self, other: &[u64; N]) {
149        other
150            .iter()
151            .zip(&mut self.buffers)
152            .for_each(|(other, this)| *this = other.to_le_bytes());
153    }
155    #[inline(always)]
156    pub(super) fn to_bigint(self) -> BigInt<N> {
157        let mut self_integer = BigInt::from(0u64);
158        self_integer
159            .0
160            .iter_mut()
161            .zip(self.buffers)
162            .for_each(|(other, this)| *other = u64::from_le_bytes(this));
163        self_integer
164    }
166    #[inline(always)]
167    /// Write up to `num_bytes` bytes from `self` to `other`.
168    /// `num_bytes` is allowed to range from `8 * (N - 1) + 1` to `8 * N + 1`.
169    pub(super) fn write_up_to(
170        &self,
171        mut other: impl Write,
172        num_bytes: usize,
173    ) -> ark_std::io::Result<()> {
174        debug_assert!(num_bytes <= 8 * N + 1, "index too large");
175        debug_assert!(num_bytes > 8 * (N - 1), "index too small");
176        // unconditionally write first `N - 1` limbs.
177        for i in 0..(N - 1) {
178            other.write_all(&self.buffers[i])?;
179        }
180        // for the `N`-th limb, depending on `index`, we can write anywhere from
181        // 1 to all bytes.
182        let remaining_bytes = num_bytes - (8 * (N - 1));
183        let write_last_byte = remaining_bytes > 8;
184        let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
185        other.write_all(&self.buffers[N - 1][..num_last_limb_bytes])?;
186        if write_last_byte {
187            other.write_all(&[self.last])?;
188        }
189        Ok(())
190    }
192    #[inline(always)]
193    /// Read up to `num_bytes` bytes from `other` to `self`.
194    /// `num_bytes` is allowed to range from `8 * (N - 1)` to `8 * N + 1`.
195    pub(super) fn read_exact_up_to(
196        &mut self,
197        mut other: impl Read,
198        num_bytes: usize,
199    ) -> ark_std::io::Result<()> {
200        debug_assert!(num_bytes <= 8 * N + 1, "index too large");
201        debug_assert!(num_bytes > 8 * (N - 1), "index too small");
202        // unconditionally write first `N - 1` limbs.
203        for i in 0..(N - 1) {
204            other.read_exact(&mut self.buffers[i])?;
205        }
206        // for the `N`-th limb, depending on `index`, we can write anywhere from
207        // 1 to all bytes.
208        let remaining_bytes = num_bytes - (8 * (N - 1));
209        let write_last_byte = remaining_bytes > 8;
210        let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
211        other.read_exact(&mut self.buffers[N - 1][..num_last_limb_bytes])?;
212        if write_last_byte {
213            let mut last = [0u8; 1];
214            other.read_exact(&mut last)?;
215            self.last = last[0];
216        }
217        Ok(())
218    }
221impl<const N: usize> Index<usize> for SerBuffer<N> {
222    type Output = u8;
223    #[inline(always)]
224    fn index(&self, index: usize) -> &Self::Output {
225        self.get(index)
226    }
229impl<const N: usize> IndexMut<usize> for SerBuffer<N> {
230    #[inline(always)]
231    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
232        self.get_mut(index)
233    }
236pub(super) struct RBuffer<const N: usize>(pub [u64; N], pub u64);
238impl<const N: usize> RBuffer<N> {
239    /// Find the number of bits in the binary decomposition of `self`.
240    pub(super) const fn num_bits(&self) -> u32 {
241        (N * 64) as u32 + (64 - self.1.leading_zeros())
242    }
244    /// Returns the `i`-th bit where bit 0 is the least significant one.
245    /// In other words, the bit with weight `2^i`.
246    pub(super) const fn get_bit(&self, i: usize) -> bool {
247        let d = i / 64;
248        let b = i % 64;
249        if d == N {
250            (self.1 >> b) & 1 == 1
251        } else {
252            (self.0[d] >> b) & 1 == 1
253        }
254    }
257pub(super) struct R2Buffer<const N: usize>(pub [u64; N], pub [u64; N], pub u64);
259impl<const N: usize> R2Buffer<N> {
260    /// Find the number of bits in the binary decomposition of `self`.
261    pub(super) const fn num_bits(&self) -> u32 {
262        ((2 * N) * 64) as u32 + (64 - self.2.leading_zeros())
263    }
265    /// Returns the `i`-th bit where bit 0 is the least significant one.
266    /// In other words, the bit with weight `2^i`.
267    pub(super) const fn get_bit(&self, i: usize) -> bool {
268        let d = i / 64;
269        let b = i % 64;
270        if d == 2 * N {
271            (self.2 >> b) & 1 == 1
272        } else if d >= N {
273            (self.1[d - N] >> b) & 1 == 1
274        } else {
275            (self.0[d] >> b) & 1 == 1
276        }
277    }
// Test-only module: compiled only under `cfg(test)`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Limbs `0..N` are read from `b0` and limbs `N..2 * N` from `b1`.
    #[test]
    fn test_mul_buffer_correctness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [20u64; 10]);

        for i in 0..20 {
            let expected = if i < 10 { 10 } else { 20 };
            assert_eq!(temp[i], expected);
        }
    }

    /// Indexing one past the `2 * N` limbs must panic via the bounds check.
    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_mul_buffer_soundness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [10u64; 10]);

        // The index comes from a loop variable rather than a literal;
        // presumably this keeps the out-of-bounds access from being
        // flagged at compile time.
        for i in 20..21 {
            // indexing `temp[20]` maps to `b1[10]` and should panic
            assert_eq!(temp[i], 10);
        }
    }
}