MMM Multidimensional Example Notebook#

In this notebook, we present a new experimental media mix model class for building multidimensional and customized marketing mix models. To showcase its capabilities, we extend the MMM Example Notebook simulation to a multidimensional hierarchical model.

Warning

Although the new MMM class is experimental, it is fully functional and can be used to create multidimensional marketing mix models. The model is under active development and will be further improved in the future (feedback welcome!).

Prepare Notebook#

import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
import xarray as xr
from pymc_extras.prior import Prior

from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import (
    MMM,
    MultiDimensionalBudgetOptimizerWrapper,
)
from pymc_marketing.paths import data_dir

warnings.filterwarnings("ignore", category=UserWarning)

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["xtick.labelsize"] = 10
plt.rcParams["ytick.labelsize"] = 8

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
C:\Users\dsaun\pymc-marketing\pymc_marketing\mmm\multidimensional.py:75: FutureWarning: This functionality is experimental and subject to change. If you encounter any issues or have suggestions, please raise them at: https://github.com/pymc-labs/pymc-marketing/issues/new
  warnings.warn(warning_msg, FutureWarning, stacklevel=1)
seed: int = sum(map(ord, "mmm_multidimensional"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Read Data#

We read the simulated data from the MMM Multidimensional Example Notebook.

data_path = data_dir / "mmm_multidimensional_example.csv"

data_df = pd.read_csv(data_path, parse_dates=["date"])
data_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 318 entries, 0 to 317
Data columns (total 7 columns):
 #   Column   Non-Null Count  Dtype         
---  ------   --------------  -----         
 0   date     318 non-null    datetime64[ns]
 1   geo      318 non-null    object        
 2   x1       318 non-null    float64       
 3   x2       318 non-null    float64       
 4   event_1  318 non-null    int64         
 5   event_2  318 non-null    int64         
 6   y        318 non-null    float64       
dtypes: datetime64[ns](1), float64(3), int64(2), object(1)
memory usage: 17.5+ KB

For our setup, imagine we are selling one product in two different countries (geo_a and geo_b). Our marketing team maintains two channels: one is usually on, while the other is more tactical and is turned on during marketing campaigns. Visual inspection of the data suggests that there is at least some effect of marketing on sales, but the relationship is noisy. Our mission is to see whether the MMM can separate the signal from the noise.

One strategy for dealing with noisy, low-signal data is to borrow information from similar contexts. If channel 2 seems to be pretty effective in geo_b, that gives us reason to suspect it will be effective in geo_a. This can be implemented with either full pooling or partial pooling (partially pooled models are often called ‘hierarchical’ or ‘multi-level’). So this notebook will demonstrate how to fit an MMM to multiple markets at the same time and make decisions about how to pool information across the two contexts.

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
fig.suptitle("Channel Spends Over Time", fontsize=16, fontweight="bold")

blue_colors = ["#1f77b4", "#7aa6c2"]  # Darker and lighter shades of blue

# Plot for geo_a
geo_a_data = data_df[data_df["geo"] == "geo_a"]
ax1.bar(geo_a_data["date"], geo_a_data["x1"], label="x1", width=7, color=blue_colors[0])
ax1.bar(
    geo_a_data["date"],
    geo_a_data["x2"],
    bottom=geo_a_data["x1"],
    label="x2",
    width=7,
    color=blue_colors[1],
)
ax1.plot(geo_a_data["date"], geo_a_data["y"], "--", label="y", color="black")
ax1.set_title("geo_a")
ax1.legend()

# Plot for geo_b
geo_b_data = data_df[data_df["geo"] == "geo_b"]
ax2.bar(geo_b_data["date"], geo_b_data["x1"], label="x1", width=7, color=blue_colors[0])
ax2.bar(
    geo_b_data["date"],
    geo_b_data["x2"],
    bottom=geo_b_data["x1"],
    label="x2",
    width=7,
    color=blue_colors[1],
)
ax2.plot(geo_b_data["date"], geo_b_data["y"], "--", label="y", color="black")
ax2.set_title("geo_b")
ax2.legend()

plt.tight_layout()
../../_images/4b30acffda9598f70a600a53db10926dbdcf9daaed969596b08cb73f01790b0c.png

Prior Specification#

The beta parameter represents the maximum number of weekly sales you could drive through a channel. Beta will be the only hierarchical parameter in this model. There is good reason to think that different markets saturate at different levels: a product might be popular in one market, where the potential audience is very large, while in another it is a niche product. To capture this, we’ll allow the saturation point to vary across geographies.

A hierarchical model does something clever - instead of assuming that every geography is different, we assume that they must be at least partially related to each other. If you had 10 geos and 9 of them had a high saturation point, you would reasonably expect the 10th one to have a high saturation point. A hierarchical model will adaptively pool information between contexts. When each geo is very different, it will transfer less information between them. When each geo is very similar, it will transfer more information. This is also called partial pooling. If you need an introduction on Bayesian hierarchical models, check out the comprehensive example “A Primer on Bayesian Methods for Multilevel Modeling” in the PyMC documentation.

We can build hierarchical parameters with the prior API. Notice that we have one Prior object with dimensions (channel, geo) and then further Prior objects for its hyperparameters. The prior on mu captures what we expect the channels to do, without considering their variation across geographies. The prior on sigma represents how much the effect varies across geographies. Bayesian inference propagates information all the way up and down the network of parameters: as we learn the values of the interior parameters, mu and sigma, they act as constraints on the behaviour of the individual beta parameters.

There are a lot of options in how we code up each type of effect. It can be difficult to keep track of which parameters share information, which are completely separate, and which are fully shared. But there is a trick to help you remember: notice that the interior Priors have fewer dimensions than the main Prior. The extra dimension is the one across which information is transferred, while the dimensions they share remain independent. Channel 1 in geo_a influences channel 1 in geo_b, but channel 1 never influences channel 2.

beta_prior = Prior(
    "Normal",
    mu=Prior("Normal", mu=-1.5, sigma=0.5, dims=("channel")),
    sigma=Prior("Exponential", scale=0.25, dims=("channel")),
    dims=("channel", "geo"),
    transform="exp",
    centered=False,
)

This notebook is only illustrative, so we’ll show you how to code each type of assumption you might make (we aren’t recommending a universal solution!). Lambda represents the efficiency of a channel: the higher the lambda, the more responsive sales are to spending on that channel. We’ll have the lambda parameter fully pooled across all geographies. We are assuming that each channel has the same efficiency in both geos, so we do not specify “geo” in dims. The package will automatically broadcast the channel effect to each geography.

saturation = LogisticSaturation(
    priors={
        "beta": beta_prior,
        "lam": Prior(
            "Gamma",
            mu=0.5,
            sigma=0.25,
            dims=("channel"),
        ),
    }
)

saturation.model_config
{'saturation_lam': Prior("Gamma", mu=0.5, sigma=0.25, dims="channel"),
 'saturation_beta': Prior("Normal", mu=Prior("Normal", mu=-1.5, sigma=0.5, dims="channel"), sigma=Prior("Exponential", scale=0.25, dims="channel"), dims=("channel", "geo"), centered=False, transform="exp")}

Adstock alpha represents how long customers remember marketing. For adstock, we’ll illustrate the unpooled strategy. Here, each channel in each geography has its own effect, and those effects do not influence each other. Notice that we include dims for both geo and channel to indicate that we want 4 unique effects.

adstock = GeometricAdstock(
    priors={"alpha": Prior("Beta", alpha=2, beta=5, dims=("geo", "channel"))}, l_max=8
)

adstock.model_config
{'adstock_alpha': Prior("Beta", alpha=2, beta=5, dims=("geo", "channel"))}

You can mix and match the unpooled, fully pooled, and partially pooled strategies for any of your effects. You can extend this strategy to controls or noise parameters as well. Given the variety of options, it can be hard to know which pooling strategy to choose for a given effect. In our opinion, the choice is primarily driven by computational considerations. Partial pooling is generally a more reasonable assumption, but it can make the model slower to estimate, more complicated to debug, and more difficult to reason about.

For example, you might notice that we set our prior with centered=False. This is known as a reparameterization, a strategy to solve computational difficulties that MCMC algorithms can run into when fitting hierarchical models. We recommend that you start with a model that uses only fully pooled or unpooled effects. Once you have a good working model you can add complexity slowly, verifying your model performance and accuracy at each stage.
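
To make the three strategies concrete, here is a minimal, illustrative sketch (the variable names and the specific distributions are ours, not part of the model above) showing how a saturation-style parameter could be declared under each one with the Prior API:

# Illustrative sketch: the same kind of parameter under the three strategies.
# These particular distributions are examples, not recommendations.

# Fully pooled: one effect per channel, shared by every geo.
beta_fully_pooled = Prior("Gamma", mu=0.5, sigma=0.25, dims=("channel",))

# Unpooled: an independent effect for every (channel, geo) pair.
beta_unpooled = Prior("Gamma", mu=0.5, sigma=0.25, dims=("channel", "geo"))

# Partially pooled (hierarchical, non-centered): per-geo effects shrink towards
# a per-channel mean, and the spread across geos is learned from the data.
beta_partially_pooled = Prior(
    "Normal",
    mu=Prior("Normal", mu=-1.5, sigma=0.5, dims=("channel",)),
    sigma=Prior("Exponential", scale=0.25, dims=("channel",)),
    dims=("channel", "geo"),
    transform="exp",
    centered=False,
)

Only the partially pooled version introduces the hyperparameters (mu and sigma) through which the geos inform each other.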

We complete the model specification with similar priors as in the MMM Example Notebook.

model_config = {
    "intercept": Prior("Gamma", mu=0.5, sigma=0.25, dims="geo"),
    "gamma_control": Prior("Normal", mu=0, sigma=0.5, dims="control"),
    "gamma_fourier": Prior(
        "Normal",
        mu=0,
        sigma=Prior("HalfNormal", sigma=0.2),
        dims=("geo", "fourier_mode"),
        centered=False,
    ),
    "likelihood": Prior(
        "TruncatedNormal",
        lower=0,
        sigma=Prior("HalfNormal", sigma=1.5),
        dims=("date", "geo"),
    ),
}

Model Definition#

We are now ready to define the model class. The API is very similar to the one in the MMM Example Notebook.

# Base MMM model specification
mmm = MMM(
    date_column="date",
    target_column="y",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    dims=("geo",),
    scaling={
        "channel": {"method": "max", "dims": ()},
        "target": {"method": "max", "dims": ()},
    },
    adstock=adstock,
    saturation=saturation,
    yearly_seasonality=2,
    model_config=model_config,
)

Tip

Observe that we have the following two new arguments:

  • dims: a tuple of strings that specify the dimensions of the model.

  • scaling: a dictionary that specifies the scaling method and dimensions for the target and media variables. In this case we leave the dimensions empty as we want to scale the target variable for each geo (see details below).
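
To make the scaling choice concrete, here is a small sketch computed directly on the raw data (our reading of the "max" method: it divides by the maximum over the remaining dimensions, so an empty dims tuple keeps a separate scale per geo, while listing "geo" would share one scale across geos):

# Sketch of the two options for the target variable (illustration only).
per_geo_max = data_df.groupby("geo")["y"].max()  # dims=() -> one scale per geo
shared_max = data_df["y"].max()  # dims=("geo",) -> a single scale shared across geos
print(per_geo_max)
print(shared_max)

The per-geo maxima computed here should match the target_scale values returned by get_scales_as_xarray further below.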

We can now prepare the training data.

x_train = data_df.drop(columns=["y"])
y_train = data_df["y"]

To build the model, we need to specify the training data and the target variable.

Tip

We do not need to build the model explicitly; we could simply call fit. We build it here only to inspect the model structure.

mmm.build_model(X=x_train, y=y_train)

Let’s look into the model graph:

pm.model_to_graphviz(mmm.model)
../../_images/cd9b2ab5972c845b7ab2586811df3acc089e5093482d5247d5f125f9095899cf.svg

It is great to see that the model automatically vectorizes and creates the expected hierarchies and dimensions 🚀!

As we are scaling our data internally, we can add deterministic terms to recover the component contributions in the original scale.

mmm.add_original_scale_contribution_variable(
    var=[
        "channel_contribution",
        "control_contribution",
        "intercept_contribution",
        "yearly_seasonality_contribution",
        "y",
    ]
)

pm.model_to_graphviz(mmm.model)
../../_images/1c155aaaa09a74f13f19aff8ff166a164ee6393d5afa93d08898d6d989bfddc2.svg

Coming back to the scalers, we can retrieve them as a dictionary of xarray DataArrays.

scalers = mmm.get_scales_as_xarray()

scalers
{'channel_scale': <xarray.DataArray '_channel' (geo: 2, channel: 2)> Size: 32B
 array([[ 9318.97848455,  9755.9729876 ],
        [10555.0774866 , 11760.98180037]])
 Coordinates:
   * geo      (geo) object 16B 'geo_a' 'geo_b'
   * channel  (channel) object 16B 'x1' 'x2',
 'target_scale': <xarray.DataArray '_target' (geo: 2)> Size: 16B
 array([13812.08025674, 11002.97913936])
 Coordinates:
   * geo      (geo) object 16B 'geo_a' 'geo_b'}

As expected from the model definition, we have scalers for the target and media variables for each geo.
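
If you ever need to undo the internal scaling by hand, you can multiply a scaled quantity by the corresponding scale; xarray aligns the broadcast on the shared coordinates. A minimal sketch with toy values:

# Minimal sketch: map a toy quantity on the scaled target axis back to sales
# units by multiplying with the per-geo target scale (broadcast on "geo").
target_scale = scalers["target_scale"]
y_scaled_toy = xr.DataArray(
    [0.5, 0.5], dims="geo", coords={"geo": target_scale["geo"].values}
)
print(y_scaled_toy * target_scale)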

Prior Predictive Checks#

Before fitting the model, we can inspect the prior predictive distribution.

with mmm.model:
    prior = pm.sample_prior_predictive()
prior
Sampling: [adstock_alpha, gamma_control, gamma_fourier_offset, gamma_fourier_sigma, intercept_contribution, saturation_beta_raw_mu, saturation_beta_raw_offset, saturation_beta_raw_sigma, saturation_lam, y, y_sigma]
arviz.InferenceData
    • <xarray.Dataset> Size: 19MB
      Dimensions:                                         (chain: 1, draw: 500,
                                                           control: 2, date: 159,
                                                           geo: 2, channel: 2,
                                                           fourier_mode: 4)
      Coordinates:
        * chain                                           (chain) int64 8B 0
        * draw                                            (draw) int64 4kB 0 1 ... 499
        * control                                         (control) <U7 56B 'event_...
        * date                                            (date) datetime64[ns] 1kB ...
        * geo                                             (geo) <U5 40B 'geo_a' 'ge...
        * channel                                         (channel) <U2 16B 'x1' 'x2'
        * fourier_mode                                    (fourier_mode) <U5 80B 's...
      Data variables: (12/23)
          gamma_control                                   (chain, draw, control) float64 8kB ...
          y_original_scale                                (chain, draw, date, geo) float64 1MB ...
          saturation_beta                                 (chain, draw, channel, geo) float64 16kB ...
          gamma_fourier                                   (chain, draw, geo, fourier_mode) float64 32kB ...
          control_contribution                            (chain, draw, date, geo, control) float64 3MB ...
          y_sigma                                         (chain, draw) float64 4kB ...
          ...                                              ...
          yearly_seasonality_contribution_original_scale  (chain, draw, date, geo) float64 1MB ...
          channel_contribution                            (chain, draw, date, geo, channel) float64 3MB ...
          intercept_contribution                          (chain, draw, geo) float64 8kB ...
          saturation_beta_raw                             (chain, draw, channel, geo) float64 16kB ...
          saturation_lam                                  (chain, draw, channel) float64 8kB ...
          gamma_fourier_sigma                             (chain, draw) float64 4kB ...
      Attributes:
          created_at:                 2025-08-15T21:50:06.142004+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1

    • <xarray.Dataset> Size: 1MB
      Dimensions:  (chain: 1, draw: 500, date: 159, geo: 2)
      Coordinates:
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * date     (date) datetime64[ns] 1kB 2022-06-06 2022-06-13 ... 2025-06-16
        * geo      (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y        (chain, draw, date, geo) float64 1MB 0.8647 1.005 ... 0.405 2.664
      Attributes:
          created_at:                 2025-08-15T21:50:06.164382+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1

    • <xarray.Dataset> Size: 4kB
      Dimensions:  (date: 159, geo: 2)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2022-06-06 2022-06-13 ... 2025-06-16
        * geo      (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y        (date, geo) float64 3kB 0.1917 0.06202 0.3635 ... 0.4068 0.5073
      Attributes:
          created_at:                 2025-08-15T21:50:06.167874+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1

    • <xarray.Dataset> Size: 12kB
      Dimensions:        (geo: 2, channel: 2, date: 159, control: 2)
      Coordinates:
        * geo            (geo) <U5 40B 'geo_a' 'geo_b'
        * channel        (channel) <U2 16B 'x1' 'x2'
        * date           (date) datetime64[ns] 1kB 2022-06-06 ... 2025-06-16
        * control        (control) <U7 56B 'event_1' 'event_2'
      Data variables:
          channel_scale  (geo, channel) float64 32B 9.319e+03 9.756e+03 ... 1.176e+04
          target_scale   (geo) float64 16B 1.381e+04 1.1e+04
          channel_data   (date, geo, channel) float64 5kB 5.528e+03 0.0 ... 8.091e+03
          target_data    (date, geo) float64 3kB 2.648e+03 682.4 ... 5.581e+03
          control_data   (date, geo, control) int32 3kB 0 0 0 0 0 0 0 ... 0 0 0 0 0 0
          dayofyear      (date) int32 636B 157 164 171 178 185 ... 139 146 153 160 167
      Attributes:
          created_at:                 2025-08-15T21:50:06.181618+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.25.1

g = sns.relplot(
    data=data_df,
    x="date",
    y="y",
    color="black",
    col="geo",
    col_wrap=1,
    kind="line",
    height=4,
    aspect=3,
)

axes = g.axes.flatten()

for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=(
            prior.prior.sel(geo=geo)["y_original_scale"]
            .unstack()
            .transpose(..., "date")
        ),
        smooth=False,
        color="C0",
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.3, "label": "94% HDI"},
        ax=ax,
    )
    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=(
            prior.prior.sel(geo=geo)["y_original_scale"]
            .unstack()
            .transpose(..., "date")
        ),
        smooth=False,
        color="C0",
        hdi_prob=0.5,
        fill_kwargs={"alpha": 0.5, "label": "50% HDI"},
        ax=ax,
    )
    ax.legend(loc="upper left")

g.figure.suptitle("Prior Predictive", fontsize=16, fontweight="bold", y=1.03);
../../_images/5406f852b2a69e1a2a13a6fb3e1bacb38b4f44640ef496ea5c742b0c93537a17.png

The prior predictive distribution looks good and not too restrictive.

Model Fitting#

We can now fit the model and generate the posterior predictive distribution.

mmm.fit(
    X=x_train,
    y=y_train,
    chains=4,
    target_accept=0.975,
    random_seed=rng,
)

mmm.sample_posterior_predictive(
    X=x_train,
    extend_idata=True,
    combined=True,
    random_seed=rng,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_lam, saturation_beta_raw_offset, saturation_beta_raw_mu, saturation_beta_raw_sigma, gamma_control, gamma_fourier_offset, gamma_fourier_sigma, y_sigma]

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 99 seconds.

Sampling: [y]

<xarray.Dataset> Size: 20MB
Dimensions:           (date: 159, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 1kB 2022-06-06 ... 2025-06-16
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y                 (date, geo, sample) float64 10MB 0.5745 0.5809 ... 0.0395
    y_original_scale  (date, geo, sample) float64 10MB 7.935e+03 ... 434.6
Attributes:
    created_at:                 2025-08-15T21:52:03.186757+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1

The sampling looks good. No divergences and the r-hat values are close to \(1\).

mmm.idata.sample_stats.diverging.sum("draw")
<xarray.DataArray 'diverging' (chain: 4)> Size: 32B
array([0, 0, 0, 0])
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
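
Before looking at the full summary table, we can also check the worst-case r-hat programmatically; a short sketch using ArviZ:

# Sketch: the largest split R-hat across all posterior variables should be ~1.
max_rhat = max(float(da.max()) for da in az.rhat(mmm.idata).data_vars.values())
print(f"max r-hat: {max_rhat:.3f}")
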
az.summary(
    mmm.idata,
    var_names=[
        "adstock_alpha",
        "gamma_control",
        "gamma_fourier",
        "intercept_contribution",
        "saturation_beta",
        "saturation_beta_raw_mu",
        "saturation_beta_raw_sigma",
        "saturation_lam",
        "y_sigma",
    ],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
adstock_alpha[geo_a, x1] 0.291 0.163 0.036 0.598 0.002 0.002 4730.0 2551.0 1.0
adstock_alpha[geo_a, x2] 0.311 0.165 0.034 0.609 0.002 0.002 4214.0 2413.0 1.0
adstock_alpha[geo_b, x1] 0.259 0.149 0.014 0.526 0.002 0.002 4529.0 2386.0 1.0
adstock_alpha[geo_b, x2] 0.273 0.154 0.009 0.544 0.002 0.002 4128.0 1896.0 1.0
gamma_control[event_1] 0.303 0.087 0.140 0.463 0.001 0.001 4832.0 2865.0 1.0
gamma_control[event_2] -0.098 0.092 -0.271 0.078 0.001 0.002 4603.0 2753.0 1.0
gamma_fourier[geo_a, sin_1] -0.350 0.035 -0.415 -0.285 0.001 0.000 3440.0 3156.0 1.0
gamma_fourier[geo_a, sin_2] -0.028 0.028 -0.081 0.025 0.000 0.001 3975.0 2903.0 1.0
gamma_fourier[geo_a, cos_1] -0.285 0.033 -0.348 -0.224 0.001 0.000 4231.0 3608.0 1.0
gamma_fourier[geo_a, cos_2] 0.004 0.027 -0.046 0.055 0.000 0.000 4641.0 2992.0 1.0
gamma_fourier[geo_b, sin_1] -0.046 0.025 -0.092 -0.000 0.000 0.000 6112.0 3272.0 1.0
gamma_fourier[geo_b, sin_2] 0.190 0.027 0.141 0.239 0.000 0.000 5180.0 3432.0 1.0
gamma_fourier[geo_b, cos_1] -0.200 0.030 -0.254 -0.142 0.000 0.000 5350.0 3719.0 1.0
gamma_fourier[geo_b, cos_2] -0.030 0.025 -0.077 0.018 0.000 0.000 5615.0 3183.0 1.0
intercept_contribution[geo_a] 0.196 0.029 0.141 0.248 0.001 0.000 3064.0 2074.0 1.0
intercept_contribution[geo_b] 0.213 0.028 0.159 0.261 0.001 0.000 3089.0 2077.0 1.0
saturation_beta[x1, geo_a] 0.215 0.114 0.038 0.422 0.002 0.002 4150.0 3142.0 1.0
saturation_beta[x1, geo_b] 0.278 0.176 0.051 0.590 0.003 0.005 3391.0 3092.0 1.0
saturation_beta[x2, geo_a] 0.242 0.127 0.051 0.480 0.002 0.003 3943.0 3061.0 1.0
saturation_beta[x2, geo_b] 0.241 0.131 0.050 0.480 0.002 0.003 4546.0 3303.0 1.0
saturation_beta_raw_mu[x1] -1.557 0.449 -2.490 -0.784 0.007 0.007 4541.0 3110.0 1.0
saturation_beta_raw_mu[x2] -1.537 0.448 -2.410 -0.730 0.007 0.007 4577.0 2795.0 1.0
saturation_beta_raw_sigma[x1] 0.270 0.261 0.000 0.765 0.004 0.006 3002.0 2377.0 1.0
saturation_beta_raw_sigma[x2] 0.238 0.233 0.000 0.662 0.003 0.006 3673.0 2250.0 1.0
saturation_lam[x1] 0.471 0.220 0.108 0.880 0.003 0.004 4053.0 2717.0 1.0
saturation_lam[x2] 0.483 0.219 0.108 0.876 0.003 0.004 4273.0 2618.0 1.0
y_sigma 0.183 0.010 0.165 0.203 0.000 0.000 2925.0 3089.0 1.0
_ = az.plot_trace(
    data=mmm.idata,
    var_names=[
        "adstock_alpha",
        "gamma_control",
        "gamma_fourier",
        "intercept_contribution",
        "saturation_beta",
        "saturation_beta_raw_mu",
        "saturation_beta_raw_sigma",
        "saturation_lam",
        "y_sigma",
    ],
    compact=True,
    backend_kwargs={"figsize": (15, 15), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16, fontweight="bold", y=1.03);
../../_images/8f8eb9ff458a1b90cc3c38e777e97bcef991f4d664ff6799892da33eb6de8111.png

Posterior Predictive Checks#

We can now inspect the posterior predictive distribution. As before, we work with the posterior predictive transformed back to the original scale so that it is comparable to the data.

fig, axes = plt.subplots(
    nrows=len(mmm.model.coords["geo"]),
    figsize=(12, 9),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for i, geo in enumerate(mmm.model.coords["geo"]):
    ax = axes[i]
    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
        color="C0",
        smooth=False,
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.2, "label": "94% HDI"},
        ax=ax,
    )

    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
        color="C0",
        smooth=False,
        hdi_prob=0.5,
        fill_kwargs={"alpha": 0.4, "label": "50% HDI"},
        ax=ax,
    )

    sns.lineplot(
        data=data_df.query("geo == @geo"),
        x="date",
        y="y",
        color="black",
        ax=ax,
    )

    ax.legend(loc="upper left")
    ax.set(title=f"{geo}")

fig.suptitle("Posterior Predictive", fontsize=16, fontweight="bold", y=1.03);
../../_images/91ecdc6cc5658e907cb5624a01ac49b1e6972f211f1df8a88f738ba936b14d67.png

The fit looks okay! There is a lot of white noise in the sales process that we cannot predict. However, the main movements in sales are captured either by our seasonality model or by the other MMM components.

Model Components#

We can extract the contributions of each component of the model in the original scale thanks to the deterministic variables added to the model.

fig, axes = plt.subplots(
    nrows=len(mmm.model.coords["geo"]),
    figsize=(15, 10),
    sharex=True,
    sharey=True,
    layout="constrained",
)

for i, geo in enumerate(mmm.model.coords["geo"]):
    ax = axes[i]

    for j, channel in enumerate(mmm.model.coords["channel"]):
        az.plot_hdi(
            x=mmm.model.coords["date"],
            y=mmm.idata["posterior"]["channel_contribution_original_scale"].sel(
                geo=geo, channel=channel
            ),
            color=f"C{j}",
            smooth=False,
            hdi_prob=0.94,
            fill_kwargs={"alpha": 0.5, "label": f"94% HDI ({channel})"},
            ax=ax,
        )

    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=mmm.idata["posterior"]["intercept_contribution_original_scale"]
        .sel(geo=geo)
        .expand_dims({"date": mmm.model.coords["date"]})
        .transpose(..., "date"),
        color="C2",
        smooth=False,
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.5, "label": "94% HDI intercept"},
        ax=ax,
    )

    az.plot_hdi(
        x=mmm.model.coords["date"],
        y=mmm.idata["posterior"]["yearly_seasonality_contribution_original_scale"].sel(
            geo=geo,
        ),
        color="C3",
        smooth=False,
        hdi_prob=0.94,
        fill_kwargs={"alpha": 0.5, "label": "94% HDI Fourier"},
        ax=ax,
    )

    for k, control in enumerate(mmm.model.coords["control"]):
        az.plot_hdi(
            x=mmm.model.coords["date"],
            y=mmm.idata["posterior"]["control_contribution_original_scale"].sel(
                geo=geo, control=control
            ),
            color=f"C{5 + k}",
            smooth=False,
            hdi_prob=0.94,
            fill_kwargs={"alpha": 0.5, "label": f"94% HDI control ({control})"},
            ax=ax,
        )

    sns.lineplot(
        data=data_df.query("geo == @geo"),
        x="date",
        y="y",
        color="black",
        label="y",
        ax=ax,
    )
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.1),
        ncol=4,
    )
    ax.set(title=f"{geo}")

fig.suptitle(
    "Posterior Predictive - Channel Contributions",
    fontsize=16,
    fontweight="bold",
    y=1.03,
);
../../_images/febf4410e7b8bf0e160e3d6a36ea62fdbb0c19e4044b1f8a38b8d90b85d4a3a8.png

Media Deep Dive#

Next, we can look into the individual channel contributions across geos. This new class exposes a plot namespace that contains many plotting methods.

fig, axes = mmm.plot.contributions_over_time(
    var=["channel_contribution_original_scale"],
)

# Adjust figure size and layout to 2x2
fig.set_size_inches(14, 10)
fig.set_constrained_layout(True)

# Reshape axes to 2x2 grid
num_axes = len(axes.flatten())
if num_axes > 0:
    # Create a new 2x2 grid
    gs = fig.add_gridspec(2, 2)

    # Move existing axes to the new grid
    for i, ax in enumerate(axes.flatten()):
        if i < 4:  # Only handle up to 4 axes for 2x2 grid
            ax.set_position(gs[i // 2, i % 2].get_position(fig))

axes = axes.flatten()

# Share x and y axes across all subplots
for ax in axes:
    ax.legend().remove()
    ax.tick_params(axis="both", which="major", labelsize=6)
    ax.tick_params(axis="both", which="minor", labelsize=6)

# Share y axis limits
y_min = min(ax.get_ylim()[0] for ax in axes)
y_max = max(ax.get_ylim()[1] for ax in axes)
for ax in axes:
    ax.set_ylim(y_min, y_max)

# Share x axis limits
x_min = min(ax.get_xlim()[0] for ax in axes)
x_max = max(ax.get_xlim()[1] for ax in axes)
for ax in axes:
    ax.set_xlim(x_min, x_max)
../../_images/c2f2b77f8110c5deb32ad4032c886ae1f484f9c4b1cdfb3768893f6c58f79878.png

We can plot the saturation curves for each channel and geo, using a few different functions:

  1. Using saturation_scatterplot, we can get only the scatterplot between investment and estimated returns.

  2. Using saturation_curves, we can get posterior samples of the saturation curves overlaid on the scatterplot of mean contributions.

mmm.plot.saturation_scatterplot(width_per_col=8, height_per_row=4, original_scale=True);
../../_images/8d352af381662a255a53df9ed158af5fcbc00925a11f1446b52c2ab62638a7da.png
curve = mmm.saturation.sample_curve(
    mmm.idata.posterior[["saturation_beta", "saturation_lam"]], max_value=2
)
fig, axes = mmm.plot.saturation_curves(
    curve,
    original_scale=True,
    n_samples=10,
    hdi_probs=0.85,
    random_seed=rng,
    subplot_kwargs={"figsize": (12, 8), "ncols": 2},
    rc_params={
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.labelsize": 10,
        "axes.titlesize": 10,
    },
)

for ax in axes.ravel():
    ax.title.set_fontsize(10)

if fig._suptitle is not None:
    fig._suptitle.set_fontsize(12)

plt.tight_layout()
plt.show()
Sampling: [saturation_beta_raw_mu, saturation_beta_raw_offset, saturation_beta_raw_sigma]

../../_images/59b2f32db0f8ed669a4bd66daed7c1bec1a346c41a5c79f7a91d9c1fd4303273.png

Parameter recovery#

One nice sign that the model is working as intended is that it can recover the true parameter values underlying the marketing mechanism. In our case, we know the true parameter values because we simulated the data. Informally, if the bulk of the posterior distribution covers the true parameter value, that’s a good sign. We do not expect the mean of the posterior to always line up with the true value: for small or noisy data, we should expect the posterior to cover a wide interval regardless of whether we built a good model or not. There are also formal frameworks for thinking about parameter recovery in simulations, which might be helpful if you need even more rigorous evidence that the model is working correctly.

Below we compare the posterior distribution to the true values for the main MMM parameters (saturation lambda, saturation beta and adstock alpha).

# Load the true parameters used to generate the data

data_path = data_dir / "mmm_multidimensional_example_true_parameters.nc"
true_parameters = xr.open_dataset(data_path)
az.plot_posterior(
    mmm.fit_result,
    var_names=[
        "saturation_lam",
    ],
    figsize=(12, 4),
    ref_val={
        "saturation_lam": [
            {
                "channel": "x1",
                "ref_val": true_parameters["saturation_lam"].sel(channel="x1").values,
            },
            {
                "channel": "x2",
                "ref_val": true_parameters["saturation_lam"].sel(channel="x2").values,
            },
        ]
    },
);
../../_images/62de3e2d4bfd5a39c80dfae43546a3a4a75421d8a5ae6e6d40a7642eddb78580.png
az.plot_posterior(
    mmm.fit_result,
    var_names=[
        "saturation_beta",
    ],
    grid=(2, 2),
    figsize=(12, 8),
    ref_val={
        "saturation_beta": [
            {
                "channel": "x1",
                "geo": "geo_a",
                "ref_val": true_parameters["saturation_beta"]
                .sel(channel="x1", geo="geo_a")
                .values,
            },
            {
                "channel": "x2",
                "geo": "geo_a",
                "ref_val": true_parameters["saturation_beta"]
                .sel(channel="x2", geo="geo_a")
                .values,
            },
            {
                "channel": "x1",
                "geo": "geo_b",
                "ref_val": true_parameters["saturation_beta"]
                .sel(channel="x1", geo="geo_b")
                .values,
            },
            {
                "channel": "x2",
                "geo": "geo_b",
                "ref_val": true_parameters["saturation_beta"]
                .sel(channel="x2", geo="geo_b")
                .values,
            },
        ]
    },
);
../../_images/2aa762bd46391047cb1911cff2f142ea6acbb6ae4c6588f6339a882762dee82f.png
az.plot_posterior(
    mmm.fit_result,
    var_names=[
        "adstock_alpha",
    ],
    grid=(2, 2),
    figsize=(12, 8),
    ref_val={
        "adstock_alpha": [
            {
                "channel": "x1",
                "geo": "geo_a",
                "ref_val": true_parameters["adstock_alpha"]
                .sel(channel="x1", geo="geo_a")
                .values,
            },
            {
                "channel": "x2",
                "geo": "geo_a",
                "ref_val": true_parameters["adstock_alpha"]
                .sel(channel="x2", geo="geo_a")
                .values,
            },
            {
                "channel": "x1",
                "geo": "geo_b",
                "ref_val": true_parameters["adstock_alpha"]
                .sel(channel="x1", geo="geo_b")
                .values,
            },
            {
                "channel": "x2",
                "geo": "geo_b",
                "ref_val": true_parameters["adstock_alpha"]
                .sel(channel="x2", geo="geo_b")
                .values,
            },
        ]
    },
);
../../_images/1ee23b791574d055f3447ddbabbfa1b68c9fb2df121114798ed11e475f61bb13.png
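
As a rough quantitative complement to these plots, here is a sketch (assuming, as the plots above already do, that true_parameters shares variable names and coordinates with the posterior) that checks whether every true value falls inside the corresponding 94% HDI:

# Sketch: check that each simulated "true" value lies inside the 94% HDI of
# the corresponding posterior marginal.
for var in ["saturation_lam", "saturation_beta", "adstock_alpha"]:
    hdi = az.hdi(mmm.idata, var_names=[var], hdi_prob=0.94)[var]
    true_val = true_parameters[var]
    covered = (true_val >= hdi.sel(hdi="lower")) & (true_val <= hdi.sel(hdi="higher"))
    print(f"{var}: all true values inside 94% HDI -> {bool(covered.all().item())}")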

Out of Sample Predictions#

It is very important to be able to make predictions out of sample. This is key for model validation, forward-looking scenario planning, and business decision making. As in the MMM Example Notebook, we assume the future spends are the same as in the last period of the training sample. This way we can create a new dataset with the future dates and channel spends and use the model to make predictions.

last_date = x_train["date"].max()

# New dates starting from last in dataset
n_new = 7
new_dates = pd.date_range(start=last_date, periods=1 + n_new, freq="W-MON")[1:]

x_out_of_sample_geo_a = pd.DataFrame({"date": new_dates, "geo": "geo_a"})
x_out_of_sample_geo_b = pd.DataFrame({"date": new_dates, "geo": "geo_b"})

# Same channel spends as in the last observed period
x_out_of_sample_geo_a["x1"] = x_train.query("geo == 'geo_a'")["x1"].iloc[-1]
x_out_of_sample_geo_a["x2"] = x_train.query("geo == 'geo_a'")["x2"].iloc[-1]

x_out_of_sample_geo_b["x1"] = x_train.query("geo == 'geo_b'")["x1"].iloc[-1]
x_out_of_sample_geo_b["x2"] = x_train.query("geo == 'geo_b'")["x2"].iloc[-1]

# Other features (events are off in the forecast horizon)
## geo_a
x_out_of_sample_geo_a["event_1"] = 0.0
x_out_of_sample_geo_a["event_2"] = 0.0
## geo_b
x_out_of_sample_geo_b["event_1"] = 0.0
x_out_of_sample_geo_b["event_2"] = 0.0

x_out_of_sample = pd.concat([x_out_of_sample_geo_a, x_out_of_sample_geo_b])

# Final dataset to generate out of sample predictions.
x_out_of_sample
date geo x1 x2 event_1 event_2
0 2025-06-23 geo_a 0.0 6384.065021 0.0 0.0
1 2025-06-30 geo_a 0.0 6384.065021 0.0 0.0
2 2025-07-07 geo_a 0.0 6384.065021 0.0 0.0
3 2025-07-14 geo_a 0.0 6384.065021 0.0 0.0
4 2025-07-21 geo_a 0.0 6384.065021 0.0 0.0
5 2025-07-28 geo_a 0.0 6384.065021 0.0 0.0
6 2025-08-04 geo_a 0.0 6384.065021 0.0 0.0
0 2025-06-23 geo_b 0.0 8090.900533 0.0 0.0
1 2025-06-30 geo_b 0.0 8090.900533 0.0 0.0
2 2025-07-07 geo_b 0.0 8090.900533 0.0 0.0
3 2025-07-14 geo_b 0.0 8090.900533 0.0 0.0
4 2025-07-21 geo_b 0.0 8090.900533 0.0 0.0
5 2025-07-28 geo_b 0.0 8090.900533 0.0 0.0
6 2025-08-04 geo_b 0.0 8090.900533 0.0 0.0

Using the same sample_posterior_predictive method, we can now generate the forecast.

y_out_of_sample = mmm.sample_posterior_predictive(
    x_out_of_sample,
    extend_idata=False,
    include_last_observations=True,
    random_seed=rng,
    var_names=["y_original_scale"],
)

y_out_of_sample
Sampling: [y]

<xarray.Dataset> Size: 544kB
Dimensions:           (date: 7, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 56B 2025-06-23 ... 2025-08-04
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y_original_scale  (date, geo, sample) float64 448kB 7.35e+03 ... 6.859e+03
Attributes:
    created_at:                 2025-08-15T21:52:21.623974+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1
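
Before plotting, a quick sketch of how to collapse these samples into point forecasts and central intervals per date and geo:

# Sketch: posterior predictive mean and a central 94% interval per date and geo.
forecast = y_out_of_sample["y_original_scale"].unstack()  # back to chain/draw dims
forecast_mean = forecast.mean(dim=("chain", "draw"))
forecast_interval = forecast.quantile([0.03, 0.97], dim=("chain", "draw"))
print(forecast_mean.to_dataframe().head())
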
fig, axes = plt.subplots(
    nrows=2,
    ncols=1,
    figsize=(12, 10),
    sharex=True,
    sharey=True,
    layout="constrained",
)

n_train_to_plot = 30

for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
    for hdi_prob in [0.94, 0.5]:
        az.plot_hdi(
            x=mmm.model.coords["date"][-n_train_to_plot:],
            y=(
                mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)[
                    :, :, -n_train_to_plot:
                ]
            ),
            color="C0",
            smooth=False,
            hdi_prob=hdi_prob,
            fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
            ax=ax,
        )

        az.plot_hdi(
            x_out_of_sample.query("geo == @geo")["date"],
            (
                y_out_of_sample["y_original_scale"]
                .sel(geo=geo)
                .unstack()
                .transpose(..., "date")
            ),
            color="C1",
            smooth=False,
            hdi_prob=hdi_prob,
            fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
            ax=ax,
        )

        ax.plot(
            x_out_of_sample.query("geo == @geo")["date"],
            y_out_of_sample["y_original_scale"].sel(geo=geo).mean(dim="sample"),
            marker="o",
            color="C1",
            label="posterior predictive mean",
        )

    sns.lineplot(
        data=data_df.query("(geo == @geo)").tail(n_train_to_plot),
        x="date",
        y="y",
        marker="o",
        color="black",
        label="observed",
        ax=ax,
    )

    ax.axvline(x=last_date, color="gray", linestyle="--", label="last observation")
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.15),
        ncol=3,
    )
    ax.set(title=f"{geo}")

fig.suptitle(
    "Posterior Predictive - Out of Sample", fontsize=16, fontweight="bold", y=1.03
);
../../_images/9bd4bbe6ce7605c0e3e5a35363707ed6546c34cbe7cab77b31bbb98329a838ba.png

Optimization#

If you want to run optimizations, then you need to use the MultiDimensionalBudgetOptimizerWrapper.

optimizable_model = MultiDimensionalBudgetOptimizerWrapper(
    model=mmm, start_date="2021-10-01", end_date="2021-12-31"
)

allocation_xarray, scipy_opt_result = optimizable_model.optimize_budget(
    budget=100_000,
)

sample_allocation = optimizable_model.sample_response_distribution(
    allocation_strategy=allocation_xarray,
)
Sampling: [y]

This object is an xarray Dataset with the allocation and the posterior predictive responses!

sample_allocation
<xarray.Dataset> Size: 4MB
Dimensions:                                  (date: 21, geo: 2, sample: 4000,
                                              channel: 2)
Coordinates:
  * date                                     (date) datetime64[ns] 168B 2021-...
  * geo                                      (geo) <U5 40B 'geo_a' 'geo_b'
  * channel                                  (channel) <U2 16B 'x1' 'x2'
  * sample                                   (sample) object 32kB MultiIndex
  * chain                                    (sample) int64 32kB 0 0 0 ... 3 3 3
  * draw                                     (sample) int64 32kB 0 1 ... 998 999
Data variables:
    y                                        (date, geo, sample) float64 1MB ...
    channel_contribution                     (date, geo, channel, sample) float64 3MB ...
    total_media_contribution_original_scale  (sample) float64 32kB 4.729e+04 ...
    allocation                               (geo, channel) float64 32B 2.622...
    x1                                       (date, geo) float64 336B 2.623e+...
    x2                                       (date, geo) float64 336B 3.105e+...
Attributes:
    created_at:                 2025-08-15T21:52:35.494520+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1

Once you get the allocation, you can plot the results 🚀

optimizable_model.plot.budget_allocation(
    samples=sample_allocation,
);
../../_images/da1317df271b5cd06999f46801388d824c52b191d69657ba026b585b9cef116b.png

The graph shows the optimal budget for each channel in each geo, next to its respective mean contribution under that budget. The method automatically identifies the number of dimensions and creates a plot based on them.

If you want to see the full uncertainty over time, you can use the plot suite and the method allocated_contribution_by_channel_over_time.

optimizable_model.plot.allocated_contribution_by_channel_over_time(
    samples=sample_allocation,
);
../../_images/d15a46af8c2d4da788e6262154328d298768e60e4e879f27483902a999204301.png

If you have a custom model, you can wrap it in the model protocol and then use the optimizer. If your model handles scaling internally, you don’t need to modify anything. Otherwise, for the plots, you may want to pass scale_factor=N, e.g.:

optimizable_model.plot.budget_allocation(
    samples=sample_allocation,
    scale_factor=120
);

Save Model#

You can optionally save the result of your hard work. The model result objects (idata) can get very large once we start working in multiple dimensions. So it can sometimes be helpful to compress the idata before saving. Below are a couple of tricks.

# Reduce your posterior (optional)
# clone_idata = mmm.idata.copy()
# clone_idata.posterior = clone_idata.posterior.astype(np.float32)
# clone_idata.posterior = clone_idata.posterior.sel(draw=slice(None, None, 10))

# clone_idata.to_netcdf("multidimensional_model_compressed.nc", groups=["posterior", "fit_data"], engine="h5netcdf")

Note

We are very excited about this new feature and the possibilities it opens up. We are looking forward to hearing your feedback!

%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor,nutpie
Last updated: Fri Aug 15 2025

Python implementation: CPython
Python version       : 3.13.5
IPython version      : 9.4.0

pymc_marketing: 0.15.1
pytensor      : 2.31.7
nutpie        : 0.15.2

xarray        : 2025.7.1
matplotlib    : 3.10.3
pymc          : 5.25.1
pandas        : 2.3.1
seaborn       : 0.13.2
numpy         : 2.2.6
arviz         : 0.22.0
pymc_extras   : 0.4.0
pymc_marketing: 0.15.1

Watermark: 2.5.0