1use anyhow::{anyhow, Result};
2
3use ark_ff::ToConstraintField;
4use ark_r1cs_std::{
5 prelude::{AllocVar, EqGadget},
6 select::CondSelectGadget,
7};
8use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError};
9use decaf377::{r1cs::FqVar, Fq};
10use penumbra_sdk_proto::{penumbra::core::component::dex::v1 as pb, DomainType};
11use penumbra_sdk_tct::Position;
12use serde::{Deserialize, Serialize};
13
14use penumbra_sdk_num::fixpoint::{bit_constrain, U128x128, U128x128Var};
15use penumbra_sdk_num::{Amount, AmountVar};
16
17use crate::TradingPairVar;
18
19use super::TradingPair;
20
/// Aggregate result of one batch swap on a single trading pair.
///
/// Individual swappers derive their share of these totals with `pro_rata_outputs`.
/// The same data is exposed to circuits as `BatchSwapOutputDataVar`. Serde goes through
/// the protobuf form `pb::BatchSwapOutputData`.
#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(try_from = "pb::BatchSwapOutputData", into = "pb::BatchSwapOutputData")]
pub struct BatchSwapOutputData {
    /// Total asset 1 contributed to the batch. This is the denominator of each
    /// swapper's asset-1 share in `pro_rata_outputs`.
    pub delta_1: Amount,
    /// Total asset 2 contributed to the batch. This is the denominator of each
    /// swapper's asset-2 share in `pro_rata_outputs`.
    pub delta_2: Amount,
    /// Total asset 1 distributed in proportion to each swapper's share of `delta_2`.
    pub lambda_1: Amount,
    /// Total asset 2 distributed in proportion to each swapper's share of `delta_1`.
    pub lambda_2: Amount,
    /// Unfilled asset 1, distributed back in proportion to each swapper's share of `delta_1`.
    pub unfilled_1: Amount,
    /// Unfilled asset 2, distributed back in proportion to each swapper's share of `delta_2`.
    pub unfilled_2: Amount,
    /// Block height associated with this batch.
    /// NOTE(review): in this file it is only stored and round-tripped through protobuf.
    /// It is not part of the circuit's public inputs.
    pub height: u64,
    /// The trading pair this batch executed on.
    pub trading_pair: TradingPair,
    /// SCT position prefix for this batch. Protobuf decoding rejects a nonzero
    /// commitment index; only `epoch()` and `block()` are exposed to circuits.
    pub sct_position_prefix: Position,
}
43
44impl BatchSwapOutputData {
45 pub fn pro_rata_outputs(&self, (delta_1_i, delta_2_i): (Amount, Amount)) -> (Amount, Amount) {
48 let delta_1_i = U128x128::from(delta_1_i);
53 let delta_2_i = U128x128::from(delta_2_i);
54 let delta_1 = U128x128::from(self.delta_1);
55 let delta_2 = U128x128::from(self.delta_2);
56 let lambda_1 = U128x128::from(self.lambda_1);
57 let lambda_2 = U128x128::from(self.lambda_2);
58 let unfilled_1 = U128x128::from(self.unfilled_1);
59 let unfilled_2 = U128x128::from(self.unfilled_2);
60
61 let pro_rata_input_1 = (delta_1_i / delta_1).unwrap_or_default();
64 let pro_rata_input_2 = (delta_2_i / delta_2).unwrap_or_default();
65
66 let lambda_2_i = (pro_rata_input_1 * lambda_2).unwrap_or_default()
67 + (pro_rata_input_2 * unfilled_2).unwrap_or_default();
68 let lambda_1_i = (pro_rata_input_1 * unfilled_1).unwrap_or_default()
69 + (pro_rata_input_2 * lambda_1).unwrap_or_default();
70
71 (
72 lambda_1_i
73 .unwrap_or_default()
74 .round_down()
75 .try_into()
76 .expect("rounded amount is integral"),
77 lambda_2_i
78 .unwrap_or_default()
79 .round_down()
80 .try_into()
81 .expect("rounded amount is integral"),
82 )
83 }
84}
85
86impl ToConstraintField<Fq> for BatchSwapOutputData {
87 fn to_field_elements(&self) -> Option<Vec<Fq>> {
88 let mut public_inputs = Vec::new();
89 let delta_1 = U128x128::from(self.delta_1);
90 public_inputs.extend(
91 delta_1
92 .to_field_elements()
93 .expect("delta_1 is a Bls12-377 field member"),
94 );
95 public_inputs.extend(
96 U128x128::from(self.delta_2)
97 .to_field_elements()
98 .expect("U128x128 types are Bls12-377 field members"),
99 );
100 public_inputs.extend(
101 U128x128::from(self.lambda_1)
102 .to_field_elements()
103 .expect("U128x128 types are Bls12-377 field members"),
104 );
105 public_inputs.extend(
106 U128x128::from(self.lambda_2)
107 .to_field_elements()
108 .expect("U128x128 types are Bls12-377 field members"),
109 );
110 public_inputs.extend(
111 U128x128::from(self.unfilled_1)
112 .to_field_elements()
113 .expect("U128x128 types are Bls12-377 field members"),
114 );
115 public_inputs.extend(
116 U128x128::from(self.unfilled_2)
117 .to_field_elements()
118 .expect("U128x128 types are Bls12-377 field members"),
119 );
120 public_inputs.extend(
121 self.trading_pair
122 .to_field_elements()
123 .expect("trading_pair is a Bls12-377 field member"),
124 );
125 public_inputs.extend(
126 Fq::from(self.sct_position_prefix.epoch())
127 .to_field_elements()
128 .expect("Position types are Bls12-377 field members"),
129 );
130 public_inputs.extend(
131 Fq::from(self.sct_position_prefix.block())
132 .to_field_elements()
133 .expect("Position types are Bls12-377 field members"),
134 );
135 Some(public_inputs)
136 }
137}
138
/// In-circuit representation of `BatchSwapOutputData`.
///
/// The six amounts are held as `U128x128Var` fixed-point values. The SCT position prefix
/// is split into its epoch and block components; the `AllocVar` impl constrains each to
/// 16 bits when allocating them. `height` has no in-circuit counterpart.
pub struct BatchSwapOutputDataVar {
    /// Batch total of asset 1 input.
    pub delta_1: U128x128Var,
    /// Batch total of asset 2 input.
    pub delta_2: U128x128Var,
    /// Asset 1 distributed pro rata to asset-2 contributors.
    pub lambda_1: U128x128Var,
    /// Asset 2 distributed pro rata to asset-1 contributors.
    pub lambda_2: U128x128Var,
    /// Unfilled asset 1, distributed pro rata to asset-1 contributors.
    pub unfilled_1: U128x128Var,
    /// Unfilled asset 2, distributed pro rata to asset-2 contributors.
    pub unfilled_2: U128x128Var,
    /// The batch's trading pair.
    pub trading_pair: TradingPairVar,
    /// `sct_position_prefix.epoch()` as a field element.
    pub epoch: FqVar,
    /// `sct_position_prefix.block()` as a field element.
    pub block_within_epoch: FqVar,
}
150
151impl AllocVar<BatchSwapOutputData, Fq> for BatchSwapOutputDataVar {
152 fn new_variable<T: std::borrow::Borrow<BatchSwapOutputData>>(
153 cs: impl Into<ark_relations::r1cs::Namespace<Fq>>,
154 f: impl FnOnce() -> Result<T, SynthesisError>,
155 mode: ark_r1cs_std::prelude::AllocationMode,
156 ) -> Result<Self, SynthesisError> {
157 let ns = cs.into();
158 let cs = ns.cs();
159 let output_data = *(f()?.borrow());
160 let delta_1_fixpoint: U128x128 = output_data.delta_1.into();
161 let delta_1 = U128x128Var::new_variable(cs.clone(), || Ok(delta_1_fixpoint), mode)?;
162 let delta_2_fixpoint: U128x128 = output_data.delta_2.into();
163 let delta_2 = U128x128Var::new_variable(cs.clone(), || Ok(delta_2_fixpoint), mode)?;
164 let lambda_1_fixpoint: U128x128 = output_data.lambda_1.into();
165 let lambda_1 = U128x128Var::new_variable(cs.clone(), || Ok(lambda_1_fixpoint), mode)?;
166 let lambda_2_fixpoint: U128x128 = output_data.lambda_2.into();
167 let lambda_2 = U128x128Var::new_variable(cs.clone(), || Ok(lambda_2_fixpoint), mode)?;
168 let unfilled_1_fixpoint: U128x128 = output_data.unfilled_1.into();
169 let unfilled_1 = U128x128Var::new_variable(cs.clone(), || Ok(unfilled_1_fixpoint), mode)?;
170 let unfilled_2_fixpoint: U128x128 = output_data.unfilled_2.into();
171 let unfilled_2 = U128x128Var::new_variable(cs.clone(), || Ok(unfilled_2_fixpoint), mode)?;
172 let trading_pair = TradingPairVar::new_variable_unchecked(
173 cs.clone(),
174 || Ok(output_data.trading_pair),
175 mode,
176 )?;
177 let epoch = FqVar::new_variable(
178 cs.clone(),
179 || Ok(Fq::from(output_data.sct_position_prefix.epoch())),
180 mode,
181 )?;
182 bit_constrain(epoch.clone(), 16)?;
183 let block_within_epoch = FqVar::new_variable(
184 cs.clone(),
185 || Ok(Fq::from(output_data.sct_position_prefix.block())),
186 mode,
187 )?;
188 bit_constrain(block_within_epoch.clone(), 16)?;
189
190 Ok(Self {
191 delta_1,
192 delta_2,
193 lambda_1,
194 lambda_2,
195 unfilled_1,
196 unfilled_2,
197 trading_pair,
198 epoch,
199 block_within_epoch,
200 })
201 }
202}
203
/// Ties `BatchSwapOutputData` to its protobuf message, `pb::BatchSwapOutputData`.
impl DomainType for BatchSwapOutputData {
    type Proto = pb::BatchSwapOutputData;
}
207
208impl From<BatchSwapOutputData> for pb::BatchSwapOutputData {
209 fn from(s: BatchSwapOutputData) -> Self {
210 #[allow(deprecated)]
211 pb::BatchSwapOutputData {
212 delta_1: Some(s.delta_1.into()),
213 delta_2: Some(s.delta_2.into()),
214 lambda_1: Some(s.lambda_1.into()),
215 lambda_2: Some(s.lambda_2.into()),
216 unfilled_1: Some(s.unfilled_1.into()),
217 unfilled_2: Some(s.unfilled_2.into()),
218 height: s.height,
219 trading_pair: Some(s.trading_pair.into()),
220 sct_position_prefix: s.sct_position_prefix.into(),
221 epoch_starting_height: Default::default(),
225 }
226 }
227}
228
229impl BatchSwapOutputDataVar {
230 pub fn pro_rata_outputs(
231 &self,
232 delta_1_i: AmountVar,
233 delta_2_i: AmountVar,
234 cs: ConstraintSystemRef<Fq>,
235 ) -> Result<(AmountVar, AmountVar), SynthesisError> {
236 let delta_1_i = U128x128Var::from_amount_var(delta_1_i)?;
241 let delta_2_i = U128x128Var::from_amount_var(delta_2_i)?;
242
243 let zero = U128x128Var::zero();
244 let one = U128x128Var::new_constant(cs.clone(), U128x128::from(1u64))?;
245
246 let delta_1_is_zero = self.delta_1.is_eq(&zero)?;
249 let divisor_1 = U128x128Var::conditionally_select(&delta_1_is_zero, &one, &self.delta_1)?;
250 let division_result_1 = delta_1_i.checked_div(&divisor_1, cs.clone())?;
251 let pro_rata_input_1 =
252 U128x128Var::conditionally_select(&delta_1_is_zero, &zero, &division_result_1)?;
253
254 let delta_2_is_zero = self.delta_2.is_eq(&zero)?;
255 let divisor_2 = U128x128Var::conditionally_select(&delta_2_is_zero, &one, &self.delta_2)?;
256 let division_result_2 = delta_2_i.checked_div(&divisor_2, cs)?;
257 let pro_rata_input_2 =
258 U128x128Var::conditionally_select(&delta_2_is_zero, &zero, &division_result_2)?;
259
260 let addition_term2_1 = pro_rata_input_1.clone().checked_mul(&self.lambda_2)?;
263 let addition_term2_2 = pro_rata_input_2.clone().checked_mul(&self.unfilled_2)?;
264 let lambda_2_i = addition_term2_1.checked_add(&addition_term2_2)?;
265
266 let addition_term1_1 = pro_rata_input_1.checked_mul(&self.unfilled_1)?;
269 let addition_term1_2 = pro_rata_input_2.checked_mul(&self.lambda_1)?;
270 let lambda_1_i = addition_term1_1.checked_add(&addition_term1_2)?;
271
272 let lambda_1_i_rounded = lambda_1_i.round_down();
273 let lambda_2_i_rounded = lambda_2_i.round_down();
274
275 Ok((lambda_1_i_rounded.into(), lambda_2_i_rounded.into()))
276 }
277}
278
279impl From<BatchSwapOutputData> for pb::BatchSwapOutputDataResponse {
280 fn from(s: BatchSwapOutputData) -> Self {
281 pb::BatchSwapOutputDataResponse {
282 data: Some(s.into()),
283 }
284 }
285}
286
287impl TryFrom<pb::BatchSwapOutputData> for BatchSwapOutputData {
288 type Error = anyhow::Error;
289 fn try_from(s: pb::BatchSwapOutputData) -> Result<Self, Self::Error> {
290 let sct_position_prefix = {
291 let prefix = Position::from(s.sct_position_prefix);
292 anyhow::ensure!(
293 prefix.commitment() == 0,
294 "sct_position_prefix.commitment() != 0"
295 );
296 prefix
297 };
298 Ok(Self {
299 delta_1: s
300 .delta_1
301 .ok_or_else(|| anyhow!("Missing delta_1"))?
302 .try_into()?,
303 delta_2: s
304 .delta_2
305 .ok_or_else(|| anyhow!("Missing delta_2"))?
306 .try_into()?,
307 lambda_1: s
308 .lambda_1
309 .ok_or_else(|| anyhow!("Missing lambda_1"))?
310 .try_into()?,
311 lambda_2: s
312 .lambda_2
313 .ok_or_else(|| anyhow!("Missing lambda_2"))?
314 .try_into()?,
315 unfilled_1: s
316 .unfilled_1
317 .ok_or_else(|| anyhow!("Missing unfilled_1"))?
318 .try_into()?,
319 unfilled_2: s
320 .unfilled_2
321 .ok_or_else(|| anyhow!("Missing unfilled_2"))?
322 .try_into()?,
323 height: s.height,
324 trading_pair: s
325 .trading_pair
326 .ok_or_else(|| anyhow!("Missing trading_pair"))?
327 .try_into()?,
328 sct_position_prefix,
329 })
330 }
331}
332
333impl TryFrom<pb::BatchSwapOutputDataResponse> for BatchSwapOutputData {
334 type Error = anyhow::Error;
335 fn try_from(value: pb::BatchSwapOutputDataResponse) -> Result<Self, Self::Error> {
336 value
337 .data
338 .ok_or_else(|| anyhow::anyhow!("empty BatchSwapOutputDataResponse message"))?
339 .try_into()
340 }
341}
342
#[cfg(test)]
mod tests {
    use ark_groth16::{r1cs_to_qap::LibsnarkReduction, Groth16};
    use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystem};
    use ark_snark::SNARK;
    use decaf377::Bls12_377;
    use penumbra_sdk_asset::asset;
    use penumbra_sdk_proof_params::{generate_test_parameters, DummyWitness};
    use rand_core::OsRng;

    use super::*;

    /// Regression test: `delta_2` is zero, so the swapper's asset-2 share must count as zero.
    /// A swapper who contributed all of `delta_1` then receives exactly `lambda_2` and no
    /// asset 1.
    #[test]
    fn pasiphae_inflation_bug() {
        let bsod: BatchSwapOutputData = serde_json::from_str(
            r#"
{
    "delta1": {
        "lo": "31730032"
    },
    "delta2": {},
    "unfilled1": {},
    "lambda2": {
        "lo": "28766268"
    },
    "lambda1": {},
    "unfilled2": {},
    "height": "2185",
    "tradingPair": {
        "asset1": {
            "inner": "HW2Eq3UZVSBttoUwUi/MUtE7rr2UU7/UH500byp7OAc="
        },
        "asset2": {
            "inner": "KeqcLzNx9qSH5+lcJHBB9KNW+YPrBk5dKzvPMiypahA="
        }
    }
}"#,
        )
        .unwrap();

        let (delta_1_i, delta_2_i) = (Amount::from(31730032u64), Amount::from(0u64));

        let (lambda_1_i, lambda_2_i) = bsod.pro_rata_outputs((delta_1_i, delta_2_i));

        assert_eq!(lambda_1_i, Amount::from(0u64));
        assert_eq!(lambda_2_i, Amount::from(28766268u64));
    }

    /// Test circuit: recomputes the pro rata outputs in-circuit from witnessed inputs
    /// and enforces equality with the witnessed claimed outputs. The BSOD is a public input.
    struct ProRataOutputCircuit {
        delta_1_i: Amount,
        delta_2_i: Amount,
        lambda_1_i: Amount,
        lambda_2_i: Amount,
        pub bsod: BatchSwapOutputData,
    }

    impl ConstraintSynthesizer<Fq> for ProRataOutputCircuit {
        fn generate_constraints(
            self,
            cs: ConstraintSystemRef<Fq>,
        ) -> ark_relations::r1cs::Result<()> {
            let delta_1_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.delta_1_i))?;
            let delta_2_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.delta_2_i))?;
            let lambda_1_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.lambda_1_i))?;
            let lambda_2_i_var = AmountVar::new_witness(cs.clone(), || Ok(self.lambda_2_i))?;
            let bsod_var = BatchSwapOutputDataVar::new_input(cs.clone(), || Ok(self.bsod))?;

            let (calculated_lambda_1_i_var, calculated_lambda_2_i_var) =
                bsod_var.pro_rata_outputs(delta_1_i_var, delta_2_i_var, cs.clone())?;
            calculated_lambda_1_i_var.enforce_equal(&lambda_1_i_var)?;
            calculated_lambda_2_i_var.enforce_equal(&lambda_2_i_var)?;

            Ok(())
        }
    }

    impl DummyWitness for ProRataOutputCircuit {
        /// Placeholder witness used for parameter generation; all amounts are 1.
        fn with_dummy_witness() -> Self {
            let trading_pair = TradingPair {
                asset_1: asset::Cache::with_known_assets()
                    .get_unit("upenumbra")
                    .expect("upenumbra denom should always be known by the asset registry")
                    .id(),
                asset_2: asset::Cache::with_known_assets()
                    .get_unit("nala")
                    .expect("nala denom should always be known by the asset registry")
                    .id(),
            };
            Self {
                delta_1_i: Amount::from(1u32),
                delta_2_i: Amount::from(1u32),
                lambda_1_i: Amount::from(1u32),
                lambda_2_i: Amount::from(1u32),
                bsod: BatchSwapOutputData {
                    delta_1: Amount::from(1u32),
                    delta_2: Amount::from(1u32),
                    lambda_1: Amount::from(1u32),
                    lambda_2: Amount::from(1u32),
                    unfilled_1: Amount::from(1u32),
                    unfilled_2: Amount::from(1u32),
                    height: 0,
                    trading_pair,
                    sct_position_prefix: 0u64.into(),
                },
            }
        }
    }

    /// A batch with nonzero totals on both sides, shared by the pro rata tests below.
    fn sample_bsod() -> BatchSwapOutputData {
        let gm = asset::Cache::with_known_assets().get_unit("gm").unwrap();
        let gn = asset::Cache::with_known_assets().get_unit("gn").unwrap();
        let trading_pair = TradingPair::new(gm.id(), gn.id());
        BatchSwapOutputData {
            delta_1: Amount::from(200u64),
            delta_2: Amount::from(300u64),
            lambda_1: Amount::from(150u64),
            lambda_2: Amount::from(125u64),
            unfilled_1: Amount::from(23u64),
            unfilled_2: Amount::from(50u64),
            height: 0u64,
            trading_pair,
            sct_position_prefix: 0u64.into(),
        }
    }

    #[test]
    fn happy_path_bsod_pro_rata() {
        let bsod = sample_bsod();

        // The user's contribution to the batch.
        let delta_1_i = Amount::from(100u64);
        let delta_2_i = Amount::from(200u64);

        let (lambda_1_i, lambda_2_i) = bsod.pro_rata_outputs((delta_1_i, delta_2_i));

        // The shares are 100/200 = 1/2 of delta_1 and 200/300 = 2/3 of delta_2:
        //   lambda_1_i = 1/2 * 23 (unfilled_1) + 2/3 * 150 (lambda_1) = 111.5  -> 111
        //   lambda_2_i = 1/2 * 125 (lambda_2)  + 2/3 * 50 (unfilled_2) ~ 95.83 -> 95
        // Neither sum is near an integer boundary, so fixed-point rounding of 2/3 cannot
        // change the floor.
        assert_eq!(lambda_1_i, Amount::from(111u64));
        assert_eq!(lambda_2_i, Amount::from(95u64));

        let circuit = ProRataOutputCircuit {
            delta_1_i,
            delta_2_i,
            lambda_1_i,
            lambda_2_i,
            bsod,
        };

        let mut rng = OsRng;
        let (pk, vk) = generate_test_parameters::<ProRataOutputCircuit>(&mut rng);

        let proof = Groth16::<Bls12_377, LibsnarkReduction>::prove(&pk, circuit, &mut rng)
            .expect("should be able to form proof");

        let proof_result = Groth16::<Bls12_377, LibsnarkReduction>::verify(
            &vk,
            &bsod.to_field_elements().unwrap(),
            &proof,
        )
        .expect("should be able to verify proof");

        assert!(proof_result);
    }

    /// Claiming more output than the pro rata share must leave the circuit unsatisfied.
    #[test]
    fn bsod_pro_rata_circuit_rejects_wrong_output() {
        let bsod = sample_bsod();
        let delta_1_i = Amount::from(100u64);
        let delta_2_i = Amount::from(200u64);
        let (lambda_1_i, _) = bsod.pro_rata_outputs((delta_1_i, delta_2_i));

        let circuit = ProRataOutputCircuit {
            delta_1_i,
            delta_2_i,
            lambda_1_i,
            // One unit of asset 2 more than the 95 the swapper is entitled to.
            lambda_2_i: Amount::from(96u64),
            bsod,
        };

        let cs = ConstraintSystem::<Fq>::new_ref();
        circuit
            .generate_constraints(cs.clone())
            .expect("constraint generation should succeed");
        assert!(!cs
            .is_satisfied()
            .expect("satisfiability check should succeed"));
    }
}