// poseidon_permutation/permutation.rs

1#![allow(non_snake_case)]
2
3use decaf377::Fq;
4use poseidon_parameters::v1::{Alpha, MatrixOperations, PoseidonParameters};
5
6/// Represents a generic instance of `Poseidon`.
7///
8/// Intended for generic fixed-width hashing.
9pub struct Instance<
10    'a,
11    const STATE_SIZE: usize,
12    const STATE_SIZE_MINUS_1: usize,
13    const NUM_MDS_ELEMENTS: usize,
14    const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
15    const NUM_ROUND_ROWS: usize,
16    const NUM_ROUND_COLS: usize,
17    const NUM_ROUND_ELEMENTS: usize,
18    const NUM_PARTIAL_ROUNDS: usize,
19> {
20    /// Parameters for this instance of Poseidon.
21    parameters: &'a PoseidonParameters<
22        STATE_SIZE,
23        STATE_SIZE_MINUS_1,
24        NUM_MDS_ELEMENTS,
25        NUM_STATE_SIZE_MINUS_1_ELEMENTS,
26        NUM_ROUND_ROWS,
27        NUM_ROUND_COLS,
28        NUM_ROUND_ELEMENTS,
29        NUM_PARTIAL_ROUNDS,
30    >,
31
32    /// Inner state.
33    state_words: [Fq; STATE_SIZE],
34}
35
36impl<
37        'a,
38        const STATE_SIZE: usize,
39        const STATE_SIZE_MINUS_1: usize,
40        const NUM_MDS_ELEMENTS: usize,
41        const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
42        const NUM_ROUND_ROWS: usize,
43        const NUM_ROUND_COLS: usize,
44        const NUM_ROUND_ELEMENTS: usize,
45        const NUM_PARTIAL_ROUNDS: usize,
46    >
47    Instance<
48        'a,
49        STATE_SIZE,
50        STATE_SIZE_MINUS_1,
51        NUM_MDS_ELEMENTS,
52        NUM_STATE_SIZE_MINUS_1_ELEMENTS,
53        NUM_ROUND_ROWS,
54        NUM_ROUND_COLS,
55        NUM_ROUND_ELEMENTS,
56        NUM_PARTIAL_ROUNDS,
57    >
58{
59    /// Instantiate a new hash function over Fq given `Parameters`.
60    pub fn new(
61        parameters: &'a PoseidonParameters<
62            STATE_SIZE,
63            STATE_SIZE_MINUS_1,
64            NUM_MDS_ELEMENTS,
65            NUM_STATE_SIZE_MINUS_1_ELEMENTS,
66            NUM_ROUND_ROWS,
67            NUM_ROUND_COLS,
68            NUM_ROUND_ELEMENTS,
69            NUM_PARTIAL_ROUNDS,
70        >,
71    ) -> Self {
72        Self {
73            parameters,
74            state_words: [Fq::from(0u64); STATE_SIZE],
75        }
76    }
77
78    /// Fixed width hash from n:1. Outputs a F given `t` input words.
79    pub fn n_to_1_fixed_hash(&mut self, input_words: &[Fq; STATE_SIZE]) -> Fq {
80        // Set internal state words.
81        for (i, input_word) in input_words.iter().enumerate() {
82            self.state_words[i] = *input_word
83        }
84
85        // Apply Poseidon permutation.
86        self.permute();
87
88        // Emit a single element since this is a n:1 hash.
89        self.state_words[1]
90    }
91
92    /// Print out internal state.
93    pub fn output_words(&self) -> [Fq; STATE_SIZE] {
94        self.state_words
95    }
96
97    /// Permutes the internal state.
98    ///
99    /// This implementation is based on the optimized Sage implementation
100    /// `poseidonperm_x3_64_optimized.sage` provided in Appendix B of the Poseidon paper.
101    fn permute(&mut self) {
102        let R_f = self.parameters.rounds.full() / 2;
103
104        // First chunk of full rounds
105        for r in 0..R_f {
106            // Apply `AddRoundConstants` layer
107            for i in 0..STATE_SIZE {
108                self.state_words[i] += self.parameters.optimized_arc.0.get_element(r, i);
109            }
110            self.full_sub_words();
111            self.mix_layer_mds();
112        }
113        let mut round_constants_counter = R_f;
114
115        // Partial rounds
116        // First part of `AddRoundConstants` layer
117        for i in 0..STATE_SIZE {
118            self.state_words[i] += self
119                .parameters
120                .optimized_arc
121                .0
122                .get_element(round_constants_counter, i);
123        }
124        // First full matrix multiplication.
125        self.mix_layer_mi();
126
127        for r in 0..self.parameters.rounds.partial() - 1 {
128            self.partial_sub_words();
129            // Rest of `AddRoundConstants` layer, moved to after the S-box layer
130            round_constants_counter += 1;
131            self.state_words[0] += self
132                .parameters
133                .optimized_arc
134                .0
135                .get_element(round_constants_counter, 0);
136            self.sparse_mat_mul(self.parameters.rounds.partial() - r - 1);
137        }
138
139        // Last partial round
140        self.partial_sub_words();
141        self.sparse_mat_mul(0);
142        round_constants_counter += 1;
143
144        // Final full rounds
145        for _ in 0..R_f {
146            // Apply `AddRoundConstants` layer
147            for i in 0..STATE_SIZE {
148                self.state_words[i] += self
149                    .parameters
150                    .optimized_arc
151                    .0
152                    .get_element(round_constants_counter, i);
153            }
154            self.full_sub_words();
155            self.mix_layer_mds();
156            round_constants_counter += 1;
157        }
158    }
159
160    /// Fixed width hash from n:1. Outputs a F given `t` input words. Unoptimized.
161    pub fn unoptimized_n_to_1_fixed_hash(&mut self, input_words: [Fq; STATE_SIZE]) -> Fq {
162        // Set internal state words.
163        for (i, input_word) in input_words.iter().enumerate() {
164            self.state_words[i] = *input_word
165        }
166
167        // Apply Poseidon permutation.
168        self.unoptimized_permute();
169
170        // Emit a single element since this is a n:1 hash.
171        self.state_words[1]
172    }
173
174    /// Permutes the internal state.
175    ///
176    /// This implementation is based on the unoptimized Sage implementation
177    /// `poseidonperm_x5_254_3.sage` provided in Appendix B of the Poseidon paper.
178    fn unoptimized_permute(&mut self) {
179        let R_f = self.parameters.rounds.full() / 2;
180        let R_P = self.parameters.rounds.partial();
181        let mut round_constants_counter = 0;
182        let round_constants = self.parameters.arc.elements();
183
184        // First full rounds
185        for _ in 0..R_f {
186            // Apply `AddRoundConstants` layer
187            for i in 0..STATE_SIZE {
188                self.state_words[i] += round_constants[round_constants_counter];
189                round_constants_counter += 1;
190            }
191            self.full_sub_words();
192            self.mix_layer_mds();
193        }
194
195        // Partial rounds
196        for _ in 0..R_P {
197            // Apply `AddRoundConstants` layer
198            for i in 0..STATE_SIZE {
199                self.state_words[i] += round_constants[round_constants_counter];
200                round_constants_counter += 1;
201            }
202            self.partial_sub_words();
203            self.mix_layer_mds();
204        }
205
206        // Final full rounds
207        for _ in 0..R_f {
208            // Apply `AddRoundConstants` layer
209            for i in 0..STATE_SIZE {
210                self.state_words[i] += round_constants[round_constants_counter];
211                round_constants_counter += 1;
212            }
213            self.full_sub_words();
214            self.mix_layer_mds();
215        }
216    }
217
218    /// Applies the partial `SubWords` layer.
219    fn partial_sub_words(&mut self) {
220        match self.parameters.alpha {
221            Alpha::Exponent(exp) => self.state_words[0] = (self.state_words[0]).power([exp as u64]),
222            Alpha::Inverse => self.state_words[0] = Fq::from(1u64) / self.state_words[0],
223        }
224    }
225
226    /// Applies the full `SubWords` layer.
227    fn full_sub_words(&mut self) {
228        match self.parameters.alpha {
229            Alpha::Exponent(exp) => {
230                for i in 0..STATE_SIZE {
231                    self.state_words[i] = self.state_words[i].power([exp as u64]);
232                }
233            }
234            Alpha::Inverse => {
235                for i in 0..STATE_SIZE {
236                    self.state_words[i] = Fq::from(1u64) / self.state_words[i];
237                }
238            }
239        }
240    }
241
242    /// Applies the `MixLayer` using the M_i matrix.
243    fn mix_layer_mi(&mut self) {
244        let mut new_state_words = [Fq::from(0u64); STATE_SIZE];
245        for (i, row) in self.parameters.optimized_mds.M_i.iter_rows().enumerate() {
246            let sum = row
247                .iter()
248                .zip(&self.state_words)
249                .map(|(x, y)| *x * *y)
250                .sum();
251            new_state_words[i] = sum;
252        }
253        self.state_words = new_state_words;
254    }
255
256    /// Applies the `MixLayer` using the MDS matrix.
257    fn mix_layer_mds(&mut self) {
258        let mut new_state_words = [Fq::from(0u64); STATE_SIZE];
259
260        for (i, row) in self.parameters.mds.0 .0.iter_rows().enumerate() {
261            let sum = row
262                .iter()
263                .zip(&self.state_words)
264                .map(|(x, y)| *x * *y)
265                .sum();
266            new_state_words[i] = sum;
267        }
268        self.state_words = new_state_words;
269    }
270
271    /// This is `cheap_matrix_mul` in the Sage spec
272    fn sparse_mat_mul(&mut self, round_number: usize) {
273        // mul_row = [(state_words[0] * v[i]) for i in range(0, t-1)]
274        // add_row = [(mul_row[i] + state_words[i+1]) for i in range(0, t-1)]
275        let mut add_row = [Fq::from(0u64); STATE_SIZE_MINUS_1];
276        for (i, x) in self.parameters.optimized_mds.v_collection[round_number]
277            .elements
278            .iter()
279            .enumerate()
280        {
281            add_row[i] = *x * self.state_words[0] + self.state_words[i + 1];
282        }
283
284        // column_1 = [M_0_0] + w_hat
285        // state_words_new[0] = sum([column_1[i] * state_words[i] for i in range(0, t)])
286        // state_words_new = [state_words_new[0]] + add_row
287        self.state_words[0] = self.parameters.optimized_mds.M_00 * self.state_words[0]
288            + self.parameters.optimized_mds.w_hat_collection[round_number]
289                .elements
290                .iter()
291                .zip(self.state_words[1..STATE_SIZE].iter())
292                .map(|(x, y)| *x * *y)
293                .sum::<Fq>();
294
295        self.state_words[1..STATE_SIZE].copy_from_slice(&add_row[..(STATE_SIZE - 1)]);
296    }
297}