# Part 2: Training of a regression network with JAX

Uncomment the command in the next cell to install all libraries needed to run this notebook.

In [1]:
#!pip install --upgrade flax optax jax jaxlib matplotlib numpy scikit-learn boost-histogram

In this exercise, we train a simple feed-forward network, implemented with the libraries JAX and Flax, on data from the LHCb simulation. The simulation models how the LHCb experiment detects particles produced in collisions of the LHC beams, here proton-proton collisions at a center-of-mass energy of 13 TeV.

The quantity to predict here is the efficiency of the LHCb tracking system. The LHCb tracking system consists of a large magnet and several tracking detectors that trace through-going charged particles produced in the proton-proton collisions. Depending on the direction and total momentum of the produced particles, the detector can miss a particle. The particle may simply bump into the magnet or some support structure and be lost. It can also decay in flight (most particles have finite lifetimes).

Running the LHCb simulation software takes a lot of time. We will train a neural network to learn the tracking efficiency, so that the network can act as a **fast surrogate model** for the real LHCb simulation. As input, we will use histograms produced from the simulation. These histograms record **how many particles were generated and how many were detected** in the simulation. The latter divided by the former is an estimate for the **tracking efficiency**. This is what we want the network to learn.
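As a minimal sketch of this ratio, assuming toy counts (the numbers below are made up for illustration and are not taken from the LHCb histograms), the per-bin efficiency estimate could be computed like this. Bins without generated particles have no defined efficiency, which the sketch marks with NaN.

```python
import numpy as np

# hypothetical toy counts per bin, not real LHCb data
n_generated = np.array([1000, 200, 0, 50])
n_detected = np.array([900, 150, 0, 10])

# efficiency = detected / generated; bins without generated particles stay NaN
efficiency = np.full(n_generated.shape, np.nan)
np.divide(n_detected, n_generated, out=efficiency, where=n_generated > 0)

print(efficiency)
```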

The histograms are three-dimensional. They have three axes, which fully describe the momentum of the produced particles:

* Pseudorapidity $\eta$ (eta): measures the angle of the particle with respect to the proton beam. $\eta = 0$ means the particle travels perpendicular to the proton beams, $\eta = \infty$ means it travels parallel to the beam.
* Transverse momentum $p_T$: measures the momentum component of the particle perpendicular to the proton beams.
* Azimuthal angle $\phi$ (phi): measures the angle of the particle around the beam axis, in a coordinate system where the z-axis is aligned with the proton beams.

Why don't we use the Cartesian components $p_x, p_y, p_z$? We could do that, but the coordinates $(\eta, p_T, \phi)$ are more useful for physicists, and they turn out to be easier to use for machine learning as well.
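For reference, the two sets of coordinates are related by standard formulas. The following sketch assumes the z-axis points along the beam; the function name `to_eta_pt_phi` is made up for this illustration and does not appear elsewhere in the notebook.

```python
import numpy as np


def to_eta_pt_phi(px, py, pz):
    """Convert Cartesian momentum components to (eta, pT, phi)."""
    pt = np.hypot(px, py)      # momentum component transverse to the beam
    phi = np.arctan2(py, px)   # angle around the beam axis, in (-pi, pi]
    eta = np.arcsinh(pz / pt)  # pseudorapidity, equal to -ln(tan(theta / 2))
    return eta, pt, phi


# a particle flying perpendicular to the beam has eta = 0
print(to_eta_pt_phi(1.0, 0.0, 0.0))
# a particle flying mostly along the beam has a large pseudorapidity
print(to_eta_pt_phi(1.0, 1.0, 50.0))
```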

## Gathering and preprocessing data

Before we can start with the training, we need to prepare our inputs.

The neural network needs pairs (`X`, `y`) to train, where `X` is a vector of the inputs (here $\eta$, $p_T$, $\phi$) and `y` is the efficiency. We construct the `X` vector from the bin centers of the histograms and compute `y` as the number of detected particles divided by the number of generated particles in the simulation.

In [2]:
import pickle
import gzip
import numpy as np

with gzip.open("lhcb_tracking_rec.pkl.gz") as f:
    rec = pickle.load(f)

with gzip.open("lhcb_tracking_gen.pkl.gz") as f:
    gen = pickle.load(f)

n_gen = gen.values()
n_rec = rec.values()

X_train = []
k_train = []
n_train = []
X_test = []
k_test = []
n_test = []
for ieta, eta in enumerate(rec.axes[0].centers):
    for iphi, phi in enumerate(rec.axes[2].centers):
        for ipt, pt in enumerate(rec.axes[1].centers):
            if (ipt + ieta + iphi) % 2 == 0:
                X = X_train
                k = k_train
                n = n_train
            else:
                X = X_test
                k = k_test
                n = n_test
            X.append((eta, pt, phi))
            k.append(n_rec[ieta, ipt, iphi])
            n.append(n_gen[ieta, ipt, iphi])

# lists to numpy arrays
X_train = np.array(X_train)
k_train = np.array(k_train)
n_train = np.array(n_train)
y_train = k_train / n_train
X_test = np.array(X_test)
k_test = np.array(k_test)
n_test = np.array(n_test)
y_test = k_test / n_test

print(X_train.shape[0], "training points", X_test.shape[0], "test points")

Neural networks train best when the inputs are standardized. Ideally, the inputs are Gaussian distributed with zero mean and unit variance.

Let's see whether that is true for our inputs.

In [3]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 3, figsize=(10, 4))
for i, label in enumerate(("eta", "pt", "phi")):
    ax[i].hist(X_train[:, i], bins=20)
    ax[i].set_xlabel(label)

In `eta` and `phi`, the distribution is uniform. This is not Gaussian, but acceptable. The `pT` distribution is far from Gaussian. We make it more Gaussian by transforming `pT` to `log(pT)`.

In [4]:
X_train[:, 1] = np.log(X_train[:, 1])
X_test[:, 1] = np.log(X_test[:, 1])

In [5]:
fig, ax = plt.subplots(1, 3, figsize=(10, 4))
for i, label in enumerate(("eta", "log(pt)", "phi")):
    ax[i].hist(X_train[:, i], bins=20)
    ax[i].set_xlabel(label)

Now the log(pT) distribution is also approximately uniform (the spikes are aliasing artifacts and can be ignored).

We still need to make the distributions centered around zero with variance one. We use the `StandardScaler` from Scikit-Learn to transform the inputs.

In [6]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
T_train = scaler.fit_transform(X_train)
T_test = scaler.transform(X_test)

fig, ax = plt.subplots(1, 3, figsize=(10, 4))
for i, label in enumerate(("eta", "log(pt)", "phi")):
    ax[i].hist(T_train[:, i], bins=20)
    ax[i].set_xlabel(label)
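What the scaler does can be sketched in a few lines of NumPy, assuming the default settings of `StandardScaler` (per-column mean and population standard deviation). The toy arrays below are made up for illustration. The statistics are computed from the training sample only and then reused for the test sample.

```python
import numpy as np

# hypothetical toy inputs with two features on very different scales
toy_train = np.array([[1.0, 100.0], [2.0, 300.0], [3.0, 500.0]])
toy_test = np.array([[2.5, 200.0]])

# statistics are computed from the training sample only
mean = toy_train.mean(axis=0)
std = toy_train.std(axis=0)  # population standard deviation (ddof=0)

# the same shift and scale is applied to training and test inputs
toy_train_scaled = (toy_train - mean) / std
toy_test_scaled = (toy_test - mean) / std

print(toy_train_scaled.mean(axis=0))  # close to 0 in each column
print(toy_train_scaled.std(axis=0))   # close to 1 in each column
```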

Let's have a first look at the efficiency as a function of the inputs with the following drawing function. 

In [7]:
def draw(predict, show_all_phi=False):
    for phi_i in sorted(np.unique(X_train[:, 2])):
        fig, axes = plt.subplots(3, 2, sharex=True, sharey=True, layout="compressed")
        for kind, (X, y) in enumerate(((X_train, y_train), (X_test, y_test))):
            plt.suptitle(rf"$\phi = {np.degrees(phi_i):.0f}$ deg")
            for eta_i, ax_i in zip(np.unique(X[:, 0]), axes.flat):
                plt.sca(ax_i)
                ma = X[:, 0] == eta_i
                ma &= X[:, 2] == phi_i
                plt.title(r"$\eta =$" + f"${eta_i:.2f}$")
                plt.plot(
                    np.exp(X[ma, 1]),
                    y[ma],
                    "os"[kind],
                    mfc="none",
                    ms=4,
                    label=["train", "test"][kind],
                )

                if predict and (kind == 1):
                    pt = np.geomspace(1e1, 1e4, 1000)
                    Xp = np.empty((len(pt), 3))
                    Xp[:, 0] = eta_i
                    Xp[:, 1] = np.log(pt)
                    Xp[:, 2] = phi_i
                    yp = predict(scaler.transform(Xp))
                    plt.plot(pt, yp, color="k", label="model")
            fig.supxlabel("$p_T$ / MeV$c^{-1}$")
            plt.sca(axes[0, 0])
            plt.legend(frameon=False, fontsize="x-small", handlelength=0.3)
            plt.semilogx()
        if not show_all_phi:
            break


draw(None)

## Model building and training with Flax / JAX / Optax

[JAX](https://jax.readthedocs.io/en/latest/) is a modern library in a functional programming style that computes derivatives of pure functions and JIT-compiles pure functions. [Flax](https://flax.readthedocs.io/en/latest/) is a library to build neural networks for JAX. [Optax](https://optax.readthedocs.io/en/latest/) is a gradient processing and optimization library for JAX. We use these libraries to implement a neural network model and the training loop.

Compared to other machine learning libraries, JAX offers a high degree of transparency and flexibility. While it takes some experience to write JAX code correctly, JAX code is easy to read since it mimics the well-known APIs of NumPy and SciPy. JAX can also make training code run faster than other libraries, thanks to its JIT compiler. However, JAX comes with less high-level functionality than other libraries.
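The two core transformations can be illustrated with a small sketch. The function `f` below is a toy example chosen for this illustration; `jax.grad` and `jax.jit` are the JAX transformations used in this notebook.

```python
import jax
import jax.numpy as jnp


# a pure function: the output depends only on the input
def f(x):
    return jnp.sum(x**2)


# jax.grad returns a new function that computes the gradient of f
grad_f = jax.grad(f)
# jax.jit returns a compiled version of that function
fast_grad_f = jax.jit(grad_f)

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))       # the gradient of sum(x^2) is 2 * x
print(fast_grad_f(x))  # same result, compiled on the first call
```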

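We first create the neural network. We make a simple feed-forward network in which each hidden layer consists of a linear mapping ("Dense" layer) followed by a ReLU activation function. The output layer is a single linear mapping without an activation function.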

In [8]:
from typing import Sequence
from flax import linen as nn


# Models are created by deriving from `flax.linen.Module`
# and implementing (usually) the `__call__` method
class Model(nn.Module):
    layer: Sequence[int]  # Number of hidden nodes per layer

    # x is a vector of N input features (N still unspecified)
    @nn.compact
    def __call__(self, x):
        for k in self.layer:
            x = nn.Dense(k)(x)
            x = nn.relu(x)
        # output of the model is 1 feature
        x = nn.Dense(1)(x).flatten()
        return x


# we create a model with 3 hidden layers of 128 nodes each
model = Model((128, 128, 128))
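To make explicit what the `Dense` and ReLU layers compute, here is a NumPy sketch of the same forward pass. It assumes three input features and uses placeholder random weights (Flax uses its own initializers; the normal random numbers below are only for illustration). It also counts the parameters of such a network.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [3, 128, 128, 128, 1]  # inputs, three hidden layers, one output

# one (kernel, bias) pair per Dense layer; the random values are placeholders
layers = [
    (rng.normal(size=(n_in, n_out)), np.zeros(n_out))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]


def forward(x):
    # hidden layers: linear mapping followed by ReLU
    for kernel, bias in layers[:-1]:
        x = np.maximum(x @ kernel + bias, 0.0)
    # output layer: linear mapping only, flattened to one value per sample
    kernel, bias = layers[-1]
    return (x @ kernel + bias).ravel()


n_params = sum(kernel.size + bias.size for kernel, bias in layers)
print(n_params)                         # 33665
print(forward(np.zeros((5, 3))).shape)  # (5,)
```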

In [9]:
import jax
import jax.numpy as jnp
import optax
from types import SimpleNamespace

print("JAX local devices:", *jax.local_devices())

If your computer has a suitable graphics card and you have installed the right version of `jaxlib`, the GPU will show up here. Otherwise, your CPU will show.

Next, we write the training loop. Although our dataset is very small, we split the input into smaller batches for training, since this improves the convergence.
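Before the full loop, the shuffling and slicing into batches can be illustrated in isolation. This is a sketch with made-up toy data; it uses NumPy's random generator, whereas the training loop uses `jax.random`.

```python
import numpy as np

rng = np.random.default_rng(0)
toy_inputs = np.arange(12).reshape(6, 2)  # 6 hypothetical samples with 2 features
batch_size = 2

# a new random order is drawn in every epoch
perm = rng.permutation(len(toy_inputs))
shuffled = toy_inputs[perm]

# consecutive slices of the shuffled data form the batches
batches = [
    shuffled[start : start + batch_size]
    for start in range(0, len(shuffled), batch_size)
]

print(len(batches), "batches of shape", batches[0].shape)
```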

In [10]:
def train(config, model, T_train, y_train, T_test, y_test, seed=0):
    # load data onto computing device, necessary to use a GPU
    T_train = jax.device_put(T_train)
    y_train = jax.device_put(y_train)
    T_test = jax.device_put(T_test)
    y_test = jax.device_put(y_test)

    # create root key for the pseudo-random number generator
    root_key = jax.random.PRNGKey(seed)
    # create one PRNG key for initialization and one for batching
    init_key, batch_key = jax.random.split(root_key)

    # initialize model and infer input shape from one input sample
    params = model.init(init_key, T_train[:1])

    # count and print number of parameters
    n = 0
    for layer in params["params"].values():
        n += np.prod(layer["kernel"].shape) + layer["bias"].shape[0]
    print("number of parameters", n)

    # create optimizer and initialize with model parameters
    opt = optax.nadamw(learning_rate=config.learning_rate)
    opt_state = opt.init(params)

    # our loss function, JIT-compiled with JAX
    @jax.jit
    def loss_fn(params, T, y):
        y_pred = model.apply(params, T)
        # l2 loss is also called squared error loss
        lo = optax.l2_loss(y_pred, y)
        return jnp.sum(lo)

    # our step function, JIT-compiled with JAX
    @jax.jit
    def step(params, opt_state, T, y):
        grads = jax.grad(loss_fn)(params, T, y)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        # the optimizer state (e.g. momentum estimates) must be returned as well
        return params, opt_state

    # training loop
    loss_train = []
    loss_test = []
    best_epoch = 0
    best_params = params
    assert len(T_train) % config.batch_size == 0
    for epoch in range(config.max_epoch):
        # compute random shuffling of training data to avoid cycles
        batch_train_key = jax.random.fold_in(batch_key, epoch)
        perm = jax.random.permutation(batch_train_key, np.arange(len(T_train)))
        T_perm = T_train[perm]
        y_perm = y_train[perm]

        # loop over batches and compute parameter updates from each batch
        for batch_idx in range(0, len(T_perm), config.batch_size):
            T = T_perm[batch_idx : batch_idx + config.batch_size]
            y = y_perm[batch_idx : batch_idx + config.batch_size]
            # here the actual magic happens: params and opt_state are updated
            params, opt_state = step(params, opt_state, T, y)

        # compute loss on training data (as a metric)
        loss = loss_fn(params, T_train, y_train)
        loss_train.append(loss.item())

        # compute loss on test data (as a metric)
        loss = loss_fn(params, T_test, y_test)
        loss_test.append(loss.item())

        if epoch % config.print_freq == 0:
            print(
                f"epoch {epoch} "
                f"loss[train]={loss_train[-1]:.2f} "
                f"loss[test]={loss_test[-1]:.2f}"
            )

        # store params with current best test loss; `<=` makes sure that the
        # params after the first epoch replace the untrained initial params
        if loss_test[-1] <= loss_test[best_epoch]:
            best_epoch = epoch
            best_params = params

    print(f"best loss[test]={loss_test[best_epoch]:.2f}")

    return best_params, loss_train, loss_test


config = SimpleNamespace()
config.max_epoch = 500
config.learning_rate = 1e-3
config.print_freq = 100
config.batch_size = 10

params, loss_train, loss_test = train(config, model, T_train, y_train, T_test, y_test)

plt.plot(loss_train, label="train")
plt.plot(loss_test, label="test")
plt.legend()
plt.semilogy()

draw(lambda X: model.apply(params, X))

The training and test loss curves are very jittery. What is going on?

## Exercise 1

Try reducing the learning rate by an order of magnitude and run the previous cell again. Repeat until the jitter largely disappears. Can you guess what happens here? Note: the loss function can have a steep and narrow minimum.

In [11]:
# develop your solution here

A large learning rate makes the network converge faster, but a learning rate that is too large leads to overshooting: the model jumps over the minimum. Tuning the learning rate by hand is tedious. Moreover, it would be optimal to use a larger learning rate in the beginning and let it decline towards the end. For these reasons, it is a good idea to adjust the learning rate dynamically. The simplest approach is to reduce the learning rate exponentially as the epoch increases.
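Overshooting and the effect of an exponentially decaying learning rate can be seen in a one-dimensional toy problem. The sketch below uses plain gradient descent on a made-up quadratic function, not the optimizer or loss of this notebook; the function name `minimize` and all numbers are chosen for illustration.

```python
def minimize(learning_rate, decay=1.0, steps=50):
    """Gradient descent on f(x) = 50 * x**2, a steep and narrow minimum at x = 0."""
    x = 1.0
    for _ in range(steps):
        gradient = 100.0 * x           # derivative of 50 * x**2
        x -= learning_rate * gradient  # multiplies x by (1 - 100 * learning_rate)
        learning_rate *= decay         # decay < 1 shrinks the learning rate exponentially
    return abs(x)  # remaining distance to the minimum


print(minimize(0.001))             # small steps: converges, but slowly
print(minimize(0.021))             # too large: every step overshoots, the distance grows
print(minimize(0.021, decay=0.9))  # same start, but the decaying rate lets it settle
```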

We can do better. Let's write an algorithm to reduce the learning rate when we reach a plateau with oscillations.

## Exercise 2

Copy the cell with the `train` function into the cell below. Replace the line
```py
opt = optax.nadamw(learning_rate=config.learning_rate)
```
with the line
```py
opt = optax.inject_hyperparams(optax.nadamw)(learning_rate=config.learning_rate)
```
This will allow us to modify the learning rate (a hyperparameter) in the training loop. To reduce the learning rate by 10 %, you can use this line
```py
opt_state.hyperparams["learning_rate"] *= 0.9
```
Note that you have to change `opt_state` and not `opt`, which consists of pure functions and holds no state.

Alter the training loop so that the learning rate is reduced whenever the smallest test loss in the last 10 epochs is not smaller than the smallest test loss in the previous 10 epochs. This indicates a plateau in the loss curve. Make sure that you only check this condition every 10 epochs and not every epoch. Stop the training if the learning rate becomes smaller than 1e-10.

You should be able to achieve a lower test loss with this dynamic schedule within the budget of 500 epochs.
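The plateau condition can be isolated in a small helper. This is one possible reading of the rule described above, not a reference solution; the name `is_plateau` and the toy loss lists are made up for illustration. In the training loop, such a check would only be evaluated every 10 epochs.

```python
def is_plateau(loss_history, window=10):
    """Return True if the best loss of the last `window` epochs is not
    better than the best loss of the `window` epochs before that."""
    if len(loss_history) < 2 * window:
        return False  # not enough history yet
    recent = min(loss_history[-window:])
    previous = min(loss_history[-2 * window : -window])
    return recent >= previous


# hypothetical loss curves for illustration
improving = [1.0 / (epoch + 1) for epoch in range(20)]
stalled = [1.0 / (epoch + 1) for epoch in range(10)] + [0.2] * 10

print(is_plateau(improving))  # False
print(is_plateau(stalled))    # True
```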

In [12]:
# develop your solution here

## Exercise 3

If we investigate the model output closely, we will find that it may predict efficiencies less than 0 or larger than 1. The output of a neural network is unbounded by default: it can range from -infinity to +infinity. This is an issue if we want to interpret the output of the network as a probability.

To enforce that the output of the model is a valid probability, use the sigmoid function:
$$
p(x) = \frac{1}{1 + e^{-x}}
$$
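The effect of the sigmoid function on unbounded values can be checked with a short NumPy sketch. The helper `sigmoid` below is written for this illustration (it is not the Flax function) and the input values are made up.

```python
import numpy as np


def sigmoid(x):
    # equivalent to 1 / (1 + exp(-x)), written to avoid overflow for large |x|
    return np.exp(-np.logaddexp(0.0, -x))


raw_output = np.array([-50.0, -2.0, 0.0, 2.0, 50.0])  # unbounded network outputs
p = sigmoid(raw_output)
print(p)  # all values lie in the interval [0, 1], and sigmoid(0) = 0.5
```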

Copy the previous cell with the `train` function into the cell below. Modify the `loss_fn`. Let the model predict `y_pred` as before, but compute `p_pred` inside the loss function, which is the output of the sigmoid function fed with `y_pred`. Use `flax.linen.sigmoid` to compute the sigmoid function. You also have to apply this transform in the `draw` function. You need to replace the line
```py
draw(lambda X: model.apply(params, X))
```
with
```py
draw(lambda X: nn.sigmoid(model.apply(params, X)))
```

In [13]:
# develop your solution here

The model is now guaranteed to compute a real probability, but the result may look slightly worse than before. The reason is that the sigmoid function is highly non-linear and more layers in the model may be needed to adapt to this non-linearity.

## Exercise 4

So far we have used the simple L2 loss, but this loss function is not correct for this problem. Our efficiency estimates have statistical uncertainties since they originate from a simulation sample of finite size. The L2 loss treats all points equally, but we can be less certain that the true efficiency in a bin is really 0.5 if we observe 1 reconstructed particle for 2 generated particles than if we observe 500 reconstructed particles for 1000 generated particles.
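The difference between these two cases can be quantified with the approximate binomial standard error $\sqrt{p(1-p)/n}$ of an efficiency estimate. The helper name `efficiency_with_error` is made up for this sketch.

```python
import math


def efficiency_with_error(k, n):
    """Efficiency estimate k / n and its approximate binomial standard error."""
    p = k / n
    return p, math.sqrt(p * (1 - p) / n)


print(efficiency_with_error(1, 2))       # estimate 0.5 with a large uncertainty
print(efficiency_with_error(500, 1000))  # estimate 0.5 with a small uncertainty
```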

To take the accuracy of individual bins into account, we use the negative log-likelihood as the loss function. The probability $P$ to observe $k$ reconstructed particles for a given number $n$ of generated particles and efficiency $p$ in a bin is given by the binomial distribution
$$
P(k; n, p) = \frac{n!}{k!(n-k)!} \, p^k \, (1 - p)^{n-k}
$$
Here, $p$ is the prediction by the neural network. We can compute this probability for each bin, and the best model would maximize the product of all these probabilities. It turns out that this is not a good computational strategy though, because the product of many small probabilities quickly underflows in floating-point arithmetic. It is better to compute the log-probability per bin and then the sum over those log-probabilities. To turn this sum into a loss, we negate it. This gives us the so-called *negative log-likelihood*.
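Written out, with $k_i$, $n_i$ and $p_i$ denoting the reconstructed count, the generated count and the predicted efficiency in bin $i$, taking the logarithm of the product and changing the sign gives
$$
-\log \prod_i P(k_i; n_i, p_i) = -\sum_i \Big[ k_i \log p_i + (n_i - k_i) \log (1 - p_i) \Big] - \sum_i \log \frac{n_i!}{k_i!(n_i-k_i)!}
$$
Only the first sum on the right-hand side depends on the predictions $p_i$.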

Copy the previous cell with the `train` function into the cell below. Alter the `loss_fn` so that it computes the negative log-likelihood.

Hints:
* You need to pass `k_train`, `k_test`, `n_train`, `n_test` to `train`, instead of `y_train`, `y_test`. You also need to change all inner functions that accept `y` so that they accept `k` and `n` instead.
* You can drop the term $\dfrac{n!}{k!(n-k)!}$ from the computation. It contributes only a constant offset, which does not depend on $p$ and therefore does not change the location of the minimum.
* You can compute `log(p_pred)` directly from `y_pred` with `flax.linen.activation.log_sigmoid(y_pred)`. Library implementations typically include clever numerical optimizations that work better than a naive translation of mathematical formulas into code.
* `log(1 - p_pred)` is best computed by `flax.linen.activation.log_sigmoid(-y_pred)` (feel free to check this on paper).
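The identity behind the last hint follows from $1 - \frac{1}{1 + e^{-y}} = \frac{e^{-y}}{1 + e^{-y}} = \frac{1}{1 + e^{y}}$, which is the sigmoid function evaluated at $-y$. The NumPy sketch below checks this numerically and shows why the direct formula is fragile. The helper `log_sigmoid` here is a stand-in written for this example with `np.logaddexp`; it is not the Flax function.

```python
import numpy as np


def log_sigmoid(x):
    # log(1 / (1 + exp(-x))) = -log(1 + exp(-x)), evaluated without overflow
    return -np.logaddexp(0.0, -x)


y = np.array([-3.0, 0.0, 3.0])
naive = np.log(1.0 - 1.0 / (1.0 + np.exp(-y)))
print(np.allclose(naive, log_sigmoid(-y)))  # True: both compute log(1 - sigmoid(y))

# for a large network output the naive formula loses all precision
big = np.array([40.0])
with np.errstate(divide="ignore"):
    print(np.log(1.0 - 1.0 / (1.0 + np.exp(-big))))  # -inf
print(log_sigmoid(-big))  # finite, close to -40
```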

If you implement the loss correctly, you will get huge loss values that are not comparable to before, but the network should nevertheless train to a good state.

In [14]:
# develop your solution here

If there is time left, explore different ways to make the model fit better. Explore other network architectures, increase the number of epochs, change the learning rate...

## Where to go from here?

In this exercise, we trained a simple model on a tiny dataset. The model is probably not fitting the data perfectly by the end of this lesson, but a satisfactory fit can be achieved with a larger dataset than the one we let you use here. Further important concepts that we did not explore here that you can study are:

#### Add BatchNorm layers to the model

Batch normalization accelerates the training of deep networks. The idea is to partially decouple the training of individual layers. BatchNorm layers should be placed after the activation.
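As a rough sketch of what such a layer computes during training, here is batch normalization written in NumPy. The function name `batch_norm` and the toy activations are made up for this illustration. At inference time, implementations typically use running averages instead of the batch statistics; that part is omitted here.

```python
import numpy as np


def batch_norm(x, scale, offset, eps=1e-5):
    """Training-mode batch normalization of activations x with shape (batch, features)."""
    mean = x.mean(axis=0)  # statistics are taken over the batch dimension
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return scale * x_hat + offset  # learnable per-feature scale and offset


rng = np.random.default_rng(0)
activations = rng.normal(loc=5.0, scale=3.0, size=(32, 4))  # hypothetical batch
out = batch_norm(activations, scale=np.ones(4), offset=np.zeros(4))
print(out.mean(axis=0))  # close to 0 per feature
print(out.std(axis=0))   # close to 1 per feature
```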

#### Add Dropout layers to the model

Dropout is a regularization technique. When a dropout layer is present, the network is forced to not rely on single network nodes very much and instead distribute its computation on several nodes. This can make the network generalize better or help with training on a small sample. Dropout layers should be placed near the end of the network, for example, after the last hidden layer.
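The mechanism can be sketched in NumPy as follows. This is the common "inverted dropout" variant, applied only during training; the function name `dropout` and the toy activations are made up for this illustration.

```python
import numpy as np


def dropout(x, rate, rng):
    """Inverted dropout: zero a random fraction `rate` of the activations
    and rescale the rest so that the expected value is unchanged."""
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
activations = np.ones((1000, 128))  # hypothetical hidden-layer output
dropped = dropout(activations, rate=0.5, rng=rng)
print((dropped == 0).mean())  # roughly half of the nodes are switched off
print(dropped.mean())         # the average activation stays close to 1
```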