1#![cfg_attr(not(feature = "std"), no_std)]
2#![warn(
3 unused,
4 future_incompatible,
5 nonstandard_style,
6 rust_2018_idioms,
7 rust_2021_compatibility
8)]
9#![forbid(unsafe_code)]
10#![doc = include_str!("../README.md")]
11mod error;
12mod flags;
13mod impls;
14
15use ark_std::borrow::ToOwned;
16pub use ark_std::io::{Read, Write};
17
18pub use error::*;
19pub use flags::*;
20
21#[cfg(feature = "derive")]
22#[doc(hidden)]
23pub use ark_serialize_derive::*;
24
25use digest::{generic_array::GenericArray, Digest, OutputSizeUser};
26
/// Selects between the compressed and uncompressed wire encodings.
///
/// Passed to `CanonicalSerialize::serialize_with_mode`, `CanonicalSerialize::serialized_size`
/// and `CanonicalDeserialize::deserialize_with_mode`. What "compressed" means is decided by
/// each implementation, and a type may produce identical bytes for both variants.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Compress {
    /// Use the implementation's compressed encoding.
    Yes,
    /// Use the implementation's uncompressed encoding.
    No,
}
34
/// Selects whether a deserialized value should be validated.
///
/// Passed to `CanonicalDeserialize::deserialize_with_mode`; the `*_unchecked` shorthands on
/// `CanonicalDeserialize` pass [`Validate::No`]. What validation entails is decided by each
/// implementation.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Validate {
    /// Validate the decoded value.
    Yes,
    /// Skip validation of the decoded value.
    No,
}
42
43pub trait Valid: Sized + Sync {
44 fn check(&self) -> Result<(), SerializationError>;
45
46 fn batch_check<'a>(
47 batch: impl Iterator<Item = &'a Self> + Send,
48 ) -> Result<(), SerializationError>
49 where
50 Self: 'a,
51 {
52 #[cfg(feature = "parallel")]
53 {
54 use rayon::{iter::ParallelBridge, prelude::ParallelIterator};
55 batch.par_bridge().try_for_each(|e| e.check())?;
56 }
57 #[cfg(not(feature = "parallel"))]
58 {
59 for item in batch {
60 item.check()?;
61 }
62 }
63 Ok(())
64 }
65}
66
67pub trait CanonicalSerialize {
84 fn serialize_with_mode<W: Write>(
86 &self,
87 writer: W,
88 compress: Compress,
89 ) -> Result<(), SerializationError>;
90
91 fn serialized_size(&self, compress: Compress) -> usize;
92
93 fn serialize_compressed<W: Write>(&self, writer: W) -> Result<(), SerializationError> {
94 self.serialize_with_mode(writer, Compress::Yes)
95 }
96
97 fn compressed_size(&self) -> usize {
98 self.serialized_size(Compress::Yes)
99 }
100
101 fn serialize_uncompressed<W: Write>(&self, writer: W) -> Result<(), SerializationError> {
102 self.serialize_with_mode(writer, Compress::No)
103 }
104
105 fn uncompressed_size(&self) -> usize {
106 self.serialized_size(Compress::No)
107 }
108}
109
110pub trait CanonicalDeserialize: Valid {
127 fn deserialize_with_mode<R: Read>(
129 reader: R,
130 compress: Compress,
131 validate: Validate,
132 ) -> Result<Self, SerializationError>;
133
134 fn deserialize_compressed<R: Read>(reader: R) -> Result<Self, SerializationError> {
135 Self::deserialize_with_mode(reader, Compress::Yes, Validate::Yes)
136 }
137
138 fn deserialize_compressed_unchecked<R: Read>(reader: R) -> Result<Self, SerializationError> {
139 Self::deserialize_with_mode(reader, Compress::Yes, Validate::No)
140 }
141
142 fn deserialize_uncompressed<R: Read>(reader: R) -> Result<Self, SerializationError> {
143 Self::deserialize_with_mode(reader, Compress::No, Validate::Yes)
144 }
145
146 fn deserialize_uncompressed_unchecked<R: Read>(reader: R) -> Result<Self, SerializationError> {
147 Self::deserialize_with_mode(reader, Compress::No, Validate::No)
148 }
149}
150
151pub trait CanonicalSerializeWithFlags: CanonicalSerialize {
153 fn serialize_with_flags<W: Write, F: Flags>(
155 &self,
156 writer: W,
157 flags: F,
158 ) -> Result<(), SerializationError>;
159
160 fn serialized_size_with_flags<F: Flags>(&self) -> usize;
162}
163
164pub trait CanonicalDeserializeWithFlags: Sized {
166 fn deserialize_with_flags<R: Read, F: Flags>(
169 reader: R,
170 ) -> Result<(Self, F), SerializationError>;
171}
172
173struct HashMarshaller<'a, H: Digest>(&'a mut H);
176
177impl<'a, H: Digest> ark_std::io::Write for HashMarshaller<'a, H> {
178 #[inline]
179 fn write(&mut self, buf: &[u8]) -> ark_std::io::Result<usize> {
180 Digest::update(self.0, buf);
181 Ok(buf.len())
182 }
183
184 #[inline]
185 fn flush(&mut self) -> ark_std::io::Result<()> {
186 Ok(())
187 }
188}
189
190pub trait CanonicalSerializeHashExt: CanonicalSerialize {
193 fn hash<H: Digest>(&self) -> GenericArray<u8, <H as OutputSizeUser>::OutputSize> {
194 let mut hasher = H::new();
195 self.serialize_compressed(HashMarshaller(&mut hasher))
196 .expect("HashMarshaller::flush should be infaillible!");
197 hasher.finalize()
198 }
199
200 fn hash_uncompressed<H: Digest>(&self) -> GenericArray<u8, <H as OutputSizeUser>::OutputSize> {
201 let mut hasher = H::new();
202 self.serialize_uncompressed(HashMarshaller(&mut hasher))
203 .expect("HashMarshaller::flush should be infaillible!");
204 hasher.finalize()
205 }
206}
207
208impl<T: CanonicalSerialize> CanonicalSerializeHashExt for T {}
211
212#[inline]
213pub fn buffer_bit_byte_size(modulus_bits: usize) -> (usize, usize) {
214 let byte_size = buffer_byte_size(modulus_bits);
215 ((byte_size * 8), byte_size)
216}
217
/// Returns the number of whole bytes needed to hold `modulus_bits` bits, i.e.
/// `ceil(modulus_bits / 8)`.
///
/// Computed without the intermediate `modulus_bits + 7`, which overflows (panicking in debug
/// builds, wrapping to a tiny result in release) for inputs within 7 of `usize::MAX`.
#[inline]
pub const fn buffer_byte_size(modulus_bits: usize) -> usize {
    modulus_bits / 8 + (modulus_bits % 8 != 0) as usize
}
224
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::{
        collections::{BTreeMap, BTreeSet},
        rand::RngCore,
        string::String,
        vec,
        vec::Vec,
    };
    use num_bigint::BigUint;

    /// Fixture whose encoding differs between modes: the single byte `100` when compressed,
    /// and the bytes `100, 200` when uncompressed.
    #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
    struct Dummy;

    impl CanonicalSerialize for Dummy {
        #[inline]
        fn serialize_with_mode<W: Write>(
            &self,
            mut writer: W,
            compress: Compress,
        ) -> Result<(), SerializationError> {
            match compress {
                Compress::Yes => 100u8.serialize_compressed(&mut writer),
                Compress::No => [100u8, 200u8].serialize_compressed(&mut writer),
            }
        }

        fn serialized_size(&self, compress: Compress) -> usize {
            match compress {
                Compress::Yes => 1,
                Compress::No => 2,
            }
        }
    }

    impl Valid for Dummy {
        fn check(&self) -> Result<(), SerializationError> {
            Ok(())
        }
    }

    impl CanonicalDeserialize for Dummy {
        // Asserts (rather than returning an error) that the bytes read are exactly the ones
        // `serialize_with_mode` writes, so any mismatch surfaces as a test failure.
        #[inline]
        fn deserialize_with_mode<R: Read>(
            reader: R,
            compress: Compress,
            _validate: Validate,
        ) -> Result<Self, SerializationError> {
            match compress {
                Compress::Yes => assert_eq!(u8::deserialize_compressed(reader)?, 100u8),
                Compress::No => {
                    assert_eq!(<[u8; 2]>::deserialize_compressed(reader)?, [100u8, 200u8])
                },
            }
            Ok(Dummy)
        }
    }

    /// Round-trips `data` through every `Compress` / `Validate` combination and checks that
    /// exactly `serialized_size(compress)` bytes are written, that deserialization consumes
    /// all of them, and that the decoded value equals `data`.
    fn test_serialize<
        T: PartialEq + core::fmt::Debug + CanonicalSerialize + CanonicalDeserialize,
    >(
        data: T,
    ) {
        for compress in [Compress::Yes, Compress::No] {
            // Serialize into a growable buffer: a pre-sized zeroed buffer would hide a
            // `serialized_size` that over-reports, since the padding is never checked.
            let mut serialized = Vec::new();
            data.serialize_with_mode(&mut serialized, compress).unwrap();
            assert_eq!(serialized.len(), data.serialized_size(compress));

            for validate in [Validate::Yes, Validate::No] {
                let mut reader = &serialized[..];
                let de = T::deserialize_with_mode(&mut reader, compress, validate).unwrap();
                assert!(
                    reader.is_empty(),
                    "deserialization left {} unread byte(s)",
                    reader.len()
                );
                assert_eq!(data, de);
            }
        }
    }

    /// Checks that `hash` / `hash_uncompressed` equal hashing the serialized bytes manually.
    fn test_hash<T: CanonicalSerialize, H: Digest + core::fmt::Debug>(data: T) {
        let h1 = data.hash::<H>();

        let mut hash = H::new();
        let mut serialized = vec![0; data.serialized_size(Compress::Yes)];
        data.serialize_compressed(&mut serialized[..]).unwrap();
        hash.update(&serialized);
        let h2 = hash.finalize();

        assert_eq!(h1, h2);

        let h3 = data.hash_uncompressed::<H>();

        let mut hash = H::new();
        serialized = vec![0; data.uncompressed_size()];
        data.serialize_uncompressed(&mut serialized[..]).unwrap();
        hash.update(&serialized);
        let h4 = hash.finalize();

        assert_eq!(h3, h4);
    }

    /// Asserts that random byte strings selected by `valid_mutation` fail to deserialize.
    ///
    /// For the compressed and then the uncompressed mode, a buffer of `data`'s encoded size is
    /// filled with random bytes (redrawn until `valid_mutation` returns `true`) and must be
    /// rejected by `deserialize_with_mode` with validation enabled. `data` is used only to size
    /// the buffers.
    fn ensure_non_malleable_encoding<
        T: PartialEq + core::fmt::Debug + CanonicalSerialize + CanonicalDeserialize,
    >(
        data: T,
        valid_mutation: fn(&[u8]) -> bool,
    ) {
        let mut r = ark_std::test_rng();
        for compress in [Compress::Yes, Compress::No] {
            let mut serialized = vec![0; data.serialized_size(compress)];
            r.fill_bytes(&mut serialized);
            while !valid_mutation(&serialized) {
                r.fill_bytes(&mut serialized);
            }
            let de = T::deserialize_with_mode(&serialized[..], compress, Validate::Yes);
            assert!(de.is_err());
        }
    }

    #[test]
    fn test_array() {
        test_serialize([1u64, 2, 3, 4, 5]);
        test_serialize([1u8; 33]);
    }

    #[test]
    fn test_vec() {
        test_serialize(vec![1u64, 2, 3, 4, 5]);
        test_serialize(Vec::<u64>::new());
    }

    #[test]
    fn test_uint() {
        test_serialize(192830918usize);
        test_serialize(192830918u64);
        test_serialize(192830918u32);
        test_serialize(22313u16);
        test_serialize(123u8);
    }

    #[test]
    fn test_string() {
        test_serialize(String::from("arkworks"));
    }

    #[test]
    fn test_tuple() {
        test_serialize(());
        test_serialize((123u64, Dummy));
        test_serialize((123u64, 234u32, Dummy));
    }

    #[test]
    fn test_tuple_vec() {
        test_serialize(vec![
            (Dummy, Dummy, Dummy),
            (Dummy, Dummy, Dummy),
            (Dummy, Dummy, Dummy),
        ]);
        test_serialize(vec![
            (86u8, 98u64, Dummy),
            (86u8, 98u64, Dummy),
            (86u8, 98u64, Dummy),
        ]);
    }

    #[test]
    fn test_option() {
        test_serialize(Some(Dummy));
        test_serialize(None::<Dummy>);

        test_serialize(Some(10u64));
        test_serialize(None::<u64>);
    }

    #[test]
    fn test_bool() {
        test_serialize(true);
        test_serialize(false);

        // Any single byte other than 0 or 1 must be rejected as a bool.
        let valid_mutation = |data: &[u8]| -> bool { data.len() == 1 && data[0] > 1 };
        for _ in 0..10 {
            ensure_non_malleable_encoding(true, valid_mutation);
            ensure_non_malleable_encoding(false, valid_mutation);
        }
    }

    #[test]
    fn test_btreemap() {
        let mut map = BTreeMap::new();
        map.insert(0u64, Dummy);
        map.insert(5u64, Dummy);
        test_serialize(map);
        let mut map = BTreeMap::new();
        map.insert(10u64, vec![1u8, 2u8, 3u8]);
        map.insert(50u64, vec![4u8, 5u8, 6u8]);
        test_serialize(map);
    }

    #[test]
    fn test_btreeset() {
        let mut set = BTreeSet::new();
        set.insert(Dummy);
        set.insert(Dummy);
        test_serialize(set);
        let mut set = BTreeSet::new();
        set.insert(vec![1u8, 2u8, 3u8]);
        set.insert(vec![4u8, 5u8, 6u8]);
        test_serialize(set);
    }

    #[test]
    fn test_phantomdata() {
        test_serialize(core::marker::PhantomData::<Dummy>);
    }

    #[test]
    fn test_sha2() {
        test_hash::<_, sha2::Sha256>(Dummy);
        test_hash::<_, sha2::Sha512>(Dummy);
    }

    #[test]
    fn test_blake2() {
        test_hash::<_, blake2::Blake2b512>(Dummy);
        test_hash::<_, blake2::Blake2s256>(Dummy);
    }

    #[test]
    fn test_sha3() {
        test_hash::<_, sha3::Sha3_256>(Dummy);
        test_hash::<_, sha3::Sha3_512>(Dummy);
    }

    #[test]
    fn test_biguint() {
        let biguint = BigUint::from(123456u64);
        test_serialize(biguint.clone());

        // Expected layout: little-endian u64 byte length, then the little-endian magnitude.
        let mut expected = (biguint.to_bytes_le().len() as u64).to_le_bytes().to_vec();
        expected.extend_from_slice(&biguint.to_bytes_le());

        let mut bytes = Vec::new();
        biguint
            .serialize_with_mode(&mut bytes, Compress::Yes)
            .unwrap();
        assert_eq!(bytes, expected);

        let mut bytes = Vec::new();
        biguint
            .serialize_with_mode(&mut bytes, Compress::No)
            .unwrap();
        assert_eq!(bytes, expected);
    }
}