1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#![allow(non_snake_case)]
use ark_ff::PrimeField;
use ark_std::vec::Vec;

use ark_r1cs_std::{fields::fp::FpVar, prelude::*};
use ark_relations::r1cs::ConstraintSystemRef;
use poseidon_parameters::v1::{Alpha, MatrixOperations, PoseidonParameters};

/// A Poseidon permutation instance whose state lives in-circuit as field variables.
///
/// Bundles the permutation parameters, a handle to the constraint system, and
/// the current permutation state.
pub struct InstanceVar<F>
where
    F: PrimeField,
{
    /// Poseidon parameters (width, rounds, S-box, round constants, MDS matrix).
    pub parameters: PoseidonParameters<F>,

    /// Handle to the constraint system this instance was created for.
    pub cs: ConstraintSystemRef<F>,

    /// Current permutation state, one field variable per word.
    pub state_words: Vec<FpVar<F>>,
}

impl<F> InstanceVar<F>
where
    F: PrimeField,
{
    /// Fixed-width n:1 hash over `t` in-circuit input words.
    ///
    /// Builds an instance whose state is `input_words`, applies [`Self::permute`],
    /// and returns the state word at index 1 as the single output [`FpVar`].
    ///
    /// # Panics
    ///
    /// Panics if `input_words.len() != parameters.t`, or on any condition that
    /// makes [`Self::permute`] panic.
    pub fn n_to_1_fixed_hash(
        parameters: PoseidonParameters<F>,
        cs: ConstraintSystemRef<F>,
        input_words: Vec<FpVar<F>>,
    ) -> FpVar<F> {
        // The state width is `t` (= rate + capacity), so exactly `t` words are required.
        assert!(
            input_words.len() == parameters.t,
            "err: input words must be t elements long"
        );

        let mut instance = InstanceVar {
            parameters,
            cs,
            state_words: input_words,
        };

        // Apply Poseidon permutation.
        instance.permute();

        // Emit a single element since this is an n:1 hash. The instance is discarded
        // afterwards, so move the word out with `swap_remove` instead of cloning it.
        instance.state_words.swap_remove(1)
    }

    /// Poseidon permutation, applied in place to `self.state_words`.
    ///
    /// Runs `rounds.full() / 2` full rounds, then `rounds.partial()` partial rounds,
    /// then `rounds.full() / 2` full rounds. Every round adds the next `t` round
    /// constants (consumed sequentially from `parameters.arc`), applies the S-box
    /// layer (all words for full rounds, word 0 only for partial rounds), and then
    /// mixes the state with the MDS matrix.
    ///
    /// Note: `rounds.full()` is halved with integer division, so an odd full-round
    /// count would drop one round.
    ///
    /// # Panics
    ///
    /// Panics if `parameters.arc` holds fewer constants than the rounds consume,
    /// if `state_words` has fewer than `t` words, if the S-box is `Alpha::Inverse`
    /// (not implemented), or if `pow_by_constant` returns an error.
    pub fn permute(&mut self) {
        let half_full_rounds = self.parameters.rounds.full() / 2;
        let partial_rounds = self.parameters.rounds.partial();
        // Cloned because the S-box and MDS layers below borrow `self` mutably.
        let round_constants = self.parameters.arc.elements().clone();
        let mut offset = 0;

        // First full rounds
        for _ in 0..half_full_rounds {
            self.add_round_constants(&round_constants, &mut offset);
            self.full_sub_words();
            self.mix_layer_mds();
        }

        // Partial rounds
        for _ in 0..partial_rounds {
            self.add_round_constants(&round_constants, &mut offset);
            self.partial_sub_words();
            self.mix_layer_mds();
        }

        // Final full rounds
        for _ in 0..half_full_rounds {
            self.add_round_constants(&round_constants, &mut offset);
            self.full_sub_words();
            self.mix_layer_mds();
        }
    }

    /// Applies the `AddRoundConstants` layer: adds `round_constants[*offset..*offset + t]`
    /// word-wise to the first `t` state words, then advances `*offset` by `t`.
    ///
    /// Panics if fewer than `t` constants remain or the state has fewer than `t` words.
    fn add_round_constants(&mut self, round_constants: &[F], offset: &mut usize) {
        let t = self.parameters.t;
        let constants = &round_constants[*offset..*offset + t];
        for (word, constant) in self.state_words[..t].iter_mut().zip(constants) {
            *word += *constant;
        }
        *offset += t;
    }

    /// Applies the partial `SubWords` layer: raises only state word 0 to `alpha`.
    fn partial_sub_words(&mut self) {
        match self.parameters.alpha {
            Alpha::Exponent(exp) => {
                self.state_words[0] = self.state_words[0]
                    .pow_by_constant([exp as u64])
                    .expect("can compute pow");
            }
            Alpha::Inverse => unimplemented!("err: inverse alpha not implemented"),
        }
    }

    /// Applies the full `SubWords` layer: raises each of the first `t` state words to `alpha`.
    fn full_sub_words(&mut self) {
        match self.parameters.alpha {
            Alpha::Exponent(exp) => {
                let t = self.parameters.t;
                for word in &mut self.state_words[..t] {
                    *word = word.pow_by_constant([exp as u64]).expect("can compute pow");
                }
            }
            Alpha::Inverse => {
                unimplemented!("err: inverse alpha not implemented")
            }
        }
    }

    /// Applies the `MixLayer`: replaces the state with `MDS * state`.
    ///
    /// Each output word is the sum over a matrix row of `entry * word`. The matrix
    /// entries are wrapped directly as `FpVar::Constant`. That is exactly what
    /// `FpVar::new_constant` produces, but without the fallible call and without
    /// cloning the constraint-system handle.
    fn mix_layer_mds(&mut self) {
        let state = &self.state_words;
        let new_state: Vec<FpVar<F>> = self
            .parameters
            .mds
            .0
             .0
            .iter_rows()
            .map(|row| {
                let products: Vec<FpVar<F>> = row
                    .iter()
                    .zip(state)
                    .map(|(entry, word)| FpVar::Constant(*entry) * word)
                    .collect();
                products.iter().sum()
            })
            .collect();
        self.state_words = new_state;
    }
}