# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class that represents a prior distribution.
The `Prior` class is a wrapper around PyMC distributions that allows the user
to create priors outside of a PyMC model.
.. note::
This module is deprecated; its contents have moved to `pymc_extras.prior`.
This is the alternative to using the dictionaries in PyMC-Marketing models.
Examples
--------
Create a normal prior.
.. code-block:: python
from pymc_extras.prior import Prior
normal = Prior("Normal")
Create a hierarchical normal prior by using distributions for the parameters
and specifying the dims.
.. code-block:: python
hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
Create a non-centered hierarchical normal prior with the `centered` parameter.
.. code-block:: python
non_centered_hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
# Only change needed to make it non-centered
centered=False,
)
Create a hierarchical beta prior by using Beta distribution, distributions for
the parameters, and specifying the dims.
.. code-block:: python
hierarchical_beta = Prior(
"Beta",
alpha=Prior("HalfNormal"),
beta=Prior("HalfNormal"),
dims="channel",
)
Create a transformed hierarchical normal prior by using the `transform`
parameter. Here the "sigmoid" transformation comes from `pm.math`.
.. code-block:: python
transformed_hierarchical_normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
transform="sigmoid",
dims="channel",
)
Create a prior with a custom transform function by registering it with
`register_tensor_transform`.
.. code-block:: python
from pymc_extras.prior import register_tensor_transform
def custom_transform(x):
return x**2
register_tensor_transform("square", custom_transform)
custom_distribution = Prior("Normal", transform="square")
"""
from __future__ import annotations
import copy
import functools
import warnings
from typing import Any
from pymc_extras import prior
from pymc_extras.deserialize import deserialize, register_deserialization
[docs]
def is_alternative_prior(data: Any) -> bool:
    """Return whether ``data`` looks like a dictionary-encoded Prior.

    Parameters
    ----------
    data : Any
        Object to inspect.

    Returns
    -------
    bool
        ``True`` only when ``data`` is a ``dict`` that contains the key
        ``"distribution"``; ``False`` for everything else.

    """
    # Anything that is not a dict can never describe a prior.
    if not isinstance(data, dict):
        return False
    return "distribution" in data
[docs]
def deserialize_alternative_prior(data: dict[str, Any]) -> prior.Prior:
    """Build a ``Prior`` from a dictionary, deserializing nested dictionaries.

    The keys ``distribution``, ``dims``, ``centered`` and ``transform`` are
    taken out of the dictionary; every remaining entry is forwarded to
    ``prior.Prior`` as a keyword argument. Any remaining value that is itself
    a dictionary is first passed through ``deserialize``, so every parameter
    may be given as a nested dictionary.

    Parameters
    ----------
    data : dict[str, Any]
        Dictionary description of the prior. It must contain the key
        ``"distribution"``. The input is deep-copied, so the caller's
        dictionary is left untouched.

    Returns
    -------
    prior.Prior
        Prior built from the distribution name, the optional ``dims``
        (default ``None``), ``centered`` (default ``True``) and ``transform``
        (default ``None``) entries, and the remaining parameters.

    Raises
    ------
    KeyError
        If ``data`` has no ``"distribution"`` key.

    Examples
    --------
    This handles cases like:

    .. code-block:: yaml

        distribution: Gamma
        alpha: 1
        beta:
            distribution: HalfNormal
            sigma: 1
            dims: channel
        dims: [brand, channel]

    """
    remaining = copy.deepcopy(data)

    # Pull out the reserved keys; whatever is left over is treated as a
    # distribution parameter.
    distribution = remaining.pop("distribution")
    dims = remaining.pop("dims", None)
    centered = remaining.pop("centered", True)
    transform = remaining.pop("transform", None)

    # Recursively deserialize any parameter that is given as a dictionary.
    parameters = {}
    for key, value in remaining.items():
        if isinstance(value, dict):
            parameters[key] = deserialize(value)
        else:
            parameters[key] = value

    return prior.Prior(
        distribution,
        transform=transform,
        centered=centered,
        dims=dims,
        **parameters,
    )
# Register the predicate/deserializer pair at import time. Presumably this lets
# `deserialize` route any dict containing a "distribution" key (including
# nested ones) to `deserialize_alternative_prior` -- confirm against
# `pymc_extras.deserialize`.
register_deserialization(is_alternative_prior, deserialize_alternative_prior)
[docs]
def warn_class_deprecation(func):
    """Decorate a method so that every call emits a class deprecation warning.

    Parameters
    ----------
    func : callable
        Method whose first positional argument is the instance (in this
        module it is applied to ``__init__``).

    Returns
    -------
    callable
        Wrapper carrying the metadata of ``func``. On each call it issues a
        ``DeprecationWarning`` that names the class of the instance, then
        calls ``func`` with the same arguments and returns its result.

    """

    @functools.wraps(func)
    def _warn_then_call(self, *args, **kwargs):
        name = self.__class__.__name__
        message = (
            f"The {name} class has moved to pymc_extras.prior module and will be removed in a future release. "
            f"Import it from `from pymc_extras.prior import {name}`. "
        )
        # stacklevel=2 attributes the warning to the caller of the wrapped
        # method rather than to this wrapper.
        warnings.warn(message, category=DeprecationWarning, stacklevel=2)
        return func(self, *args, **kwargs)

    return _warn_then_call
[docs]
def warn_function_deprecation(func):
    """Decorate a function so that every call emits a deprecation warning.

    Parameters
    ----------
    func : callable
        Function to wrap; its ``__name__`` is used in the warning message.

    Returns
    -------
    callable
        Wrapper carrying the metadata of ``func``. On each call it issues a
        ``DeprecationWarning``, then calls ``func`` with the same arguments
        and returns its result.

    """

    @functools.wraps(func)
    def _warn_then_call(*args, **kwargs):
        name = func.__name__
        message = (
            f"The {name} function has moved to pymc_extras.prior module and will be removed in a future release. "
            f"Import it from `from pymc_extras.prior import {name}`."
        )
        # stacklevel=2 attributes the warning to the caller of the wrapped
        # function rather than to this wrapper.
        warnings.warn(message, category=DeprecationWarning, stacklevel=2)
        return func(*args, **kwargs)

    return _warn_then_call
[docs]
class Prior(prior.Prior):
    """Deprecated alias of ``pymc_extras.prior.Prior``.

    Only ``__init__`` is overridden: creating an instance emits a
    ``DeprecationWarning`` (via ``warn_class_deprecation``) before the
    arguments are handed to the parent class.
    """

    @warn_class_deprecation
    def __init__(self, *args, **kwargs):
        """Warn about the deprecation, then forward all arguments to the parent."""
        super().__init__(*args, **kwargs)
[docs]
class Censored(prior.Censored):
    """Deprecated alias of ``pymc_extras.prior.Censored``.

    Only ``__init__`` is overridden: creating an instance emits a
    ``DeprecationWarning`` (via ``warn_class_deprecation``) before the
    arguments are handed to the parent class.
    """

    @warn_class_deprecation
    def __init__(self, *args, **kwargs):
        """Warn about the deprecation, then forward all arguments to the parent."""
        super().__init__(*args, **kwargs)
[docs]
class Scaled(prior.Scaled):
    """Deprecated alias of ``pymc_extras.prior.Scaled``.

    Only ``__init__`` is overridden: creating an instance emits a
    ``DeprecationWarning`` (via ``warn_class_deprecation``) before the
    arguments are handed to the parent class.
    """

    @warn_class_deprecation
    def __init__(self, *args, **kwargs):
        """Warn about the deprecation, then forward all arguments to the parent."""
        super().__init__(*args, **kwargs)
# Deprecated re-exports: each name wraps the function of the same name from
# `pymc_extras.prior` with `warn_function_deprecation`, so a call emits a
# DeprecationWarning and then delegates to the original function.
sample_prior = warn_function_deprecation(prior.sample_prior)
create_dim_handler = warn_function_deprecation(prior.create_dim_handler)
handle_dims = warn_function_deprecation(prior.handle_dims)
register_tensor_transform = warn_function_deprecation(prior.register_tensor_transform)