// poseidon_permutation/r1cs.rs
#![allow(non_snake_case)]
use ark_std::vec::Vec;

use ark_r1cs_std::{fields::fp::FpVar, prelude::*};
use ark_relations::r1cs::ConstraintSystemRef;
use decaf377::Fq;
use poseidon_parameters::v1::{Alpha, MatrixOperations, PoseidonParameters};

/// In-circuit (R1CS) instance of a Poseidon permutation over [`Fq`].
///
/// The const generic parameters are forwarded unchanged to [`PoseidonParameters`];
/// they only fix the array sizes of the parameter set this instance uses.
pub struct InstanceVar<
    const STATE_SIZE: usize,
    const STATE_SIZE_MINUS_1: usize,
    const NUM_MDS_ELEMENTS: usize,
    const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
    const NUM_ROUND_ROWS: usize,
    const NUM_ROUND_COLS: usize,
    const NUM_ROUND_ELEMENTS: usize,
    const NUM_PARTIAL_ROUNDS: usize,
> {
    /// Poseidon parameter set driving this instance: round counts, round
    /// constants, MDS matrix and S-box exponent.
    pub parameters: PoseidonParameters<
        STATE_SIZE,
        STATE_SIZE_MINUS_1,
        NUM_MDS_ELEMENTS,
        NUM_STATE_SIZE_MINUS_1_ELEMENTS,
        NUM_ROUND_ROWS,
        NUM_ROUND_COLS,
        NUM_ROUND_ELEMENTS,
        NUM_PARTIAL_ROUNDS,
    >,

    /// Handle to the constraint system associated with this instance.
    pub cs: ConstraintSystemRef<Fq>,

    /// The permutation state as circuit variables; expected to hold exactly
    /// `STATE_SIZE` words.
    pub state_words: Vec<FpVar<Fq>>,
}

impl<
        const STATE_SIZE: usize,
        const STATE_SIZE_MINUS_1: usize,
        const NUM_MDS_ELEMENTS: usize,
        const NUM_STATE_SIZE_MINUS_1_ELEMENTS: usize,
        const NUM_ROUND_ROWS: usize,
        const NUM_ROUND_COLS: usize,
        const NUM_ROUND_ELEMENTS: usize,
        const NUM_PARTIAL_ROUNDS: usize,
    >
    InstanceVar<
        STATE_SIZE,
        STATE_SIZE_MINUS_1,
        NUM_MDS_ELEMENTS,
        NUM_STATE_SIZE_MINUS_1_ELEMENTS,
        NUM_ROUND_ROWS,
        NUM_ROUND_COLS,
        NUM_ROUND_ELEMENTS,
        NUM_PARTIAL_ROUNDS,
    >
{
    /// Fixed-width n:1 hash: outputs a single `Fq` variable given `t` input words.
    ///
    /// The whole state (`t = STATE_SIZE` words) is initialised directly from
    /// `input_words`. The permutation is applied once, and the state word at
    /// index 1 is returned.
    ///
    /// # Panics
    ///
    /// Panics if `STATE_SIZE < 2` (there is no word at index 1), or under any of the
    /// conditions documented on [`Self::permute`].
    pub fn n_to_1_fixed_hash(
        parameters: PoseidonParameters<
            STATE_SIZE,
            STATE_SIZE_MINUS_1,
            NUM_MDS_ELEMENTS,
            NUM_STATE_SIZE_MINUS_1_ELEMENTS,
            NUM_ROUND_ROWS,
            NUM_ROUND_COLS,
            NUM_ROUND_ELEMENTS,
            NUM_PARTIAL_ROUNDS,
        >,
        cs: ConstraintSystemRef<Fq>,
        input_words: [FpVar<Fq>; STATE_SIZE],
    ) -> FpVar<Fq> {
        let mut instance = InstanceVar {
            parameters,
            cs,
            // We own the array, so move the words in instead of cloning each one.
            state_words: Vec::from(input_words),
        };

        // Apply Poseidon permutation.
        instance.permute();

        // Emit a single element since this is an n:1 hash. `instance` is discarded
        // afterwards, so move the word out rather than cloning it.
        instance.state_words.swap_remove(1)
    }

    /// Poseidon permutation, applied in place to `self.state_words`.
    ///
    /// Runs `full / 2` full rounds, then `partial` partial rounds, then `full / 2`
    /// full rounds. Each round adds the next `STATE_SIZE` round constants from
    /// `parameters.arc` (consumed in order), applies the S-box layer, and mixes the
    /// state with the MDS matrix. Full rounds apply the S-box to every word;
    /// partial rounds apply it to word 0 only.
    ///
    /// # Panics
    ///
    /// - If `state_words` does not hold exactly `STATE_SIZE` words.
    /// - If `parameters.arc` does not provide `STATE_SIZE` constants for every round.
    /// - If `parameters.alpha` is [`Alpha::Inverse`] (not implemented).
    /// - If constraint synthesis for the S-box power fails.
    pub fn permute(&mut self) {
        // `state_words` is a public field. Without this check, a longer state would
        // be silently truncated by the MDS mix, and a shorter one would panic mid-way.
        assert_eq!(
            self.state_words.len(),
            STATE_SIZE,
            "Poseidon state must hold exactly STATE_SIZE words"
        );

        let half_full_rounds = self.parameters.rounds.full() / 2;
        let partial_rounds = self.parameters.rounds.partial();
        let round_constants: [Fq; NUM_ROUND_ELEMENTS] = self.parameters.arc.inner_elements();

        // Every round consumes the next `STATE_SIZE` constants, in order.
        let mut constants_per_round = round_constants.chunks_exact(STATE_SIZE);
        let mut next_round_constants = || {
            constants_per_round
                .next()
                .expect("parameters.arc must provide STATE_SIZE constants per round")
        };

        // First full rounds
        for _ in 0..half_full_rounds {
            self.full_round(next_round_constants());
        }

        // Partial rounds
        for _ in 0..partial_rounds {
            self.partial_round(next_round_constants());
        }

        // Final full rounds
        for _ in 0..half_full_rounds {
            self.full_round(next_round_constants());
        }
    }

    /// One full round: `AddRoundConstants`, S-box on every word, then `MixLayer`.
    fn full_round(&mut self, constants: &[Fq]) {
        self.add_round_constants(constants);
        self.full_sub_words();
        self.mix_layer_mds();
    }

    /// One partial round: `AddRoundConstants`, S-box on word 0 only, then `MixLayer`.
    fn partial_round(&mut self, constants: &[Fq]) {
        self.add_round_constants(constants);
        self.partial_sub_words();
        self.mix_layer_mds();
    }

    /// Applies the `AddRoundConstants` layer: adds `constants[i]` to state word `i`.
    fn add_round_constants(&mut self, constants: &[Fq]) {
        for (word, constant) in self.state_words.iter_mut().zip(constants) {
            *word += *constant;
        }
    }

    /// Returns the S-box exponent as the `u64` limb expected by `pow_by_constant`.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` is [`Alpha::Inverse`], which is not implemented in-circuit.
    fn sbox_exponent(&self) -> u64 {
        match self.parameters.alpha {
            Alpha::Exponent(exp) => exp as u64,
            Alpha::Inverse => unimplemented!("err: inverse alpha not implemented"),
        }
    }

    /// Applies the partial `SubWords` layer: raises only state word 0 to `alpha`.
    fn partial_sub_words(&mut self) {
        let exponent = self.sbox_exponent();
        self.state_words[0] = self.state_words[0]
            .pow_by_constant([exponent])
            .expect("can compute pow");
    }

    /// Applies the full `SubWords` layer: raises every state word to `alpha`.
    fn full_sub_words(&mut self) {
        let exponent = self.sbox_exponent();
        for word in self.state_words.iter_mut() {
            *word = word.pow_by_constant([exponent]).expect("can compute pow");
        }
    }

    /// Applies the `MixLayer`: replaces the state with `M · state`, where `M` is the
    /// MDS matrix, so new word `i` is `sum_j M[i][j] * state[j]`.
    ///
    /// Matrix entries are native `Fq` values, so each state word is multiplied by the
    /// entry directly. There is no need to wrap every entry in a constant variable
    /// first, which would add an `expect` on every round.
    fn mix_layer_mds(&mut self) {
        let mixed: Vec<FpVar<Fq>> = self
            .parameters
            .mds
            .0
             .0
            .iter_rows()
            .map(|row| {
                let terms: Vec<FpVar<Fq>> = row
                    .iter()
                    .zip(&self.state_words)
                    .map(|(coefficient, word)| word * *coefficient)
                    .collect();
                terms.iter().sum()
            })
            .collect();
        self.state_words = mixed;
    }
}