Scientific Computing with JAX
Presented at Durham HPC Days 2025.
JAX is a Python library that combines just-in-time (JIT) compilation with automatic differentiation. It is powered by the XLA (Accelerated Linear Algebra) compiler and targets multiple hardware architectures, including CPU, GPU (NVIDIA, AMD, Intel, Apple), and Google TPU. While primarily designed for machine learning research, JAX presents compelling advantages for scientific computing in the HPC landscape.
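To make the combination of JIT compilation and automatic differentiation concrete, here is a minimal sketch. The function `sum_of_squares` and its inputs are illustrative assumptions, not code from the talk; only `jax.jit`, `jax.grad`, and `jax.devices` are actual JAX calls.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Plain array code; JAX traces it and hands it to XLA for compilation.
    return jnp.sum(x ** 2)

# jit compiles the function for whichever backend JAX finds (CPU, GPU, or TPU).
fast_sum_of_squares = jax.jit(sum_of_squares)

# grad builds a new function returning d(sum_of_squares)/dx; it composes with jit.
grad_sum_of_squares = jax.jit(jax.grad(sum_of_squares))

x = jnp.array([1.0, 2.0, 3.0])
value = fast_sum_of_squares(x)      # 1 + 4 + 9 = 14
gradient = grad_sum_of_squares(x)   # 2 * x = [2, 4, 6]
print(jax.devices())                # lists the devices JAX will run on
```

The same source runs unchanged on any backend JAX detects, which is the sense in which the code is hardware-agnostic.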
As HPC increasingly depends on heterogeneous accelerators, JAX’s “write once, deploy anywhere” approach lets scientific libraries use diverse computing resources efficiently. Python programmers familiar with array programming and with other JIT solutions such as Numba can transition to JAX with minimal code changes, which makes adoption straightforward. In addition, automatic differentiation supplies gradients for optimization problems with minimal developer effort. It generates derivatives that would be prohibitively complex to implement by hand, and the gradient information helps optimization algorithms converge more rapidly.
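As a rough illustration of what such a transition can look like, the sketch below writes the same chi-squared sum twice: once as an explicit loop in the style Numba compiles well, and once as whole-array operations under `jax.jit`. Both functions and the sample arrays are hypothetical examples, and the Numba decorator is left out so the snippet runs without Numba installed.

```python
import jax
import jax.numpy as jnp
import numpy as np

def chi_squared_loop(data, model, noise):
    # Loop style that Numba's @njit handles well (decorator omitted so this
    # sketch runs without Numba installed).
    total = 0.0
    for i in range(data.shape[0]):
        residual = (data[i] - model[i]) / noise[i]
        total += residual * residual
    return total

@jax.jit
def chi_squared_jax(data, model, noise):
    # JAX version: the same arithmetic written as whole-array operations,
    # since JAX arrays are immutable and jit unrolls Python loops during tracing.
    residual = (data - model) / noise
    return jnp.sum(residual ** 2)

data = np.array([1.0, 2.0, 4.0])
model = np.array([1.5, 2.0, 3.0])
noise = np.array([0.5, 1.0, 0.5])

loop_result = chi_squared_loop(data, model, noise)
jax_result = chi_squared_jax(data, model, noise)
```

Code that is already vectorized tends to port almost unchanged, while loops that mutate arrays in place usually need restructuring. Note also that JAX defaults to 32-bit floats, so results may differ from a 64-bit NumPy or Numba version in the last digits.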
This presentation examines the adaptation of a likelihood function calculation from Numba to JAX in the context of gravitational lensing analysis of James Webb Space Telescope observations. The case study specifically addresses the convergence of AI/ML techniques with traditional scientific computing, highlighting performance gains, implementation challenges, and practical considerations when migrating existing scientific code to JAX’s programming model. We’ll discuss how JAX enables scalable processing across different HPC hardware resources while maintaining scientific accuracy and computational efficiency.
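For readers unfamiliar with gradient-based likelihood fitting, the toy sketch below shows the general pattern. It is not the lensing likelihood from the case study: the straight-line model, the synthetic data, the step size, and the iteration count are all assumptions chosen to keep the example small.

```python
import jax
import jax.numpy as jnp

def log_likelihood(params, x, data, noise):
    # Hypothetical straight-line model standing in for a real lens model.
    slope, intercept = params[0], params[1]
    model = slope * x + intercept
    # Gaussian log-likelihood up to an additive constant.
    return -0.5 * jnp.sum(((data - model) / noise) ** 2)

# value_and_grad returns the log-likelihood and its gradient w.r.t. params.
loglike_and_grad = jax.jit(jax.value_and_grad(log_likelihood))

x = jnp.linspace(0.0, 1.0, 20)
data = 2.0 * x + 1.0           # noise-free synthetic data, true params (2, 1)
noise = jnp.ones_like(x)

params = jnp.array([0.0, 0.0])
step = 0.05                    # arbitrary step size chosen for this toy problem
for _ in range(500):
    value, grad = loglike_and_grad(params, x, data, noise)
    params = params + step * grad   # gradient ascent on the log-likelihood
```

After the loop, `params` should sit close to the true slope and intercept. In practice a dedicated optimizer or sampler would replace the hand-written ascent loop, but the gradient would be obtained in the same way.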