# Source code for celerite2.pymc.distribution

# -*- coding: utf-8 -*-

__all__ = ["CeleriteNormal"]

import numpy as np
import pytensor.tensor as pt
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous
from pymc.distributions.shape_utils import rv_size_is_none
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import (
    broadcast_params,
    supp_shape_from_ref_param_shape,
)

import celerite2.driver as driver
from celerite2.pymc import ops


def safe_celerite_normal(rng, mean, norm, t, c, U, W, d, size=None):
    """Draw samples from a celerite multivariate normal using NumPy.

    Standard-normal draws are scaled by ``sqrt(d)``. They are multiplied by
    the lower-triangular operator encoded by ``(t, c, U, W)`` via
    ``driver.matmul_lower``, and then shifted by ``mean``. Adding the mean
    keeps the draws consistent with ``CeleriteNormal.logp``, which evaluates
    the density of ``value - mean``.

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of randomness (presumably a NumPy ``Generator``; only
        ``standard_normal`` is used).
    mean : ndarray, shape (N,)
        Mean vector; its length sets the number of data points ``N``.
    norm : ndarray
        Unused here; accepted so the signature mirrors the RV parameters.
    t, c, U, W : ndarray
        Celerite representation, passed straight to ``driver.matmul_lower``.
    d : ndarray, shape (N,)
        Per-point scale; the draws are multiplied by ``sqrt(d)``.
    size : tuple of int or None
        Batch shape of independent draws. ``None`` requests a single draw.

    Returns
    -------
    ndarray
        Shape ``(N,)`` when ``size`` is None, otherwise ``size + (N,)``.
    """
    num_data = mean.shape[0]
    if size is None:
        num_draws = 1
        out_shape = (num_data,)
    else:
        size = tuple(size)
        # int(): np.prod(()) is the float 1.0, which numpy rejects as a
        # shape entry, so size=() would otherwise crash.
        num_draws = int(np.prod(size))
        out_shape = size + (num_data,)

    # One column per independent draw, each scaled point-wise by sqrt(d).
    n = rng.standard_normal(size=(num_data, num_draws)) * np.sqrt(d)[:, None]
    result = driver.matmul_lower(t, c, U, W, n, n)

    # Columns -> rows so each draw is contiguous, then restore the batch shape
    # and shift by the mean (broadcast along the trailing data axis).
    return mean + np.reshape(np.transpose(result), out_shape)


class CeleriteNormalRV(RandomVariable):
    """PyTensor random variable that draws from a celerite multivariate normal.

    Parameters are passed in the order ``mean, norm, t, c, U, W, d``. Their
    core (non-batch) dimensionalities are listed in ``ndims_params``. Any
    leading dimensions beyond those are treated as batch dimensions, and one
    independent draw is made per batch element.
    """

    name = "celerite_normal"
    ndim_supp = 1
    # Core ndim of each parameter, in order: mean, norm, t, c, U, W, d.
    ndims_params = [1, 0, 1, 1, 2, 2, 1]
    dtype = "floatX"
    _print_name = ("CeleriteNormal", "\\operatorname{CeleriteNormal}")

    def _supp_shape_from_params(self, dist_params, param_shapes=None):
        # The support length is read from ``t`` (parameter index 2).
        return supp_shape_from_ref_param_shape(
            ndim_supp=self.ndim_supp,
            dist_params=dist_params,
            param_shapes=param_shapes,
            ref_param_idx=2,
        )

    @classmethod
    def rng_fn(cls, rng, mean, norm, t, c, U, W, d, size):
        """Draw samples with NumPy, looping over any batch dimensions.

        Raises
        ------
        ValueError
            If an explicit ``size`` is incompatible with the batch shape of
            the parameters.
        """
        params = (mean, norm, t, c, U, W, d)
        is_batched = any(
            x.ndim > n for n, x in zip(cls.ndims_params, params)
        )
        if not is_batched:
            # Unbatched parameters: a single call handles ``size`` directly.
            return safe_celerite_normal(
                rng, mean, norm, t, c, U, W, d, size=size
            )

        mean, norm, t, c, U, W, d = broadcast_params(
            list(params), cls.ndims_params
        )
        size = tuple(size or ())

        if size:
            if (
                0 < mean.ndim - 1 <= len(size)
                and size[-mean.ndim + 1 :] != mean.shape[:-1]
            ):
                raise ValueError(
                    "shape mismatch: objects cannot be broadcast to a single shape"
                )
            # Keep exactly each parameter's core dims (per ndims_params) and
            # prepend ``size``. Deriving this from ndims_params avoids
            # hand-typed slices; ``d`` is a vector, so it keeps only its last
            # axis.
            mean, norm, t, c, U, W, d = (
                np.broadcast_to(p, size + p.shape[p.ndim - core :])
                for p, core in zip(
                    (mean, norm, t, c, U, W, d), cls.ndims_params
                )
            )

        res = np.empty(mean.shape)
        for idx in np.ndindex(mean.shape[:-1]):
            res[idx] = safe_celerite_normal(
                rng,
                mean[idx],
                norm[idx],
                t[idx],
                c[idx],
                U[idx],
                W[idx],
                d[idx],
            )
        return res


# Shared op instance used as ``CeleriteNormal.rv_op``.
celerite_normal = CeleriteNormalRV()


class CeleriteNormal(Continuous):
    """A multivariate normal distribution with a celerite covariance matrix

    Parameters
    ----------
    mean : tensor_like
        Mean vector; broadcast against ``t`` in :meth:`dist`.
    norm : tensor_like
        Scalar added to the log-density. This is presumably the precomputed
        normalization term; confirm against the caller.
    t, c, U, W, d : tensor_like
        Celerite representation of the covariance, forwarded to
        ``ops.solve_lower`` in :meth:`logp`. ``d`` must be strictly positive.
    """

    rv_op = celerite_normal

    @classmethod
    def dist(cls, mean, norm, t, c, U, W, d, **kwargs):
        mean, norm, t, c, U, W, d = (
            pt.as_tensor_variable(x) for x in (mean, norm, t, c, U, W, d)
        )
        # Give ``mean`` the same length as ``t`` (e.g. a scalar mean).
        mean = pt.broadcast_arrays(mean, t)[0]
        return super().dist([mean, norm, t, c, U, W, d], **kwargs)

    def moment(rv, size, mean, *args):
        # The mean is used as the initial point, tiled out to ``size``.
        moment = mean
        if not rv_size_is_none(size):
            moment_size = pt.concatenate([size, [mean.shape[-1]]])
            moment = pt.full(moment_size, mean)
        return moment

    def logp(value, mean, norm, t, c, U, W, d):
        ok = pt.all(pt.gt(d, 0.0))
        alpha = value - mean
        # Treat the residual as a single column; assumes ``value`` is 1-D
        # (one draw) — TODO confirm batched values are never passed here.
        alpha = ops.solve_lower(t, c, U, W, alpha[:, None])[0][:, 0]
        logp = norm - 0.5 * pt.sum(alpha**2 / d)
        return check_parameters(logp, ok, msg="d > 0")