penumbra_sdk_num/fixpoint/div.rs

1use ethnum::U256;
2use ibig::UBig;
3
4use super::Error;
5
6/// Computes (2^128 * x) / y and its remainder.
7/// TEMP HACK: need to implement this properly
8pub(super) fn stub_div_rem_u384_by_u256(x: U256, y: U256) -> Result<(U256, U256), Error> {
9    if y == U256::ZERO {
10        return Err(Error::DivisionByZero);
11    }
12
13    let x_big = ibig::UBig::from_le_bytes(&x.to_le_bytes());
14    let y_big = ibig::UBig::from_le_bytes(&y.to_le_bytes());
15    // this is what we actually want to compute: 384-bit / 256-bit division.
16    let x_big_128 = x_big << 128;
17    let q_big = &x_big_128 / &y_big;
18    let rem_big = x_big_128 - (&y_big * &q_big);
19
20    let Some(q) = ubig_to_u256(&q_big) else {
21        return Err(Error::Overflow);
22    };
23    let rem = ubig_to_u256(&rem_big).expect("rem < q, so we already returned on overflow");
24
25    Ok((q, rem))
26}
27
28#[allow(dead_code)]
29fn u256_to_ubig(x: U256) -> UBig {
30    let mut bytes = [0; 32];
31    bytes.copy_from_slice(&x.to_le_bytes());
32    UBig::from_le_bytes(&bytes)
33}
34
35fn ubig_to_u256(x: &UBig) -> Option<U256> {
36    let bytes = x.to_le_bytes();
37    if bytes.len() <= 32 {
38        let mut u256_bytes = [0; 32];
39        u256_bytes[..bytes.len()].copy_from_slice(&bytes);
40        Some(U256::from_le_bytes(u256_bytes))
41    } else {
42        None
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Uniformly random 256-bit values.
    fn u256_strategy() -> BoxedStrategy<U256> {
        any::<[u8; 32]>().prop_map(U256::from_le_bytes).boxed()
    }

    /// Values below 2^128. Paired with a full-width `x`, these make the quotient overflow
    /// 256 bits often, exercising the error path that uniform divisors almost never hit.
    fn u128_valued_strategy() -> BoxedStrategy<U256> {
        any::<u128>().prop_map(U256::from).boxed()
    }

    /// Checks the full contract of `stub_div_rem_u384_by_u256` for one `(x, y)` pair.
    ///
    /// On success, `2^128 * x == q * y + rem` and `rem < y` must hold. On failure, the
    /// error must be `DivisionByZero` exactly when `y == 0`, and otherwise `Overflow`
    /// with a true quotient that really exceeds `U256::MAX`.
    fn check_div_rem(x: U256, y: U256) -> Result<(), TestCaseError> {
        let y_big = u256_to_ubig(y);
        let dividend = u256_to_ubig(x) << 128;

        match stub_div_rem_u384_by_u256(x, y) {
            Ok((q, rem)) => {
                let rem_big = u256_to_ubig(rem);
                prop_assert!(rem_big < y_big);
                prop_assert_eq!(dividend, &u256_to_ubig(q) * &y_big + &rem_big);
            }
            Err(err) => {
                if y == U256::ZERO {
                    prop_assert!(matches!(err, Error::DivisionByZero));
                } else {
                    prop_assert!(matches!(err, Error::Overflow));
                    prop_assert!(&dividend / &y_big > u256_to_ubig(U256::MAX));
                }
            }
        }
        Ok(())
    }

    proptest! {
        #[test]
        fn stub_div_rem_works(
            x in u256_strategy(),
            y in u256_strategy()
        ) {
            check_div_rem(x, y)?;
        }

        #[test]
        fn stub_div_rem_works_with_small_divisors(
            x in u256_strategy(),
            y in u128_valued_strategy()
        ) {
            check_div_rem(x, y)?;
        }
    }

    #[test]
    fn stub_div_rem_rejects_zero_divisor() {
        assert!(matches!(
            stub_div_rem_u384_by_u256(U256::ONE, U256::ZERO),
            Err(Error::DivisionByZero)
        ));
    }

    #[test]
    fn stub_div_rem_overflow_boundary() {
        // With y == 1 the quotient is x * 2^128, which fits in 256 bits iff x < 2^128.
        let largest_ok = U256::from(u128::MAX);
        let Ok((q, rem)) = stub_div_rem_u384_by_u256(largest_ok, U256::ONE) else {
            panic!("quotient (2^128 - 1) * 2^128 fits in a U256");
        };
        assert_eq!(u256_to_ubig(q), u256_to_ubig(largest_ok) << 128);
        assert_eq!(rem, U256::ZERO);

        assert!(matches!(
            stub_div_rem_u384_by_u256(largest_ok + U256::ONE, U256::ONE),
            Err(Error::Overflow)
        ));
    }
}
76
/// Divides a 384-bit unsigned integer `u` by a nonzero 256-bit unsigned integer `v`.
///
/// Operands and results are little-endian arrays of 64-bit words (index 0 is the least
/// significant word). Returns `(quotient, remainder)` such that
/// `u == quotient * v + remainder` and `remainder < v`.
///
/// Multi-word divisors use Algorithm D from Knuth, *The Art of Computer Programming*,
/// vol. 2, §4.3.1 (p. 272); single-word divisors use schoolbook long division.
///
/// # Panics
///
/// Panics if `v` is zero.
// Not called from anywhere in this file yet.
#[allow(dead_code)]
fn div_rem_u384_by_u256(u: [u64; 6], mut v: [u64; 4]) -> ([u64; 6], [u64; 4]) {
    // The word base b = 2^64.
    const B: u128 = 1 << 64;

    // Number of significant words in v (index of its most significant nonzero word + 1).
    let n = v
        .iter()
        .rposition(|&w| w != 0)
        .expect("division by zero: v has no nonzero word")
        + 1;

    // Algorithm D needs v[n-2]; a single-word divisor takes the simpler path.
    if n == 1 {
        return div_rem_u384_by_u64(u, v[0]);
    }

    // The dividend has m + n = 6 words, so the quotient has m + 1 significant words.
    let m = 6 - n;

    // D1. [Normalize.] Shift v left so the top bit of v[n-1] is set (Knuth's precondition
    // for the q_hat estimate in D3). Shift u by the same amount into a buffer with one
    // extra word to catch the bits shifted out; the quotient is unchanged and the
    // remainder comes out scaled by 2^shift, which D8 undoes.
    let shift = v[n - 1].leading_zeros();
    shl_words(&mut v[..n], shift);
    let mut un = [u[0], u[1], u[2], u[3], u[4], u[5], 0];
    shl_words(&mut un, shift);

    let mut q = [0u64; 6];
    // D2/D7. Produce one quotient word per iteration, for j = m down to 0.
    for j in (0..=m).rev() {
        // D3. [Calculate q_hat.] Estimate the quotient word from the top two words of the
        // current partial remainder and the top word of v.
        let top = (u128::from(un[j + n]) << 64) | u128::from(un[j + n - 1]);
        let divisor = u128::from(v[n - 1]);
        let mut q_hat = top / divisor;
        let mut r_hat = top % divisor;

        // Refine the estimate using v[n-2]. The right-hand test only runs when
        // q_hat < B and r_hat < B, so the product and the shift both fit in a u128.
        while q_hat >= B
            || q_hat * u128::from(v[n - 2]) > ((r_hat << 64) | u128::from(un[j + n - 2]))
        {
            q_hat -= 1;
            r_hat += divisor;
            if r_hat >= B {
                break;
            }
        }
        debug_assert!(q_hat < B);
        let mut q_word = q_hat as u64;

        // D4. [Multiply and subtract.] un[j..=j+n] -= q_word * v[..n].
        let window = &mut un[j..=j + n];
        // D5/D6. [Test remainder / Add back.] A borrow out of the top word means q_word
        // was one too large: decrement it and add v back once.
        if mul_sub_words(window, &v[..n], q_word) {
            q_word -= 1;
            add_words(window, &v[..n]);
        }
        q[j] = q_word;
    }

    // D8. [Unnormalize.] The remainder (times 2^shift) now sits in un[..n] and all higher
    // words are zero, so shift the low words back down by `shift` bits.
    let mut r = [0u64; 4];
    for (i, word) in r.iter_mut().enumerate() {
        // `checked_shl` yields `None` for a shift of 64 (when `shift == 0`): nothing carries.
        *word = (un[i] >> shift) | un[i + 1].checked_shl(64 - shift).unwrap_or(0);
    }

    (q, r)
}

/// Divides the little-endian 6-word integer `u` by the nonzero word `d`, returning the
/// quotient and the remainder (the remainder occupies only the lowest output word).
fn div_rem_u384_by_u64(u: [u64; 6], d: u64) -> ([u64; 6], [u64; 4]) {
    let d = u128::from(d);
    let mut q = [0u64; 6];
    let mut rem: u128 = 0;
    for i in (0..6).rev() {
        // rem < d < 2^64, so `cur` fits in a u128 and `cur / d` fits in a u64.
        let cur = (rem << 64) | u128::from(u[i]);
        q[i] = (cur / d) as u64;
        rem = cur % d;
    }
    (q, [rem as u64, 0, 0, 0])
}

/// Shifts the little-endian multi-word integer `words` left by `shift` bits in place,
/// discarding bits shifted out of the top word. Requires `shift < 64`.
fn shl_words(words: &mut [u64], shift: u32) {
    debug_assert!(shift < 64);
    // Work from the top down so each word still sees its lower neighbour's old value.
    for i in (1..words.len()).rev() {
        // `checked_shr` yields `None` for a shift of 64 (when `shift == 0`): nothing carries.
        let carried = words[i - 1].checked_shr(64 - shift).unwrap_or(0);
        words[i] = (words[i] << shift) | carried;
    }
    if let Some(lowest) = words.first_mut() {
        *lowest <<= shift;
    }
}

/// Computes `acc -= q * v` in place, where `acc` has exactly one more word than `v`.
///
/// Returns `true` if the exact result was negative (a borrow left the top word); `acc`
/// then holds the result modulo `2^(64 * acc.len())`.
fn mul_sub_words(acc: &mut [u64], v: &[u64], q: u64) -> bool {
    debug_assert_eq!(acc.len(), v.len() + 1);
    // High word of the running product q * v.
    let mut carry: u64 = 0;
    let mut borrow = false;
    for (a, &vi) in acc.iter_mut().zip(v) {
        // q * vi + carry <= (2^64 - 1)^2 + (2^64 - 1) < 2^128.
        let p = u128::from(q) * u128::from(vi) + u128::from(carry);
        carry = (p >> 64) as u64;
        let (d, b1) = a.overflowing_sub(p as u64);
        let (d, b2) = d.overflowing_sub(u64::from(borrow));
        *a = d;
        // At most one of b1/b2 can be set, so the borrow stays a single bit.
        borrow = b1 || b2;
    }
    let top = &mut acc[v.len()];
    let (d, b1) = top.overflowing_sub(carry);
    let (d, b2) = d.overflowing_sub(u64::from(borrow));
    *top = d;
    b1 || b2
}

/// Computes `acc += v` in place, where `acc` has exactly one more word than `v`,
/// discarding any carry out of the top word.
fn add_words(acc: &mut [u64], v: &[u64]) {
    debug_assert_eq!(acc.len(), v.len() + 1);
    let mut carry = false;
    for (a, &vi) in acc.iter_mut().zip(v) {
        let (s, c1) = a.overflowing_add(vi);
        let (s, c2) = s.overflowing_add(u64::from(carry));
        *a = s;
        carry = c1 || c2;
    }
    let top = &mut acc[v.len()];
    *top = top.wrapping_add(u64::from(carry));
}
150
151    todo!()
152}