Kinetic Plasma Physics Using JAX
We wrote a 1D-1V Pythonic Vlasov-Poisson-Fokker-Planck solver last year and published it in JOSS. This used NumPy and was tested against well-established plasma physics calculations.
A couple of principal design choices were made:
- The code was written in a functional style with a composable interface in mind, i.e. the individual solvers could be composed into a full timestep.
- It shipped with an experiment-management dependency in the form of MLFlow.
An example of the functional approach is in the following code snippet.
import numpy as np
from numpy import fft


def get_edfdv_exponential(kv):
    """
    This function creates the exponential e df/dv stepper

    It uses kv as metadata that should stay constant throughout the simulation

    :param kv: (float array (nv,)) the velocity-space wavenumber
    :return: a function with the above values initialized as static variables
    """

    def step_edfdv_exponential(f, e, dt):
        """
        evolution of df/dt = e df/dv using the exponential integrator described in [1].

        [1] - https://juliavlasov.github.io/ -- Dr. Pierre Navarro

        :param f: (float array (nx, nv)) distribution function
        :param e: (float array (nx,)) the electric field in real space
        :param dt: (float) timestep
        :return: (float array (nx, nv)) updated distribution function
        """
        return np.real(
            fft.ifft(np.exp(-1j * kv * dt * e[:, None]) * fft.fft(f, axis=1), axis=1)
        )

    return step_edfdv_exponential
In this snippet, we initialize the $e \partial f/\partial v$ step as a function and return it to another function, which composes these smaller sub-physics pieces into a full timestep. For us, the main reason for following the composition approach was that it enables explicit transformations of the functions.
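As a rough sketch of what that composition can look like (the get_vdfdx_exponential and get_timestep factories here are illustrative stand-ins, not VlaPy's actual API), one might write:

def get_vdfdx_exponential(kx, v):
    """
    Illustrative v df/dx stepper built in the same factory style
    (not VlaPy's exact implementation).
    """

    def step_vdfdx_exponential(f, dt):
        # exponential integrator for df/dt + v df/dx = 0, spectral in x
        return np.real(
            fft.ifft(
                np.exp(-1j * kx[:, None] * dt * v[None, :]) * fft.fft(f, axis=0), axis=0
            )
        )

    return step_vdfdx_exponential


def get_timestep(kx, kv, v, dt):
    """
    Illustrative composition of the sub-physics steppers into one split-step update.
    """
    vdfdx = get_vdfdx_exponential(kx, v)  # v df/dx piece
    edfdv = get_edfdv_exponential(kv)     # e df/dv piece from the snippet above

    def timestep(f, e):
        f = vdfdx(f, dt)     # advect in x
        f = edfdv(f, e, dt)  # push in v
        return f

    return timestep

Because every sub-step is a pure function of arrays, the composed timestep is itself a pure function that can be transformed.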
These transformations are exactly what modern ML frameworks exploit. Here, we count on them to enable the following two capabilities:
- Be able to perform differentiable Vlasov-Poisson-Fokker-Planck simulations
- Run the simulations on a GPU
JAX readily enables differentiable simulations, although it has its weaknesses, and it can compile code for a GPU to consume. I realize Julia enables this as well. I hope to write about a Julia implementation one day.
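As a toy illustration of what those transformations buy us (the loss function below is made up purely for demonstration), jax.grad and jax.jit can be applied directly to a function that wraps a simulation step:

import jax
import jax.numpy as jnp


def toy_loss(e_amplitude, f0, kv, dt):
    """
    Made-up scalar diagnostic: run one e df/dv step with a uniform field
    of the given amplitude and measure how much the distribution changed.
    """
    nx, nv = f0.shape
    e = e_amplitude * jnp.ones(nx)
    f1 = jnp.real(
        jnp.fft.ifft(jnp.exp(-1j * kv * dt * e[:, None]) * jnp.fft.fft(f0, axis=1), axis=1)
    )
    return jnp.sum((f1 - f0) ** 2)


# grad differentiates through the simulation step with respect to the field
# amplitude, and jit compiles the whole thing (for the GPU if one is available)
dloss_damp = jax.jit(jax.grad(toy_loss))
# e.g. dloss_damp(0.5, f0, kv, 0.1) returns d(loss)/d(amplitude)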
In today’s post, I’ll discuss some gotchas in porting from NumPy to JAX and the performance gain from running on GPUs.
In a future post, I will discuss some applications of the differentiable simulations.
Porting VlaPy to JAX
Swap JAX NumPy for NumPy
In the new version of VlaPy, the above code snippet was transformed into:
def get_edfdv_exponential(custom_methods: dict, kv):
    """
    This function creates the exponential e df/dv stepper

    It uses kv as metadata that should stay constant throughout the simulation

    :param custom_methods: (dict) backend-dependent methods: the NumPy-like module, checkpoint, and jit
    :param kv: (float array (nv,)) the velocity-space wavenumber
    :return: a JIT-compiled function with the above values initialized as static variables
    """
    this_np = custom_methods["numpy"]
    checkpoint = custom_methods["checkpoint"]
    jit = custom_methods["jit"]

    kv_real = kv[: int(kv.size // 2) + 1]

    @checkpoint
    def step_edfdv_exponential(f, e, dt):
        """
        evolution of df/dt = e df/dv

        :param f: distribution function (numpy array of shape (nx, nv))
        :param e: electric field (numpy array of shape (nx,))
        :param dt: timestep (single float value)
        :return: updated distribution function (numpy array of shape (nx, nv))
        """
        return this_np.fft.irfft(
            this_np.exp(-1j * kv_real[None, :] * dt * e[:, None])
            * this_np.fft.rfft(f, axis=1),
            axis=1,
        )

    return jit(step_edfdv_exponential)
There are a few additions here. We now use this_np to denote which NumPy we’re using; it’s passed in through the custom_methods dictionary. We also make use of jax.checkpoint and jax.jit for checkpointing backwards gradient calls and JIT-compiling our functions. If the NumPy backend is being used, checkpoint and jit are dummy functions.
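For reference, assembling that dictionary might look something like the following sketch (the function name and exact layout are illustrative, not VlaPy's actual configuration code):

import numpy as np
import jax
import jax.numpy as jnp


def get_custom_methods(backend="jax"):
    """
    Sketch of a backend selector: returns the NumPy-like module and the
    (possibly dummy) function transformations for the chosen backend.
    """
    if backend == "jax":
        return {"numpy": jnp, "checkpoint": jax.checkpoint, "jit": jax.jit}
    else:
        # NumPy backend: checkpoint and jit are identity decorators
        identity = lambda func: func
        return {"numpy": np, "checkpoint": identity, "jit": identity}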
FFT package
At the time of porting, JAX had not yet implemented an FFT package. As you can see in the code snippet above, VlaPy’s solvers use spectral derivatives; in fact, we find spectral differentiation to be far more stable than finite differences. In order to port to JAX, we needed a set of FFT routines in JAX. Luckily, I was able to contribute these with a few very straightforward lines of code because the underlying routines were already written and integrated.
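For concreteness, a spectral first derivative along x using the JAX FFT routines looks something like this (an illustrative helper, not code from VlaPy):

import jax.numpy as jnp


def spectral_ddx(f, kx):
    """
    d/dx of f(x, v) along axis 0 via FFTs: multiply by i*kx in Fourier space.

    :param f: (float array (nx, nv)) distribution function
    :param kx: (float array (nx,)) the real-space wavenumber
    :return: (float array (nx, nv)) df/dx
    """
    return jnp.real(jnp.fft.ifft(1j * kx[:, None] * jnp.fft.fft(f, axis=0), axis=0))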
Tridiagonal Solver
The other missing piece was the broadcasted tridiagonal solver implemented in the NumPy version. Here is the NumPy code:
def _batched_tridiagonal_solver_(a, b, c, f):
    """
    Arrayed/sliced algorithm for an in-place tridiagonal solve.

    About 50x faster than looping through a numpy linalg call
    for a 256x2048 solve

    :param a: (2D float array (nx, nv)) subdiagonal of the linear operator
    :param b: (2D float array (nx, nv)) main diagonal of the linear operator
    :param c: (2D float array (nx, nv)) superdiagonal of the linear operator
    :param f: (2D float array (nx, nv)) right-hand side
    :return: (2D float array (nx, nv)) solution of the batched tridiagonal systems
    """
    nv = a.shape[1]

    ac = a.copy()
    bc = b.copy()
    cc = c.copy()
    dc = f.copy()

    # forward sweep
    for it in range(1, nv):
        mc = ac[:, it - 1] / bc[:, it - 1]
        bc[:, it] = bc[:, it] - mc * cc[:, it - 1]
        dc[:, it] = dc[:, it] - mc * dc[:, it - 1]

    # back substitution (xc aliases bc, which is no longer needed)
    xc = bc
    xc[:, -1] = dc[:, -1] / bc[:, -1]
    for il in range(nv - 2, -1, -1):
        xc[:, il] = (dc[:, il] - cc[:, il] * xc[:, il + 1]) / bc[:, il]

    return xc
That routine relies on in-place, index-based assignment, so a more JAX-friendly version was needed. We found an excellent routine (written for solving tridiagonal systems in CFD, see ref. [1]) and ported it over for our purposes as follows:
@checkpoint
@jit
def compute_primes(last_primes, x):
    """
    This function is a single iteration of the forward pass in the non-in-place Thomas
    tridiagonal algorithm

    :param last_primes: (float array (2, nx)) the modified coefficients (c', d') from the previous row
    :param x: (float array (4, nx)) the (a, b, c, d) values for the current row
    :return: the new (c', d') pair, as both the scan carry and the scan output
    """
    last_cp, last_dp = last_primes
    a, b, c, d = x
    cp = c / (b - a * last_cp)
    dp = (d - a * last_dp) / (b - a * last_cp)
    new_primes = this_np.stack((cp, dp))
    return new_primes, new_primes


@checkpoint
@jit
def backsubstitution(last_x, x):
    """
    This function is a single iteration of the backward pass in the non-in-place Thomas
    tridiagonal algorithm

    :param last_x: (float array (nx,)) the solution at the previously processed (next) row
    :param x: (float array (2, nx)) the modified coefficients (c', d') for the current row
    :return: the solution at the current row, as both the scan carry and the scan output
    """
    cp, dp = x
    new_x = dp - cp * last_x
    return new_x, new_x


@checkpoint
def _batched_tridiagonal_solver_(a, b, c, d):
    """
    Solves a tridiagonal matrix system with diagonals a, b, c and RHS vector d.
    This uses the non-in-place Thomas tridiagonal algorithm.
    The NumPy version, on the other hand, uses the in-place algorithm.

    :param a: (2D float array (nx, nv)) represents the subdiagonal of the linear operator
    :param b: (2D float array (nx, nv)) represents the main diagonal of the linear operator
    :param c: (2D float array (nx, nv)) represents the super diagonal of the linear operator
    :param d: (2D float array (nx, nv)) represents the right hand side of the linear operator
    :return: (2D float array (nx, nv)) solution of the batched tridiagonal systems
    """
    diags_stacked = this_np.stack(
        [arr.transpose((1, 0)) for arr in (a, b, c, d)], axis=1
    )
    _, primes = scan(
        compute_primes,
        this_np.zeros((2, *a.shape[:-1])),
        diags_stacked,
        unroll=num_unroll,
    )
    _, sol = scan(
        backsubstitution,
        this_np.zeros(a.shape[:-1]),
        primes[::-1],
        unroll=num_unroll,
    )
    return sol[::-1].transpose((1, 0))
Implementing this was a great reminder of why tests are useful.
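For example, a check along the following lines catches most mistakes. This is only a sketch (the tight tolerance assumes double precision is enabled, see below), where solver would be either the NumPy or the JAX version of _batched_tridiagonal_solver_:

import numpy as np


def check_batched_tridiagonal_solver(solver, nx=8, nv=64, seed=0):
    """
    Sketch of a correctness check: build nx random, diagonally dominant
    tridiagonal systems of size nv and compare `solver` against dense solves.
    """
    rng = np.random.default_rng(seed)
    a, b, c, d = (rng.standard_normal((nx, nv)) for _ in range(4))
    b = b + 10.0  # make the systems comfortably diagonally dominant

    expected = np.zeros((nx, nv))
    for ix in range(nx):
        mat = np.diag(b[ix]) + np.diag(a[ix, 1:], -1) + np.diag(c[ix, :-1], 1)
        expected[ix] = np.linalg.solve(mat, d[ix])

    # tight tolerance assumes double precision (x64) is enabled
    np.testing.assert_allclose(
        np.asarray(solver(a, b, c, d)), expected, rtol=1e-8, atol=1e-8
    )


# e.g. check_batched_tridiagonal_solver(_batched_tridiagonal_solver_)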
Double Precision
It was also important to enable double precision in JAX; otherwise, the Landau damping tests do not pass! This makes sense because Landau damping is a result of the interaction of the electrostatic field with the distribution function in phase space, where the value of the distribution function is often very small, i.e. $\sim 10^{-8}$.
Sidenote - this is one blocker to adopting ML-friendly reduced-precision types like BF16 for these simulations.
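For reference, enabling double precision in JAX is a one-line configuration change that has to happen before any arrays are created:

from jax import config

# must run at startup, before any JAX arrays are created in the process
config.update("jax_enable_x64", True)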
Performance Results
So was all that work worth it?
One thing to check is the performance since JAX runs well on GPUs.
Running the VFP solver on a GPU using JAX is about 30x faster than the NumPy implementation on a CPU.
Here are some performance metrics with respect to grid size. $N_x$ was varied from 32 to 1024. The upper limit was primarily defined by how patient I was willing to be with the NumPy/SciPy CPU implementation. We generally see about a 30x speed-up when running on JAX and GPU.
Solving the Fokker-Planck equation is nearly free with NumPy and the broadcasted Thomas tridiagonal solver. On JAX, however, there is a sizable performance hit when solving the Fokker-Planck equation. This can be remedied going forward by exposing CUDA’s batched tridiagonal solver through a JAX API.
Here’s a similar plot but for $N_v$.
What does this mean for real simulations?
To simulate an electrostatic wave for $\sim 500 \omega_p^{-1}$, we need 1000 timesteps using the sixth-order time-stepper. A possible grid resolution for this might be $N_x = 256, N_v = 2048$. We observe that a single time-step at this grid size takes 0.7 ms, so the simulation takes roughly 1 s on a T4.
Okay, how about a very high-resolution electrostatic wave, like a KEEN wave simulation? For $N_x = 1024, N_v = 16384$, one timestep takes 0.3 s, so 2000 timesteps take about 10 minutes. That compares well with the benchmarks in ref. [2], where a similar simulation takes 2847 s, i.e. 45 minutes, on 128 cores. This is a nice 5x - 640x speedup (depending on how you factor in the parallelization) in comparison to their highly optimized Fortran code. However, part of this speedup is because of the solvers used as well: modern FFTs are quite fast, even compared to the straightforward tridiagonal solves needed for a semi-Lagrangian method.
Okay, but what about an electrostatic wavepacket like in ref. [3]? For those, the timescale is similar, but the grid is larger in $N_x$. We performed some wavepacket simulations with $N_x = 4096, N_v = 8192$ and observed the per-time-step calculation time to be 0.17 s. A full simulation takes about 5.5 minutes.
All in all, the speed-up is excellent!
Note that all of these timings are on an NVIDIA T4 GPU; running on a V100 is about 3x faster (but 6x more expensive on AWS!).
Conclusion
Porting to JAX was straightforward. I wanted to maintain the ability to switch back to NumPy for testing and to be able to swap modules in and out. I also needed to add @jit and @checkpoint decorators to many of the core functions in order to fully realize the potential of JAX. To do that and maintain backward compatibility, I did have to create dummy versions of those two decorators that simply pass the function through. In hindsight, much of this could be solved by using a proper configuration management system like hydra or gin.
It was definitely worth it for the speed-up alone. Being able to run the wavepacket simulation in a few minutes is a happy outcome. This enables the generation of large datasets for machine learning, and parametric scans in general.
In a future post, I will show some results of canonical 1D-1V VPFP problems and their performance using JAX on GPUs. In the meantime, here is a plot of the distribution function of a non-linear electrostatic wavepacket.
References
[1]: Häfner, Dion, René Løwe Jacobsen, Carsten Eden, Mads R. B. Kristensen, Markus Jochum, Roman Nuterman, and Brian Vinter. “Veros v0.1 – a Fast and Versatile Ocean Simulator in Pure Python.” Geoscientific Model Development 11, no. 8 (August 16, 2018): 3299–3312. https://doi.org/10.5194/gmd-11-3299-2018.
[2]: Afeyan, Bedros, Fernando Casas, Nicolas Crouseilles, Adila Dodhy, Erwan Faou, Michel Mehrenberger, and Eric Sonnendrücker. “Simulations of Kinetic Electrostatic Electron Nonlinear (KEEN) Waves with Variable Velocity Resolution Grids and High-Order Time-Splitting.” The European Physical Journal D 68, no. 10 (October 2014): 295. https://doi.org/10.1140/epjd/e2014-50212-6.
[3]: Fahlen, J. E., B. J. Winjum, T. Grismayer, and W. B. Mori. “Propagation and Damping of Nonlinear Plasma Wave Packets.” Physical Review Letters 102, no. 24 (June 17, 2009): 245002. https://doi.org/10.1103/PhysRevLett.102.245002.