penumbra_sdk_dex/
batch_swap_output_data.rs

1use anyhow::{anyhow, Result};
2
3use ark_ff::ToConstraintField;
4use ark_r1cs_std::{
5    prelude::{AllocVar, EqGadget},
6    select::CondSelectGadget,
7};
8use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError};
9use decaf377::{r1cs::FqVar, Fq};
10use penumbra_sdk_proto::{penumbra::core::component::dex::v1 as pb, DomainType};
11use penumbra_sdk_tct::Position;
12use serde::{Deserialize, Serialize};
13
14use penumbra_sdk_num::fixpoint::{bit_constrain, U128x128, U128x128Var};
15use penumbra_sdk_num::{Amount, AmountVar};
16
17use crate::TradingPairVar;
18
19use super::TradingPair;
20
21#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(try_from = "pb::BatchSwapOutputData", into = "pb::BatchSwapOutputData")]
23pub struct BatchSwapOutputData {
24    /// The total amount of asset 1 that was input to the batch swap.
25    pub delta_1: Amount,
26    /// The total amount of asset 2 that was input to the batch swap.
27    pub delta_2: Amount,
28    /// The total amount of asset 1 that was output from the batch swap for 2=>1 trades.
29    pub lambda_1: Amount,
30    /// The total amount of asset 2 that was output from the batch swap for 1=>2 trades.
31    pub lambda_2: Amount,
32    /// The amount of asset 1 that was returned unfilled from the batch swap for 1=>2 trades.
33    pub unfilled_1: Amount,
34    /// The amount of asset 2 that was returned unfilled from the batch swap for 2=>1 trades.
35    pub unfilled_2: Amount,
36    /// The height for which the batch swap data is valid.
37    pub height: u64,
38    /// The trading pair associated with the batch swap.
39    pub trading_pair: TradingPair,
40    /// The position prefix where this batch swap occurred. The commitment index must be 0.
41    pub sct_position_prefix: Position,
42}
43
44impl BatchSwapOutputData {
45    /// Given a user's inputs `(delta_1_i, delta_2_i)`, compute their pro rata share
46    /// of the batch output `(lambda_1_i, lambda_2_i)`.
47    pub fn pro_rata_outputs(&self, (delta_1_i, delta_2_i): (Amount, Amount)) -> (Amount, Amount) {
48        // The pro rata fraction is delta_j_i / delta_j, which we can multiply through:
49        //   lambda_2_i = (delta_1_i / delta_1) * lambda_2   + (delta_2_i / delta_2) * unfilled_2
50        //   lambda_1_i = (delta_1_i / delta_1) * unfilled_1 + (delta_2_i / delta_2) * lambda_1
51
52        let delta_1_i = U128x128::from(delta_1_i);
53        let delta_2_i = U128x128::from(delta_2_i);
54        let delta_1 = U128x128::from(self.delta_1);
55        let delta_2 = U128x128::from(self.delta_2);
56        let lambda_1 = U128x128::from(self.lambda_1);
57        let lambda_2 = U128x128::from(self.lambda_2);
58        let unfilled_1 = U128x128::from(self.unfilled_1);
59        let unfilled_2 = U128x128::from(self.unfilled_2);
60
61        // Compute the user i's share of the batch inputs of assets 1 and 2.
62        // The .unwrap_or_default ensures that when the batch input delta_1 is zero, all pro-rata shares of it are also zero.
63        let pro_rata_input_1 = (delta_1_i / delta_1).unwrap_or_default();
64        let pro_rata_input_2 = (delta_2_i / delta_2).unwrap_or_default();
65
66        let lambda_2_i = (pro_rata_input_1 * lambda_2).unwrap_or_default()
67            + (pro_rata_input_2 * unfilled_2).unwrap_or_default();
68        let lambda_1_i = (pro_rata_input_1 * unfilled_1).unwrap_or_default()
69            + (pro_rata_input_2 * lambda_1).unwrap_or_default();
70
71        (
72            lambda_1_i
73                .unwrap_or_default()
74                .round_down()
75                .try_into()
76                .expect("rounded amount is integral"),
77            lambda_2_i
78                .unwrap_or_default()
79                .round_down()
80                .try_into()
81                .expect("rounded amount is integral"),
82        )
83    }
84}
85
86impl ToConstraintField<Fq> for BatchSwapOutputData {
87    fn to_field_elements(&self) -> Option<Vec<Fq>> {
88        let mut public_inputs = Vec::new();
89        let delta_1 = U128x128::from(self.delta_1);
90        public_inputs.extend(
91            delta_1
92                .to_field_elements()
93                .expect("delta_1 is a Bls12-377 field member"),
94        );
95        public_inputs.extend(
96            U128x128::from(self.delta_2)
97                .to_field_elements()
98                .expect("U128x128 types are Bls12-377 field members"),
99        );
100        public_inputs.extend(
101            U128x128::from(self.lambda_1)
102                .to_field_elements()
103                .expect("U128x128 types are Bls12-377 field members"),
104        );
105        public_inputs.extend(
106            U128x128::from(self.lambda_2)
107                .to_field_elements()
108                .expect("U128x128 types are Bls12-377 field members"),
109        );
110        public_inputs.extend(
111            U128x128::from(self.unfilled_1)
112                .to_field_elements()
113                .expect("U128x128 types are Bls12-377 field members"),
114        );
115        public_inputs.extend(
116            U128x128::from(self.unfilled_2)
117                .to_field_elements()
118                .expect("U128x128 types are Bls12-377 field members"),
119        );
120        public_inputs.extend(
121            self.trading_pair
122                .to_field_elements()
123                .expect("trading_pair is a Bls12-377 field member"),
124        );
125        public_inputs.extend(
126            Fq::from(self.sct_position_prefix.epoch())
127                .to_field_elements()
128                .expect("Position types are Bls12-377 field members"),
129        );
130        public_inputs.extend(
131            Fq::from(self.sct_position_prefix.block())
132                .to_field_elements()
133                .expect("Position types are Bls12-377 field members"),
134        );
135        Some(public_inputs)
136    }
137}
138
/// In-circuit (R1CS) representation of [`BatchSwapOutputData`].
///
/// The amounts are held as `U128x128Var` fixed-point variables. The `height`
/// field of the out-of-circuit type has no counterpart here.
pub struct BatchSwapOutputDataVar {
    /// Batch-wide input of asset 1.
    pub delta_1: U128x128Var,
    /// Batch-wide input of asset 2.
    pub delta_2: U128x128Var,
    /// Batch-wide output of asset 1.
    pub lambda_1: U128x128Var,
    /// Batch-wide output of asset 2.
    pub lambda_2: U128x128Var,
    /// Unfilled amount of asset 1.
    pub unfilled_1: U128x128Var,
    /// Unfilled amount of asset 2.
    pub unfilled_2: U128x128Var,
    /// Trading pair of the batch swap.
    pub trading_pair: TradingPairVar,
    /// Epoch component of the SCT position prefix.
    pub epoch: FqVar,
    /// Block component of the SCT position prefix.
    pub block_within_epoch: FqVar,
}
150
151impl AllocVar<BatchSwapOutputData, Fq> for BatchSwapOutputDataVar {
152    fn new_variable<T: std::borrow::Borrow<BatchSwapOutputData>>(
153        cs: impl Into<ark_relations::r1cs::Namespace<Fq>>,
154        f: impl FnOnce() -> Result<T, SynthesisError>,
155        mode: ark_r1cs_std::prelude::AllocationMode,
156    ) -> Result<Self, SynthesisError> {
157        let ns = cs.into();
158        let cs = ns.cs();
159        let output_data = *(f()?.borrow());
160        let delta_1_fixpoint: U128x128 = output_data.delta_1.into();
161        let delta_1 = U128x128Var::new_variable(cs.clone(), || Ok(delta_1_fixpoint), mode)?;
162        let delta_2_fixpoint: U128x128 = output_data.delta_2.into();
163        let delta_2 = U128x128Var::new_variable(cs.clone(), || Ok(delta_2_fixpoint), mode)?;
164        let lambda_1_fixpoint: U128x128 = output_data.lambda_1.into();
165        let lambda_1 = U128x128Var::new_variable(cs.clone(), || Ok(lambda_1_fixpoint), mode)?;
166        let lambda_2_fixpoint: U128x128 = output_data.lambda_2.into();
167        let lambda_2 = U128x128Var::new_variable(cs.clone(), || Ok(lambda_2_fixpoint), mode)?;
168        let unfilled_1_fixpoint: U128x128 = output_data.unfilled_1.into();
169        let unfilled_1 = U128x128Var::new_variable(cs.clone(), || Ok(unfilled_1_fixpoint), mode)?;
170        let unfilled_2_fixpoint: U128x128 = output_data.unfilled_2.into();
171        let unfilled_2 = U128x128Var::new_variable(cs.clone(), || Ok(unfilled_2_fixpoint), mode)?;
172        let trading_pair = TradingPairVar::new_variable_unchecked(
173            cs.clone(),
174            || Ok(output_data.trading_pair),
175            mode,
176        )?;
177        let epoch = FqVar::new_variable(
178            cs.clone(),
179            || Ok(Fq::from(output_data.sct_position_prefix.epoch())),
180            mode,
181        )?;
182        bit_constrain(epoch.clone(), 16)?;
183        let block_within_epoch = FqVar::new_variable(
184            cs.clone(),
185            || Ok(Fq::from(output_data.sct_position_prefix.block())),
186            mode,
187        )?;
188        bit_constrain(block_within_epoch.clone(), 16)?;
189
190        Ok(Self {
191            delta_1,
192            delta_2,
193            lambda_1,
194            lambda_2,
195            unfilled_1,
196            unfilled_2,
197            trading_pair,
198            epoch,
199            block_within_epoch,
200        })
201    }
202}
203
// Ties the domain type to its protobuf message, `pb::BatchSwapOutputData`.
impl DomainType for BatchSwapOutputData {
    type Proto = pb::BatchSwapOutputData;
}
207
208impl From<BatchSwapOutputData> for pb::BatchSwapOutputData {
209    fn from(s: BatchSwapOutputData) -> Self {
210        #[allow(deprecated)]
211        pb::BatchSwapOutputData {
212            delta_1: Some(s.delta_1.into()),
213            delta_2: Some(s.delta_2.into()),
214            lambda_1: Some(s.lambda_1.into()),
215            lambda_2: Some(s.lambda_2.into()),
216            unfilled_1: Some(s.unfilled_1.into()),
217            unfilled_2: Some(s.unfilled_2.into()),
218            height: s.height,
219            trading_pair: Some(s.trading_pair.into()),
220            sct_position_prefix: s.sct_position_prefix.into(),
221            // Deprecated fields we explicitly fill with defaults.
222            // We could instead use a `..Default::default()` here, but that would silently
223            // work if we were to add fields to the domain type.
224            epoch_starting_height: Default::default(),
225        }
226    }
227}
228
229impl BatchSwapOutputDataVar {
230    pub fn pro_rata_outputs(
231        &self,
232        delta_1_i: AmountVar,
233        delta_2_i: AmountVar,
234        cs: ConstraintSystemRef<Fq>,
235    ) -> Result<(AmountVar, AmountVar), SynthesisError> {
236        // The pro rata fraction is delta_j_i / delta_j, which we can multiply through:
237        //   lambda_2_i = (delta_1_i / delta_1) * lambda_2   + (delta_2_i / delta_2) * unfilled_2
238        //   lambda_1_i = (delta_1_i / delta_1) * unfilled_1 + (delta_2_i / delta_2) * lambda_1
239
240        let delta_1_i = U128x128Var::from_amount_var(delta_1_i)?;
241        let delta_2_i = U128x128Var::from_amount_var(delta_2_i)?;
242
243        let zero = U128x128Var::zero();
244        let one = U128x128Var::new_constant(cs.clone(), U128x128::from(1u64))?;
245
246        // Compute the user i's share of the batch inputs of assets 1 and 2.
247        // When the batch input delta_1 is zero, all pro-rata shares of it are also zero.
248        let delta_1_is_zero = self.delta_1.is_eq(&zero)?;
249        let divisor_1 = U128x128Var::conditionally_select(&delta_1_is_zero, &one, &self.delta_1)?;
250        let division_result_1 = delta_1_i.checked_div(&divisor_1, cs.clone())?;
251        let pro_rata_input_1 =
252            U128x128Var::conditionally_select(&delta_1_is_zero, &zero, &division_result_1)?;
253
254        let delta_2_is_zero = self.delta_2.is_eq(&zero)?;
255        let divisor_2 = U128x128Var::conditionally_select(&delta_2_is_zero, &one, &self.delta_2)?;
256        let division_result_2 = delta_2_i.checked_div(&divisor_2, cs)?;
257        let pro_rata_input_2 =
258            U128x128Var::conditionally_select(&delta_2_is_zero, &zero, &division_result_2)?;
259
260        // let lambda_2_i = (pro_rata_input_1 * lambda_2).unwrap_or_default()
261        //     + (pro_rata_input_2 * unfilled_2).unwrap_or_default();
262        let addition_term2_1 = pro_rata_input_1.clone().checked_mul(&self.lambda_2)?;
263        let addition_term2_2 = pro_rata_input_2.clone().checked_mul(&self.unfilled_2)?;
264        let lambda_2_i = addition_term2_1.checked_add(&addition_term2_2)?;
265
266        // let lambda_1_i = (pro_rata_input_1 * unfilled_1).unwrap_or_default()
267        //     + (pro_rata_input_2 * lambda_1).unwrap_or_default();
268        let addition_term1_1 = pro_rata_input_1.checked_mul(&self.unfilled_1)?;
269        let addition_term1_2 = pro_rata_input_2.checked_mul(&self.lambda_1)?;
270        let lambda_1_i = addition_term1_1.checked_add(&addition_term1_2)?;
271
272        let lambda_1_i_rounded = lambda_1_i.round_down();
273        let lambda_2_i_rounded = lambda_2_i.round_down();
274
275        Ok((lambda_1_i_rounded.into(), lambda_2_i_rounded.into()))
276    }
277}
278
279impl From<BatchSwapOutputData> for pb::BatchSwapOutputDataResponse {
280    fn from(s: BatchSwapOutputData) -> Self {
281        pb::BatchSwapOutputDataResponse {
282            data: Some(s.into()),
283        }
284    }
285}
286
287impl TryFrom<pb::BatchSwapOutputData> for BatchSwapOutputData {
288    type Error = anyhow::Error;
289    fn try_from(s: pb::BatchSwapOutputData) -> Result<Self, Self::Error> {
290        let sct_position_prefix = {
291            let prefix = Position::from(s.sct_position_prefix);
292            anyhow::ensure!(
293                prefix.commitment() == 0,
294                "sct_position_prefix.commitment() != 0"
295            );
296            prefix
297        };
298        Ok(Self {
299            delta_1: s
300                .delta_1
301                .ok_or_else(|| anyhow!("Missing delta_1"))?
302                .try_into()?,
303            delta_2: s
304                .delta_2
305                .ok_or_else(|| anyhow!("Missing delta_2"))?
306                .try_into()?,
307            lambda_1: s
308                .lambda_1
309                .ok_or_else(|| anyhow!("Missing lambda_1"))?
310                .try_into()?,
311            lambda_2: s
312                .lambda_2
313                .ok_or_else(|| anyhow!("Missing lambda_2"))?
314                .try_into()?,
315            unfilled_1: s
316                .unfilled_1
317                .ok_or_else(|| anyhow!("Missing unfilled_1"))?
318                .try_into()?,
319            unfilled_2: s
320                .unfilled_2
321                .ok_or_else(|| anyhow!("Missing unfilled_2"))?
322                .try_into()?,
323            height: s.height,
324            trading_pair: s
325                .trading_pair
326                .ok_or_else(|| anyhow!("Missing trading_pair"))?
327                .try_into()?,
328            sct_position_prefix,
329        })
330    }
331}
332
333impl TryFrom<pb::BatchSwapOutputDataResponse> for BatchSwapOutputData {
334    type Error = anyhow::Error;
335    fn try_from(value: pb::BatchSwapOutputDataResponse) -> Result<Self, Self::Error> {
336        value
337            .data
338            .ok_or_else(|| anyhow::anyhow!("empty BatchSwapOutputDataResponse message"))?
339            .try_into()
340    }
341}
342
343#[cfg(test)]
344mod tests {
345    use ark_groth16::{r1cs_to_qap::LibsnarkReduction, Groth16};
346    use ark_relations::r1cs::ConstraintSynthesizer;
347    use ark_snark::SNARK;
348    use decaf377::Bls12_377;
349    use penumbra_sdk_asset::asset;
350    use penumbra_sdk_proof_params::{generate_test_parameters, DummyWitness};
351    use rand_core::OsRng;
352
353    use super::*;
354
355    #[test]
356    fn pasiphae_inflation_bug() {
357        let bsod: BatchSwapOutputData = serde_json::from_str(
358            r#"
359{
360    "delta1": {
361        "lo": "31730032"
362    },
363    "delta2": {},
364    "unfilled1": {},
365    "lambda2": {
366        "lo": "28766268"
367    },
368    "lambda1": {},
369    "unfilled2": {},
370    "height": "2185",
371    "tradingPair": {
372        "asset1": {
373        "inner": "HW2Eq3UZVSBttoUwUi/MUtE7rr2UU7/UH500byp7OAc="
374        },
375        "asset2": {
376        "inner": "KeqcLzNx9qSH5+lcJHBB9KNW+YPrBk5dKzvPMiypahA="
377        }
378    }
379}"#,
380        )
381        .unwrap();
382
383        let (delta_1_i, delta_2_i) = (Amount::from(31730032u64), Amount::from(0u64));
384
385        let (lambda_1_i, lambda_2_i) = bsod.pro_rata_outputs((delta_1_i, delta_2_i));
386
387        assert_eq!(lambda_1_i, Amount::from(0u64));
388        assert_eq!(lambda_2_i, Amount::from(28766268u64));
389    }
390
    /// Test circuit checking that the in-circuit pro rata computation over
    /// `bsod` reproduces the claimed outputs for the given user inputs.
    struct ProRataOutputCircuit {
        /// User's input of asset 1 (witness).
        delta_1_i: Amount,
        /// User's input of asset 2 (witness).
        delta_2_i: Amount,
        /// Claimed pro rata output of asset 1 (witness).
        lambda_1_i: Amount,
        /// Claimed pro rata output of asset 2 (witness).
        lambda_2_i: Amount,
        /// Batch swap output data, allocated as a public input.
        pub bsod: BatchSwapOutputData,
    }
398
399    impl ConstraintSynthesizer<Fq> for ProRataOutputCircuit {
400        fn generate_constraints(
401            self,
402            cs: ConstraintSystemRef<Fq>,
403        ) -> ark_relations::r1cs::Result<()> {
404            let delta_1_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.delta_1_i))?;
405            let delta_2_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.delta_2_i))?;
406            let lambda_1_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.lambda_1_i))?;
407            let lambda_2_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.lambda_2_i))?;
408            let bsod_var = BatchSwapOutputDataVar::new_input(cs.clone(), || Ok(self.bsod))?;
409
410            let (calculated_lambda_1_i_var, calculated_lambda_2_i_var) =
411                bsod_var.pro_rata_outputs(delta_1_i_var, delta_2_i_var, cs.clone())?;
412            calculated_lambda_1_i_var.enforce_equal(&lambda_1_i_var)?;
413            calculated_lambda_2_i_var.enforce_equal(&lambda_2_i_var)?;
414
415            Ok(())
416        }
417    }
418
419    impl DummyWitness for ProRataOutputCircuit {
420        fn with_dummy_witness() -> Self {
421            let trading_pair = TradingPair {
422                asset_1: asset::Cache::with_known_assets()
423                    .get_unit("upenumbra")
424                    .expect("upenumbra denom should always be known by the asset registry")
425                    .id(),
426                asset_2: asset::Cache::with_known_assets()
427                    .get_unit("nala")
428                    .expect("nala denom should always be known by the asset registry")
429                    .id(),
430            };
431            Self {
432                delta_1_i: Amount::from(1u32),
433                delta_2_i: Amount::from(1u32),
434                lambda_1_i: Amount::from(1u32),
435                lambda_2_i: Amount::from(1u32),
436                bsod: BatchSwapOutputData {
437                    delta_1: Amount::from(1u32),
438                    delta_2: Amount::from(1u32),
439                    lambda_1: Amount::from(1u32),
440                    lambda_2: Amount::from(1u32),
441                    unfilled_1: Amount::from(1u32),
442                    unfilled_2: Amount::from(1u32),
443                    height: 0,
444                    trading_pair,
445                    sct_position_prefix: 0u64.into(),
446                },
447            }
448        }
449    }
450
451    #[test]
452    fn happy_path_bsod_pro_rata() {
453        // Example Chain-wide swap output data
454        let gm = asset::Cache::with_known_assets().get_unit("gm").unwrap();
455        let gn = asset::Cache::with_known_assets().get_unit("gn").unwrap();
456        let trading_pair = TradingPair::new(gm.id(), gn.id());
457        let bsod = BatchSwapOutputData {
458            delta_1: Amount::from(200u64),
459            delta_2: Amount::from(300u64),
460            lambda_1: Amount::from(150u64),
461            lambda_2: Amount::from(125u64),
462            unfilled_1: Amount::from(23u64),
463            unfilled_2: Amount::from(50u64),
464            height: 0u64,
465            trading_pair,
466            sct_position_prefix: 0u64.into(),
467        };
468
469        // Now suppose our user's contribution is:
470        let delta_1_i = Amount::from(100u64);
471        let delta_2_i = Amount::from(200u64);
472
473        // Then their pro-rata outputs (out-of-circuit) are:
474        let (lambda_1_i, lambda_2_i) = bsod.pro_rata_outputs((delta_1_i, delta_2_i));
475
476        let circuit = ProRataOutputCircuit {
477            delta_1_i,
478            delta_2_i,
479            lambda_1_i,
480            lambda_2_i,
481            bsod,
482        };
483
484        let mut rng = OsRng;
485        let (pk, vk) = generate_test_parameters::<ProRataOutputCircuit>(&mut rng);
486
487        let proof = Groth16::<Bls12_377, LibsnarkReduction>::prove(&pk, circuit, &mut rng)
488            .expect("should be able to form proof");
489
490        let proof_result = Groth16::<Bls12_377, LibsnarkReduction>::verify(
491            &vk,
492            &bsod.to_field_elements().unwrap(),
493            &proof,
494        )
495        .expect("should be able to verify proof");
496
497        assert!(proof_result);
498    }
499}