Source code for pymc_extras.prior

"""Class that represents a prior distribution.

The `Prior` class is a wrapper around PyMC distributions that allows the user
to create outside of the PyMC model.

Examples
--------
Create a normal prior.

.. code-block:: python

    from pymc_extras.prior import Prior

    normal = Prior("Normal")

Create a hierarchical normal prior by using distributions for the parameters
and specifying the dims.

.. code-block:: python

    hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        dims="channel",
    )

Create a non-centered hierarchical normal prior with the `centered` parameter.

.. code-block:: python

    non_centered_hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        dims="channel",
        # Only change needed to make it non-centered
        centered=False,
    )

Create a hierarchical beta prior by using Beta distribution, distributions for
the parameters, and specifying the dims.

.. code-block:: python

    hierarchical_beta = Prior(
        "Beta",
        alpha=Prior("HalfNormal"),
        beta=Prior("HalfNormal"),
        dims="channel",
    )

Create a transformed hierarchical normal prior by using the `transform`
parameter. Here the "sigmoid" transformation comes from `pm.math`.

.. code-block:: python

    transformed_hierarchical_normal = Prior(
        "Normal",
        mu=Prior("Normal"),
        sigma=Prior("HalfNormal"),
        transform="sigmoid",
        dims="channel",
    )

Create a prior with a custom transform function by registering it with
`register_tensor_transform`.

.. code-block:: python

    from pymc_extras.prior import register_tensor_transform

    def custom_transform(x):
        return x ** 2

    register_tensor_transform("square", custom_transform)

    custom_distribution = Prior("Normal", transform="square")

"""

from __future__ import annotations

import copy

from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import Any, Protocol, runtime_checkable

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr

from pydantic import InstanceOf, validate_call
from pydantic.dataclasses import dataclass
from pymc.distributions.shape_utils import Dims

from pymc_extras.deserialize import deserialize, register_deserialization


class UnsupportedShapeError(Exception):
    """Error for when the shapes from variables are not compatible."""


class UnsupportedDistributionError(Exception):
    """Error for when an unsupported distribution is used."""


class UnsupportedParameterizationError(Exception):
    """The follow parameterization is not supported."""


class MuAlreadyExistsError(Exception):
    """Error for when 'mu' is present in Prior."""

    def __init__(self, distribution: Prior) -> None:
        self.distribution = distribution
        self.message = f"The mu parameter is already defined in {distribution}"
        super().__init__(self.message)


class UnknownTransformError(Exception):
    """Error for when an unknown transform is used."""


def _remove_leading_xs(args: list[str | int]) -> list[str | int]:
    """Remove leading 'x' from the args."""
    while args and args[0] == "x":
        args.pop(0)

    return args


def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVariable:
    """Take a tensor of dims `dims` and align it to `desired_dims`.

    Doesn't check for validity of the dims

    Examples
    --------
    1D to 2D with new dim

    .. code-block:: python

        x = np.array([1, 2, 3])
        dims = "channel"

        desired_dims = ("channel", "group")

        handle_dims(x, dims, desired_dims)

    """
    x = pt.as_tensor_variable(x)

    if np.ndim(x) == 0:
        return x

    dims = dims if isinstance(dims, tuple) else (dims,)
    desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)

    if difference := set(dims).difference(desired_dims):
        raise UnsupportedShapeError(
            f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
            f"{difference} is missing from the desired dims."
        )

    aligned_dims = np.array(dims)[:, None] == np.array(desired_dims)

    missing_dims = aligned_dims.sum(axis=0) == 0
    new_idx = aligned_dims.argmax(axis=0)

    args = ["x" if missing else idx for (idx, missing) in zip(new_idx, missing_dims, strict=False)]
    args = _remove_leading_xs(args)
    return x.dimshuffle(*args)


DimHandler = Callable[[pt.TensorLike, Dims], pt.TensorLike]


def create_dim_handler(desired_dims: Dims) -> DimHandler:
    """Wrap the `handle_dims` function to act like the previous `create_dim_handler` function."""

    def func(x: pt.TensorLike, dims: Dims) -> pt.TensorVariable:
        return handle_dims(x, dims, desired_dims)

    return func


def _dims_to_str(obj: tuple[str, ...]) -> str:
    if len(obj) == 1:
        return f'"{obj[0]}"'

    return "(" + ", ".join(f'"{i}"' if isinstance(i, str) else str(i) for i in obj) + ")"


def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
    if not hasattr(pm, name):
        raise UnsupportedDistributionError(f"PyMC doesn't have a distribution of name {name!r}")

    return getattr(pm, name)


Transform = Callable[[pt.TensorLike], pt.TensorLike]

CUSTOM_TRANSFORMS: dict[str, Transform] = {}


def register_tensor_transform(name: str, transform: Transform) -> None:
    """Register a tensor transform function to be used in the `Prior` class.

    Parameters
    ----------
    name : str
        The name of the transform.
    func : Callable[[pt.TensorLike], pt.TensorLike]
        The function to apply to the tensor.

    Examples
    --------
    Register a custom transform function.

    .. code-block:: python

        from pymc_extras.prior import (
            Prior,
            register_tensor_transform,
        )

        def custom_transform(x):
            return x ** 2

        register_tensor_transform("square", custom_transform)

        custom_distribution = Prior("Normal", transform="square")

    """
    CUSTOM_TRANSFORMS[name] = transform


def _get_transform(name: str):
    if name in CUSTOM_TRANSFORMS:
        return CUSTOM_TRANSFORMS[name]

    for module in (pt, pm.math):
        if hasattr(module, name):
            break
    else:
        module = None

    if not module:
        msg = (
            f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
            "If this is a custom function, register it with the "
            "`pymc_extras.prior.register_tensor_transform` function before "
            "previous function call."
        )

        raise UnknownTransformError(msg)

    return getattr(module, name)


def _get_pymc_parameters(distribution: pm.Distribution) -> set[str]:
    return set(signature(distribution.dist).parameters.keys()) - {"kwargs", "args"}


@runtime_checkable
class VariableFactory(Protocol):
    """Protocol for something that works like a Prior class."""

    dims: tuple[str, ...]

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create a TensorVariable."""


def sample_prior(
    factory: VariableFactory,
    coords=None,
    name: str = "variable",
    wrap: bool = False,
    **sample_prior_predictive_kwargs,
) -> xr.Dataset:
    """Sample the prior for an arbitrary VariableFactory.

    Parameters
    ----------
    factory : VariableFactory
        The factory to sample from.
    coords : dict[str, list[str]], optional
        The coordinates for the variable, by default None.
        Only required if the dims are specified.
    name : str, optional
        The name of the variable, by default "variable".
    wrap : bool, optional
        Whether to wrap the variable in a `pm.Deterministic` node, by default False.
    sample_prior_predictive_kwargs : dict
        Additional arguments to pass to `pm.sample_prior_predictive`.

    Returns
    -------
    xr.Dataset
        The dataset of the prior samples.

    Example
    -------
    Sample from an arbitrary variable factory.

    .. code-block:: python

        import pymc as pm

        import pytensor.tensor as pt

        from pymc_extras.prior import sample_prior

        class CustomVariableDefinition:
            def __init__(self, dims, n: int):
                self.dims = dims
                self.n = n

            def create_variable(self, name: str) -> "TensorVariable":
                x = pm.Normal(f"{name}_x", mu=0, sigma=1, dims=self.dims)
                return pt.sum([x ** n for n in range(1, self.n + 1)], axis=0)

        cubic = CustomVariableDefinition(dims=("channel",), n=3)
        coords = {"channel": ["C1", "C2", "C3"]}
        # Doesn't include the return value
        prior = sample_prior(cubic, coords=coords)

        prior_with = sample_prior(cubic, coords=coords, wrap=True)

    """
    coords = coords or {}

    if isinstance(factory.dims, str):
        dims = (factory.dims,)
    else:
        dims = factory.dims

    if missing_keys := set(dims) - set(coords.keys()):
        raise KeyError(f"Coords are missing the following dims: {missing_keys}")

    with pm.Model(coords=coords) as model:
        if wrap:
            pm.Deterministic(name, factory.create_variable(name), dims=factory.dims)
        else:
            factory.create_variable(name)

    return pm.sample_prior_predictive(
        model=model,
        **sample_prior_predictive_kwargs,
    ).prior


class Prior:
    """A class to represent a prior distribution.

    Make use of the various helper methods to understand the distributions
    better.

    - `preliz` attribute to get the equivalent distribution in `preliz`
    - `sample_prior` method to sample from the prior
    - `to_graph` get a dummy model graph with the distribution
    - `constrain` to shift the distribution to a different range

    Parameters
    ----------
    distribution : str
        The name of PyMC distribution.
    dims : Dims, optional
        The dimensions of the variable, by default None
    centered : bool, optional
        Whether the variable is centered or not, by default True.
        Only allowed for Normal distribution.
    transform : str, optional
        The name of the transform to apply to the variable after it is
        created, by default None or no transform. The transformation must
        be registered with `register_tensor_transform` function or
        be available in either `pytensor.tensor` or `pymc.math`.

    """

    # Taken from https://en.wikipedia.org/wiki/Location%E2%80%93scale_family
    non_centered_distributions: dict[str, dict[str, float]] = {
        "Normal": {"mu": 0, "sigma": 1},
        "StudentT": {"mu": 0, "sigma": 1},
        "ZeroSumNormal": {"sigma": 1},
    }

    pymc_distribution: type[pm.Distribution]
    pytensor_transform: Callable[[pt.TensorLike], pt.TensorLike] | None

    @validate_call
    def __init__(
        self,
        distribution: str,
        *,
        dims: Dims | None = None,
        centered: bool = True,
        transform: str | None = None,
        **parameters,
    ) -> None:
        self.distribution = distribution
        self.parameters = parameters
        self.dims = dims
        self.centered = centered
        self.transform = transform

        self._checks()

    @property
    def distribution(self) -> str:
        """The name of the PyMC distribution."""
        return self._distribution

    @distribution.setter
    def distribution(self, distribution: str) -> None:
        if hasattr(self, "_distribution"):
            raise AttributeError("Can't change the distribution")

        self._distribution = distribution
        self.pymc_distribution = _get_pymc_distribution(distribution)

    @property
    def transform(self) -> str | None:
        """The name of the transform to apply to the variable after it is created."""
        return self._transform

    @transform.setter
    def transform(self, transform: str | None) -> None:
        self._transform = transform
        self.pytensor_transform = not transform or _get_transform(transform)  # type: ignore

    @property
    def dims(self) -> Dims:
        """The dimensions of the variable."""
        return self._dims

    @dims.setter
    def dims(self, dims) -> None:
        if isinstance(dims, str):
            dims = (dims,)

        if isinstance(dims, list):
            dims = tuple(dims)

        self._dims = dims or ()

        self._param_dims_work()
        self._unique_dims()

    def __getitem__(self, key: str) -> Prior | Any:
        """Return the parameter of the prior."""
        return self.parameters[key]

    def _checks(self) -> None:
        if not self.centered:
            self._correct_non_centered_distribution()

        self._parameters_are_at_least_subset_of_pymc()
        self._convert_lists_to_numpy()
        self._parameters_are_correct_type()

    def _parameters_are_at_least_subset_of_pymc(self) -> None:
        pymc_params = _get_pymc_parameters(self.pymc_distribution)
        if not set(self.parameters.keys()).issubset(pymc_params):
            msg = (
                f"Parameters {set(self.parameters.keys())} "
                "are not a subset of the pymc distribution "
                f"parameters {set(pymc_params)}"
            )
            raise ValueError(msg)

    def _convert_lists_to_numpy(self) -> None:
        def convert(x):
            if not isinstance(x, list):
                return x

            return np.array(x)

        self.parameters = {key: convert(value) for key, value in self.parameters.items()}

    def _parameters_are_correct_type(self) -> None:
        supported_types = (
            int,
            float,
            np.ndarray,
            Prior,
            pt.TensorVariable,
            VariableFactory,
        )

        incorrect_types = {
            param: type(value)
            for param, value in self.parameters.items()
            if not isinstance(value, supported_types)
        }
        if incorrect_types:
            msg = (
                "Parameters must be one of the following types: "
                f"(int, float, np.array, Prior, pt.TensorVariable). Incorrect parameters: {incorrect_types}"
            )
            raise ValueError(msg)

    def _correct_non_centered_distribution(self) -> None:
        if not self.centered and self.distribution not in self.non_centered_distributions:
            raise UnsupportedParameterizationError(
                f"{self.distribution!r} is not supported for non-centered parameterization. "
                f"Choose from {list(self.non_centered_distributions.keys())}"
            )

        required_parameters = set(self.non_centered_distributions[self.distribution].keys())

        if set(self.parameters.keys()) < required_parameters:
            msg = " and ".join([f"{param!r}" for param in required_parameters])
            raise ValueError(
                f"Must have at least {msg} parameter for non-centered for {self.distribution!r}"
            )

    def _unique_dims(self) -> None:
        if not self.dims:
            return

        if len(self.dims) != len(set(self.dims)):
            raise ValueError("Dims must be unique")

    def _param_dims_work(self) -> None:
        other_dims = set()
        for value in self.parameters.values():
            if hasattr(value, "dims"):
                other_dims.update(value.dims)

        if not other_dims.issubset(self.dims):
            raise UnsupportedShapeError(
                f"Parameter dims {other_dims} are not a subset of the prior dims {self.dims}"
            )

    def __str__(self) -> str:
        """Return a string representation of the prior."""
        param_str = ", ".join([f"{param}={value}" for param, value in self.parameters.items()])
        param_str = "" if not param_str else f", {param_str}"

        dim_str = f", dims={_dims_to_str(self.dims)}" if self.dims else ""
        centered_str = f", centered={self.centered}" if not self.centered else ""
        transform_str = f', transform="{self.transform}"' if self.transform else ""
        return f'Prior("{self.distribution}"{param_str}{dim_str}{centered_str}{transform_str})'

    def __repr__(self) -> str:
        """Return a string representation of the prior."""
        return f"{self}"

    def _create_parameter(self, param, value, name):
        if not hasattr(value, "create_variable"):
            return value

        child_name = f"{name}_{param}"
        return self.dim_handler(value.create_variable(child_name), value.dims)

    def _create_centered_variable(self, name: str):
        parameters = {
            param: self._create_parameter(param, value, name)
            for param, value in self.parameters.items()
        }
        return self.pymc_distribution(name, **parameters, dims=self.dims)

    def _create_non_centered_variable(self, name: str) -> pt.TensorVariable:
        def handle_variable(var_name: str):
            parameter = self.parameters[var_name]
            if not hasattr(parameter, "create_variable"):
                return parameter

            return self.dim_handler(
                parameter.create_variable(f"{name}_{var_name}"),
                parameter.dims,
            )

        defaults = self.non_centered_distributions[self.distribution]
        other_parameters = {
            param: handle_variable(param)
            for param in self.parameters.keys()
            if param not in defaults
        }
        offset = self.pymc_distribution(
            f"{name}_offset",
            **defaults,
            **other_parameters,
            dims=self.dims,
        )
        if "mu" in self.parameters:
            mu = (
                handle_variable("mu")
                if isinstance(self.parameters["mu"], Prior)
                else self.parameters["mu"]
            )
        else:
            mu = 0

        sigma = (
            handle_variable("sigma")
            if isinstance(self.parameters["sigma"], Prior)
            else self.parameters["sigma"]
        )

        return pm.Deterministic(
            name,
            mu + sigma * offset,
            dims=self.dims,
        )

[docs] def create_variable(self, name: str) -> pt.TensorVariable: """Create a PyMC variable from the prior. Must be used in a PyMC model context. Parameters ---------- name : str The name of the variable. Returns ------- pt.TensorVariable The PyMC variable. Examples -------- Create a hierarchical normal variable in larger PyMC model. .. code-block:: python dist = Prior( "Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel", ) coords = {"channel": ["C1", "C2", "C3"]} with pm.Model(coords=coords): var = dist.create_variable("var") """ self.dim_handler = create_dim_handler(self.dims) if self.transform: var_name = f"{name}_raw" def transform(var): return pm.Deterministic(name, self.pytensor_transform(var), dims=self.dims) else: var_name = name def transform(var): return var create_variable = ( self._create_centered_variable if self.centered else self._create_non_centered_variable ) var = create_variable(name=var_name) return transform(var)
@property def preliz(self): """Create an equivalent preliz distribution. Helpful to visualize a distribution when it is univariate. Returns ------- preliz.distributions.Distribution Examples -------- Create a preliz distribution from a prior. .. code-block:: python from pymc_extras.prior import Prior dist = Prior("Gamma", alpha=5, beta=1) dist.preliz.plot_pdf() """ import preliz as pz return getattr(pz, self.distribution)(**self.parameters)
[docs] def to_dict(self) -> dict[str, Any]: """Convert the prior to dictionary format. Returns ------- dict[str, Any] The dictionary format of the prior. Examples -------- Convert a prior to the dictionary format. .. code-block:: python from pymc_extras.prior import Prior dist = Prior("Normal", mu=0, sigma=1) dist.to_dict() Convert a hierarchical prior to the dictionary format. .. code-block:: python dist = Prior( "Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel", ) dist.to_dict() """ data: dict[str, Any] = { "dist": self.distribution, } if self.parameters: def handle_value(value): if isinstance(value, Prior): return value.to_dict() if isinstance(value, pt.TensorVariable): value = value.eval() if isinstance(value, np.ndarray): return value.tolist() if hasattr(value, "to_dict"): return value.to_dict() return value data["kwargs"] = { param: handle_value(value) for param, value in self.parameters.items() } if not self.centered: data["centered"] = False if self.dims: data["dims"] = self.dims if self.transform: data["transform"] = self.transform return data
[docs] @classmethod def from_dict(cls, data) -> Prior: """Create a Prior from the dictionary format. Parameters ---------- data : dict[str, Any] The dictionary format of the prior. Returns ------- Prior The prior distribution. Examples -------- Convert prior in the dictionary format to a Prior instance. .. code-block:: python from pymc_extras.prior import Prior data = { "dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}, } dist = Prior.from_dict(data) dist # Prior("Normal", mu=0, sigma=1) """ if not isinstance(data, dict): msg = ( "Must be a dictionary representation of a prior distribution. " f"Not of type: {type(data)}" ) raise ValueError(msg) dist = data["dist"] kwargs = data.get("kwargs", {}) def handle_value(value): if isinstance(value, dict): return deserialize(value) if isinstance(value, list): return np.array(value) return value kwargs = {param: handle_value(value) for param, value in kwargs.items()} centered = data.get("centered", True) dims = data.get("dims") if isinstance(dims, list): dims = tuple(dims) transform = data.get("transform") return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs)
[docs] def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior: """Create a new prior with a given mass constrained within the given bounds. Wrapper around `preliz.maxent`. Parameters ---------- lower : float The lower bound. upper : float The upper bound. mass: float = 0.95 The mass of the distribution to keep within the bounds. kwargs : dict Additional arguments to pass to `pz.maxent`. Returns ------- Prior The maximum entropy prior with a mass constrained to the given bounds. Examples -------- Create a Beta distribution that is constrained to have 95% of the mass between 0.5 and 0.8. .. code-block:: python dist = Prior( "Beta", ).constrain(lower=0.5, upper=0.8) Create a Beta distribution with mean 0.6, that is constrained to have 95% of the mass between 0.5 and 0.8. .. code-block:: python dist = Prior( "Beta", mu=0.6, ).constrain(lower=0.5, upper=0.8) """ from preliz import maxent if self.transform: raise ValueError("Can't constrain a transformed variable") if kwargs is None: kwargs = {} kwargs.setdefault("plot", False) if kwargs["plot"]: new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs)[0].params_dict else: new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs).params_dict return Prior( self.distribution, dims=self.dims, transform=self.transform, centered=self.centered, **new_parameters, )
def __eq__(self, other) -> bool: """Check if two priors are equal.""" if not isinstance(other, Prior): return False try: np.testing.assert_equal(self.parameters, other.parameters) except AssertionError: return False return ( self.distribution == other.distribution and self.dims == other.dims and self.centered == other.centered and self.transform == other.transform )
[docs] def sample_prior( self, coords=None, name: str = "variable", **sample_prior_predictive_kwargs, ) -> xr.Dataset: """Sample the prior distribution for the variable. Parameters ---------- coords : dict[str, list[str]], optional The coordinates for the variable, by default None. Only required if the dims are specified. name : str, optional The name of the variable, by default "variable". sample_prior_predictive_kwargs : dict Additional arguments to pass to `pm.sample_prior_predictive`. Returns ------- xr.Dataset The dataset of the prior samples. Example ------- Sample from a hierarchical normal distribution. .. code-block:: python dist = Prior( "Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel", ) coords = {"channel": ["C1", "C2", "C3"]} prior = dist.sample_prior(coords=coords) """ return sample_prior( factory=self, coords=coords, name=name, **sample_prior_predictive_kwargs, )
def __deepcopy__(self, memo) -> Prior: """Return a deep copy of the prior.""" if id(self) in memo: return memo[id(self)] copy_obj = Prior( self.distribution, dims=copy.copy(self.dims), centered=self.centered, transform=self.transform, **copy.deepcopy(self.parameters), ) memo[id(self)] = copy_obj return copy_obj
[docs] def deepcopy(self) -> Prior: """Return a deep copy of the prior.""" return copy.deepcopy(self)
[docs] def to_graph(self): """Generate a graph of the variables. Examples -------- Create the graph for a 2D transformed hierarchical distribution. .. code-block:: python from pymc_extras.prior import Prior mu = Prior( "Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel", ) sigma = Prior("HalfNormal", dims="channel") dist = Prior( "Normal", mu=mu, sigma=sigma, dims=("channel", "geo"), centered=False, transform="sigmoid", ) dist.to_graph() .. image:: /_static/example-graph.png :alt: Example graph """ coords = {name: ["DUMMY"] for name in self.dims} with pm.Model(coords=coords) as model: self.create_variable("var") return pm.model_to_graphviz(model)
[docs] def create_likelihood_variable( self, name: str, mu: pt.TensorLike, observed: pt.TensorLike, ) -> pt.TensorVariable: """Create a likelihood variable from the prior. Will require that the distribution has a `mu` parameter and that it has not been set in the parameters. Parameters ---------- name : str The name of the variable. mu : pt.TensorLike The mu parameter for the likelihood. observed : pt.TensorLike The observed data. Returns ------- pt.TensorVariable The PyMC variable. Examples -------- Create a likelihood variable in a larger PyMC model. .. code-block:: python import pymc as pm dist = Prior("Normal", sigma=Prior("HalfNormal")) with pm.Model(): # Create the likelihood variable mu = pm.Normal("mu", mu=0, sigma=1) dist.create_likelihood_variable("y", mu=mu, observed=observed) """ if "mu" not in _get_pymc_parameters(self.pymc_distribution): raise UnsupportedDistributionError( f"Likelihood distribution {self.distribution!r} is not supported." ) if "mu" in self.parameters: raise MuAlreadyExistsError(self) distribution = self.deepcopy() distribution.parameters["mu"] = mu distribution.parameters["observed"] = observed return distribution.create_variable(name)
class VariableNotFound(Exception): """Variable is not found.""" def _remove_random_variable(var: pt.TensorVariable) -> None: if var.name is None: raise ValueError("This isn't removable") name: str = var.name model = pm.modelcontext(None) for idx, free_rv in enumerate(model.free_RVs): if var == free_rv: index_to_remove = idx break else: raise VariableNotFound(f"Variable {var.name!r} not found") var.name = None model.free_RVs.pop(index_to_remove) model.named_vars.pop(name) @dataclass class Censored: """Create censored random variable. Examples -------- Create a censored Normal distribution: .. code-block:: python from pymc_extras.prior import Prior, Censored normal = Prior("Normal") censored_normal = Censored(normal, lower=0) Create hierarchical censored Normal distribution: .. code-block:: python from pymc_extras.prior import Prior, Censored normal = Prior( "Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal"), dims="channel", ) censored_normal = Censored(normal, lower=0) coords = {"channel": range(3)} samples = censored_normal.sample_prior(coords=coords) """ distribution: InstanceOf[Prior] lower: float | InstanceOf[pt.TensorVariable] = -np.inf upper: float | InstanceOf[pt.TensorVariable] = np.inf def __post_init__(self) -> None: """Check validity at initialization.""" if not self.distribution.centered: raise ValueError( "Censored distribution must be centered so that .dist() API can be used on distribution." ) if self.distribution.transform is not None: raise ValueError( "Censored distribution can't have a transform so that .dist() API can be used on distribution." ) @property def dims(self) -> tuple[str, ...]: """The dims from the distribution to censor.""" return self.distribution.dims @dims.setter def dims(self, dims) -> None: self.distribution.dims = dims
[docs] def create_variable(self, name: str) -> pt.TensorVariable: """Create censored random variable.""" dist = self.distribution.create_variable(name) _remove_random_variable(var=dist) return pm.Censored( name, dist, lower=self.lower, upper=self.upper, dims=self.dims, )
[docs] def to_dict(self) -> dict[str, Any]: """Convert the censored distribution to a dictionary.""" def handle_value(value): if isinstance(value, pt.TensorVariable): return value.eval().tolist() return value return { "class": "Censored", "data": { "dist": self.distribution.to_dict(), "lower": handle_value(self.lower), "upper": handle_value(self.upper), }, }
[docs] @classmethod def from_dict(cls, data: dict[str, Any]) -> Censored: """Create a censored distribution from a dictionary.""" data = data["data"] return cls( # type: ignore distribution=deserialize(data["dist"]), lower=data["lower"], upper=data["upper"], )
[docs] def sample_prior( self, coords=None, name: str = "variable", **sample_prior_predictive_kwargs, ) -> xr.Dataset: """Sample the prior distribution for the variable. Parameters ---------- coords : dict[str, list[str]], optional The coordinates for the variable, by default None. Only required if the dims are specified. name : str, optional The name of the variable, by default "var". sample_prior_predictive_kwargs : dict Additional arguments to pass to `pm.sample_prior_predictive`. Returns ------- xr.Dataset The dataset of the prior samples. Example ------- Sample from a censored Gamma distribution. .. code-block:: python gamma = Prior("Gamma", mu=1, sigma=1, dims="channel") dist = Censored(gamma, lower=0.5) coords = {"channel": ["C1", "C2", "C3"]} prior = dist.sample_prior(coords=coords) """ return sample_prior( factory=self, coords=coords, name=name, **sample_prior_predictive_kwargs, )
[docs] def to_graph(self): """Generate a graph of the variables. Examples -------- Create graph for a censored Normal distribution .. code-block:: python from pymc_extras.prior import Prior, Censored normal = Prior("Normal") censored_normal = Censored(normal, lower=0) censored_normal.to_graph() """ coords = {name: ["DUMMY"] for name in self.dims} with pm.Model(coords=coords) as model: self.create_variable("var") return pm.model_to_graphviz(model)
[docs] def create_likelihood_variable( self, name: str, mu: pt.TensorLike, observed: pt.TensorLike, ) -> pt.TensorVariable: """Create observed censored variable. Will require that the distribution has a `mu` parameter and that it has not been set in the parameters. Parameters ---------- name : str The name of the variable. mu : pt.TensorLike The mu parameter for the likelihood. observed : pt.TensorLike The observed data. Returns ------- pt.TensorVariable The PyMC variable. Examples -------- Create a censored likelihood variable in a larger PyMC model. .. code-block:: python import pymc as pm from pymc_extras.prior import Prior, Censored normal = Prior("Normal", sigma=Prior("HalfNormal")) dist = Censored(normal, lower=0) observed = 1 with pm.Model(): # Create the likelihood variable mu = pm.HalfNormal("mu", sigma=1) dist.create_likelihood_variable("y", mu=mu, observed=observed) """ if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution): raise UnsupportedDistributionError( f"Likelihood distribution {self.distribution.distribution!r} is not supported." ) if "mu" in self.distribution.parameters: raise MuAlreadyExistsError(self.distribution) distribution = self.distribution.deepcopy() distribution.parameters["mu"] = mu dist = distribution.create_variable(name) _remove_random_variable(var=dist) return pm.Censored( name, dist, observed=observed, lower=self.lower, upper=self.upper, dims=self.dims, )
class Scaled: """Scaled distribution for numerical stability.""" def __init__(self, dist: Prior, factor: float | pt.TensorVariable) -> None: self.dist = dist self.factor = factor @property def dims(self) -> Dims: """The dimensions of the scaled distribution.""" return self.dist.dims
[docs] def create_variable(self, name: str) -> pt.TensorVariable: """Create a scaled variable. Parameters ---------- name : str The name of the variable. Returns ------- pt.TensorVariable The scaled variable. """ var = self.dist.create_variable(f"{name}_unscaled") return pm.Deterministic(name, var * self.factor, dims=self.dims)
def _is_prior_type(data: dict) -> bool: return "dist" in data def _is_censored_type(data: dict) -> bool: return data.keys() == {"class", "data"} and data["class"] == "Censored" register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict) register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) def __getattr__(name: str): """Get Prior class through the module. Examples -------- Create a normal distribution. .. code-block:: python from pymc_extras.prior import Normal dist = Normal(mu=1, sigma=2) Create a hierarchical normal distribution. .. code-block:: python import pymc_extras.prior as pr dist = pr.Normal(mu=pr.Normal(), sigma=pr.HalfNormal(), dims="channel") samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]}) """ # Protect against doctest if name == "__wrapped__": return _get_pymc_distribution(name) return partial(Prior, distribution=name)