Getting started

This tutorial is based on the quickstart example in the celerite documentation, but it has been updated to work with celerite2.

For this tutorial, we're going to fit a Gaussian Process (GP) model to a simulated dataset with quasi-periodic oscillations. We'll also leave a gap in the simulated data and use the GP model to predict what we would have observed for those "missing" data points.

To start, here’s some code to simulate the dataset:

[3]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

t = np.sort(
    np.append(
        np.random.uniform(0, 3.8, 57),
        np.random.uniform(5.5, 10, 68),
    )
)  # The input coordinates must be sorted
yerr = np.random.uniform(0.08, 0.22, len(t))
y = (
    0.2 * (t - 5)
    + np.sin(3 * t + 0.1 * (t - 5) ** 2)
    + yerr * np.random.randn(len(t))
)

true_t = np.linspace(0, 10, 500)
true_y = 0.2 * (true_t - 5) + np.sin(3 * true_t + 0.1 * (true_t - 5) ** 2)

plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3)
plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0)
plt.xlabel("x [day]")
plt.ylabel("y [ppm]")
plt.xlim(0, 10)
plt.ylim(-2.5, 2.5)
_ = plt.title("simulated data")
[figure: simulated data (../../_images/tutorials_first_3_0.png)]

Now, let's fit this dataset using a mixture of two SHOTerm kernels: one quasi-periodic component and one non-periodic component. First, let's set up an initial model to see how it looks:

[4]:
import celerite2
from celerite2 import terms

# Quasi-periodic term
term1 = terms.SHOTerm(sigma=1.0, rho=1.0, tau=10.0)

# Non-periodic component
term2 = terms.SHOTerm(sigma=1.0, rho=5.0, Q=0.25)
kernel = term1 + term2

# Setup the GP
gp = celerite2.GaussianProcess(kernel, mean=0.0)
gp.compute(t, yerr=yerr)

print("Initial log likelihood: {0}".format(gp.log_likelihood(y)))
Initial log likelihood: -16.75164079832632
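
As a sanity check on these choices, it can help to translate the (sigma, rho, tau) parameterization back into the oscillator's quality factor Q. Here's a minimal sketch, assuming the reparameterization described in the celerite2 docs (rho = 2*pi/omega0 and tau = 2*Q/omega0, which implies Q = pi*tau/rho); the implied_Q helper is purely illustrative:

# Quality factor implied by the (rho, tau) parameterization (assumes Q = pi * tau / rho)
def implied_Q(rho, tau):
    return np.pi * tau / rho


print("term1 Q:", implied_Q(rho=1.0, tau=10.0))  # ~31.4, well above 0.5: underdamped, quasi-periodic
print("term2 Q:", 0.25)  # set directly; below 0.5: overdamped, non-periodic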

Let’s look at the underlying power spectral density of this initial model:

[5]:
freq = np.linspace(1.0 / 8, 1.0 / 0.3, 500)
omega = 2 * np.pi * freq


def plot_psd(gp):
    for n, term in enumerate(gp.kernel.terms):
        plt.loglog(freq, term.get_psd(omega), label="term {0}".format(n + 1))
    plt.loglog(freq, gp.kernel.get_psd(omega), ":k", label="full model")
    plt.xlim(freq.min(), freq.max())
    plt.legend()
    plt.xlabel("frequency [1 / day]")
    plt.ylabel("power [day ppt$^2$]")


plt.title("initial psd")
plot_psd(gp)
[figure: initial psd (../../_images/tutorials_first_7_0.png)]

And then we can also plot the prediction that this model makes for the missing data and compare it to the truth:

[6]:
def plot_prediction(gp):
    plt.plot(true_t, true_y, "k", lw=1.5, alpha=0.3, label="truth")
    plt.errorbar(t, y, yerr=yerr, fmt=".k", capsize=0, label="data")

    if gp:
        mu, variance = gp.predict(y, t=true_t, return_var=True)
        sigma = np.sqrt(variance)
        plt.plot(true_t, mu, label="prediction")
        plt.fill_between(true_t, mu - sigma, mu + sigma, color="C0", alpha=0.2)

    plt.xlabel("x [day]")
    plt.ylabel("y [ppm]")
    plt.xlim(0, 10)
    plt.ylim(-2.5, 2.5)
    plt.legend()


plt.title("initial prediction")
plot_prediction(gp)
[figure: initial prediction (../../_images/tutorials_first_9_0.png)]
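
Before we judge this fit, a quick aside: the gp object can do more than return the per-point variance. It can return the full predictive covariance, and it can draw random realizations of the process conditioned on the data. A small sketch, using the sample_conditional method that appears later in this tutorial and assuming the return_cov flag from the celerite2 predict API:

# Full predictive covariance matrix instead of just the per-point variance
mu, cov = gp.predict(y, t=true_t, return_cov=True)
print(cov.shape)  # (len(true_t), len(true_t))

# One random realization of the GP conditioned on the observed data
sample = gp.sample_conditional(y, true_t)
plt.plot(true_t, sample, color="C1", alpha=0.5)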

Ok, that looks pretty terrible, but we can get a better fit by numerically maximizing the likelihood as described in the following section.

Maximum likelihood

In this section, we'll improve our initial GP model by maximizing the likelihood with respect to the kernel parameters, the mean, and a "jitter" term (a constant variance added to the diagonal of the covariance matrix). To do this, we'll use scipy.optimize.minimize:

[7]:
from scipy.optimize import minimize


def set_params(params, gp):
    # params = [mean, log_sigma1, log_rho1, log_tau, log_sigma2, log_rho2, log_jitter]
    gp.mean = params[0]
    theta = np.exp(params[1:])
    gp.kernel = terms.SHOTerm(
        sigma=theta[0], rho=theta[1], tau=theta[2]
    ) + terms.SHOTerm(sigma=theta[3], rho=theta[4], Q=0.25)
    gp.compute(t, diag=yerr ** 2 + theta[5], quiet=True)
    return gp


def neg_log_like(params, gp):
    gp = set_params(params, gp)
    return -gp.log_likelihood(y)


initial_params = [0.0, 0.0, 0.0, np.log(10.0), 0.0, np.log(5.0), np.log(0.01)]
soln = minimize(neg_log_like, initial_params, method="L-BFGS-B", args=(gp,))
opt_gp = set_params(soln.x, gp)
soln
[7]:
      fun: -15.942826380167517
 hess_inv: <7x7 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 2.84217094e-06,  4.12114787e-05, -4.33431067e-04,  9.94759836e-06,
       -8.95283843e-05,  1.35003121e-04,  3.12638806e-05])
  message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 432
      nit: 49
     njev: 54
   status: 0
  success: True
        x: array([ 4.73106807e-03, -3.41522802e-01,  7.00011980e-01,  1.94365570e+00,
        6.06033252e-01,  3.75573934e+00, -7.87331665e+00])
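
Since the optimizer works with the mean plus the logarithms of the remaining parameters (see set_params above), it can be handy to translate soln.x back into named values. A small sketch using the same parameter ordering as initial_params:

print("mean: {0:.4f}".format(soln.x[0]))
for name, value in zip(
    ["sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"], np.exp(soln.x[1:])
):
    # The optimizer works with log-parameters, so exponentiate to recover the physical values
    print("{0}: {1:.4f}".format(name, value))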

Now let’s make the same plots for the maximum likelihood model:

[8]:
plt.figure()
plt.title("maximum likelihood psd")
plot_psd(opt_gp)

plt.figure()
plt.title("maximum likelihood prediction")
plot_prediction(opt_gp)
[figure: maximum likelihood psd (../../_images/tutorials_first_13_0.png)]
[figure: maximum likelihood prediction (../../_images/tutorials_first_13_1.png)]

These predictions are starting to look much better!

Posterior inference using emcee

Now, to get a sense of the uncertainties on our model parameters, let's use Markov chain Monte Carlo (MCMC) to numerically estimate the posterior expectations of the model. In this first example, we'll use the emcee package to run our MCMC. The likelihood function is the same as in the previous section, but we'll also place a wide normal prior on each of our parameters.

[9]:
import emcee

prior_sigma = 2.0


def log_prob(params, gp):
    gp = set_params(params, gp)
    # Return the log probability (likelihood plus a Gaussian prior of width prior_sigma
    # on every parameter) and the PSD as an emcee "blob" so it gets stored during sampling
    return (
        gp.log_likelihood(y) - 0.5 * np.sum((params / prior_sigma) ** 2),
        gp.kernel.get_psd(omega),
    )


np.random.seed(5693854)
coords = soln.x + 1e-5 * np.random.randn(32, len(soln.x))
sampler = emcee.EnsembleSampler(
    coords.shape[0], coords.shape[1], log_prob, args=(gp,)
)
state = sampler.run_mcmc(coords, 2000, progress=True)
sampler.reset()
state = sampler.run_mcmc(state, 5000, progress=True)
100%|██████████| 2000/2000 [00:24<00:00, 83.10it/s]
100%|██████████| 5000/5000 [00:59<00:00, 84.06it/s]
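
Before trusting these samples, it's worth estimating how correlated the chain is. A quick sketch using emcee's built-in integrated autocorrelation time estimate (quiet=True turns the "chain too short" error into a warning):

# Integrated autocorrelation time per parameter; the effective number of independent
# samples is roughly (steps x walkers) / autocorrelation time
tau_autocorr = sampler.get_autocorr_time(quiet=True)
print("autocorrelation times:", tau_autocorr)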

After running our MCMC, we can plot the predictions that the model makes for a handful of samples from the chain. This gives a qualitative sense of the uncertainty in the predictions.

[10]:
chain = sampler.get_chain(discard=100, flat=True)

for sample in chain[np.random.randint(len(chain), size=50)]:
    gp = set_params(sample, gp)
    plt.plot(true_t, gp.sample_conditional(y, true_t), color="C0", alpha=0.1)

plt.title("posterior prediction")
plot_prediction(None)
[figure: posterior prediction (../../_images/tutorials_first_17_0.png)]

Similarly, we can plot the posterior expectation for the power spectral density:

[11]:
psds = sampler.get_blobs(discard=100, flat=True)

q = np.percentile(psds, [16, 50, 84], axis=0)

plt.loglog(freq, q[1], color="C0")
plt.fill_between(freq, q[0], q[2], color="C0", alpha=0.1)

plt.xlim(freq.min(), freq.max())
plt.xlabel("frequency [1 / day]")
plt.ylabel("power [day ppt$^2$]")
_ = plt.title("posterior psd using emcee")
[figure: posterior psd using emcee (../../_images/tutorials_first_19_0.png)]

Posterior inference using PyMC3

celerite2 also includes support for probabilistic modeling using PyMC3, and we can implement the same model from above as follows:

[12]:
import pymc3 as pm

import celerite2.theano
from celerite2.theano import terms as theano_terms

with pm.Model() as model:

    mean = pm.Normal("mean", mu=0.0, sigma=prior_sigma)
    jitter = pm.Lognormal("jitter", mu=0.0, sigma=prior_sigma)

    sigma1 = pm.Lognormal("sigma1", mu=0.0, sigma=prior_sigma)
    rho1 = pm.Lognormal("rho1", mu=0.0, sigma=prior_sigma)
    tau = pm.Lognormal("tau", mu=0.0, sigma=prior_sigma)
    term1 = theano_terms.SHOTerm(sigma=sigma1, rho=rho1, tau=tau)

    sigma2 = pm.Lognormal("sigma2", mu=0.0, sigma=prior_sigma)
    rho2 = pm.Lognormal("rho2", mu=0.0, sigma=prior_sigma)
    term2 = theano_terms.SHOTerm(sigma=sigma2, rho=rho2, Q=0.25)

    kernel = term1 + term2
    gp = celerite2.theano.GaussianProcess(kernel, mean=mean)
    gp.compute(t, diag=yerr ** 2 + jitter, quiet=True)
    gp.marginal("obs", observed=y)

    pm.Deterministic("psd", kernel.get_psd(omega))

    trace = pm.sample(
        tune=1000,
        draws=1000,
        target_accept=0.8,
        init="adapt_full",
        cores=2,
        chains=2,
        random_seed=34923,
    )
Auto-assigning NUTS sampler...
Initializing NUTS using adapt_full...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [rho2, sigma2, tau, rho1, sigma1, jitter, mean]
100.00% [4000/4000 00:13<00:00 Sampling 2 chains, 4 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 25 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
There were 3 divergences after tuning. Increase `target_accept` or reparameterize.
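
A few divergences out of 2,000 draws is usually nothing to worry about, but it's easy to check exactly how many there were. A minimal sketch, assuming PyMC3's trace exposes the NUTS sampler statistics through get_sampler_stats:

# Count the divergent transitions flagged by NUTS during sampling
divergent = trace.get_sampler_stats("diverging")
print("divergences:", int(np.sum(divergent)), "out of", divergent.size, "draws")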

As before, we can plot the posterior estimate of the power spectrum to show that the results are qualitatively similar:

[13]:
psds = trace["psd"]

q = np.percentile(psds, [16, 50, 84], axis=0)

plt.loglog(freq, q[1], color="C0")
plt.fill_between(freq, q[0], q[2], color="C0", alpha=0.1)

plt.xlim(freq.min(), freq.max())
plt.xlabel("frequency [1 / day]")
plt.ylabel("power [day ppt$^2$]")
_ = plt.title("posterior psd using PyMC3")
[figure: posterior psd using PyMC3 (../../_images/tutorials_first_23_0.png)]

Posterior inference using numpyro

Since celerite2 includes support for JAX as well as Theano, you can also use tools like numpyro for inference. The following is similar to the previous PyMC3 example; the main difference is that (for technical reasons related to how JAX works) SHOTerms cannot be used in combination with jax.jit, so we need to explicitly specify each term as "underdamped" (UnderdampedSHOTerm) or "overdamped" (OverdampedSHOTerm).

[14]:
from jax.config import config

config.update("jax_enable_x64", True)

from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import celerite2.jax
from celerite2.jax import terms as jax_terms


def numpyro_model(t, yerr, y=None):
    mean = numpyro.sample("mean", dist.Normal(0.0, prior_sigma))
    jitter = numpyro.sample("jitter", dist.LogNormal(0.0, prior_sigma))

    sigma1 = numpyro.sample("sigma1", dist.LogNormal(0.0, prior_sigma))
    rho1 = numpyro.sample("rho1", dist.LogNormal(0.0, prior_sigma))
    tau = numpyro.sample("tau", dist.LogNormal(0.0, prior_sigma))
    term1 = jax_terms.UnderdampedSHOTerm(sigma=sigma1, rho=rho1, tau=tau)

    sigma2 = numpyro.sample("sigma2", dist.LogNormal(0.0, prior_sigma))
    rho2 = numpyro.sample("rho2", dist.LogNormal(0.0, prior_sigma))
    term2 = jax_terms.OverdampedSHOTerm(sigma=sigma2, rho=rho2, Q=0.25)

    kernel = term1 + term2
    gp = celerite2.jax.GaussianProcess(kernel, mean=mean)
    gp.compute(t, diag=yerr ** 2 + jitter, check_sorted=False)

    numpyro.sample("obs", gp.numpyro_dist(), obs=y)
    numpyro.deterministic("psd", kernel.get_psd(omega))


nuts_kernel = NUTS(numpyro_model, dense_mass=True)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=2)
rng_key = random.PRNGKey(34923)
%time mcmc.run(rng_key, t, yerr, y=y)
CPU times: user 18.7 s, sys: 109 ms, total: 18.8 s
Wall time: 18.8 s
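
numpyro can also print its own posterior summary (means, standard deviations, effective sample sizes, and r_hat) before we pull everything into ArviZ below. A quick sketch assuming the standard MCMC.print_summary and get_samples methods; depending on the numpyro version, the deterministic "psd" site may or may not appear in the printout:

# Per-parameter posterior summary and convergence diagnostics straight from numpyro
mcmc.print_summary()

# The raw posterior samples are available as a dictionary of arrays
samples = mcmc.get_samples()
print(samples["rho1"].shape)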

This runtime was similar to the PyMC3 result from above, and (as we’ll see below) the convergence is also similar. Any difference in runtime will probably disappear for more computationally expensive models, but this interface is looking pretty great here!

As above, we can plot the posterior expectations for the power spectrum:

[15]:
psds = np.asarray(mcmc.get_samples()["psd"])

q = np.percentile(psds, [16, 50, 84], axis=0)

plt.loglog(freq, q[1], color="C0")
plt.fill_between(freq, q[0], q[2], color="C0", alpha=0.1)

plt.xlim(freq.min(), freq.max())
plt.xlabel("frequency [1 / day]")
plt.ylabel("power [day ppt$^2$]")
_ = plt.title("posterior psd using numpyro")
[figure: posterior psd using numpyro (../../_images/tutorials_first_27_0.png)]

Comparison

Finally, let's compare the results of these different inference methods a bit more quantitatively. First, let's look at the posterior constraint on rho1, the period of the underdamped harmonic oscillator, which sets the effective period of the oscillatory signal.

[16]:
import arviz as az

emcee_data = az.from_emcee(
    sampler,
    var_names=[
        "mean",
        "log_sigma1",
        "log_rho1",
        "log_tau",
        "log_sigma2",
        "log_rho2",
        "log_jitter",
    ],
)
for k in emcee_data.posterior.data_vars:
    if k.startswith("log_"):
        emcee_data.posterior[k[4:]] = np.exp(emcee_data.posterior[k])

with model:
    pm_data = az.from_pymc3(trace)

numpyro_data = az.from_numpyro(mcmc)

bins = np.linspace(1.5, 2.75, 25)
plt.hist(
    np.asarray((emcee_data.posterior["rho1"].T)).flatten(),
    bins,
    histtype="step",
    density=True,
    label="emcee",
)
plt.hist(
    np.asarray((pm_data.posterior["rho1"].T)).flatten(),
    bins,
    histtype="step",
    density=True,
    label="PyMC3",
)
plt.hist(
    np.asarray((numpyro_data.posterior["rho1"].T)).flatten(),
    bins,
    histtype="step",
    density=True,
    label="numpyro",
)
plt.legend()
plt.yticks([])
plt.xlabel(r"$\rho_1$")
_ = plt.ylabel(r"$p(\rho_1)$")
[figure: posterior constraints on rho1 from emcee, PyMC3, and numpyro (../../_images/tutorials_first_29_0.png)]

That looks pretty consistent.

Next, we can look at the ArviZ summary for each method to compare the posterior expectations and convergence diagnostics.

[17]:
az.summary(
    emcee_data,
    var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
)
[17]:
          mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mean     0.018   1.243  -2.331    2.565      0.033    0.023    1461.0  1461.0    1447.0    3237.0   1.02
sigma1   0.804   0.338   0.376    1.387      0.008    0.005    1981.0  1981.0    1850.0    3712.0   1.02
rho1     2.020   0.125   1.809    2.244      0.003    0.002    1759.0  1759.0    1730.0    3646.0   1.02
tau     10.063  13.457   1.032   26.135      0.259    0.183    2707.0  2707.0    1644.0    3917.0   1.02
sigma2   2.165   2.297   0.414    5.026      0.046    0.033    2484.0  2484.0    1119.0    4211.0   1.03
rho2    42.039  54.934   3.299  120.103      1.095    0.774    2517.0  2517.0    1135.0    5057.0   1.03
jitter   0.004   0.002   0.000    0.008      0.000    0.000    2252.0  2252.0    1947.0    3790.0   1.02
[18]:
az.summary(
    pm_data,
    var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
)
[18]:
          mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mean     0.085   1.143  -2.033    2.461      0.031    0.027    1328.0   874.0    1417.0    1090.0    1.0
sigma1   0.791   0.333   0.397    1.316      0.011    0.009     879.0   749.0    1454.0    1028.0    1.0
rho1     2.019   0.109   1.822    2.233      0.003    0.002    1539.0  1528.0    1546.0    1079.0    1.0
tau      9.853  15.296   1.456   23.933      0.536    0.379     813.0   813.0    1294.0     958.0    1.0
sigma2   2.127   3.180   0.385    4.872      0.095    0.067    1115.0  1115.0     905.0     968.0    1.0
rho2    41.443  59.027   3.233  112.047      1.849    1.308    1019.0  1019.0     787.0     868.0    1.0
jitter   0.004   0.002   0.000    0.008      0.000    0.000    1956.0  1950.0    1725.0    1186.0    1.0
[19]:
az.summary(
    numpyro_data,
    var_names=["mean", "sigma1", "rho1", "tau", "sigma2", "rho2", "jitter"],
)
[19]:
          mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_mean  ess_sd  ess_bulk  ess_tail  r_hat
mean     0.066   1.150  -2.293    2.270      0.031    0.028    1376.0   861.0    1451.0    1178.0   1.01
sigma1   0.794   0.316   0.389    1.348      0.009    0.007    1117.0  1076.0    1542.0    1111.0   1.00
rho1     2.019   0.117   1.805    2.241      0.004    0.003     927.0   888.0    1112.0     654.0   1.00
tau      9.948  11.897   0.993   26.627      0.369    0.261    1037.0  1037.0    1497.0    1164.0   1.00
sigma2   2.114   2.025   0.360    5.267      0.067    0.048     901.0   901.0    1136.0     922.0   1.00
rho2    41.828  60.314   4.434  117.973      1.900    1.344    1008.0  1008.0    1098.0    1031.0   1.00
jitter   0.004   0.002   0.000    0.008      0.000    0.000    1950.0  1950.0    1617.0    1068.0   1.00

Overall these results are consistent, but the $\hat{R}$ values are a bit high for the emcee run, so I'd probably run that chain for longer. Either way, for models like these, PyMC3 and numpyro are generally going to be much better inference tools (in terms of runtime per effective sample) than emcee, so those are the recommended interfaces if the rest of your model can be easily implemented in such a framework.
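
If you did want to tighten up the emcee diagnostics, the simplest option is to keep sampling from where the chain left off. A sketch using the state object and sampler from above (the extra 5,000 steps are just a guess; re-check the $\hat{R}$ values or the autocorrelation time afterwards):

# Continue the existing emcee run from its final state to accumulate more samples
state = sampler.run_mcmc(state, 5000, progress=True)
chain = sampler.get_chain(discard=100, flat=True)
print("flattened samples:", len(chain))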