Getting started

This tutorial is based on the quickstart example in the celerite documentation, but it has been updated to work with celerite2.

For this tutorial, we’re going to fit a Gaussian Process (GP) model to a simulated dataset with quasiperiodic oscillations. We’re also going to leave a gap in the simulated data, and we’ll use the GP model to predict what we would have observed for those “missing” data points.

To start, here’s some code to simulate the dataset:

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

t = np.sort(
    np.append(
        np.random.uniform(0, 3.8, 57),
        np.random.uniform(5.5, 10, 68),
    )
)  # The input coordinates must be sorted
yerr = np.random.uniform(0.08, 0.22, len(t))
y = (
    0.2 * (t - 5)
    + np.sin(3 * t + 0.1 * (t - 5) ** 2)
    + yerr * np.random.randn(len(t))
)

true_t = np.linspace(0, 10, 500)
true_y = 0.2 * (true_t - 5) + np.sin(3 * true_t + 0.1 * (true_t - 5) ** 2)

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3)
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.xlabel("x [day]")
plt.ylabel("y [ppm]")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5)
_ = plt.title("simulated data")
[Figure: simulated data]

Now, let’s fit this dataset using a mixture of two SHOTerm kernels: one quasi-periodic component and one non-periodic component. First, let’s set up an initial model to see how it looks:

import celerite2
from celerite2 import terms

# Quasi-periodic term
term1 = terms.SHOTerm(sigma=1.0, rho=1.0, tau=10.0)

# Non-periodic component
term2 = terms.SHOTerm(sigma=1.0, rho=5.0, Q=0.25)
kernel = term1 + term2

# Set up the GP
gp = celerite2.GaussianProcess(kernel, mean=0.0)
gp.compute(t, yerr=yerr)

print("Initial log likelihood: {0}".format(gp.log_likelihood(y)))
Initial log likelihood: -16.751640798326278
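As a quick reminder of what these hyperparameters mean (this follows the SHOTerm parameterization described in the celerite2 API docs; double-check there for the exact definitions), each SHOTerm is a stochastically driven, damped simple harmonic oscillator with power spectrum

$$S(\omega) = \sqrt{\frac{2}{\pi}}\,\frac{S_0\,\omega_0^4}{(\omega^2 - \omega_0^2)^2 + \omega_0^2\,\omega^2/Q^2}$$

where $\rho = 2\pi/\omega_0$ is the undamped period, $\tau = 2Q/\omega_0$ is the damping timescale, and $\sigma = \sqrt{S_0\,\omega_0\,Q}$ is the standard deviation of the process. The second term uses $Q = 1/4$, which makes it overdamped and therefore non-periodic.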

Let’s look at the underlying power spectral density of this initial model:

freq = np.linspace(1.0 / 8, 1.0 / 0.3, 500)
omega = 2 * np.pi * freq


def plot_psd(gp):
    for n, term in enumerate(gp.kernel.terms):
        plt.loglog(freq, term.get_psd(omega), label="term {0}".format(n + 1))
    plt.loglog(freq, gp.kernel.get_psd(omega), ":k", label="full model")
    plt.xlim(freq.min(), freq.max())
    plt.legend()
    plt.xlabel("frequency [1 / day]")
    plt.ylabel("power [day ppt$^2$]")


plt.title("initial psd")
plot_psd(gp)
[Figure: initial psd]

And then we can also plot the prediction that this model makes for the missing data and compare it to the truth:

def plot_prediction(gp):
    plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3, label="truth")
    plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0, label="data")

    if gp:
        mu, variance = gp.predict(y, t=true_t, return_var=True)
        sigma = np.sqrt(variance)
        plt.plot(true_t, mu, label="prediction")
        plt.fill_between(true_t, mu - sigma, mu + sigma, color="C0", alpha=0.2)

    plt.xlabel("x [day]")
    plt.ylabel("y [ppm]")
    plt.xlim(0, 10)
    plt.ylim(-2.5, 2.5)
    plt.legend()


plt.title("initial prediction")
plot_prediction(gp)
[Figure: initial prediction]

Ok, that looks pretty terrible, but we can get a better fit by numerically maximizing the likelihood as described in the following section.

Maximum likelihood

In this section, we’ll improve our initial GP model by maximizing the likelihood function for the parameters of the kernel, the mean, and a “jitter” (a constant variance term added to the diagonal of our covariance matrix). To do this, we’ll use the numerical optimization routine from scipy:
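Concretely, writing $\sigma_{\mathrm{jit}}^2$ for the jitter variance (a symbol introduced here just for notation), the covariance matrix being factorized is

$$K_{ij} = k(t_i, t_j) + \left(\sigma_{y,i}^2 + \sigma_{\mathrm{jit}}^2\right)\,\delta_{ij},$$

which is exactly what the diag=yerr**2 + theta[5] argument to gp.compute builds in the code below.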

from scipy.optimize import minimize


def set_params(params, gp):
    gp.mean = params[0]
    theta = np.exp(params[1:])
    gp.kernel = terms.SHOTerm(
        sigma=theta[0], rho=theta[1], tau=theta[2]
    ) + terms.SHOTerm(sigma=theta[3], rho=theta[4], Q=0.25)
    gp.compute(t, diag=yerr**2 + theta[5], quiet=True)
    return gp


def neg_log_like(params, gp):
    gp = set_params(params, gp)
    return -gp.log_likelihood(y)


initial_params = [0.0, 0.0, 0.0, np.log(10.0), 0.0, np.log(5.0), np.log(0.01)]
soln = minimize(neg_log_like, initial_params, method="L-BFGS-B", args=(gp,))
opt_gp = set_params(soln.x, gp)
soln
  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  success: True
   status: 0
      fun: -15.94282638406871
        x: [ 4.596e-03 -3.416e-01  7.000e-01  1.944e+00  6.059e-01
             3.755e+00 -7.873e+00]
      nit: 52
      jac: [-4.263e-05 -2.714e-04 -1.990e-05  6.253e-05  3.837e-05
             2.132e-05  3.695e-05]
     nfev: 464
     njev: 58
 hess_inv: <7x7 LbfgsInvHessProduct with dtype=float64>
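If you want to read off the fitted parameters in their natural units, a small sketch like the following works (not part of the original tutorial; the names simply follow the ordering used in set_params):

# Unpack the maximum likelihood parameters; params[0] is the mean and the rest
# are stored in log space, in the order used by set_params above.
names = ["sigma1", "rho1", "tau", "sigma2", "rho2", "jitter variance"]
print("mean = {0:.3f}".format(soln.x[0]))
for name, value in zip(names, np.exp(soln.x[1:])):
    print("{0} = {1:.3f}".format(name, value))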

Now let’s make the same plots for the maximum likelihood model:

plt.figure()
plt.title("maximum likelihood psd")
plot_psd(opt_gp)

plt.figure()
plt.title("maximum likelihood prediction")
plot_prediction(opt_gp)
[Figures: maximum likelihood psd; maximum likelihood prediction]

These predictions are starting to look much better!

Posterior inference using emcee

Now, to get a sense for the uncertainties on our model, let’s use Markov chain Monte Carlo (MCMC) to numerically estimate the posterior expectations of the model. In this first example, we’ll use the emcee package to run our MCMC. Our likelihood function is the same as the one we used in the previous section, but we’ll also choose a wide normal prior on each of our parameters.

import emcee

prior_sigma = 2.0


def log_prob(params, gp):
    gp = set_params(params, gp)
    return (
        gp.log_likelihood(y) - 0.5 * np.sum((params / prior_sigma) ** 2),
        gp.kernel.get_psd(omega),
    )


np.random.seed(5693854)
coords = soln.x + 1e-5 * np.random.randn(32, len(soln.x))
sampler = emcee.EnsembleSampler(
    coords.shape[0], coords.shape[1], log_prob, args=(gp,)
)
state = sampler.run_mcmc(coords, 2000, progress=False)
sampler.reset()
state = sampler.run_mcmc(state, 5000, progress=False)
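
Before using the chain, it can be worth a quick convergence check. This isn’t in the original tutorial, but emcee’s built-in autocorrelation time estimate is one option:

# Optional check: estimate the integrated autocorrelation time for each parameter;
# emcee raises an AutocorrError if the chain is too short for a reliable estimate.
try:
    print("autocorrelation times:", sampler.get_autocorr_time())
except emcee.autocorr.AutocorrError as err:
    print(err)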

After running our MCMC, we can plot the predictions that the model makes for a handful of samples from the chain. This gives a qualitative sense of the uncertainty in the predictions.

chain = sampler.get_chain(discard=100, flat=True)

for sample in chain[np.random.randint(len(chain), size=50)]:
    gp = set_params(sample, gp)
    conditional = gp.condition(y, true_t)
    plt.plot(true_t, conditional.sample(), color="C0", alpha=0.1)

plt.title("posterior prediction")
plot_prediction(None)
[Figure: posterior prediction]

Similarly, we can plot the posterior expectation for the power spectral density:

psds = sampler.get_blobs(discard=100, flat=True)

q = np.percentile(psds, [16, 50, 84], axis=0)

plt.loglog(freq, q[1], color="C0")
plt.fill_between(freq, q[0], q[2], color="C0", alpha=0.1)

plt.xlim(freq.min(), freq.max())
plt.xlabel("frequency [1 / day]")
plt.ylabel("power [day ppt$^2$]")
_ = plt.title("posterior psd using emcee")
[Figure: posterior psd using emcee]

Posterior inference using PyMC

celerite2 also includes support for probabilistic modeling using PyMC (v5 or v3, using the celerite2.pymc or celerite2.pymc3 submodule respectively), and we can implement the same model from above as follows:

import pymc as pm
from celerite2.pymc import GaussianProcess, terms as pm_terms

with pm.Model() as model:
    mean = pm.Normal("mean", mu=0.0, sigma=prior_sigma)
    log_jitter = pm.Normal("log_jitter", mu=0.0, sigma=prior_sigma)

    log_sigma1 = pm.Normal("log_sigma1", mu=0.0, sigma=prior_sigma)
    log_rho1 = pm.Normal("log_rho1", mu=0.0, sigma=prior_sigma)
    log_tau = pm.Normal("log_tau", mu=0.0, sigma=prior_sigma)
    term1 = pm_terms.SHOTerm(
        sigma=pm.math.exp(log_sigma1),
        rho=pm.math.exp(log_rho1),
        tau=pm.math.exp(log_tau),
    )

    log_sigma2 = pm.Normal("log_sigma2", mu=0.0, sigma=prior_sigma)
    log_rho2 = pm.Normal("log_rho2", mu=0.0, sigma=prior_sigma)
    term2 = pm_terms.SHOTerm(
        sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25
    )

    kernel = term1 + term2
    gp = GaussianProcess(kernel, mean=mean)
    gp.compute(t, diag=yerr**2 + pm.math.exp(log_jitter), quiet=True)
    gp.marginal("obs", observed=y)

    pm.Deterministic("psd", kernel.get_psd(omega))

    trace = pm.sample(
        tune=1000,
        draws=1000,
        target_accept=0.9,
        init="adapt_full",
        cores=2,
        chains=2,
        random_seed=34923,
    )
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_full...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mean, log_jitter, log_sigma1, log_rho1, log_tau, log_sigma2, log_rho2]

Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 26 seconds.
There were 1 divergences after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
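
The sampler reports a single divergence after tuning. If you want to check that programmatically, the divergence flags live in the sample stats of the returned InferenceData object (a small sketch, assuming the default return value of pm.sample in PyMC v5):

# Count the divergent transitions flagged during sampling.
n_divergent = int(trace.sample_stats["diverging"].sum())
print("divergences:", n_divergent)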

Like before, we can plot the posterior estimate of the power spectrum to show that the results are qualitatively similar:

psds = trace.posterior["psd"].values

q = np.percentile(psds, [16, 50, 84], axis=(0, 1))

plt.loglog(freq, q[1], color="C0")
plt.fill_between(freq, q[0], q[2], color="C0", alpha=0.1)

plt.xlim(freq.min(), freq.max())
plt.xlabel("frequency [1 / day]")
plt.ylabel("power [day ppt$^2$]")
_ = plt.title("posterior psd using PyMC")
[Figure: posterior psd using PyMC]

Posterior inference using numpyro

celerite2 also includes support for JAX, so you can use tools like numpyro for inference.

from jax import config

config.update("jax_enable_x64", True)

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import celerite2.jax
from celerite2.jax import terms as jax_terms


def numpyro_model(t, yerr, y=None):
    mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
    log_jitter = numpyro.sample("log_jitter", dist.Normal(0.0, prior_sigma))

    log_sigma1 = numpyro.sample("log_sigma1", dist.Normal(0.0, prior_sigma))
    log_rho1 = numpyro.sample("log_rho1", dist.Normal(0.0, prior_sigma))
    log_tau = numpyro.sample("log_tau", dist.Normal(0.0, prior_sigma))
    term1 = jax_terms.SHOTerm(
        sigma=jnp.exp(log_sigma1), rho=jnp.exp(log_rho1), tau=jnp.exp(log_tau)
    )

    log_sigma2 = numpyro.sample("log_sigma2", dist.Normal(0.0, prior_sigma))
    log_rho2 = numpyro.sample("log_rho2", dist.Normal(0.0, prior_sigma))
    term2 = jax_terms.SHOTerm(
        sigma=jnp.exp(log_sigma2), rho=jnp.exp(log_rho2), Q=0.25
    )

    kernel = term1 + term2
    gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
    gp.compute(t, diag=yerr**2 + jnp.exp(log_jitter), check_sorted=False)

    numpyro.sample("obs", gp.numpyro_dist(), obs=y)
    numpyro.deterministic("psd", kernel.get_psd(omega))


nuts_kernel = NUTS(numpyro_model, dense_mass=True)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=False,
)
rng_key = random.PRNGKey(34923)
%time mcmc.run(rng_key, t, yerr, y=y)
CPU times: user 11.8 s, sys: 156 ms, total: 12 s
Wall time: 11.9 s

This runtime was similar to the PyMC result from above, and (as we’ll see below) the convergence is also similar. Any difference in runtime will probably disappear for more computationally expensive models, but this interface is looking pretty great here!
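
numpyro also ships a built-in summary with $\hat{R}$ and effective sample size estimates, which makes for a handy quick check before the more careful ArviZ comparison below:

# Print per-parameter posterior summaries, r_hat, and n_eff for this run
# (deterministic sites like "psd" are excluded by default).
mcmc.print_summary()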

As above, we can plot the posterior expectations for the power spectrum:

psds = np.asarray(mcmc.get_samples()["psd"])

q = np.percentile(psds, [16, 50, 84], axis=0)

plt.loglog(freq, q[1], color="C0")
plt.fill_between(freq, q[0], q[2], color="C0", alpha=0.1)

plt.xlim(freq.min(), freq.max())
plt.xlabel("frequency [1 / day]")
plt.ylabel("power [day ppt$^2$]")
_ = plt.title("posterior psd using numpyro")
[Figure: posterior psd using numpyro]

Comparison

Finally, let’s compare the results of these different inference methods a bit more quantitatively. First, let’s look at the posterior constraint on the period of the underdamped harmonic oscillator term, which sets the effective period of the oscillatory signal.

import arviz as az

emcee_data = az.from_emcee(
    sampler,
    var_names=[
        "mean",
        "log_sigma1",
        "log_rho1",
        "log_tau",
        "log_sigma2",
        "log_rho2",
        "log_jitter",
    ],
)

pm_data = trace
numpyro_data = az.from_numpyro(mcmc)

bins = np.linspace(1.5, 2.75, 25)
plt.hist(
    np.exp(np.asarray((emcee_data.posterior["log_rho1"].T)).flatten()),
    bins,
    histtype="step",
    density=True,
    label="emcee",
)
plt.hist(
    np.exp(np.asarray((pm_data.posterior["log_rho1"].T)).flatten()),
    bins,
    histtype="step",
    density=True,
    label="PyMC",
)
plt.hist(
    np.exp(np.asarray((numpyro_data.posterior["log_rho1"].T)).flatten()),
    bins,
    histtype="step",
    density=True,
    label="numpyro",
)
plt.legend()
plt.yticks([])
plt.xlabel(r"$\rho_1$")
_ = plt.ylabel(r"$p(\rho_1)$")
[Figure: posterior histograms of $\rho_1$ from emcee, PyMC, and numpyro]

That looks pretty consistent.

Next we can look at the ArviZ summary for each method to see how the posterior expectations and convergence diagnostics look.

az.summary(
    emcee_data,
    var_names=[
        "mean",
        "log_sigma1",
        "log_rho1",
        "log_tau",
        "log_sigma2",
        "log_rho2",
        "log_jitter",
    ],
)
              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
mean        -0.006  1.257  -2.491    2.451      0.034    0.024    1627.0    3418.0   1.02
log_sigma1  -0.297  0.331  -0.869    0.338      0.008    0.005    1865.0    3584.0   1.02
log_rho1     0.701  0.059   0.591    0.810      0.001    0.001    1782.0    4277.0   1.02
log_tau      1.889  0.753   0.569    3.337      0.019    0.013    1680.0    3647.0   1.02
log_sigma2   0.524  0.655  -0.608    1.769      0.016    0.012    1608.0    3082.0   1.02
log_rho2     3.322  0.884   1.759    4.981      0.022    0.015    1650.0    4081.0   1.02
log_jitter  -5.829  0.723  -7.180   -4.559      0.020    0.014    1427.0    2897.0   1.02
az.summary(
    pm_data,
    var_names=[
        "mean",
        "log_sigma1",
        "log_rho1",
        "log_tau",
        "log_sigma2",
        "log_rho2",
        "log_jitter",
    ],
)
              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
mean        -0.009  1.224  -2.288    2.511      0.037    0.028    1193.0    1124.0    1.0
log_sigma1  -0.288  0.331  -0.878    0.313      0.009    0.006    1571.0    1053.0    1.0
log_rho1     0.702  0.059   0.604    0.816      0.002    0.002    1097.0     600.0    1.0
log_tau      1.924  0.753   0.510    3.283      0.020    0.015    1539.0    1176.0    1.0
log_sigma2   0.510  0.660  -0.637    1.743      0.019    0.016    1418.0     997.0    1.0
log_rho2     3.302  0.878   1.697    4.848      0.025    0.018    1315.0    1212.0    1.0
log_jitter  -5.840  0.753  -7.275   -4.501      0.022    0.016    1460.0     951.0    1.0
az.summary(
    numpyro_data,
    var_names=[
        "mean",
        "log_sigma1",
        "log_rho1",
        "log_tau",
        "log_sigma2",
        "log_rho2",
        "log_jitter",
    ],
)
              mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
mean         0.031  1.225  -2.541    2.301      0.035    0.032    1349.0    1013.0    1.0
log_sigma1  -0.286  0.351  -0.942    0.373      0.009    0.006    1785.0    1087.0    1.0
log_rho1     0.700  0.055   0.598    0.810      0.001    0.001    1772.0    1390.0    1.0
log_tau      1.954  0.771   0.666    3.504      0.021    0.016    1490.0    1250.0    1.0
log_sigma2   0.530  0.652  -0.577    1.772      0.021    0.017    1037.0    1015.0    1.0
log_rho2     3.314  0.863   1.853    4.960      0.028    0.020    1055.0    1094.0    1.0
log_jitter  -5.817  0.744  -7.214   -4.522      0.020    0.015    1651.0    1203.0    1.0

Overall these results are consistent, but the $\hat{R}$ values are a bit high for the emcee run, so I’d probably run that for longer. Either way, for models like these, PyMC and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework.
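
If you want to make that “runtime per effective sample” comparison concrete, a minimal sketch using ArviZ looks something like the following; divide each number by the wall-clock time you measured for the corresponding sampler to get effective samples per second:

# Bulk effective sample size for one representative parameter from each backend.
for name, data in [("emcee", emcee_data), ("PyMC", pm_data), ("numpyro", numpyro_data)]:
    ess = float(az.ess(data, var_names=["log_rho1"])["log_rho1"].values)
    print("{0}: bulk ESS for log_rho1 = {1:.0f}".format(name, ess))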