Automatic differentiation with JAX

Here we look at automatic differentiation, which can speed up fits with many free parameters.

iminuit’s minimization algorithm MIGRAD uses a mix of gradient descent and Newton’s method to find the minimum. Both require a first derivative, which MIGRAD usually computes numerically from finite differences. This requires many function evaluations and the gradient may not be accurate. As an alternative, iminuit also allows the user to compute the gradient and pass it to MIGRAD.

Although computing derivatives is often straightforward, it is usually too much hassle to do manually. Automatic differentiation (AD) is an interesting alternative: it allows one to compute exact derivatives efficiently for pure Python/numpy functions. We demonstrate automatic differentiation with the JAX module, which can not only compute derivatives, but also accelerate the computation of Python code (including the gradient code) with a just-in-time compiler.

Recommended reading: Gentle introduction to AD
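
A minimal sketch of what this looks like (the function fcn here is a made-up example, not part of the fit below): jax.grad returns a new function that evaluates the exact gradient, which could then be passed to Minuit via the grad keyword.

import jax
from jax import numpy as jnp

def fcn(x):
    # hypothetical example function: sum of squares
    return jnp.sum(x ** 2)

grad_fcn = jax.grad(fcn)  # exact gradient of fcn, computed by automatic differentiation
grad_fcn(jnp.array([1.0, 2.0, 3.0]))  # -> [2. 4. 6.]
# such a gradient could be passed to iminuit as Minuit(fcn, x0, grad=grad_fcn)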

Fit of a Gaussian model to a histogram

We fit a Gaussian to a histogram using a maximum-likelihood approach based on Poisson statistics. This example is used to investigate how automatic differentiation can accelerate a typical fit in a counting experiment.

To compare fits with and without passing an analytic gradient fairly, we use Minuit.strategy = 0, which prevents Minuit from automatically computing the Hesse matrix after the fit.

[1]:
# !pip install jax jaxlib matplotlib numpy iminuit numba-stats

import jax
from jax import numpy as jnp  # replacement for normal numpy
from jax.scipy.special import erf  # replacement for scipy.special.erf
from iminuit import Minuit
import numba as nb
import numpy as np  # original numpy still needed, since jax does not cover full API

jax.config.update("jax_enable_x64", True)  # enable float64 precision, default is float32

print(f"JAX version {jax.__version__}")
print(f"numba version {nb.__version__}")

We generate some toy data and write the negative log-likelihood (nll) for a fit to binned data, assuming Poisson-distributed counts.
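
For reference, the quantity implemented in nll below is the negative Poisson log-likelihood with constant terms dropped: for observed counts \(n_i\) and expected counts \(\mu_i\) per bin,

\[-\ln L = \sum_i \left(\mu_i - n_i + n_i \ln \frac{n_i}{\mu_i}\right).\]

The small constants added inside the logarithm in the code only guard against taking the logarithm of zero.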

Note: We write all statistical functions in pure Python code, to demonstrate Jax’s ability to automatically differentiate and JIT compile this code. In practice, one should import JIT-able statistical distributions from jax.scipy.stats. The library versions can be expected to have fewer bugs and to be faster and more accurate than hand-written code.
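
As a sketch of that library route (assuming the same toy data xe and n generated in the next cell, and a hypothetical function name nll_lib), the hand-written cdf could be replaced by jax.scipy.stats.norm.cdf:

from jax.scipy.stats import norm as jax_norm  # JIT-able normal distribution

def nll_lib(par):
    # same negative log-likelihood as nll below, but using the library cdf
    amp, mu, sigma = par
    expected = amp * jnp.diff(jax_norm.cdf(xe, mu, sigma))
    return jnp.sum(expected - n + n * jnp.log(n / (expected + 1e-100) + 1e-100))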

[2]:
# generate some toy data
rng = np.random.default_rng(seed=1)
n, xe = np.histogram(rng.normal(size=10000), bins=1000)


def cdf(x, mu, sigma):
    # cdf of a normal distribution, needed to compute the expected counts per bin
    # better alternative for real code: from jax.scipy.stats.norm import cdf
    z = (x - mu) / sigma
    return 0.5 * (1 + erf(z / np.sqrt(2)))


def nll(par):  # negative log-likelihood with constants stripped
    amp = par[0]
    mu, sigma = par[1:]
    p = cdf(xe, mu, sigma)
    mu = amp * jnp.diff(p)
    result = jnp.sum(mu - n + n * jnp.log(n / (mu + 1e-100) + 1e-100))
    return result

Let’s check results from all combinations of using JIT and gradient and then compare the execution times.

[3]:
start_values = (1.5 * np.sum(n), 1.0, 2.0)
limits = ((0, None), (None, None), (0, None))


def make_and_run_minuit(fcn, grad=None):
    m = Minuit(fcn, start_values, grad=grad, name=("amp", "mu", "sigma"))
    m.errordef = Minuit.LIKELIHOOD
    m.limits = limits
    m.strategy = 0 # do not explicitly compute hessian after minimisation
    m.migrad()
    return m
[4]:
m1 = make_and_run_minuit(nll)
m1.fmin
[5]:
m2 = make_and_run_minuit(nll, grad=jax.grad(nll))
m2.fmin
[6]:
m3 = make_and_run_minuit(jax.jit(nll), grad=jax.grad(nll))
m3.fmin
[7]:
m4 = make_and_run_minuit(jax.jit(nll), grad=jax.jit(jax.grad(nll)))
m4.fmin
[8]:
from numba_stats import norm # numba jit-able version of norm

@nb.njit
def nb_nll(par):
    amp = par[0]
    mu, sigma = par[1:]
    p = norm.cdf(xe, mu, sigma)
    mu = amp * np.diff(p)
    result = np.sum(mu - n + n * np.log(n / (mu + 1e-323) + 1e-323))
    return result

m5 = make_and_run_minuit(nb_nll)
m5.fmin
[9]:
from timeit import timeit

times = {
    "no JIT, no grad": "m1",
    "no JIT, grad": "m2",
    "jax JIT, no grad": "m3",
    "jax JIT, grad": "m4",
    "numba JIT, no grad": "m5",
}
for k, v in times.items():
    t = timeit(
        f"{v}.values = start_values; {v}.migrad()",
        f"from __main__ import {v}, start_values",
        number=1,
    )
    times[k] = t
[10]:
from matplotlib import pyplot as plt

x = np.fromiter(times.values(), dtype=float)
xmin = np.min(x)

y = -np.arange(len(times))
plt.barh(y, x)
for yi, k, v in zip(y, times, x):
    plt.text(v, yi, f"{v/xmin:.1f}x")
plt.yticks(y, times.keys())
for loc in ("top", "right"):
    plt.gca().spines[loc].set_visible(False)
plt.xlabel("execution time / s");

Conclusions:

  1. As expected, the best results are obtained by JIT compiling the function and the gradient.

  2. JIT compiling the cost function with Jax but not using the gradient gives a negligible performance improvement. Numba is able to do much better.

  3. JIT compiling the gradient is very important. Using an uncompiled Python gradient actually reduces performance drastically in this example.

In general, the gain from using a gradient is larger for functions with hundreds of parameters, as is common in machine learning. Human-made models often have fewer than 10 parameters, and then the gain is not so dramatic.

Computing covariance matrices with JAX

Automatic differentiation gives us another way to compute uncertainties of fitted parameters. MINUIT computes the uncertainties with the HESSE algorithm by default, which approximates the matrix of second derivatives using finite differences and then inverts it.

Let’s compare the output of HESSE with the exact (within floating point precision) computation using automatic differentiation.

[11]:
m4.hesse()
cov_hesse = m4.covariance


def jax_covariance(par):
    return jnp.linalg.inv(jax.hessian(nll)(par))


par = np.array(m4.values)
cov_jax = jax_covariance(par)

print(
    f"sigma[amp]  : HESSE = {cov_hesse[0, 0] ** 0.5:6.1f}, JAX = {cov_jax[0, 0] ** 0.5:6.1f}"
)
print(
    f"sigma[mu]   : HESSE = {cov_hesse[1, 1] ** 0.5:6.4f}, JAX = {cov_jax[1, 1] ** 0.5:6.4f}"
)
print(
    f"sigma[sigma]: HESSE = {cov_hesse[2, 2] ** 0.5:6.4f}, JAX = {cov_jax[2, 2] ** 0.5:6.4f}"
)

Success: HESSE and JAX give the same answer within the relevant precision.

Note: If you compute the covariance matrix in this way from a least-squares cost function instead of a negative log-likelihood, you must multiply it by 2.
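
A minimal sketch of that factor of two (least_squares_fcn is a placeholder for some least-squares cost function, not defined in this notebook):

def lsq_covariance(least_squares_fcn, par):
    # for a least-squares cost, the covariance is twice the inverse Hessian
    return 2 * jnp.linalg.inv(jax.hessian(least_squares_fcn)(par))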

Let us compare the performance of HESSE with Jax.

[12]:
%%timeit -n 1 -r 3
m = Minuit(nll, par)
m.errordef = Minuit.LIKELIHOOD
m.hesse()
[13]:
%%timeit -n 1 -r 3
jax_covariance(par)

The computation with Jax is faster, but not by orders of magnitude. It is also more accurate (although the added precision is not relevant).

Minuit’s HESSE algorithm still makes sense today. It has the advantage that it works with any function, while Jax does not: Jax cannot differentiate a function that calls into C/C++ or Cython code, for example.

Final note: If we JIT compile jax_covariance, it greatly outperforms Minuit’s HESSE algorithm, but that only makes sense if the Hessian has to be computed at many different parameter values, so that the extra time spent on compilation is balanced by the time saved over many invocations. That is not the case here: the Hessian is only needed at the best-fit point.

[14]:
%%timeit -n 1 -r 3 jit_jax_covariance = jax.jit(jax_covariance); jit_jax_covariance(par)
jit_jax_covariance(par)

It is much faster… but only because the compilation cost is excluded here.

[15]:
%%timeit -n 1 -r 1
# if we include the JIT compilation cost, the performance drops dramatically
@jax.jit
def jax_covariance(par):
    return jnp.linalg.inv(jax.hessian(nll)(par))


jax_covariance(par)

With compilation cost included, it is much slower.

Conclusion: Using the JIT compiler makes a lot of sense if the covariance matrix has to be computed repeatedly for the same cost function but different parameters, but this is not the case when we use it to compute parameter errors.
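
As an illustration of that use case, a sketch of reusing a JIT-compiled Hessian over several made-up parameter points, so that the one-time compilation cost is amortized:

hess = jax.jit(jax.hessian(nll))  # compiled once, reused below
for p in (par, par * 1.01, par * 0.99):  # hypothetical parameter points
    cov = jnp.linalg.inv(hess(p))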

Fit data points with uncertainties in x and y

Let’s say we have some data points \((x_i \pm \sigma_{x,i}, y_i \pm \sigma_{y,i})\) and a model \(y=f(x)\) that we want to adapt to this data. If \(\sigma_{x,i}\) were zero, we could use the usual least-squares method, minimizing the sum of squared residuals \(r^2_i = (y_i - f(x_i))^2 / \sigma^2_{y,i}\). Here, we don’t know where to evaluate \(f(x)\), since the exact \(x\)-location is only known up to \(\sigma_{x,i}\).

We can approximately extend the standard least-squares method to handle this case. We exploit the fact that the uncertainty along the \(x\)-axis can be converted into an additional uncertainty along the \(y\)-axis via error propagation,

\[f(x_i \pm \sigma_{x,i}) \simeq f(x_i) \pm f'(x_i)\,\sigma_{x,i}.\]

Using this, we obtain modified squared residuals

\[r^2_i = \frac{(y_i - f(x_i))^2}{\sigma^2_{y,i} + (f'(x_i) \,\sigma_{x,i})^2}.\]

We demonstrate this with a fit of a polynomial.

[16]:
# polynomial model
def f(x, par):
    return jnp.polyval(par, x)


# true polynomial f(x) = x^2 + 2 x + 3
par_true = np.array((1, 2, 3))


# grad computes derivative with respect to the first argument
f_prime = jax.jit(jax.grad(f))


# checking first derivative f'(x) = 2 x + 2
assert f_prime(0.0, par_true) == 2
assert f_prime(1.0, par_true) == 4
assert f_prime(2.0, par_true) == 6
# ok!

# generate toy data
n = 30
data_x = np.linspace(-4, 7, n)
data_y = f(data_x, par_true)

rng = np.random.default_rng(seed=1)
sigma_x = 0.5
sigma_y = 5
data_x += rng.normal(0, sigma_x, n)
data_y += rng.normal(0, sigma_y, n)
[17]:
plt.errorbar(data_x, data_y, sigma_y, sigma_x, fmt="o");
[18]:
# define the cost function
@jax.jit
def cost(par):
    result = 0.0
    for xi, yi in zip(data_x, data_y):
        y_var = sigma_y ** 2 + (f_prime(xi, par) * sigma_x) ** 2
        result += (yi - f(xi, par)) ** 2 / y_var
    return result

cost.errordef = Minuit.LEAST_SQUARES

# test the jit-ed function
cost(np.zeros(3))
[19]:
m = Minuit(cost, np.zeros(3))
m.migrad()
[20]:
plt.errorbar(data_x, data_y, sigma_y, sigma_x, fmt="o", label="data")
x = np.linspace(data_x[0], data_x[-1], 200)
par = np.array(m.values)
plt.plot(x, f(x, par), label="fit")
plt.legend()

# check fit quality
chi2 = m.fval
ndof = len(data_y) - 3
plt.title(f"$\\chi^2 / n_\\mathrm{{dof}} = {chi2:.2f} / {ndof} = {chi2/ndof:.2f}$");

We obtained a good fit.