pcli/dex_utils/replicate/
math_utils.rs

1use ndarray::s;
2use ndarray::Array;
3use ndarray::Array2;
4
5/// Applies the Gaus-Seidel method to a square matrix A and returns
6/// a vector of solutions.
7pub(crate) fn gauss_seidel(
8    A: Array<f64, ndarray::Dim<[usize; 2]>>,
9    b: Array<f64, ndarray::Dim<[usize; 1]>>,
10    max_iterations: usize,
11    tolerance: f64,
12) -> anyhow::Result<Array<f64, ndarray::Dim<[usize; 1]>>> {
13    let n = A.shape()[0];
14
15    // First, we decompose the matrix into a lower triangular (L),
16    // and an off-diagonal upper triangular matrix (D) st. A = L + D
17    let L = lower_triangular(&A);
18    let D = &A - &L;
19
20    let mut k = Array::zeros(n);
21    for _ in 0..max_iterations {
22        let k_old = k.clone();
23
24        for i in 0..n {
25            let partial_off_diagonal_solution = D.slice(s![i, ..]).dot(&k);
26            let partial_lower_triangular_solution = L.slice(s![i, ..i]).dot(&k.slice(s![..i]));
27            let sum_ld = partial_off_diagonal_solution + partial_lower_triangular_solution;
28            k[i] = (b[i] - sum_ld) / L[[i, i]];
29        }
30
31        let delta_approximation = &k - &k_old;
32        let l2_norm_delta = delta_approximation
33            .iter()
34            .map(|&x| x * x)
35            .sum::<f64>()
36            .sqrt();
37
38        if l2_norm_delta < tolerance {
39            break;
40        }
41    }
42
43    Ok(k)
44}
45
46/// Converts a square matrix into a lower triangular matrix.
47pub(crate) fn lower_triangular(matrix: &Array2<f64>) -> Array2<f64> {
48    let (rows, cols) = matrix.dim();
49    let mut result = Array2::zeros((rows, cols));
50
51    for i in 0..rows {
52        for j in 0..=i {
53            result[[i, j]] = matrix[[i, j]];
54        }
55    }
56
57    result
58}
59
/// Returns `num_points` evenly spaced samples `step, 2·step, …, num_points·step`,
/// where `step = upper / num_points`; the last sample equals `upper` up to
/// floating-point rounding and zero itself is never included.
///
/// Returns an empty vector when `num_points` is zero.
pub(crate) fn sample_to_upper(upper: f64, num_points: usize) -> Vec<f64> {
    let increment = upper / num_points as f64;

    let mut samples = Vec::with_capacity(num_points);
    for index in 1..=num_points {
        samples.push(index as f64 * increment);
    }
    samples
}