Source code for celerite2.jax.distribution
# -*- coding: utf-8 -*-
__all__ = ["CeleriteNormal"]
from jax import numpy as jnp
from jax import random as random
from numpyro import distributions as dist
[docs]
class CeleriteNormal(dist.Distribution):
    """numpyro distribution whose density and draws are delegated to a GP.

    ``log_prob`` forwards to ``gp.log_likelihood`` and ``sample`` pushes
    standard-normal noise through ``gp.dot_tril``. The distribution has an
    empty batch shape and an event shape equal to ``jnp.shape(gp._t)``.

    Args:
        gp: Object providing ``_t``, ``log_likelihood`` and ``dot_tril``.
            # presumably a celerite2 JAX GaussianProcess that has already
            # been computed -- confirm against callers.
        validate_args: Forwarded unchanged to ``dist.Distribution``.
    """

    support = dist.constraints.real_vector

    def __init__(self, gp, validate_args=None):
        self.gp = gp
        # One event is a full vector with the same shape as the GP's stored
        # input coordinates.
        event_shape = jnp.shape(self.gp._t)
        super().__init__(
            batch_shape=(),
            event_shape=event_shape,
            validate_args=validate_args,
        )

    @dist.util.validate_sample
    def log_prob(self, value):
        """Return ``gp.log_likelihood(value)``."""
        log_like = self.gp.log_likelihood(value)
        return log_like

    def sample(self, key, sample_shape=()):
        """Draw samples by transforming white noise with ``gp.dot_tril``.

        Args:
            key: PRNG key handed to ``jax.random.normal``.
            sample_shape: Extra dimensions for the draw. It is concatenated
                onto ``event_shape`` with ``+``, so it is expected to be a
                tuple.

        Returns:
            The output of ``gp.dot_tril`` with its leading axis moved to the
            last position.
            # assumes dot_tril preserves the shape of its input, giving
            # ``sample_shape + event_shape`` -- TODO confirm.
        """
        # The event axis goes first so that dot_tril acts along axis 0; the
        # sample dimensions trail behind it.
        noise_shape = self.event_shape + sample_shape
        white_noise = random.normal(key, shape=noise_shape)
        # NOTE(review): no mean term is added here; if gp.log_likelihood
        # accounts for a non-zero mean and dot_tril does not, draws would be
        # offset relative to log_prob -- confirm against the GP implementation.
        correlated = self.gp.dot_tril(white_noise)
        # Put the event axis last, as the trailing dimensions of a draw.
        return jnp.moveaxis(correlated, 0, -1)