// ark-serialize: src/lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(
3    unused,
4    future_incompatible,
5    nonstandard_style,
6    rust_2018_idioms,
7    rust_2021_compatibility
8)]
9#![forbid(unsafe_code)]
10#![doc = include_str!("../README.md")]
11mod error;
12mod flags;
13mod impls;
14
15use ark_std::borrow::ToOwned;
16pub use ark_std::io::{Read, Write};
17
18pub use error::*;
19pub use flags::*;
20
21#[cfg(feature = "derive")]
22#[doc(hidden)]
23pub use ark_serialize_derive::*;
24
25use digest::{generic_array::GenericArray, Digest, OutputSizeUser};
26
/// Whether to use a compressed version of the serialization algorithm. Specific behavior depends
/// on implementation. If no compressed version exists (e.g. on `Fp`), mode is ignored.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Compress {
    /// Serialize in compressed form.
    Yes,
    /// Serialize in uncompressed form.
    No,
}
34
/// Whether to validate the element after deserializing it. Specific behavior depends on
/// implementation. If no validation algorithm exists (e.g. on `Fp`), mode is ignored.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Validate {
    /// Run `Valid::check` on the deserialized element.
    Yes,
    /// Skip validity checks.
    No,
}
42
/// Types whose elements can be checked for validity after deserialization.
pub trait Valid: Sized + Sync {
    /// Checks that `self` is a valid element, returning a
    /// `SerializationError` otherwise.
    fn check(&self) -> Result<(), SerializationError>;

    /// Checks every element yielded by `batch`, failing on the first invalid
    /// one. With the `parallel` feature enabled, the checks are distributed
    /// over rayon's thread pool; otherwise they run sequentially.
    fn batch_check<'a>(
        batch: impl Iterator<Item = &'a Self> + Send,
    ) -> Result<(), SerializationError>
    where
        Self: 'a,
    {
        #[cfg(feature = "parallel")]
        {
            use rayon::{iter::ParallelBridge, prelude::ParallelIterator};
            batch.par_bridge().try_for_each(|e| e.check())?;
        }
        #[cfg(not(feature = "parallel"))]
        {
            for item in batch {
                item.check()?;
            }
        }
        Ok(())
    }
}
66
/// Serializer in little endian format.
/// This trait can be derived if all fields of a struct implement
/// `CanonicalSerialize` and the `derive` feature is enabled.
///
/// # Example
/// ```
/// // The `derive` feature must be set for the derivation to work.
/// use ark_serialize::*;
///
/// # #[cfg(feature = "derive")]
/// #[derive(CanonicalSerialize)]
/// struct TestStruct {
///     a: u64,
///     b: (u64, (u64, u64)),
/// }
/// ```
pub trait CanonicalSerialize {
    /// The general serialize method that takes in customization flags.
    fn serialize_with_mode<W: Write>(
        &self,
        writer: W,
        compress: Compress,
    ) -> Result<(), SerializationError>;

    /// Returns the number of bytes that `serialize_with_mode` would write for
    /// the given `compress` mode.
    fn serialized_size(&self, compress: Compress) -> usize;

    /// Serializes `self` into `writer` in compressed form.
    fn serialize_compressed<W: Write>(&self, writer: W) -> Result<(), SerializationError> {
        self.serialize_with_mode(writer, Compress::Yes)
    }

    /// Size in bytes of the compressed serialization of `self`.
    fn compressed_size(&self) -> usize {
        self.serialized_size(Compress::Yes)
    }

    /// Serializes `self` into `writer` in uncompressed form.
    fn serialize_uncompressed<W: Write>(&self, writer: W) -> Result<(), SerializationError> {
        self.serialize_with_mode(writer, Compress::No)
    }

    /// Size in bytes of the uncompressed serialization of `self`.
    fn uncompressed_size(&self) -> usize {
        self.serialized_size(Compress::No)
    }
}
109
/// Deserializer in little endian format.
/// This trait can be derived if all fields of a struct implement
/// `CanonicalDeserialize` and the `derive` feature is enabled.
///
/// # Example
/// ```
/// // The `derive` feature must be set for the derivation to work.
/// use ark_serialize::*;
///
/// # #[cfg(feature = "derive")]
/// #[derive(CanonicalDeserialize)]
/// struct TestStruct {
///     a: u64,
///     b: (u64, (u64, u64)),
/// }
/// ```
pub trait CanonicalDeserialize: Valid {
    /// The general deserialize method that takes in customization flags.
    fn deserialize_with_mode<R: Read>(
        reader: R,
        compress: Compress,
        validate: Validate,
    ) -> Result<Self, SerializationError>;

    /// Deserializes a compressed element and validates it.
    fn deserialize_compressed<R: Read>(reader: R) -> Result<Self, SerializationError> {
        Self::deserialize_with_mode(reader, Compress::Yes, Validate::Yes)
    }

    /// Deserializes a compressed element without validating it.
    fn deserialize_compressed_unchecked<R: Read>(reader: R) -> Result<Self, SerializationError> {
        Self::deserialize_with_mode(reader, Compress::Yes, Validate::No)
    }

    /// Deserializes an uncompressed element and validates it.
    fn deserialize_uncompressed<R: Read>(reader: R) -> Result<Self, SerializationError> {
        Self::deserialize_with_mode(reader, Compress::No, Validate::Yes)
    }

    /// Deserializes an uncompressed element without validating it.
    fn deserialize_uncompressed_unchecked<R: Read>(reader: R) -> Result<Self, SerializationError> {
        Self::deserialize_with_mode(reader, Compress::No, Validate::No)
    }
}
150
/// Serializer in little endian format allowing to encode flags.
pub trait CanonicalSerializeWithFlags: CanonicalSerialize {
    /// Serializes `self` and `flags` into `writer`.
    fn serialize_with_flags<W: Write, F: Flags>(
        &self,
        writer: W,
        flags: F,
    ) -> Result<(), SerializationError>;

    /// Returns the size in bytes of `self` serialized together with flags of
    /// type `F` (i.e. what `serialize_with_flags` would write).
    fn serialized_size_with_flags<F: Flags>(&self) -> usize;
}
163
/// Deserializer in little endian format allowing flags to be encoded.
pub trait CanonicalDeserializeWithFlags: Sized {
    /// Reads `Self` and `Flags` from `reader`.
    /// Returns empty flags by default.
    fn deserialize_with_flags<R: Read, F: Flags>(
        reader: R,
    ) -> Result<(Self, F), SerializationError>;
}
172
173// This private struct works around Serialize taking the pre-existing
174// std::io::Write instance of most digest::Digest implementations by value
175struct HashMarshaller<'a, H: Digest>(&'a mut H);
176
177impl<'a, H: Digest> ark_std::io::Write for HashMarshaller<'a, H> {
178    #[inline]
179    fn write(&mut self, buf: &[u8]) -> ark_std::io::Result<usize> {
180        Digest::update(self.0, buf);
181        Ok(buf.len())
182    }
183
184    #[inline]
185    fn flush(&mut self) -> ark_std::io::Result<()> {
186        Ok(())
187    }
188}
189
190/// The CanonicalSerialize induces a natural way to hash the
191/// corresponding value, of which this is the convenience trait.
192pub trait CanonicalSerializeHashExt: CanonicalSerialize {
193    fn hash<H: Digest>(&self) -> GenericArray<u8, <H as OutputSizeUser>::OutputSize> {
194        let mut hasher = H::new();
195        self.serialize_compressed(HashMarshaller(&mut hasher))
196            .expect("HashMarshaller::flush should be infaillible!");
197        hasher.finalize()
198    }
199
200    fn hash_uncompressed<H: Digest>(&self) -> GenericArray<u8, <H as OutputSizeUser>::OutputSize> {
201        let mut hasher = H::new();
202        self.serialize_uncompressed(HashMarshaller(&mut hasher))
203            .expect("HashMarshaller::flush should be infaillible!");
204        hasher.finalize()
205    }
206}
207
/// `CanonicalSerializeHashExt` is a (blanket) extension trait of
/// `CanonicalSerialize`: every serializable type is hashable for free.
impl<T: CanonicalSerialize> CanonicalSerializeHashExt for T {}
211
212#[inline]
213pub fn buffer_bit_byte_size(modulus_bits: usize) -> (usize, usize) {
214    let byte_size = buffer_byte_size(modulus_bits);
215    ((byte_size * 8), byte_size)
216}
217
/// Converts the number of bits required to represent a number
/// into the number of bytes required to represent it,
/// i.e. `ceil(modulus_bits / 8)`.
///
/// Written as `bits / 8 + (bits % 8 != 0)` rather than `(bits + 7) / 8`:
/// the latter overflows (and panics in debug builds) for
/// `modulus_bits > usize::MAX - 7`, while this form is total over `usize`.
#[inline]
pub const fn buffer_byte_size(modulus_bits: usize) -> usize {
    modulus_bits / 8 + (modulus_bits % 8 != 0) as usize
}
224
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::{
        collections::{BTreeMap, BTreeSet},
        rand::RngCore,
        string::String,
        vec,
        vec::Vec,
    };
    use num_bigint::BigUint;

    // Zero-sized helper type with hand-rolled (de)serialization: it writes one
    // marker byte (100) when compressed and two bytes ([100, 200]) when
    // uncompressed, so tests can tell the two modes apart.
    #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
    struct Dummy;

    impl CanonicalSerialize for Dummy {
        #[inline]
        fn serialize_with_mode<W: Write>(
            &self,
            mut writer: W,
            compress: Compress,
        ) -> Result<(), SerializationError> {
            match compress {
                Compress::Yes => 100u8.serialize_compressed(&mut writer),
                Compress::No => [100u8, 200u8].serialize_compressed(&mut writer),
            }
        }

        fn serialized_size(&self, compress: Compress) -> usize {
            match compress {
                Compress::Yes => 1,
                Compress::No => 2,
            }
        }
    }

    impl Valid for Dummy {
        fn check(&self) -> Result<(), SerializationError> {
            Ok(())
        }
    }
    impl CanonicalDeserialize for Dummy {
        #[inline]
        fn deserialize_with_mode<R: Read>(
            reader: R,
            compress: Compress,
            _validate: Validate,
        ) -> Result<Self, SerializationError> {
            // Asserts (rather than returning an error) on the exact marker
            // bytes written by `serialize_with_mode`, so a mismatch fails
            // the surrounding test loudly.
            match compress {
                Compress::Yes => assert_eq!(u8::deserialize_compressed(reader)?, 100u8),
                Compress::No => {
                    assert_eq!(<[u8; 2]>::deserialize_compressed(reader)?, [100u8, 200u8])
                },
            }
            Ok(Dummy)
        }
    }

    // Round-trips `data` through serialization and deserialization for every
    // (compress, validate) combination, asserting equality each time.
    fn test_serialize<
        T: PartialEq + core::fmt::Debug + CanonicalSerialize + CanonicalDeserialize,
    >(
        data: T,
    ) {
        for compress in [Compress::Yes, Compress::No] {
            for validate in [Validate::Yes, Validate::No] {
                let mut serialized = vec![0; data.serialized_size(compress)];
                data.serialize_with_mode(&mut serialized[..], compress)
                    .unwrap();
                let de = T::deserialize_with_mode(&serialized[..], compress, validate).unwrap();
                assert_eq!(data, de);
            }
        }
    }

    // Checks that `hash`/`hash_uncompressed` agree with hashing the
    // corresponding serialized bytes directly with the same digest `H`.
    fn test_hash<T: CanonicalSerialize, H: Digest + core::fmt::Debug>(data: T) {
        let h1 = data.hash::<H>();

        let mut hash = H::new();
        let mut serialized = vec![0; data.serialized_size(Compress::Yes)];
        data.serialize_compressed(&mut serialized[..]).unwrap();
        hash.update(&serialized);
        let h2 = hash.finalize();

        assert_eq!(h1, h2);

        let h3 = data.hash_uncompressed::<H>();

        let mut hash = H::new();
        serialized = vec![0; data.uncompressed_size()];
        data.serialize_uncompressed(&mut serialized[..]).unwrap();
        hash.update(&serialized);
        let h4 = hash.finalize();

        assert_eq!(h3, h4);
    }

    // Serialize T, randomly mutate the data, and deserialize it.
    // Ensure it fails.
    // Up to the caller to provide a valid mutation criterion
    // to ensure that this test always fails.
    // This method requires a concrete instance of the data to be provided,
    // to get the serialized size.
    fn ensure_non_malleable_encoding<
        T: PartialEq + core::fmt::Debug + CanonicalSerialize + CanonicalDeserialize,
    >(
        data: T,
        valid_mutation: fn(&[u8]) -> bool,
    ) {
        let mut r = ark_std::test_rng();
        // Compressed case: draw random buffers until one satisfies the
        // caller's mutation criterion, then expect deserialization to fail.
        let mut serialized = vec![0; data.compressed_size()];
        r.fill_bytes(&mut serialized);
        while !valid_mutation(&serialized) {
            r.fill_bytes(&mut serialized);
        }
        let de = T::deserialize_compressed(&serialized[..]);
        assert!(de.is_err());

        // Same procedure for the uncompressed encoding.
        let mut serialized = vec![0; data.uncompressed_size()];
        r.fill_bytes(&mut serialized);
        while !valid_mutation(&serialized) {
            r.fill_bytes(&mut serialized);
        }
        let de = T::deserialize_uncompressed(&serialized[..]);
        assert!(de.is_err());
    }

    #[test]
    fn test_array() {
        test_serialize([1u64, 2, 3, 4, 5]);
        test_serialize([1u8; 33]);
    }

    #[test]
    fn test_vec() {
        test_serialize(vec![1u64, 2, 3, 4, 5]);
        test_serialize(Vec::<u64>::new());
    }

    #[test]
    fn test_uint() {
        test_serialize(192830918usize);
        test_serialize(192830918u64);
        test_serialize(192830918u32);
        test_serialize(22313u16);
        test_serialize(123u8);
    }

    #[test]
    fn test_string() {
        test_serialize(String::from("arkworks"));
    }

    #[test]
    fn test_tuple() {
        test_serialize(());
        test_serialize((123u64, Dummy));
        test_serialize((123u64, 234u32, Dummy));
    }

    #[test]
    fn test_tuple_vec() {
        test_serialize(vec![
            (Dummy, Dummy, Dummy),
            (Dummy, Dummy, Dummy),
            (Dummy, Dummy, Dummy),
        ]);
        test_serialize(vec![
            (86u8, 98u64, Dummy),
            (86u8, 98u64, Dummy),
            (86u8, 98u64, Dummy),
        ]);
    }

    #[test]
    fn test_option() {
        test_serialize(Some(Dummy));
        test_serialize(None::<Dummy>);

        test_serialize(Some(10u64));
        test_serialize(None::<u64>);
    }

    #[test]
    fn test_bool() {
        test_serialize(true);
        test_serialize(false);

        // Any single byte > 1 is an invalid bool encoding and must be rejected.
        let valid_mutation = |data: &[u8]| -> bool { data.len() == 1 && data[0] > 1 };
        for _ in 0..10 {
            ensure_non_malleable_encoding(true, valid_mutation);
            ensure_non_malleable_encoding(false, valid_mutation);
        }
    }

    #[test]
    fn test_btreemap() {
        let mut map = BTreeMap::new();
        map.insert(0u64, Dummy);
        map.insert(5u64, Dummy);
        test_serialize(map);
        let mut map = BTreeMap::new();
        map.insert(10u64, vec![1u8, 2u8, 3u8]);
        map.insert(50u64, vec![4u8, 5u8, 6u8]);
        test_serialize(map);
    }

    #[test]
    fn test_btreeset() {
        let mut set = BTreeSet::new();
        set.insert(Dummy);
        set.insert(Dummy);
        test_serialize(set);
        let mut set = BTreeSet::new();
        set.insert(vec![1u8, 2u8, 3u8]);
        set.insert(vec![4u8, 5u8, 6u8]);
        test_serialize(set);
    }

    #[test]
    fn test_phantomdata() {
        test_serialize(core::marker::PhantomData::<Dummy>);
    }

    #[test]
    fn test_sha2() {
        test_hash::<_, sha2::Sha256>(Dummy);
        test_hash::<_, sha2::Sha512>(Dummy);
    }

    #[test]
    fn test_blake2() {
        test_hash::<_, blake2::Blake2b512>(Dummy);
        test_hash::<_, blake2::Blake2s256>(Dummy);
    }

    #[test]
    fn test_sha3() {
        test_hash::<_, sha3::Sha3_256>(Dummy);
        test_hash::<_, sha3::Sha3_512>(Dummy);
    }

    #[test]
    fn test_biguint() {
        let biguint = BigUint::from(123456u64);
        test_serialize(biguint.clone());

        // Expected encoding: a u64 little-endian length prefix followed by the
        // value's little-endian bytes; asserted identical for both modes below.
        let mut expected = (biguint.to_bytes_le().len() as u64).to_le_bytes().to_vec();
        expected.extend_from_slice(&biguint.to_bytes_le());

        let mut bytes = Vec::new();
        biguint
            .serialize_with_mode(&mut bytes, Compress::Yes)
            .unwrap();
        assert_eq!(bytes, expected);

        let mut bytes = Vec::new();
        biguint
            .serialize_with_mode(&mut bytes, Compress::No)
            .unwrap();
        assert_eq!(bytes, expected);
    }
}
486}