1use ark_serialize::{Read, Write};
2use ark_std::ops::{Index, IndexMut};
3
4use crate::BigInt;
5
/// Emulates a `for` loop over a half-open range in places where a real `for`
/// loop cannot be used, such as the body of a `const fn`.
///
/// `const_for!((i in start..end) body)` declares a mutable counter `i`
/// initialised to `start` and evaluates `body` once per value of `i` for as
/// long as `i < end`, adding one to `i` after every pass.
///
/// # Usage notes
///
/// * `start` and `end` are each matched as a *single* token tree, so a bound
///   made of several tokens has to be parenthesised: `0..(array.len())`.
/// * `body` may be a block or any other expression (for example `sum += i`);
///   its value is discarded.
/// * The expansion is a plain `while` loop whose increment comes after
///   `body`, so a `continue` inside `body` skips the increment.
///
/// ```ignore
/// const fn squares() -> [usize; 4] {
///     let mut out = [0usize; 4];
///     const_for!((i in 0..(out.len())) {
///         out[i] = i * i;
///     });
///     out
/// }
/// ```
#[macro_export]
macro_rules! const_for {
    (($i:ident in $start:tt..$end:tt) $code:expr) => {{
        let mut $i = $start;
        while $i < $end {
            // The explicit `;` makes the body a statement of its own, so
            // non-block expressions are accepted as well as `{ ... }` bodies.
            $code;
            $i += 1;
        }
    }};
}
31
32#[derive(Copy, Clone)]
35#[repr(C, align(8))]
36pub(super) struct MulBuffer<const N: usize> {
37 pub(super) b0: [u64; N],
38 pub(super) b1: [u64; N],
39}
40
41impl<const N: usize> MulBuffer<N> {
42 const fn new(b0: [u64; N], b1: [u64; N]) -> Self {
43 Self { b0, b1 }
44 }
45
46 pub(super) const fn zeroed() -> Self {
47 let b = [0u64; N];
48 Self::new(b, b)
49 }
50
51 #[inline(always)]
52 pub(super) const fn get(&self, index: usize) -> &u64 {
53 if index < N {
54 &self.b0[index]
55 } else {
56 &self.b1[index - N]
57 }
58 }
59
60 #[inline(always)]
61 pub(super) fn get_mut(&mut self, index: usize) -> &mut u64 {
62 if index < N {
63 &mut self.b0[index]
64 } else {
65 &mut self.b1[index - N]
66 }
67 }
68}
69
70impl<const N: usize> Index<usize> for MulBuffer<N> {
71 type Output = u64;
72 #[inline(always)]
73 fn index(&self, index: usize) -> &Self::Output {
74 self.get(index)
75 }
76}
77
78impl<const N: usize> IndexMut<usize> for MulBuffer<N> {
79 #[inline(always)]
80 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
81 self.get_mut(index)
82 }
83}
84
85#[derive(Copy, Clone)]
88#[repr(C, align(1))]
89pub(super) struct SerBuffer<const N: usize> {
90 pub(super) buffers: [[u8; 8]; N],
91 pub(super) last: u8,
92}
93
94impl<const N: usize> SerBuffer<N> {
95 pub(super) const fn zeroed() -> Self {
96 Self {
97 buffers: [[0u8; 8]; N],
98 last: 0u8,
99 }
100 }
101
102 #[inline(always)]
103 pub(super) const fn get(&self, index: usize) -> &u8 {
104 if index == 8 * N {
105 &self.last
106 } else {
107 let part = index / 8;
108 let in_buffer_index = index % 8;
109 &self.buffers[part][in_buffer_index]
110 }
111 }
112
113 #[inline(always)]
114 pub(super) fn get_mut(&mut self, index: usize) -> &mut u8 {
115 if index == 8 * N {
116 &mut self.last
117 } else {
118 let part = index / 8;
119 let in_buffer_index = index % 8;
120 &mut self.buffers[part][in_buffer_index]
121 }
122 }
123
124 #[allow(unsafe_code)]
125 pub(super) fn as_slice(&self) -> &[u8] {
126 unsafe { ark_std::slice::from_raw_parts((self as *const Self) as *const u8, 8 * N + 1) }
127 }
128
129 #[inline(always)]
130 pub(super) fn last_n_plus_1_bytes_mut(&mut self) -> impl Iterator<Item = &mut u8> {
131 self.buffers[N - 1]
132 .iter_mut()
133 .chain(ark_std::iter::once(&mut self.last))
134 }
135
136 #[inline(always)]
137 pub(super) fn copy_from_u8_slice(&mut self, other: &[u8]) {
138 other.chunks(8).enumerate().for_each(|(i, chunk)| {
139 if i < N {
140 self.buffers[i][..chunk.len()].copy_from_slice(chunk);
141 } else {
142 self.last = chunk[0]
143 }
144 });
145 }
146
147 #[inline(always)]
148 pub(super) fn copy_from_u64_slice(&mut self, other: &[u64; N]) {
149 other
150 .iter()
151 .zip(&mut self.buffers)
152 .for_each(|(other, this)| *this = other.to_le_bytes());
153 }
154
155 #[inline(always)]
156 pub(super) fn to_bigint(self) -> BigInt<N> {
157 let mut self_integer = BigInt::from(0u64);
158 self_integer
159 .0
160 .iter_mut()
161 .zip(self.buffers)
162 .for_each(|(other, this)| *other = u64::from_le_bytes(this));
163 self_integer
164 }
165
166 #[inline(always)]
167 pub(super) fn write_up_to(
170 &self,
171 mut other: impl Write,
172 num_bytes: usize,
173 ) -> ark_std::io::Result<()> {
174 debug_assert!(num_bytes <= 8 * N + 1, "index too large");
175 debug_assert!(num_bytes > 8 * (N - 1), "index too small");
176 for i in 0..(N - 1) {
178 other.write_all(&self.buffers[i])?;
179 }
180 let remaining_bytes = num_bytes - (8 * (N - 1));
183 let write_last_byte = remaining_bytes > 8;
184 let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
185 other.write_all(&self.buffers[N - 1][..num_last_limb_bytes])?;
186 if write_last_byte {
187 other.write_all(&[self.last])?;
188 }
189 Ok(())
190 }
191
192 #[inline(always)]
193 pub(super) fn read_exact_up_to(
196 &mut self,
197 mut other: impl Read,
198 num_bytes: usize,
199 ) -> ark_std::io::Result<()> {
200 debug_assert!(num_bytes <= 8 * N + 1, "index too large");
201 debug_assert!(num_bytes > 8 * (N - 1), "index too small");
202 for i in 0..(N - 1) {
204 other.read_exact(&mut self.buffers[i])?;
205 }
206 let remaining_bytes = num_bytes - (8 * (N - 1));
209 let write_last_byte = remaining_bytes > 8;
210 let num_last_limb_bytes = ark_std::cmp::min(8, remaining_bytes);
211 other.read_exact(&mut self.buffers[N - 1][..num_last_limb_bytes])?;
212 if write_last_byte {
213 let mut last = [0u8; 1];
214 other.read_exact(&mut last)?;
215 self.last = last[0];
216 }
217 Ok(())
218 }
219}
220
221impl<const N: usize> Index<usize> for SerBuffer<N> {
222 type Output = u8;
223 #[inline(always)]
224 fn index(&self, index: usize) -> &Self::Output {
225 self.get(index)
226 }
227}
228
229impl<const N: usize> IndexMut<usize> for SerBuffer<N> {
230 #[inline(always)]
231 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
232 self.get_mut(index)
233 }
234}
235
236pub(super) struct RBuffer<const N: usize>(pub [u64; N], pub u64);
237
238impl<const N: usize> RBuffer<N> {
239 pub(super) const fn num_bits(&self) -> u32 {
241 (N * 64) as u32 + (64 - self.1.leading_zeros())
242 }
243
244 pub(super) const fn get_bit(&self, i: usize) -> bool {
247 let d = i / 64;
248 let b = i % 64;
249 if d == N {
250 (self.1 >> b) & 1 == 1
251 } else {
252 (self.0[d] >> b) & 1 == 1
253 }
254 }
255}
256
257pub(super) struct R2Buffer<const N: usize>(pub [u64; N], pub [u64; N], pub u64);
258
259impl<const N: usize> R2Buffer<N> {
260 pub(super) const fn num_bits(&self) -> u32 {
262 ((2 * N) * 64) as u32 + (64 - self.2.leading_zeros())
263 }
264
265 pub(super) const fn get_bit(&self, i: usize) -> bool {
268 let d = i / 64;
269 let b = i % 64;
270 if d == 2 * N {
271 (self.2 >> b) & 1 == 1
272 } else if d >= N {
273 (self.1[d - N] >> b) & 1 == 1
274 } else {
275 (self.0[d] >> b) & 1 == 1
276 }
277 }
278}
279
// Unit tests for `MulBuffer` indexing. Gated so the module is only compiled
// for test builds.
#[cfg(test)]
mod tests {
    use super::*;

    /// Positions `0..10` must read from the first half and positions
    /// `10..20` from the second half.
    #[test]
    fn test_mul_buffer_correctness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [20u64; 10]);

        for i in 0..20 {
            let expected = if i < 10 { 10 } else { 20 };
            assert_eq!(temp[i], expected);
        }
    }

    /// Position `20` is one past the end of a `MulBuffer<10>` and must be
    /// rejected by the bounds check.
    ///
    /// The expected message is pinned and no value assertion is made, so the
    /// test cannot pass merely because an out-of-range read returned an
    /// unexpected value.
    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_mul_buffer_soundness() {
        type Buf = MulBuffer<10>;
        let temp = Buf::new([10u64; 10], [10u64; 10]);

        let _value: u64 = temp[20];
    }
}