Kinetic Plasma Physics Using JAX

We wrote a 1D-1V Pythonic Vlasov-Poisson-Fokker-Planck solver last year and published it in JOSS. This used NumPy and was tested against well-established plasma physics calculations.

There were a couple of principal design choices.

  • It was functional and built with a composable interface in mind, i.e. the solvers were composable.
  • It came with an experiment-management software dependency in the form of MLFlow.

An example of the functional approach is in the following code snippet.

import numpy as np
from numpy import fft


def get_edfdv_exponential(kv):
    """
    This function creates the exponential v df/dx stepper
    It uses kv as metadata that should stay constant throughout the simulation
    :param kv: (float array (nv, )) the velocity-space wavenumber
    :return: a function with the above values initialized as static variables
    """

    def step_edfdv_exponential(f, e, dt):
        """
        evolution of df/dt = e df/dv using the exponential integrator described in
        [1].
        [1] - https://juliavlasov.github.io/ -- Dr. Pierre Navarro
        :param f: (float array (nx, nv)) distribution function
        :param e: (float array (nx, )) the electric field in real space
        :param dt: (float) timestep
        :return: (float array (nx, nv)) updated distribution function
        """

        return np.real(
            fft.ifft(np.exp(-1j * kv * dt * e[:, None]) * fft.fft(f, axis=1), axis=1)
        )

    return step_edfdv_exponential

In this snippet, we initialize the $e \partial f/\partial v$ step as a function and return it to another function, which composes these smaller sub-physics pieces into a full timestep. For us, the main reason for following the composition approach was that it enables explicit transformations of the functions.
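As a minimal sketch of what that composition looks like (the vdfdx and Poisson pieces below are illustrative stand-ins, not VlaPy's actual API):

def get_timestep(kx, kv, dt):
    # Build the sub-physics steppers once; each closes over its static metadata.
    edfdv = get_edfdv_exponential(kv)   # the e df/dv substep from the snippet above
    vdfdx = get_vdfdx_exponential(kx)   # hypothetical v df/dx substep, built the same way
    poisson = get_poisson_solver(kx)    # hypothetical spectral Poisson solve for e

    def timestep(f):
        # One split-step update composed out of the smaller sub-physics functions.
        f = vdfdx(f, 0.5 * dt)
        e = poisson(f)
        f = edfdv(f, e, dt)
        f = vdfdx(f, 0.5 * dt)
        return f

    return timestep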

Such explicit function transformations are exploited by modern ML frameworks. Here, we count on them to enable the following two functionalities:

  • Be able to perform differentiable Vlasov-Poisson-Fokker-Planck simulations
  • Run the simulations on a GPU

JAX readily enables differentiable simulations, although it has its weaknesses, and it can compile code to run on a GPU. I realize Julia enables this as well; I hope to write about a Julia implementation one day.
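To make those two points concrete, here is a self-contained toy example (not VlaPy code) of the transformations we rely on: jax.grad differentiates through an entire time loop, and jax.jit compiles it for the accelerator.

import jax
import jax.numpy as jnp


def run(x0, k):
    # Toy "simulation": a damped quantity integrated with forward Euler for 1000 steps.
    def step(x, _):
        x = x + 0.01 * (-k * x)
        return x, x

    xf, _ = jax.lax.scan(step, x0, None, length=1000)
    return xf


def loss(k):
    # Scalar objective built on the simulation output.
    return (run(jnp.array(1.0), k) - 0.5) ** 2


dloss_dk = jax.grad(loss)   # differentiable simulation
fast_loss = jax.jit(loss)   # compiled for CPU/GPU
print(fast_loss(0.3), dloss_dk(0.3))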

In today’s post, I’ll discuss some gotchas in porting from NumPy to JAX and the performance gain from running on GPUs. In a future post, I will discuss some applications of the differentiable simulations.

Porting VlaPy to JAX

Swap JAX NumPy for NumPy

In the new version of VlaPy, the above code snippet was transformed into

def get_edfdv_exponential(custom_methods: dict, kv):
    """
    This function creates the exponential edfdv stepper

    It uses kv as metadata that should stay constant throughout the simulation

    :param custom_methods: (dict) backend-specific methods: the NumPy module to use and the checkpoint and jit decorators
    :param kv: (float array (nv, )) the velocity-space wavenumber
    :return: a jitted function with the above values initialized as static variables
    """
    this_np = custom_methods["numpy"]
    checkpoint = custom_methods["checkpoint"]
    jit = custom_methods["jit"]

    kv_real = kv[: int(kv.size // 2) + 1]

    @checkpoint
    def step_edfdv_exponential(f, e, dt):
        """
        evolution of df/dt = e df/dv

        :param f: distribution function. (numpy array of shape (nx, nv))
        :param e: electric field (numpy array of shape (nx,))
        :param dt: timestep (single float value)
        :return: updated distribution function (numpy array of shape (nx, nv))
        """

        return this_np.fft.irfft(
            this_np.exp(-1j * kv_real[None, :] * dt * e[:, None])
            * this_np.fft.rfft(f, axis=1),
            axis=1,
        )

    return jit(step_edfdv_exponential)

There are a few additions here. We now use this_np to denote which NumPy we’re using; it’s passed in through the custom_methods dictionary. We also make use of jax.checkpoint and jax.jit for checkpointing the backward (gradient) pass and JIT-compiling our functions. If the NumPy backend is being used, checkpoint and jit are dummy pass-through functions.
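A minimal sketch of how that dictionary could be assembled is below; the function name and structure are illustrative rather than VlaPy's actual implementation, but the keys match the snippet above.

def get_backend_methods(backend: str) -> dict:
    # Illustrative: assemble the custom_methods dictionary consumed by the steppers above.
    if backend == "jax":
        import jax
        import jax.numpy as jnp

        return {"numpy": jnp, "checkpoint": jax.checkpoint, "jit": jax.jit}
    else:
        import numpy as np

        def passthrough(func):
            # Dummy stand-in for jit/checkpoint so the NumPy path is left unchanged.
            return func

        return {"numpy": np, "checkpoint": passthrough, "jit": passthrough}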

FFT package

At the time of porting, JAX had yet to implement an FFT package. As you can see in the code snippet above, VlaPy’s solvers rely on spectral derivatives; in fact, we find spectral differentiation to be far more stable than finite differences. To port to JAX, we therefore needed a set of FFT routines in JAX. Luckily, I was able to contribute a few very straightforward lines of code here, because the underlying routines were already written and integrated.
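As an aside, a spectral derivative with the now-available jax.numpy.fft routines takes only a few lines; here is a minimal sketch on an illustrative periodic grid.

import jax.numpy as jnp

nx, length = 64, 2.0 * jnp.pi
x = jnp.linspace(0.0, length, nx, endpoint=False)
kx = 2.0 * jnp.pi * jnp.fft.rfftfreq(nx, d=length / nx)   # real-FFT wavenumbers

f = jnp.sin(3.0 * x)
dfdx = jnp.fft.irfft(1j * kx * jnp.fft.rfft(f), n=nx)     # spectral d/dx, approx 3*cos(3x)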

Tridiagonal Solver

The other missing piece was the broadcasted tridiagonal solver implemented in the NumPy version. Here is the code for the NumPy version:

def _batched_tridiagonal_solver_(a, b, c, f):
    """
    Arrayed/Sliced algorithm for tridiagonal solve.

    About 50x faster than looping through a numpy linalg call
    for a 256x2048 solve

    :param a: (2D float array (nx, nv)) subdiagonal of the linear operator
    :param b: (2D float array (nx, nv)) main diagonal of the linear operator
    :param c: (2D float array (nx, nv)) superdiagonal of the linear operator
    :param f: (2D float array (nx, nv)) right-hand side
    :return: (2D float array (nx, nv)) solution of the tridiagonal system
    """

    ac = a.copy()
    bc = b.copy()
    cc = c.copy()
    dc = f.copy()

    nv = a.shape[1]  # number of velocity cells; the solve is along the velocity axis

    for it in range(1, nv):
        mc = ac[:, it - 1] / bc[:, it - 1]
        bc[:, it] = bc[:, it] - mc * cc[:, it - 1]
        dc[:, it] = dc[:, it] - mc * dc[:, it - 1]

    xc = bc
    xc[:, -1] = dc[:, -1] / bc[:, -1]

    for il in range(nv - 2, -1, -1):
        xc[:, il] = (dc[:, il] - cc[:, il] * xc[:, il + 1]) / bc[:, il]

    return xc

That routine relies on in-place, index-based assignment, so a more JAX-friendly version was needed. We found an excellent routine (written for solving tridiagonal systems in CFD; see ref. [1]) and ported it over for our purposes as follows:

@checkpoint
@jit
def compute_primes(last_primes, x):
    """
    This function is a single iteration of the forward pass in the non-in-place Thomas
    tridiagonal algorithm

    :param last_primes: (float array (2, nx)) the c' and d' coefficients carried from the previous row
    :param x: (float array (4, nx)) the a, b, c, d values of the current row
    :return: the new (c', d') coefficients, both carried and stored by the scan
    """

    last_cp, last_dp = last_primes
    a, b, c, d = x
    cp = c / (b - a * last_cp)
    dp = (d - a * last_dp) / (b - a * last_cp)
    new_primes = this_np.stack((cp, dp))
    return new_primes, new_primes

@checkpoint
@jit
def backsubstitution(last_x, x):
    """
    This function is a single iteration of the backward pass in the non-in-place Thomas
    tridiagonal algorithm

    :param last_x: (float array (nx,)) the solution value from the previous (higher-index) row
    :param x: (float array (2, nx)) the c' and d' coefficients of the current row
    :return: the solution value of the current row, both carried and stored by the scan
    """
    cp, dp = x
    new_x = dp - cp * last_x
    return new_x, new_x

@checkpoint
def _batched_tridiagonal_solver_(a, b, c, d):
    """
    Solves a tridiagonal matrix system with diagonals a, b, c and RHS vector d.

    This uses the non-in-place Thomas tridiagonal algorithm.

    The NumPy version, on the other hand, uses the in-place algorithm.

    :param a: (2D float array (nx, nv)) represents the subdiagonal of the linear operator
    :param b: (2D float array (nx, nv)) represents the main diagonal of the linear operator
    :param c: (2D float array (nx, nv)) represents the super diagonal of the linear operator
    :param d: (2D float array (nx, nv)) represents the right hand side of the linear operator
    :return: (2D float array (nx, nv)) the solution of the tridiagonal system
    """

    diags_stacked = this_np.stack(
        [arr.transpose((1, 0)) for arr in (a, b, c, d)], axis=1
    )
    _, primes = scan(
        compute_primes,
        this_np.zeros((2, *a.shape[:-1])),
        diags_stacked,
        unroll=num_unroll,
    )
    _, sol = scan(
        backsubstitution,
        this_np.zeros(a.shape[:-1]),
        primes[::-1],
        unroll=num_unroll,
    )
    return sol[::-1].transpose((1, 0))

Implementing this was a great reminder of why tests are useful.
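For instance, the sort of sanity check that caught mistakes here compares the batched solver against a dense solve on a small, diagonally dominant system; a sketch (assuming double precision is enabled and the solver above is importable):

import numpy as np

nx, nv = 4, 16
rng = np.random.default_rng(0)
a = rng.uniform(0.1, 1.0, (nx, nv))   # subdiagonal (first entry of each row is unused)
b = rng.uniform(3.0, 4.0, (nx, nv))   # main diagonal, kept diagonally dominant
c = rng.uniform(0.1, 1.0, (nx, nv))   # superdiagonal (last entry of each row is unused)
d = rng.uniform(-1.0, 1.0, (nx, nv))  # right-hand side

sol = np.asarray(_batched_tridiagonal_solver_(a, b, c, d))

for ix in range(nx):
    # Build the dense tridiagonal matrix for this batch element and solve it directly.
    dense = np.diag(b[ix]) + np.diag(a[ix, 1:], -1) + np.diag(c[ix, :-1], 1)
    np.testing.assert_allclose(sol[ix], np.linalg.solve(dense, d[ix]), rtol=1e-6)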

Double Precision

It was also important to enable double precision in JAX; otherwise, the Landau damping tests do not pass! This makes sense because Landau damping results from the interaction of the electrostatic field with the distribution function in regions of phase space where the value of the distribution function is very small, i.e. $\sim 10^{-8}$.
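For reference, 64-bit floats are off by default in JAX and have to be switched on before any arrays are created:

from jax import config

config.update("jax_enable_x64", True)  # default is float32; the Landau damping tests need float64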

Sidenote: this is one blocker towards the adoption of ML-friendly types like BF16.

Performance Results

So was all that work worth it?

One thing to check is the performance since JAX runs well on GPUs.

Running VPFP on a GPU using JAX is about 30x faster

Here are some performance metrics with respect to grid size. $N_x$ was varied from 32 to 1024. The upper limit was primarily defined by how patient I was willing to be with the NumPy/SciPy CPU implementation. We generally see about a 30x speed-up when running with JAX on a GPU.

Scaling with $N_x$

Solving the Fokker-Planck equation is nearly free with NumPy and the broadcasted Thomas tridiagonal solver. With JAX, however, there is a sizable performance hit when solving the Fokker-Planck equation. This can be remedied going forward by providing JAX with an API to CUDA’s batched tridiagonal solver.

Scaling with $N_v$

Here’s a similar plot but for $N_v$.

What does this mean for real simulations?

To simulate an electrostatic wave for $\sim 500 \omega_p^{-1}$, we need 1000 timesteps using the sixth-order time-stepper. A possible grid resolution for this might be $N_x = 256, N_v = 2048$. We observe that a single time-step at this grid size takes 0.7 ms, so the simulation takes roughly 1 s on a T4.

Okay, how about a very high-resolution electrostatic wave, like a KEEN wave simulation? For $N_x = 1024, N_v = 16384$, one timestep takes 0.3 s, so 2000 timesteps take about 10 minutes. That compares well with the benchmarks in ref. [2], where a similar simulation takes 2847 s, i.e. roughly 45 minutes, on 128 cores. This is a nice 5x - 640x speedup (depending on how you factor in the parallelization) in comparison to their highly optimized FORTRAN code. However, part of this speedup is because of the solvers used as well. Modern FFTs are quite fast, even compared to the straightforward tridiagonal solves that are needed for a semi-Lagrangian method.

Okay, but what about an electrostatic wavepacket like in ref. [3]? For those, the timescale is similar, but the grid is larger in $N_x$. We performed some wavepacket simulations with $N_x = 4096, N_v = 8192$ and observed the per-time-step calculation time to be 0.17 s. A full simulation takes about 5.5 minutes.

All in all, the speed-up is excellent!

Note that all of these times are on an NVIDIA T4 GPU; running on a V100 is about 3 times faster (but 6x more expensive on AWS!).

Conclusion

Porting to JAX was straightforward. I wanted to maintain the ability to switch back to NumPy for testing and to be able to swap modules in and out. I also needed to add @jit and @checkpoint decorators to many of the core functions in order to fully realize JAX’s potential. To do that and maintain backward compatibility, I did have to create dummy pass-through versions of those two decorators. In hindsight, much of this could be solved by using a proper configuration management system like hydra or gin.

It was definitely worth it for the speed-up alone. Being able to run the wavepacket simulation in a few minutes is a happy outcome. This enables the generation of large datasets for machine learning, and parametric scans in general.

In a future post, I will show some results of canonical 1D-1V VPFP problems and their performance using JAX on GPUs. In the meantime, here is a plot of the distribution function of a non-linear electrostatic wavepacket.

Non-linear electrostatic wavepacket

References

[1]: Häfner, Dion, René Løwe Jacobsen, Carsten Eden, Mads R. B. Kristensen, Markus Jochum, Roman Nuterman, and Brian Vinter. “Veros v0.1 – a Fast and Versatile Ocean Simulator in Pure Python.” Geoscientific Model Development 11, no. 8 (August 16, 2018): 3299–3312. https://doi.org/10.5194/gmd-11-3299-2018.

[2]: Afeyan, Bedros, Fernando Casas, Nicolas Crouseilles, Adila Dodhy, Erwan Faou, Michel Mehrenberger, and Eric Sonnendrücker. “Simulations of Kinetic Electrostatic Electron Nonlinear (KEEN) Waves with Variable Velocity Resolution Grids and High-Order Time-Splitting.” The European Physical Journal D 68, no. 10 (October 2014): 295. https://doi.org/10.1140/epjd/e2014-50212-6.

[3]: Fahlen, J. E., B. J. Winjum, T. Grismayer, and W. B. Mori. “Propagation and Damping of Nonlinear Plasma Wave Packets.” Physical Review Letters 102, no. 24 (June 17, 2009): 245002. https://doi.org/10.1103/PhysRevLett.102.245002.

Dr. Archis Joglekar
ML Researcher | Research Engineer | Theoretical Physicist

I like doing math with computers. I got a PhD in fusion plasma physics. It happens to be a perfect blend of applied mathematics, physics, and computing. I used to use supercomputers to do the math, now I use the cloud. I also like to do written math. I am currently working on something new at the intersection of deep learning and fusion. I am an Affiliate Researcher with the Laboratory for Laser Energetics and I am also an Adjunct Professor at the University of Michigan.