JAX interface
The celerite2.jax submodule provides an interface to celerite2 models that can be used from JAX. The Getting started tutorial demonstrates the use of this interface, while this page provides the details for the celerite2.jax.GaussianProcess class, which provides all of this functionality. This page does not include documentation for the term models defined in JAX, but you can refer to the Model building section of the Python interface documentation. All of those models are implemented in JAX, and you can access them using something like the following:
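import jax
from celerite2.jax import GaussianProcess, terms

# Enable 64-bit floats, which is generally advisable for GP linear algebra
jax.config.update("jax_enable_x64", True)


@jax.jit
def log_likelihood(params, x, diag, y):
    # Build an SHO term from the (traced) parameters
    term = terms.SHOTerm(S0=params["S0"], w0=params["w0"], Q=params["Q"])
    gp = GaussianProcess(term)
    # Factorize the covariance matrix at the observation coordinates x
    gp.compute(x, diag=diag)
    return gp.log_likelihood(y)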
The celerite2.jax.GaussianProcess class is detailed below:
- class celerite2.jax.GaussianProcess(kernel, t=None, *, mean=0.0, **kwargs)
- apply_inverse(y, *, inplace=False)
Apply the inverse of the covariance matrix to a vector or matrix.
Solve $K\,x = y$ for $x$, where $K$ is the covariance matrix of the GP.
Note: The mean function is not applied in this method.
- Parameters:
y (shape[N] or shape[N, M]) – The vector or matrix y described above.
inplace (bool, optional) – If True, y will be overwritten with the result x.
- Raises:
RuntimeError – If GaussianProcess.compute() is not called first.
ValueError – When the inputs are not valid (shape, number, etc.).
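As a point of reference, the following sketch computes the same quantity as apply_inverse with a dense O(N^3) Cholesky solve in jax.numpy, instead of celerite's scalable solver. The kernel $k(\tau) = a\,e^{-c|\tau|}$ (the form of a celerite RealTerm), its parameter values, and the helper names exp_kernel and dense_apply_inverse are illustrative assumptions, not part of celerite2. It handles both a shape[N] vector and a shape[N, M] matrix right-hand side.

```python
import jax

jax.config.update("jax_enable_x64", True)  # use float64 for the linear algebra

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def exp_kernel(t, a=1.0, c=0.5):
    """Hypothetical RealTerm-like kernel matrix: k(tau) = a * exp(-c * |tau|)."""
    tau = jnp.abs(t[:, None] - t[None, :])
    return a * jnp.exp(-c * tau)


def dense_apply_inverse(K, y):
    """Hypothetical helper: solve K x = y for x with a dense Cholesky (O(N^3))."""
    factor = cho_factor(K, lower=True)
    return cho_solve(factor, y)


t = jnp.linspace(0.0, 10.0, 50)
diag = 0.1 * jnp.ones_like(t)          # observation variance on the diagonal
K = exp_kernel(t) + jnp.diag(diag)
y = jnp.sin(t)

x = dense_apply_inverse(K, y)           # vector right-hand side, shape (N,)
Y = jnp.stack([y, jnp.cos(t)], axis=1)
X = dense_apply_inverse(K, Y)           # matrix right-hand side, shape (N, 2)
```

In celerite2 the same solve is carried out on the structured factorization from compute(), so it scales linearly with N instead of cubically.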
- compute(t, *, yerr=None, diag=None, check_sorted=True, quiet=False)
Compute the Cholesky factorization of the GP covariance matrix.
- Parameters:
t (shape[N]) – The independent coordinates of the observations. This must be sorted in increasing order.
yerr (shape[N], optional) – If provided, the diagonal standard deviation of the observation model.
diag (shape[N], optional) – If provided, the diagonal variance of the observation model.
check_sorted (bool, optional) – If True, a check is performed to make sure that t is correctly sorted. A ValueError will be raised when this check fails.
quiet (bool, optional) – If True, when the matrix cannot be factorized (because of numerics or otherwise), the solver’s LinAlgError will be silenced and the determinant will be set to zero. Otherwise, the exception will be propagated.
- Raises:
ValueError – When the inputs are not valid (shape, number, etc.).
LinAlgError – When the matrix is not numerically positive definite.
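The sketch below mirrors the bookkeeping described for compute using dense matrices: it optionally checks that t is sorted, converts the standard deviation yerr into the diagonal variance diag = yerr**2, adds that to the kernel matrix, and returns the Cholesky factor. The helper names and the exponential kernel are hypothetical, and the real method uses celerite's structured factorization rather than jnp.linalg.cholesky.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def exp_kernel(t, a=1.0, c=0.5):
    # Hypothetical RealTerm-like kernel: k(tau) = a * exp(-c * |tau|)
    tau = jnp.abs(t[:, None] - t[None, :])
    return a * jnp.exp(-c * tau)


def dense_compute(t, *, yerr=None, diag=None, check_sorted=True):
    """Hypothetical dense stand-in for compute(): returns lower L with K = L @ L.T."""
    t = jnp.asarray(t)
    # Python-level check on concrete values, so this helper is not jit-compatible
    if check_sorted and bool(jnp.any(jnp.diff(t) < 0)):
        raise ValueError("the input coordinates must be sorted")
    if yerr is not None:
        diag = jnp.asarray(yerr) ** 2  # standard deviation -> variance
    if diag is None:
        diag = jnp.zeros_like(t)
    K = exp_kernel(t) + jnp.diag(diag)
    # Note: jnp.linalg.cholesky does not raise on failure; it returns NaNs instead
    return jnp.linalg.cholesky(K)


t = jnp.sort(jnp.array([3.0, 0.5, 7.2, 1.1, 4.8]))
yerr = jnp.full_like(t, 0.2)

L_from_yerr = dense_compute(t, yerr=yerr)
L_from_diag = dense_compute(t, diag=yerr**2)  # same observation model
```

Passing yerr or the equivalent diag describes the same observation model, which is why the two factors agree.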
- dot_tril(y, *, inplace=False)
Dot the Cholesky factor of the GP system into a vector or matrix.
Compute $x = L\,y$, where $K = L\,L^T$ and $K$ is the covariance matrix of the GP.
Note: The mean function is not applied in this method.
- Parameters:
y (shape[N] or shape[N, M]) – The vector or matrix y described above.
inplace (bool, optional) – If True, y will be overwritten with the result x.
- Raises:
RuntimeError – If GaussianProcess.compute() is not called first.
ValueError – When the inputs are not valid (shape, number, etc.).
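Because $K = L\,L^T$, multiplying standard-normal draws by $L$ yields draws from $\mathcal{N}(0, K)$, which is a common reason to want dot_tril. The dense sketch below shows that relationship with jnp.linalg.cholesky and jax.random; the kernel, its parameters, and the small diagonal jitter are illustrative assumptions, and celerite2 itself never forms $L$ as a dense matrix.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def exp_kernel(t, a=1.0, c=0.5):
    # Hypothetical RealTerm-like kernel: k(tau) = a * exp(-c * |tau|)
    tau = jnp.abs(t[:, None] - t[None, :])
    return a * jnp.exp(-c * tau)


t = jnp.linspace(0.0, 5.0, 20)
K = exp_kernel(t) + 1e-6 * jnp.eye(t.size)  # small jitter keeps K positive definite
L = jnp.linalg.cholesky(K)                  # K = L @ L.T, L lower triangular

# Dense equivalent of dot_tril: x = L @ y, applied column by column
key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (t.size, 3))     # three standard-normal columns
samples = L @ z                             # each column is a draw from N(0, K)
```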
- log_likelihood(y, *, inplace=False)
Compute the log marginal likelihood of the GP model.
The factorized matrix from the previous call to GaussianProcess.compute() is used, so that method must be called first.
- Parameters:
y (shape[N]) – The observations at coordinates t as defined by GaussianProcess.compute().
inplace (bool, optional) – If True, y will be overwritten in the process of the calculation. This will reduce the memory footprint, but should be used with care since this will overwrite the data.
- Raises:
RuntimeError – If GaussianProcess.compute() is not called first.
ValueError – When the inputs are not valid (shape, number, etc.).
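For reference, a Gaussian log marginal likelihood has the form $\log p(y) = -\tfrac{1}{2}\,r^T K^{-1} r - \tfrac{1}{2}\log\det K - \tfrac{N}{2}\log 2\pi$ with $r = y - \mu$, assuming the mean function $\mu$ is subtracted from the observations. The dense sketch below evaluates that formula with a Cholesky factorization, wraps it in jax.jit like the example at the top of this page, and differentiates it with jax.grad. The exponential kernel, its parameters a and c, and the name dense_log_likelihood are hypothetical; celerite2 evaluates the same kind of quantity with cost that scales linearly in N rather than as O(N^3).

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats import multivariate_normal


def exp_kernel(t, a, c):
    # Hypothetical RealTerm-like kernel: k(tau) = a * exp(-c * |tau|)
    tau = jnp.abs(t[:, None] - t[None, :])
    return a * jnp.exp(-c * tau)


@jax.jit
def dense_log_likelihood(params, t, diag, y, mean=0.0):
    """Hypothetical dense O(N^3) Gaussian log marginal likelihood."""
    K = exp_kernel(t, params["a"], params["c"]) + jnp.diag(diag)
    r = y - mean                                # subtract the mean function
    factor = cho_factor(K, lower=True)
    alpha = cho_solve(factor, r)                # K^{-1} r
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(factor[0])))
    n = y.shape[0]
    return -0.5 * (r @ alpha + log_det + n * jnp.log(2.0 * jnp.pi))


t = jnp.linspace(0.0, 10.0, 30)
diag = 0.05 * jnp.ones_like(t)
y = jnp.sin(t)
params = {"a": 1.0, "c": 0.5}

ll = dense_log_likelihood(params, t, diag, y)
grads = jax.grad(dense_log_likelihood)(params, t, diag, y)  # dict with keys "a", "c"

# Cross-check against JAX's built-in multivariate normal log-density
K = exp_kernel(t, params["a"], params["c"]) + jnp.diag(diag)
reference = multivariate_normal.logpdf(y, jnp.zeros_like(y), K)
```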
- predict(y, t=None, *, return_cov=False, return_var=False, include_mean=True, kernel=None)
Compute the conditional distribution of the GP given the observations.
The factorized matrix from the previous call to GaussianProcess.compute() is used, so that method must be called first.
- Parameters:
y (shape[N]) – The observations at coordinates t as defined by GaussianProcess.compute().
t (shape[M], optional) – The independent coordinates where the prediction should be evaluated. If not provided, this will be evaluated at the observations t from GaussianProcess.compute().
return_var (bool, optional) – Return the variance of the conditional distribution.
return_cov (bool, optional) – Return the full covariance matrix of the conditional distribution.
include_mean (bool, optional) – Include the mean function in the prediction.
kernel (optional) – If provided, compute the conditional distribution using a different kernel. This is generally used to separate the contributions from different model components. Note that the computational cost and scaling will be worse when using this parameter.
- Raises:
RuntimeError – If GaussianProcess.compute() is not called first.
ValueError – When the inputs are not valid (shape, number, etc.).
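The conditional distribution follows the standard GP equations. With $K = k(t, t) + \mathrm{diag}$, $K_* = k(t, t_*)$ and a constant mean $m$, the conditional mean is $\mu_* = m + K_*^T K^{-1}(y - m)$ and the conditional variance is the diagonal of $k(t_*, t_*) - K_*^T K^{-1} K_*$. The dense sketch below evaluates these expressions, roughly the quantities described by predict(y, t, return_var=True). It is an O(N^3) illustration that assumes a constant mean and a hypothetical exponential kernel; the helper dense_predict is not part of celerite2, and the variance it returns is that of the noise-free process at t_star.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def exp_kernel(t1, t2, a=1.0, c=0.5):
    # Hypothetical RealTerm-like kernel: k(tau) = a * exp(-c * |tau|)
    tau = jnp.abs(t1[:, None] - t2[None, :])
    return a * jnp.exp(-c * tau)


def dense_predict(t, diag, y, t_star, mean=0.0):
    """Hypothetical dense GP conditional mean and variance at t_star."""
    K = exp_kernel(t, t) + jnp.diag(diag)
    K_s = exp_kernel(t, t_star)                  # shape (N, M)
    factor = cho_factor(K, lower=True)
    alpha = cho_solve(factor, y - mean)          # K^{-1} (y - m)
    mu = mean + K_s.T @ alpha                    # conditional mean
    v = cho_solve(factor, K_s)                   # K^{-1} K_s
    prior_var = jnp.diag(exp_kernel(t_star, t_star))
    var = prior_var - jnp.sum(K_s * v, axis=0)   # diagonal of the conditional covariance
    return mu, var


t = jnp.linspace(0.0, 10.0, 40)
diag = 0.01 * jnp.ones_like(t)
y = jnp.sin(t)
t_star = jnp.linspace(0.0, 10.0, 7)

mu, var = dense_predict(t, diag, y, t_star)
```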
numpyro support
This implementation comes with a custom numpyro Distribution that represents a multivariate normal with a celerite covariance matrix. This distribution is used by the celerite2.jax.GaussianProcess.numpyro_dist() method, which adds a marginal likelihood node to a numpyro model.