poseidon_permutation/
permutation.rs1#![allow(non_snake_case)]
2
3use decaf377::Fq;
4use poseidon_parameters::v1::{Alpha, MatrixOperations, PoseidonParameters};
5
6pub struct Instance<
10 'a,
11 const STATE_SIZE: usize,
12 const STATE_SIZE_MINUS_1: usize,
13 const NUM_MDS_ELEMENTS: usize,
14 const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
15 const NUM_ROUND_ROWS: usize,
16 const NUM_ROUND_COLS: usize,
17 const NUM_ROUND_ELEMENTS: usize,
18 const NUM_PARTIAL_ROUNDS: usize,
19> {
20 parameters: &'a PoseidonParameters<
22 STATE_SIZE,
23 STATE_SIZE_MINUS_1,
24 NUM_MDS_ELEMENTS,
25 NUM_STATE_SIZE_MINUS_1_ELEMENTS,
26 NUM_ROUND_ROWS,
27 NUM_ROUND_COLS,
28 NUM_ROUND_ELEMENTS,
29 NUM_PARTIAL_ROUNDS,
30 >,
31
32 state_words: [Fq; STATE_SIZE],
34}
35
36impl<
37 'a,
38 const STATE_SIZE: usize,
39 const STATE_SIZE_MINUS_1: usize,
40 const NUM_MDS_ELEMENTS: usize,
41 const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
42 const NUM_ROUND_ROWS: usize,
43 const NUM_ROUND_COLS: usize,
44 const NUM_ROUND_ELEMENTS: usize,
45 const NUM_PARTIAL_ROUNDS: usize,
46 >
47 Instance<
48 'a,
49 STATE_SIZE,
50 STATE_SIZE_MINUS_1,
51 NUM_MDS_ELEMENTS,
52 NUM_STATE_SIZE_MINUS_1_ELEMENTS,
53 NUM_ROUND_ROWS,
54 NUM_ROUND_COLS,
55 NUM_ROUND_ELEMENTS,
56 NUM_PARTIAL_ROUNDS,
57 >
58{
59 pub fn new(
61 parameters: &'a PoseidonParameters<
62 STATE_SIZE,
63 STATE_SIZE_MINUS_1,
64 NUM_MDS_ELEMENTS,
65 NUM_STATE_SIZE_MINUS_1_ELEMENTS,
66 NUM_ROUND_ROWS,
67 NUM_ROUND_COLS,
68 NUM_ROUND_ELEMENTS,
69 NUM_PARTIAL_ROUNDS,
70 >,
71 ) -> Self {
72 Self {
73 parameters,
74 state_words: [Fq::from(0u64); STATE_SIZE],
75 }
76 }
77
78 pub fn n_to_1_fixed_hash(&mut self, input_words: &[Fq; STATE_SIZE]) -> Fq {
80 for (i, input_word) in input_words.iter().enumerate() {
82 self.state_words[i] = *input_word
83 }
84
85 self.permute();
87
88 self.state_words[1]
90 }
91
92 pub fn output_words(&self) -> [Fq; STATE_SIZE] {
94 self.state_words
95 }
96
97 fn permute(&mut self) {
102 let R_f = self.parameters.rounds.full() / 2;
103
104 for r in 0..R_f {
106 for i in 0..STATE_SIZE {
108 self.state_words[i] += self.parameters.optimized_arc.0.get_element(r, i);
109 }
110 self.full_sub_words();
111 self.mix_layer_mds();
112 }
113 let mut round_constants_counter = R_f;
114
115 for i in 0..STATE_SIZE {
118 self.state_words[i] += self
119 .parameters
120 .optimized_arc
121 .0
122 .get_element(round_constants_counter, i);
123 }
124 self.mix_layer_mi();
126
127 for r in 0..self.parameters.rounds.partial() - 1 {
128 self.partial_sub_words();
129 round_constants_counter += 1;
131 self.state_words[0] += self
132 .parameters
133 .optimized_arc
134 .0
135 .get_element(round_constants_counter, 0);
136 self.sparse_mat_mul(self.parameters.rounds.partial() - r - 1);
137 }
138
139 self.partial_sub_words();
141 self.sparse_mat_mul(0);
142 round_constants_counter += 1;
143
144 for _ in 0..R_f {
146 for i in 0..STATE_SIZE {
148 self.state_words[i] += self
149 .parameters
150 .optimized_arc
151 .0
152 .get_element(round_constants_counter, i);
153 }
154 self.full_sub_words();
155 self.mix_layer_mds();
156 round_constants_counter += 1;
157 }
158 }
159
160 pub fn unoptimized_n_to_1_fixed_hash(&mut self, input_words: [Fq; STATE_SIZE]) -> Fq {
162 for (i, input_word) in input_words.iter().enumerate() {
164 self.state_words[i] = *input_word
165 }
166
167 self.unoptimized_permute();
169
170 self.state_words[1]
172 }
173
174 fn unoptimized_permute(&mut self) {
179 let R_f = self.parameters.rounds.full() / 2;
180 let R_P = self.parameters.rounds.partial();
181 let mut round_constants_counter = 0;
182 let round_constants = self.parameters.arc.elements();
183
184 for _ in 0..R_f {
186 for i in 0..STATE_SIZE {
188 self.state_words[i] += round_constants[round_constants_counter];
189 round_constants_counter += 1;
190 }
191 self.full_sub_words();
192 self.mix_layer_mds();
193 }
194
195 for _ in 0..R_P {
197 for i in 0..STATE_SIZE {
199 self.state_words[i] += round_constants[round_constants_counter];
200 round_constants_counter += 1;
201 }
202 self.partial_sub_words();
203 self.mix_layer_mds();
204 }
205
206 for _ in 0..R_f {
208 for i in 0..STATE_SIZE {
210 self.state_words[i] += round_constants[round_constants_counter];
211 round_constants_counter += 1;
212 }
213 self.full_sub_words();
214 self.mix_layer_mds();
215 }
216 }
217
218 fn partial_sub_words(&mut self) {
220 match self.parameters.alpha {
221 Alpha::Exponent(exp) => self.state_words[0] = (self.state_words[0]).power([exp as u64]),
222 Alpha::Inverse => self.state_words[0] = Fq::from(1u64) / self.state_words[0],
223 }
224 }
225
226 fn full_sub_words(&mut self) {
228 match self.parameters.alpha {
229 Alpha::Exponent(exp) => {
230 for i in 0..STATE_SIZE {
231 self.state_words[i] = self.state_words[i].power([exp as u64]);
232 }
233 }
234 Alpha::Inverse => {
235 for i in 0..STATE_SIZE {
236 self.state_words[i] = Fq::from(1u64) / self.state_words[i];
237 }
238 }
239 }
240 }
241
242 fn mix_layer_mi(&mut self) {
244 let mut new_state_words = [Fq::from(0u64); STATE_SIZE];
245 for (i, row) in self.parameters.optimized_mds.M_i.iter_rows().enumerate() {
246 let sum = row
247 .iter()
248 .zip(&self.state_words)
249 .map(|(x, y)| *x * *y)
250 .sum();
251 new_state_words[i] = sum;
252 }
253 self.state_words = new_state_words;
254 }
255
256 fn mix_layer_mds(&mut self) {
258 let mut new_state_words = [Fq::from(0u64); STATE_SIZE];
259
260 for (i, row) in self.parameters.mds.0 .0.iter_rows().enumerate() {
261 let sum = row
262 .iter()
263 .zip(&self.state_words)
264 .map(|(x, y)| *x * *y)
265 .sum();
266 new_state_words[i] = sum;
267 }
268 self.state_words = new_state_words;
269 }
270
271 fn sparse_mat_mul(&mut self, round_number: usize) {
273 let mut add_row = [Fq::from(0u64); STATE_SIZE_MINUS_1];
276 for (i, x) in self.parameters.optimized_mds.v_collection[round_number]
277 .elements
278 .iter()
279 .enumerate()
280 {
281 add_row[i] = *x * self.state_words[0] + self.state_words[i + 1];
282 }
283
284 self.state_words[0] = self.parameters.optimized_mds.M_00 * self.state_words[0]
288 + self.parameters.optimized_mds.w_hat_collection[round_number]
289 .elements
290 .iter()
291 .zip(self.state_words[1..STATE_SIZE].iter())
292 .map(|(x, y)| *x * *y)
293 .sum::<Fq>();
294
295 self.state_words[1..STATE_SIZE].copy_from_slice(&add_row[..(STATE_SIZE - 1)]);
296 }
297}