DiRAC: revealing the nature of dark matter with the James Webb space telescope and JAX

Scientific computing with JAX, a case study

Kolen Cheung

March 26th, 2025

Introduction

Newer versions

An updated version of this talk was presented at Durham HPC Days 2025 and is available here.

Introduction of the physics

Non-linear Search (Dynesty)

PyAutoFit supports nested sampling (Dynesty), MCMC (emcee), and particle swarm optimization (PySwarms).

Introduction of the project

Goal of the project

Excerpts from project proposal:

… Both codes quote speed ups exceeding 20 times, sometimes going up to factors of 50 times faster. A times ten speed up would be sufficient for the multi-wavelength JWST analysis this proposal aims to do, thus JAX appears to be capable of giving the speed up necessary.

Efficiency gains will come from both making the lens model-fitting auto differentiable and the lensing calculations support accelerated linear algebra (XLA).

Organization of the project

Organization of the project (cont.)

Case study

\(\tilde{w}\)—Code: original version

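\[\tilde{W}_{ij} = \sum_{k=1}^{K} \frac{1}{n_k^2} \cos(2\pi[(g_{i1} - g_{j1})u_{k0} + (g_{i0} - g_{j0})u_{k1}])\]

where \(g_i\) is the \(i\)-th grid position (in radians), \(u_k\) is the \(k\)-th \((u, v)\) baseline (in wavelengths), \(n_k\) is its real noise value, and \(K\) is the number of visibilities.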

@numba.jit(nopython=True, nogil=True, parallel=True)
def w_tilde_curvature_interferometer_from(
    noise_map_real: np.ndarray,
    uv_wavelengths: np.ndarray,
    grid_radians_slim: np.ndarray,
) -> np.ndarray:
    """Compute the symmetric ``w_tilde`` matrix with explicit loops (baseline version).

    Entry ``(i, j)`` is the sum over visibilities ``k`` of
    ``noise_map_real[k] ** -2 * cos(2 * pi * phase)`` with
    ``phase = (g[i, 1] - g[j, 1]) * uv[k, 0] + (g[i, 0] - g[j, 0]) * uv[k, 1]``,
    where ``g = grid_radians_slim`` and ``uv = uv_wavelengths``.

    Parameters
    ----------
    noise_map_real
        Noise values indexed by the same visibility index ``k`` as the rows of
        ``uv_wavelengths``.
    uv_wavelengths
        2-D array with one row per visibility; only columns 0 and 1 are read.
    grid_radians_slim
        2-D array with one row per grid point; only columns 0 and 1 are read.

    Returns
    -------
    np.ndarray
        Symmetric ``(M, M)`` array with ``M = grid_radians_slim.shape[0]``.
    """
    w_tilde = np.zeros((grid_radians_slim.shape[0], grid_radians_slim.shape[0]))

    # Fill only the upper triangle (j >= i); it is mirrored into the lower one below.
    # NOTE(review): these loops use plain ``range``. Numba only parallelises
    # ``numba.prange`` loops, so ``parallel=True`` probably has no effect here.
    # Making the outer loop a ``prange`` would be race-free, because iteration
    # ``i`` writes only row ``i``.
    for i in range(w_tilde.shape[0]):
        for j in range(i, w_tilde.shape[1]):
            y_offset = grid_radians_slim[i, 1] - grid_radians_slim[j, 1]
            x_offset = grid_radians_slim[i, 0] - grid_radians_slim[j, 0]

            # Sum over visibilities k.
            # NOTE(review): ``noise_map_real[k] ** -2.0`` is recomputed for every
            # (i, j) pair; it could be precomputed once per k.
            for vis_1d_index in range(uv_wavelengths.shape[0]):
                w_tilde[i, j] += noise_map_real[vis_1d_index] ** -2.0 * np.cos(
                    2.0
                    * np.pi
                    * (y_offset * uv_wavelengths[vis_1d_index, 0] + x_offset * uv_wavelengths[vis_1d_index, 1])
                )

    # Copy the upper triangle into the lower triangle so the result is symmetric.
    for i in range(w_tilde.shape[0]):
        for j in range(i, w_tilde.shape[1]):
            w_tilde[j, i] = w_tilde[i, j]

    return w_tilde

\(\tilde{w}\)—Code: 1st try

\[\tilde{W}_{ij} = \sum_{k=1}^{K} \frac{1}{n_k^2} \cos(2\pi[(g_{i1} - g_{j1})u_{k0} + (g_{i0} - g_{j0})u_{k1}])\]

@jax.jit
def w_tilde_curvature_interferometer_from(
    noise_map_real: np.ndarray[tuple[int], np.float64],
    uv_wavelengths: np.ndarray[tuple[int, int], np.float64],
    grid_radians_slim: np.ndarray[tuple[int, int], np.float64],
) -> np.ndarray[tuple[int, int], np.float64]:
    """Compute ``w_tilde`` by broadcasting over every (i, j, k) triple at once.

    This is the direct vectorised port of the triple loop. It materialises an
    ``(M, M, K)`` intermediate, so memory grows as ``M**2 * K``.

    Parameters
    ----------
    noise_map_real
        Noise values, one per visibility ``k``.
    uv_wavelengths
        Baseline coordinates, reshaped to ``(K, 2)``.
    grid_radians_slim
        Grid coordinates, reshaped to ``(M, 2)``.

    Returns
    -------
    ``(M, M)`` array; entry ``(i, j)`` sums
    ``cos(2 * pi * phase_ijk) / noise_map_real[k] ** 2`` over ``k``.
    """
    rows = grid_radians_slim.reshape(-1, 1, 1, 2)
    cols = grid_radians_slim.reshape(1, -1, 1, 2)
    offsets = rows - cols  # (M, M, 1, 2): offsets[..., c] = g_ic - g_jc
    uv = uv_wavelengths.reshape(1, 1, -1, 2)  # (1, 1, K, 2)

    # Each offset component pairs with the *other* uv column, as in the formula.
    phase = (  # (M, M, K)
        offsets[:, :, :, 0] * uv[:, :, :, 1]
        + offsets[:, :, :, 1] * uv[:, :, :, 0]
    )

    noise_sq = jnp.square(noise_map_real).reshape(1, 1, -1)  # (1, 1, K)
    terms = jnp.cos((2.0 * jnp.pi) * phase) / noise_sq
    return terms.sum(2)  # collapse the visibility axis k

\(\tilde{w}\)—Code: 2nd try

\[\tilde{W}_{ij} = \sum_{k=1}^{K} \frac{1}{n_k^2} \cos(2\pi[(g_{i1} - g_{j1})u_{k0} + (g_{i0} - g_{j0})u_{k1}])\]

@jax.jit
def w_tilde_curvature_interferometer_from(
    noise_map_real: np.ndarray[tuple[int], np.float64],
    uv_wavelengths: np.ndarray[tuple[int, int], np.float64],
    grid_radians_slim: np.ndarray[tuple[int, int], np.float64],
) -> np.ndarray[tuple[int, int], np.float64]:
    """Compute ``w_tilde`` with two matrix products instead of an (M, M, K) tensor.

    This uses ``cos(a_i - a_j) = cos(a_i) cos(a_j) + sin(a_i) sin(a_j)``. With
    ``A[m, k]`` the phase of grid point ``m`` for visibility ``k``, and each
    factor divided by ``noise_map_real[k]``, the result is
    ``C @ C.T + S @ S.T``, where ``C = cos(A) / n`` and ``S = sin(A) / n``.

    Parameters
    ----------
    noise_map_real
        Noise values, one per visibility ``k``.
    uv_wavelengths
        ``(K, 2)`` baseline coordinates.
    grid_radians_slim
        ``(M, 2)`` grid coordinates.

    Returns
    -------
    ``(M, M)`` array.
    """
    # TWO_PI is a module-level constant (presumably 2 * pi). It scales the uv
    # array (K rows) rather than the grid (M rows), which is cheaper when M > K.
    # The columns are swapped so that grid column 0 pairs with uv column 1, and
    # vice versa, as in the formula.
    uv_scaled_swapped = (TWO_PI * uv_wavelengths)[:, ::-1]
    phases = grid_radians_slim @ uv_scaled_swapped.T  # (M, K)

    inv_noise = jnp.reciprocal(noise_map_real)  # 1 / n_k, broadcast across rows
    cos_weighted = jnp.cos(phases) * inv_noise
    sin_weighted = jnp.sin(phases) * inv_noise

    return cos_weighted @ cos_weighted.T + sin_weighted @ sin_weighted.T

\(\tilde{w}\)—Code: final try

\[\tilde{W}_{ij} = \sum_{k=1}^{K} \frac{1}{n_k^2} \cos(2\pi[(g_{i1} - g_{j1})u_{k0} + (g_{i0} - g_{j0})u_{k1}])\]

@jax.jit
def w_tilde_curvature_interferometer_from(
    noise_map_real: np.ndarray[tuple[int], np.float64],
    uv_wavelengths: np.ndarray[tuple[int, int], np.float64],
    grid_radians_slim: np.ndarray[tuple[int, int], np.float64],
) -> np.ndarray[tuple[int, int], np.float64]:
    """Compute ``w_tilde`` by accumulating one ``(M, M)`` term per visibility.

    ``jax.lax.scan`` steps over the visibilities ``k`` (the leading axis of
    ``noise_map_real`` and ``uv_wavelengths``). At each step it adds
    ``cos(phase_k) / noise_map_real[k] ** 2`` to a running ``(M, M)`` sum, so
    no ``(M, M, K)`` intermediate is built.

    Parameters
    ----------
    noise_map_real
        Noise values, one per visibility ``k``.
    uv_wavelengths
        ``(K, 2)`` baseline coordinates.
    grid_radians_slim
        ``(M, 2)`` grid coordinates.

    Returns
    -------
    ``(M, M)`` array.
    """
    n_grid = grid_radians_slim.shape[0]

    # Pairwise offsets, pre-scaled by TWO_PI (a module-level constant,
    # presumably 2 * pi), with shape (M, M, 2).
    scaled_grid = TWO_PI * grid_radians_slim
    offsets = scaled_grid.reshape(n_grid, 1, 2) - scaled_grid.reshape(1, n_grid, 2)
    offset_col0 = offsets[:, :, 0]
    offset_col1 = offsets[:, :, 1]

    def accumulate(
        running_total: np.ndarray[tuple[int, int], np.float64],
        visibility: tuple[float, np.ndarray[tuple[int], np.float64]],
    ) -> tuple[np.ndarray[tuple[int, int], np.float64], None]:
        noise_k, uv_k = visibility
        # Grid column 1 pairs with uv[0], and grid column 0 with uv[1].
        contribution = jnp.cos(offset_col1 * uv_k[0] + offset_col0 * uv_k[1]) * jnp.reciprocal(
            jnp.square(noise_k)
        )
        return running_total + contribution, None

    total, _ = jax.lax.scan(
        accumulate,
        jnp.zeros((n_grid, n_grid)),
        (noise_map_real, uv_wavelengths),
    )
    return total

March madness: Numba vs JAX with 1 CPU core

N=64_B=3_K=32768_P=32_S=256

pixi run pytest-benchmark compare 1_N=64_B=3_K=32768_P=32_S=256_NUM_THREADS=1 256_N=64_B=3_K=32768_P=32_S=256_NUM_THREADS=256 --columns=mean,stddev,ops,rounds,iterations --sort=mean
--------------------------------------------- benchmark 'w_tilde_curvature_interferometer_from_DataGenerated': 9 tests ---------------------------------------------
Name (time in s)                                                                              Mean            StdDev               OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_w_tilde_curvature_interferometer_from_jax_compact[DataGenerated]                       2.5336 (1.0)      0.0014 (3.59)     0.3947 (1.0)           5           1
test_w_tilde_curvature_interferometer_from_jax_compact_expanded[DataGenerated]              2.5398 (1.00)     0.0021 (5.48)     0.3937 (1.00)          5           1
test_w_tilde_curvature_interferometer_from_numba_compact[DataGenerated]                     2.7585 (1.09)     0.0004 (1.0)      0.3625 (0.92)          5           1
test_w_tilde_curvature_interferometer_from_numba_compact_expanded[DataGenerated]            2.7880 (1.10)     0.0011 (3.01)     0.3587 (0.91)          5           1
test_w_tilde_curvature_interferometer_from_original_preload[DataGenerated]                 13.6873 (5.40)     0.0009 (2.38)     0.0731 (0.19)          5           1
test_w_tilde_curvature_interferometer_from_original_preload_expanded[DataGenerated]        13.7170 (5.41)     0.0014 (3.57)     0.0729 (0.18)          5           1
test_w_tilde_curvature_interferometer_from_jax[DataGenerated]                           3,284.1203 (>1000.0)  0.3030 (799.20)   0.0003 (0.00)          5           1
test_w_tilde_curvature_interferometer_from_numba[DataGenerated]                         3,613.6091 (>1000.0)  6.4241 (>1000.0)  0.0003 (0.00)          5           1
test_w_tilde_curvature_interferometer_from_original[DataGenerated]                      4,390.4546 (>1000.0)  0.7409 (>1000.0)  0.0002 (0.00)          5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------

Match 2: Numba with 128 CPU cores and JAX with CUDA on GPU (A100)

N=32_B=300_K=8192_P=32_S=256

pixi run pytest-benchmark compare --columns=mean,stddev,ops,rounds,iterations --sort=mean 0009_N=32_B=300_K=8192_P=32_S=256_NUM_THREADS=128_cuda
----------------------------------------------- benchmark 'w_tilde_curvature_interferometer_from_DataGenerated': 9 tests -----------------------------------------------
Name (time in ms)                                                                              Mean             StdDev                 OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_w_tilde_curvature_interferometer_from_numba_compact[DataGenerated]                      2.5029 (1.0)       0.0520 (1.25)     399.5298 (1.0)         260           1
test_w_tilde_curvature_interferometer_from_numba_compact_expanded[DataGenerated]             3.7808 (1.51)      0.0415 (1.0)      264.4928 (0.66)        196           1
test_w_tilde_curvature_interferometer_from_jax_compact_expanded[DataGenerated]              58.6799 (23.44)     9.1624 (220.76)    17.0416 (0.04)         12           1
test_w_tilde_curvature_interferometer_from_jax_compact[DataGenerated]                       61.5560 (24.59)     6.8555 (165.18)    16.2454 (0.04)         12           1
test_w_tilde_curvature_interferometer_from_jax[DataGenerated]                              143.2451 (57.23)     0.0749 (1.80)       6.9810 (0.02)          5           1
test_w_tilde_curvature_interferometer_from_original_preload[DataGenerated]                 840.0727 (335.63)    0.1761 (4.24)       1.1904 (0.00)          5           1
test_w_tilde_curvature_interferometer_from_original_preload_expanded[DataGenerated]        842.6588 (336.67)    0.4933 (11.89)      1.1867 (0.00)          5           1
test_w_tilde_curvature_interferometer_from_numba[DataGenerated]                          1,794.1949 (716.83)   14.9648 (360.56)     0.5574 (0.00)          5           1
test_w_tilde_curvature_interferometer_from_original[DataGenerated]                      69,304.2738 (>1000.0)  13.5543 (326.58)     0.0144 (0.00)          5           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Bonus round: Numba vs JAX with 1 CPU core (\(F\))

\[F = T^T \tilde{w} T\]

N=64_B=3_K=32768_P=32_S=256

pixi run pytest-benchmark compare 1_N=64_B=3_K=32768_P=32_S=256_NUM_THREADS=1 256_N=64_B=3_K=32768_P=32_S=256_NUM_THREADS=256 --columns=mean,stddev,ops,rounds,iterations --sort=mean
--------------------------------------------- benchmark 'curvature_matrix_DataGenerated': 11 tests ---------------------------------------------
Name (time in ms)                                                        Mean            StdDev                OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------
test_curvature_matrix_numba_sparse[DataGenerated]                     10.2812 (1.0)      0.1751 (1.50)     97.2652 (1.0)          94           1
test_curvature_matrix_jax[DataGenerated]                              20.0919 (1.95)     1.5245 (13.03)    49.7714 (0.51)         30           1
test_curvature_matrix_jax_sparse[DataGenerated]                       36.8161 (3.58)     1.2018 (10.27)    27.1620 (0.28)          9           1
test_curvature_matrix_numba_compact_sparse[DataGenerated]             52.9714 (5.15)     0.1170 (1.0)      18.8781 (0.19)         19           1
test_curvature_matrix_jax_BCOO[DataGenerated]                         63.1237 (6.14)     3.4937 (29.87)    15.8419 (0.16)          7           1
test_curvature_matrix_original_preload_direct[DataGenerated]          99.3323 (9.66)     0.1512 (1.29)     10.0672 (0.10)         11           1
test_curvature_matrix_numba[DataGenerated]                           126.0225 (12.26)    0.1512 (1.29)      7.9351 (0.08)          8           1
test_curvature_matrix_original[DataGenerated]                        126.5191 (12.31)    0.1542 (1.32)      7.9039 (0.08)          8           1
test_curvature_matrix_numba_compact_sparse_direct[DataGenerated]     140.8911 (13.70)    0.6626 (5.66)      7.0977 (0.07)          8           1
test_curvature_matrix_jax_compact_sparse[DataGenerated]              533.6890 (51.91)    5.1037 (43.63)     1.8738 (0.02)          5           1
test_curvature_matrix_jax_compact_sparse_BCOO[DataGenerated]         537.2100 (52.25)    2.2710 (19.42)     1.8615 (0.02)          5           1
------------------------------------------------------------------------------------------------------------------------------------------------

Bonus match 2: Numba with 128 CPU cores and JAX with CUDA on GPU (A100) (\(F\))

N=32_B=300_K=8192_P=32_S=256

pixi run pytest-benchmark compare --columns=mean,stddev,ops,rounds,iterations --sort=mean 0009_N=32_B=300_K=8192_P=32_S=256_NUM_THREADS=128_cuda
---------------------------------------------------- benchmark 'curvature_matrix_DataGenerated': 11 tests ----------------------------------------------------
Name (time in us)                                                               Mean                StdDev                   OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------
test_curvature_matrix_jax[DataGenerated]                                    260.5957 (1.0)         29.3714 (1.0)      3,837.3618 (1.0)        1295           1
test_curvature_matrix_jax_BCOO[DataGenerated]                             3,078.2068 (11.81)       35.9463 (1.22)       324.8645 (0.08)          5           1
test_curvature_matrix_jax_compact_sparse_BCOO[DataGenerated]              3,207.3388 (12.31)      107.0798 (3.65)       311.7850 (0.08)        293           1
test_curvature_matrix_numba_sparse[DataGenerated]                         5,548.5175 (21.29)       64.3711 (2.19)       180.2283 (0.05)        146           1
test_curvature_matrix_jax_compact_sparse[DataGenerated]                   7,190.9015 (27.59)       35.7355 (1.22)       139.0646 (0.04)        134           1
test_curvature_matrix_numba[DataGenerated]                               18,187.5003 (69.79)    5,603.6081 (190.78)      54.9828 (0.01)        207           1
test_curvature_matrix_original[DataGenerated]                            18,279.9851 (70.15)    6,052.1386 (206.06)      54.7046 (0.01)          9           1
test_curvature_matrix_jax_sparse[DataGenerated]                          19,786.7200 (75.93)       42.9344 (1.46)        50.5389 (0.01)          6           1
test_curvature_matrix_numba_compact_sparse[DataGenerated]                32,605.2243 (125.12)     248.8764 (8.47)        30.6699 (0.01)         30           1
test_curvature_matrix_numba_compact_sparse_direct[DataGenerated]      1,362,329.9249 (>1000.0)  1,366.9112 (46.54)        0.7340 (0.00)          5           1
test_curvature_matrix_original_preload_direct[DataGenerated]         25,218,633.7856 (>1000.0)  8,722.4870 (296.97)       0.0397 (0.00)          5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------

Beyond porting

Pipeline of \(\tilde{w}\)

Internal sparse data structure

Lessons learnt from Numba vs. JAX

# Pin every threading backend (MKL, OpenMP, numexpr, OpenBLAS, Numba, XLA) to the
# same thread count, so benchmark runs at different core counts are comparable.
# The brace-wrapped num_threads tokens are template placeholders. They are not
# valid shell as written and must be substituted before this is sourced.

# Intel MKL: set the global thread count and the BLAS-domain thread count, and
# stop MKL from lowering the count dynamically.
export MKL_NUM_THREADS={num_threads}
export MKL_DOMAIN_NUM_THREADS="MKL_BLAS={num_threads}"
export MKL_DYNAMIC=FALSE

# OpenMP runtimes: fixed thread count, one place per hardware thread, threads
# spread across places, and no dynamic adjustment.
export OMP_NUM_THREADS={num_threads}
export OMP_PLACES=threads
export OMP_PROC_BIND=spread
export OMP_DYNAMIC=FALSE

export NUMEXPR_NUM_THREADS={num_threads}

export OPENBLAS_NUM_THREADS={num_threads}

export NUMBA_NUM_THREADS={num_threads}

# XLA (JAX) on CPU.
# NOTE(review): "intra_op_parallelism_threads=..." has no leading "--" and
# appears to be a TensorFlow session option rather than an XLA flag, so XLA
# probably ignores it. Confirm before relying on it.
# NOTE(review): --xla_force_host_platform_device_count splits the host into that
# many separate XLA CPU devices. That matters for sharded/pmap code, not for the
# thread count of single-device computations. Confirm it is intended here.
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads={num_threads} --xla_force_host_platform_device_count={num_threads}"
# TensorFlow runtime thread pools. These presumably have no effect on JAX; confirm.
export TF_NUM_INTEROP_THREADS=1
export TF_NUM_INTRAOP_THREADS={num_threads}

Misc.

Pull \(K\) outside the log-likelihood loop.