MMM Multidimensional Example Notebook#
In this notebook, we present a new experimental media mix model class for creating multidimensional and customized marketing mix models. To showcase its capabilities, we extend the MMM Example Notebook simulation to create a multidimensional hierarchical model.
Warning
Although the new MMM class is experimental, it is fully functional and can be used to create multidimensional marketing mix models. The model is under active development and will be further improved in the future (feedback welcome!).
Prepare Notebook#
import warnings
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
import xarray as xr
from pymc_extras.prior import Prior
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import (
MMM,
MultiDimensionalBudgetOptimizerWrapper,
)
from pymc_marketing.paths import data_dir
warnings.filterwarnings("ignore", category=UserWarning)
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["xtick.labelsize"] = 10
plt.rcParams["ytick.labelsize"] = 8
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
C:\Users\dsaun\pymc-marketing\pymc_marketing\mmm\multidimensional.py:75: FutureWarning: This functionality is experimental and subject to change. If you encounter any issues or have suggestions, please raise them at: https://github.com/pymc-labs/pymc-marketing/issues/new
warnings.warn(warning_msg, FutureWarning, stacklevel=1)
seed: int = sum(map(ord, "mmm_multidimensional"))
rng: np.random.Generator = np.random.default_rng(seed=seed)
Read Data#
We read the simulated data from the MMM Multidimensional Example Notebook.
data_path = data_dir / "mmm_multidimensional_example.csv"
data_df = pd.read_csv(data_path, parse_dates=["date"])
data_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 318 entries, 0 to 317
Data columns (total 7 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date 318 non-null datetime64[ns]
1 geo 318 non-null object
2 x1 318 non-null float64
3 x2 318 non-null float64
4 event_1 318 non-null int64
5 event_2 318 non-null int64
6 y 318 non-null float64
dtypes: datetime64[ns](1), float64(3), int64(2), object(1)
memory usage: 17.5+ KB
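The data is in long format, with one row per (date, geo) pair. As a quick, optional sketch, we can confirm the panel structure directly from the dataframe:
# One row per (date, geo): weekly observations for each of the two geos.
data_df.groupby("geo")["date"].agg(["min", "max", "count"])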
For our setup, imagine we are selling one product in two different countries (geo_a and geo_b). Our marketing team maintains two channels: one is an always-on channel, while the other is more tactical and is turned on during marketing campaigns. Visual inspection of the data suggests that marketing has at least some effect on sales, but the relationship is noisy. Our mission is to see whether the MMM can parse the signal from the noise.
One strategy for dealing with noisy, low-signal data is to borrow information from similar contexts. If channel 2 seems to be fairly effective in geo_b, that gives us reason to suspect it will also be effective in geo_a. This can be implemented with either full pooling or partial pooling (partial pooling models are often called "hierarchical" or "multi-level"). This notebook therefore demonstrates how to fit an MMM to multiple markets at the same time and how to decide how to pool information across the two contexts.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
fig.suptitle("Channel Spends Over Time", fontsize=16, fontweight="bold")
blue_colors = ["#1f77b4", "#7aa6c2"] # Darker and lighter shades of blue
# Plot for geo_a
geo_a_data = data_df[data_df["geo"] == "geo_a"]
ax1.bar(geo_a_data["date"], geo_a_data["x1"], label="x1", width=7, color=blue_colors[0])
ax1.bar(
geo_a_data["date"],
geo_a_data["x2"],
bottom=geo_a_data["x1"],
label="x2",
width=7,
color=blue_colors[1],
)
ax1.plot(geo_a_data["date"], geo_a_data["y"], "--", label="y", color="black")
ax1.set_title("geo_a")
ax1.legend()
# Plot for geo_b
geo_b_data = data_df[data_df["geo"] == "geo_b"]
ax2.bar(geo_b_data["date"], geo_b_data["x1"], label="x1", width=7, color=blue_colors[0])
ax2.bar(
geo_b_data["date"],
geo_b_data["x2"],
bottom=geo_b_data["x1"],
label="x2",
width=7,
color=blue_colors[1],
)
ax2.plot(geo_b_data["date"], geo_b_data["y"], "--", label="y", color="black")
ax2.set_title("geo_b")
ax2.legend()
plt.tight_layout()

Prior Specification#
The beta parameter represents the maximum number of weekly sales you could drive through a channel. Beta will be the only hierarchical parameter in this model. There is good reason to think that different markets saturate at different levels: a product might be popular in one market, where the potential audience is very large, while in another market it is a niche product. To capture this, we allow the saturation point to vary across geographies.
A hierarchical model does something clever: instead of assuming that every geography is different, we assume that they must be at least partially related to each other. If you had 10 geos and 9 of them had a high saturation point, you would reasonably expect the 10th one to have a high saturation point as well. A hierarchical model adaptively pools information between contexts: when the geos are very different, it transfers less information between them; when they are very similar, it transfers more. This is also called partial pooling. If you need an introduction to Bayesian hierarchical models, check out the comprehensive example "A Primer on Bayesian Methods for Multilevel Modeling" in the PyMC documentation.
We can build hierarchical parameters with the prior API. Notice that we have one Prior object with dimensions (channel, geo), and then further Priors for each of its parameters. The prior on mu captures what we expect the channels to do, without considering their variation across geographies. The prior on sigma represents how much the effect varies across geographies. Bayesian inference propagates information all the way up and down the network of parameters: as we learn the values of the interior parameters mu and sigma, they act as constraints on the behaviour of the individual beta parameters.
There are a lot of options for how we code up each type of effect. It can be difficult to keep track of which parameters share information, which are completely separate, and which are shared. But there is a trick to help you remember: notice that the interior Priors have fewer dimensions than the main Prior. The dimension that appears in the main Prior but not in the interior Priors (here, geo) is the one across which information is transferred, while the shared dimension (here, channel) remains independent. So channel 1 in geo_a influences channel 1 in geo_b, but channel 1 never influences channel 2.
beta_prior = Prior(
"Normal",
mu=Prior("Normal", mu=-1.5, sigma=0.5, dims=("channel")),
sigma=Prior("Exponential", scale=0.25, dims=("channel")),
dims=("channel", "geo"),
transform="exp",
centered=False,
)
This notebook is only illustrative, so we'll show you how to code each type of assumption you might make (we aren't recommending a universal solution!). Lambda represents the efficiency of a channel: the higher the lambda, the more responsive sales are to spending on that channel. We'll have the lambda parameter fully pooled across all geographies. We are assuming that each channel has the same efficiency in both geos, so we do not specify "geo" in dims. The package will automatically broadcast the channel effect across geographies.
saturation = LogisticSaturation(
priors={
"beta": beta_prior,
"lam": Prior(
"Gamma",
mu=0.5,
sigma=0.25,
dims=("channel"),
),
}
)
saturation.model_config
{'saturation_lam': Prior("Gamma", mu=0.5, sigma=0.25, dims="channel"),
'saturation_beta': Prior("Normal", mu=Prior("Normal", mu=-1.5, sigma=0.5, dims="channel"), sigma=Prior("Exponential", scale=0.25, dims="channel"), dims=("channel", "geo"), centered=False, transform="exp")}
Adstock alpha represents how long customers remember marketing. For adstock, we'll illustrate the unpooled strategy: each channel in each geography has its own effect, and those effects do not influence each other. Notice that we include dims for both geo and channel to indicate that we want 4 unique effects.
adstock = GeometricAdstock(
priors={"alpha": Prior("Beta", alpha=2, beta=5, dims=("geo", "channel"))}, l_max=8
)
adstock.model_config
{'adstock_alpha': Prior("Beta", alpha=2, beta=5, dims=("geo", "channel"))}
You can mix and match the unpooled, fully pooled, and partially pooled strategies for any of your effects, and you can extend this to control or noise parameters as well. Given the variety of options, it can be hard to know which pooling strategy to choose for a given effect. In our opinion, the choice is primarily driven by computational considerations: partial pooling is generally the more reasonable assumption, but it can make the model slower to estimate, more complicated to debug, and more difficult to reason about.
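For reference, here is how the three strategies look side by side for a single hypothetical parameter, using the same Prior API as above. This is only a sketch with illustrative hyperparameters, not a recommendation:
# Unpooled: an independent effect for every (geo, channel) pair; nothing is shared.
unpooled = Prior("Beta", alpha=2, beta=5, dims=("geo", "channel"))

# Fully pooled: a single effect per channel, applied identically in every geo.
fully_pooled = Prior("Beta", alpha=2, beta=5, dims=("channel",))

# Partially pooled (hierarchical): one effect per (channel, geo), tied together by
# channel-level hyperpriors on the mean and the spread across geos.
partially_pooled = Prior(
    "Normal",
    mu=Prior("Normal", mu=0, sigma=1, dims=("channel",)),
    sigma=Prior("HalfNormal", sigma=0.5, dims=("channel",)),
    dims=("channel", "geo"),
)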
For example, you might notice that we set our prior with centered=False. This is known as a reparameterization: a strategy for resolving computational difficulties that MCMC algorithms can run into when fitting hierarchical models. We recommend that you start with a model that uses only fully pooled or unpooled effects. Once you have a good working model, you can add complexity slowly, verifying your model performance and accuracy at each stage.
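For intuition, centered=False is the familiar non-centered parameterization: instead of drawing each (log) beta directly from Normal(mu, sigma), the model draws a standard-normal offset and rescales it. Roughly, the beta prior above corresponds to something like the following raw-PyMC sketch (variable names are illustrative, not the exact ones the library creates):
with pm.Model(coords={"channel": ["x1", "x2"], "geo": ["geo_a", "geo_b"]}) as sketch:
    mu = pm.Normal("beta_mu", mu=-1.5, sigma=0.5, dims="channel")
    sigma = pm.Exponential("beta_sigma", scale=0.25, dims="channel")

    # Non-centered trick: sample a standard-normal offset and rescale it, which is
    # equivalent to beta_log ~ Normal(mu, sigma) but often samples more efficiently.
    offset = pm.Normal("beta_offset", mu=0, sigma=1, dims=("channel", "geo"))
    beta_log = mu[:, None] + sigma[:, None] * offset

    # transform="exp" in the Prior above maps the Normal draw to a positive scale.
    beta = pm.Deterministic("beta", pm.math.exp(beta_log), dims=("channel", "geo"))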
We complete the model specification with similar priors as in the MMM Example Notebook.
model_config = {
"intercept": Prior("Gamma", mu=0.5, sigma=0.25, dims="geo"),
"gamma_control": Prior("Normal", mu=0, sigma=0.5, dims="control"),
"gamma_fourier": Prior(
"Normal",
mu=0,
sigma=Prior("HalfNormal", sigma=0.2),
dims=("geo", "fourier_mode"),
centered=False,
),
"likelihood": Prior(
"TruncatedNormal",
lower=0,
sigma=Prior("HalfNormal", sigma=1.5),
dims=("date", "geo"),
),
}
Model Definition#
We are now ready to define the model class. The API is very similar to the one in the MMM Example Notebook.
# Base MMM model specification
mmm = MMM(
date_column="date",
target_column="y",
channel_columns=["x1", "x2"],
control_columns=["event_1", "event_2"],
dims=("geo",),
scaling={
"channel": {"method": "max", "dims": ()},
"target": {"method": "max", "dims": ()},
},
adstock=adstock,
saturation=saturation,
yearly_seasonality=2,
model_config=model_config,
)
Tip
Observe that we have the following two new arguments:
- dims: a tuple of strings that specifies the additional dimensions of the model.
- scaling: a dictionary that specifies the scaling method and dimensions for the target and media variables. In this case we leave the dimensions empty because we want to scale the target and media variables separately for each geo (see the sketch below for details).
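To make the scaling behaviour concrete, here is our reading of this configuration, illustrated with plain pandas. This is a sketch of the idea, not the library's internal code:
# With "method": "max" and empty dims, the maximum is taken over dates only,
# separately for every geo (and every channel), so each geo gets its own scale.
per_geo_channel_max = data_df.groupby("geo")[["x1", "x2"]].max()
per_geo_target_max = data_df.groupby("geo")["y"].max()

# Had we passed dims=("geo",) instead, the maximum would also be taken across geos,
# yielding a single scale per channel shared by all geos.
shared_channel_max = data_df[["x1", "x2"]].max()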
We can now prepare the training data.
x_train = data_df.drop(columns=["y"])
y_train = data_df["y"]
To build the model, we need to pass the training data and the target variable.
Tip
We do not need to build the model explicitly before fitting; we can simply call fit. We build it here only to inspect the model structure.
mmm.build_model(X=x_train, y=y_train)
Let’s look into the model graph:
pm.model_to_graphviz(mmm.model)
It is great to see that the model automatically vectorizes and creates the expected hierarchies and dimensions 🚀!
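If you prefer a programmatic check over reading the graph, the coordinates and the dims attached to each free random variable are exposed on the underlying PyMC model. A small sketch using standard PyMC attributes:
# Inspect the coordinate lengths and the dims attached to each free random variable.
print({name: len(values) for name, values in mmm.model.coords.items()})
print(
    {
        rv.name: mmm.model.named_vars_to_dims.get(rv.name, ())
        for rv in mmm.model.free_RVs
    }
)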
As we are scaling our data internally, we can add deterministic terms to recover the component contributions in the original scale.
mmm.add_original_scale_contribution_variable(
var=[
"channel_contribution",
"control_contribution",
"intercept_contribution",
"yearly_seasonality_contribution",
"y",
]
)
pm.model_to_graphviz(mmm.model)
Coming back to the scalers, we can get them as an xarray dataset.
scalers = mmm.get_scales_as_xarray()
scalers
{'channel_scale': <xarray.DataArray '_channel' (geo: 2, channel: 2)> Size: 32B
array([[ 9318.97848455, 9755.9729876 ],
[10555.0774866 , 11760.98180037]])
Coordinates:
* geo (geo) object 16B 'geo_a' 'geo_b'
* channel (channel) object 16B 'x1' 'x2',
'target_scale': <xarray.DataArray '_target' (geo: 2)> Size: 16B
array([13812.08025674, 11002.97913936])
Coordinates:
* geo (geo) object 16B 'geo_a' 'geo_b'}
As expected from the model definition, we have scalers for the target and media variables across geos.
Prior Predictive Checks#
Before fitting the model, we can inspect the prior predictive distribution.
with mmm.model:
prior = pm.sample_prior_predictive()
prior
Sampling: [adstock_alpha, gamma_control, gamma_fourier_offset, gamma_fourier_sigma, intercept_contribution, saturation_beta_raw_mu, saturation_beta_raw_offset, saturation_beta_raw_sigma, saturation_lam, y, y_sigma]
prior
<xarray.Dataset> Size: 19MB
Dimensions:  (chain: 1, draw: 500, control: 2, date: 159, geo: 2, channel: 2, fourier_mode: 4)
Coordinates:
  * chain         (chain) int64 8B 0
  * draw          (draw) int64 4kB 0 1 ... 499
  * control       (control) <U7 56B 'event_...
  * date          (date) datetime64[ns] 1kB ...
  * geo           (geo) <U5 40B 'geo_a' 'ge...
  * channel       (channel) <U2 16B 'x1' 'x2'
  * fourier_mode  (fourier_mode) <U5 80B 's...
Data variables: (12/23)
    gamma_control                                    (chain, draw, control) float64 8kB ...
    y_original_scale                                 (chain, draw, date, geo) float64 1MB ...
    saturation_beta                                  (chain, draw, channel, geo) float64 16kB ...
    gamma_fourier                                    (chain, draw, geo, fourier_mode) float64 32kB ...
    control_contribution                             (chain, draw, date, geo, control) float64 3MB ...
    y_sigma                                          (chain, draw) float64 4kB ...
    ...                                              ...
    yearly_seasonality_contribution_original_scale   (chain, draw, date, geo) float64 1MB ...
    channel_contribution                             (chain, draw, date, geo, channel) float64 3MB ...
    intercept_contribution                           (chain, draw, geo) float64 8kB ...
    saturation_beta_raw                              (chain, draw, channel, geo) float64 16kB ...
    saturation_lam                                   (chain, draw, channel) float64 8kB ...
    gamma_fourier_sigma                              (chain, draw) float64 4kB ...
Attributes:
    created_at:                 2025-08-15T21:50:06.142004+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1

prior_predictive
<xarray.Dataset> Size: 1MB
Dimensions:  (chain: 1, draw: 500, date: 159, geo: 2)
Coordinates:
  * chain    (chain) int64 8B 0
  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * date     (date) datetime64[ns] 1kB 2022-06-06 2022-06-13 ... 2025-06-16
  * geo      (geo) <U5 40B 'geo_a' 'geo_b'
Data variables:
    y        (chain, draw, date, geo) float64 1MB 0.8647 1.005 ... 0.405 2.664
Attributes:
    created_at:                 2025-08-15T21:50:06.164382+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1

observed_data
<xarray.Dataset> Size: 4kB
Dimensions:  (date: 159, geo: 2)
Coordinates:
  * date     (date) datetime64[ns] 1kB 2022-06-06 2022-06-13 ... 2025-06-16
  * geo      (geo) <U5 40B 'geo_a' 'geo_b'
Data variables:
    y        (date, geo) float64 3kB 0.1917 0.06202 0.3635 ... 0.4068 0.5073
Attributes:
    created_at:                 2025-08-15T21:50:06.167874+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1

constant_data
<xarray.Dataset> Size: 12kB
Dimensions:  (geo: 2, channel: 2, date: 159, control: 2)
Coordinates:
  * geo            (geo) <U5 40B 'geo_a' 'geo_b'
  * channel        (channel) <U2 16B 'x1' 'x2'
  * date           (date) datetime64[ns] 1kB 2022-06-06 ... 2025-06-16
  * control        (control) <U7 56B 'event_1' 'event_2'
Data variables:
    channel_scale  (geo, channel) float64 32B 9.319e+03 9.756e+03 ... 1.176e+04
    target_scale   (geo) float64 16B 1.381e+04 1.1e+04
    channel_data   (date, geo, channel) float64 5kB 5.528e+03 0.0 ... 8.091e+03
    target_data    (date, geo) float64 3kB 2.648e+03 682.4 ... 5.581e+03
    control_data   (date, geo, control) int32 3kB 0 0 0 0 0 0 0 ... 0 0 0 0 0 0
    dayofyear      (date) int32 636B 157 164 171 178 185 ... 139 146 153 160 167
Attributes:
    created_at:                 2025-08-15T21:50:06.181618+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1
g = sns.relplot(
data=data_df,
x="date",
y="y",
color="black",
col="geo",
col_wrap=1,
kind="line",
height=4,
aspect=3,
)
axes = g.axes.flatten()
for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
az.plot_hdi(
x=mmm.model.coords["date"],
y=(
prior.prior.sel(geo=geo)["y_original_scale"]
.unstack()
.transpose(..., "date")
),
smooth=False,
color="C0",
hdi_prob=0.94,
fill_kwargs={"alpha": 0.3, "label": "94% HDI"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=(
prior.prior.sel(geo=geo)["y_original_scale"]
.unstack()
.transpose(..., "date")
),
smooth=False,
color="C0",
hdi_prob=0.5,
fill_kwargs={"alpha": 0.5, "label": "50% HDI"},
ax=ax,
)
ax.legend(loc="upper left")
g.figure.suptitle("Prior Predictive", fontsize=16, fontweight="bold", y=1.03);

The prior predictive distribution looks good and is not too restrictive.
Model Fitting#
We can now fit the model and generate the posterior predictive distribution.
mmm.fit(
X=x_train,
y=y_train,
chains=4,
target_accept=0.975,
random_seed=rng,
)
mmm.sample_posterior_predictive(
X=x_train,
extend_idata=True,
combined=True,
random_seed=rng,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_lam, saturation_beta_raw_offset, saturation_beta_raw_mu, saturation_beta_raw_sigma, gamma_control, gamma_fourier_offset, gamma_fourier_sigma, y_sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 99 seconds.
Sampling: [y]
<xarray.Dataset> Size: 20MB
Dimensions:           (date: 159, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 1kB 2022-06-06 ... 2025-06-16
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y                 (date, geo, sample) float64 10MB 0.5745 0.5809 ... 0.0395
    y_original_scale  (date, geo, sample) float64 10MB 7.935e+03 ... 434.6
Attributes:
    created_at:                 2025-08-15T21:52:03.186757+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1
The sampling looks good. No divergences and the r-hat values are close to \(1\).
mmm.idata.sample_stats.diverging.sum("draw")
<xarray.DataArray 'diverging' (chain: 4)> Size: 32B
array([0, 0, 0, 0])
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
az.summary(
mmm.idata,
var_names=[
"adstock_alpha",
"gamma_control",
"gamma_fourier",
"intercept_contribution",
"saturation_beta",
"saturation_beta_raw_mu",
"saturation_beta_raw_sigma",
"saturation_lam",
"y_sigma",
],
)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
adstock_alpha[geo_a, x1] | 0.291 | 0.163 | 0.036 | 0.598 | 0.002 | 0.002 | 4730.0 | 2551.0 | 1.0 |
adstock_alpha[geo_a, x2] | 0.311 | 0.165 | 0.034 | 0.609 | 0.002 | 0.002 | 4214.0 | 2413.0 | 1.0 |
adstock_alpha[geo_b, x1] | 0.259 | 0.149 | 0.014 | 0.526 | 0.002 | 0.002 | 4529.0 | 2386.0 | 1.0 |
adstock_alpha[geo_b, x2] | 0.273 | 0.154 | 0.009 | 0.544 | 0.002 | 0.002 | 4128.0 | 1896.0 | 1.0 |
gamma_control[event_1] | 0.303 | 0.087 | 0.140 | 0.463 | 0.001 | 0.001 | 4832.0 | 2865.0 | 1.0 |
gamma_control[event_2] | -0.098 | 0.092 | -0.271 | 0.078 | 0.001 | 0.002 | 4603.0 | 2753.0 | 1.0 |
gamma_fourier[geo_a, sin_1] | -0.350 | 0.035 | -0.415 | -0.285 | 0.001 | 0.000 | 3440.0 | 3156.0 | 1.0 |
gamma_fourier[geo_a, sin_2] | -0.028 | 0.028 | -0.081 | 0.025 | 0.000 | 0.001 | 3975.0 | 2903.0 | 1.0 |
gamma_fourier[geo_a, cos_1] | -0.285 | 0.033 | -0.348 | -0.224 | 0.001 | 0.000 | 4231.0 | 3608.0 | 1.0 |
gamma_fourier[geo_a, cos_2] | 0.004 | 0.027 | -0.046 | 0.055 | 0.000 | 0.000 | 4641.0 | 2992.0 | 1.0 |
gamma_fourier[geo_b, sin_1] | -0.046 | 0.025 | -0.092 | -0.000 | 0.000 | 0.000 | 6112.0 | 3272.0 | 1.0 |
gamma_fourier[geo_b, sin_2] | 0.190 | 0.027 | 0.141 | 0.239 | 0.000 | 0.000 | 5180.0 | 3432.0 | 1.0 |
gamma_fourier[geo_b, cos_1] | -0.200 | 0.030 | -0.254 | -0.142 | 0.000 | 0.000 | 5350.0 | 3719.0 | 1.0 |
gamma_fourier[geo_b, cos_2] | -0.030 | 0.025 | -0.077 | 0.018 | 0.000 | 0.000 | 5615.0 | 3183.0 | 1.0 |
intercept_contribution[geo_a] | 0.196 | 0.029 | 0.141 | 0.248 | 0.001 | 0.000 | 3064.0 | 2074.0 | 1.0 |
intercept_contribution[geo_b] | 0.213 | 0.028 | 0.159 | 0.261 | 0.001 | 0.000 | 3089.0 | 2077.0 | 1.0 |
saturation_beta[x1, geo_a] | 0.215 | 0.114 | 0.038 | 0.422 | 0.002 | 0.002 | 4150.0 | 3142.0 | 1.0 |
saturation_beta[x1, geo_b] | 0.278 | 0.176 | 0.051 | 0.590 | 0.003 | 0.005 | 3391.0 | 3092.0 | 1.0 |
saturation_beta[x2, geo_a] | 0.242 | 0.127 | 0.051 | 0.480 | 0.002 | 0.003 | 3943.0 | 3061.0 | 1.0 |
saturation_beta[x2, geo_b] | 0.241 | 0.131 | 0.050 | 0.480 | 0.002 | 0.003 | 4546.0 | 3303.0 | 1.0 |
saturation_beta_raw_mu[x1] | -1.557 | 0.449 | -2.490 | -0.784 | 0.007 | 0.007 | 4541.0 | 3110.0 | 1.0 |
saturation_beta_raw_mu[x2] | -1.537 | 0.448 | -2.410 | -0.730 | 0.007 | 0.007 | 4577.0 | 2795.0 | 1.0 |
saturation_beta_raw_sigma[x1] | 0.270 | 0.261 | 0.000 | 0.765 | 0.004 | 0.006 | 3002.0 | 2377.0 | 1.0 |
saturation_beta_raw_sigma[x2] | 0.238 | 0.233 | 0.000 | 0.662 | 0.003 | 0.006 | 3673.0 | 2250.0 | 1.0 |
saturation_lam[x1] | 0.471 | 0.220 | 0.108 | 0.880 | 0.003 | 0.004 | 4053.0 | 2717.0 | 1.0 |
saturation_lam[x2] | 0.483 | 0.219 | 0.108 | 0.876 | 0.003 | 0.004 | 4273.0 | 2618.0 | 1.0 |
y_sigma | 0.183 | 0.010 | 0.165 | 0.203 | 0.000 | 0.000 | 2925.0 | 3089.0 | 1.0 |
_ = az.plot_trace(
data=mmm.idata,
var_names=[
"adstock_alpha",
"gamma_control",
"gamma_fourier",
"intercept_contribution",
"saturation_beta",
"saturation_beta_raw_mu",
"saturation_beta_raw_sigma",
"saturation_lam",
"y_sigma",
],
compact=True,
backend_kwargs={"figsize": (15, 15), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16, fontweight="bold", y=1.03);

Posterior Predictive Checks#
We can now inspect the posterior predictive distribution. As before, we need to scale the posterior predictive to the original scale to make it comparable to the data.
fig, axes = plt.subplots(
nrows=len(mmm.model.coords["geo"]),
figsize=(12, 9),
sharex=True,
sharey=True,
layout="constrained",
)
for i, geo in enumerate(mmm.model.coords["geo"]):
ax = axes[i]
az.plot_hdi(
x=mmm.model.coords["date"],
y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
color="C0",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.2, "label": "94% HDI"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=(mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)),
color="C0",
smooth=False,
hdi_prob=0.5,
fill_kwargs={"alpha": 0.4, "label": "50% HDI"},
ax=ax,
)
sns.lineplot(
data=data_df.query("geo == @geo"),
x="date",
y="y",
color="black",
ax=ax,
)
ax.legend(loc="upper left")
ax.set(title=f"{geo}")
fig.suptitle("Posterior Predictive", fontsize=16, fontweight="bold", y=1.03);

The fit looks okay! There is a lot of white noise in the sales process that we cannot predict. However, the main movements in sales are captured either by the seasonality terms or by the media components of the model.
Model Components#
We can extract the contributions of each component of the model in the original scale thanks to the deterministic variables added to the model.
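Since these deterministic variables live in the posterior, we can also build quick numerical summaries from them. For example, a rough, illustrative share of total sales attributed to each channel per geo (this is not a method provided by the library):
# Posterior-mean channel contributions, summed over the training window,
# expressed as a share of total observed sales per geo.
channel_total = (
    mmm.idata["posterior"]["channel_contribution_original_scale"]
    .mean(dim=["chain", "draw"])
    .sum(dim="date")
)
sales_total = data_df.groupby("geo")["y"].sum().to_xarray()
(channel_total / sales_total).to_pandas().round(3)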
fig, axes = plt.subplots(
nrows=len(mmm.model.coords["geo"]),
figsize=(15, 10),
sharex=True,
sharey=True,
layout="constrained",
)
for i, geo in enumerate(mmm.model.coords["geo"]):
ax = axes[i]
for j, channel in enumerate(mmm.model.coords["channel"]):
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["channel_contribution_original_scale"].sel(
geo=geo, channel=channel
),
color=f"C{j}",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": f"94% HDI ({channel})"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["intercept_contribution_original_scale"]
.sel(geo=geo)
.expand_dims({"date": mmm.model.coords["date"]})
.transpose(..., "date"),
color="C2",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": "94% HDI intercept"},
ax=ax,
)
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["yearly_seasonality_contribution_original_scale"].sel(
geo=geo,
),
color="C3",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": "94% HDI Fourier"},
ax=ax,
)
for k, control in enumerate(mmm.model.coords["control"]):
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["posterior"]["control_contribution_original_scale"].sel(
geo=geo, control=control
),
color=f"C{5 + k}",
smooth=False,
hdi_prob=0.94,
fill_kwargs={"alpha": 0.5, "label": f"94% HDI control ({control})"},
ax=ax,
)
sns.lineplot(
data=data_df.query("geo == @geo"),
x="date",
y="y",
color="black",
label="y",
ax=ax,
)
ax.legend(
loc="upper center",
bbox_to_anchor=(0.5, -0.1),
ncol=4,
)
ax.set(title=f"{geo}")
fig.suptitle(
"Posterior Predictive - Channel Contributions",
fontsize=16,
fontweight="bold",
y=1.03,
);

Media Deep Dive#
Next, we can look into the individual channel contributions across geos. This new class has a plot namespace that contains many plotting methods.
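If you are curious what else is available, standard Python introspection lists the public plotting methods; the exact set depends on your pymc-marketing version:
# List the public plotting methods exposed by the plot namespace.
[method for method in dir(mmm.plot) if not method.startswith("_")]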
fig, axes = mmm.plot.contributions_over_time(
var=["channel_contribution_original_scale"],
)
# Adjust figure size and layout to 2x2
fig.set_size_inches(14, 10)
fig.set_constrained_layout(True)
# Reshape axes to 2x2 grid
num_axes = len(axes.flatten())
if num_axes > 0:
# Create a new 2x2 grid
gs = fig.add_gridspec(2, 2)
# Move existing axes to the new grid
for i, ax in enumerate(axes.flatten()):
if i < 4: # Only handle up to 4 axes for 2x2 grid
ax.set_position(gs[i // 2, i % 2].get_position(fig))
axes = axes.flatten()
# Share x and y axes across all subplots
for ax in axes:
ax.legend().remove()
ax.tick_params(axis="both", which="major", labelsize=6)
ax.tick_params(axis="both", which="minor", labelsize=6)
# Share y axis limits
y_min = min(ax.get_ylim()[0] for ax in axes)
y_max = max(ax.get_ylim()[1] for ax in axes)
for ax in axes:
ax.set_ylim(y_min, y_max)
# Share x axis limits
x_min = min(ax.get_xlim()[0] for ax in axes)
x_max = max(ax.get_xlim()[1] for ax in axes)
for ax in axes:
ax.set_xlim(x_min, x_max)

We can plot the saturation curves for each channel and geo using a couple of different functions:
- saturation_scatterplot gives just the scatterplot of spend against estimated contribution.
- saturation_curves additionally overlays posterior draws of the saturation curves, showing how they fit the estimated mean contributions.
mmm.plot.saturation_scatterplot(width_per_col=8, height_per_row=4, original_scale=True);

curve = mmm.saturation.sample_curve(
mmm.idata.posterior[["saturation_beta", "saturation_lam"]], max_value=2
)
fig, axes = mmm.plot.saturation_curves(
curve,
original_scale=True,
n_samples=10,
hdi_probs=0.85,
random_seed=rng,
subplot_kwargs={"figsize": (12, 8), "ncols": 2},
rc_params={
"xtick.labelsize": 10,
"ytick.labelsize": 10,
"axes.labelsize": 10,
"axes.titlesize": 10,
},
)
for ax in axes.ravel():
ax.title.set_fontsize(10)
if fig._suptitle is not None:
fig._suptitle.set_fontsize(12)
plt.tight_layout()
plt.show()
Sampling: [saturation_beta_raw_mu, saturation_beta_raw_offset, saturation_beta_raw_sigma]

Parameter recovery#
One nice sign that the model is working as intended is that it can recover the true parameter values underlying the marketing mechanism. In our case, we know the true parameter values because we simulated the data. Informally, if the bulk of the posterior distribution covers the true value, that's a good sign. We do not expect the mean of the posterior to always line up with the true value: for small or noisy data, we should expect the posterior to cover a wide interval regardless of whether we built a good model. There are also formal frameworks for thinking about parameter recovery in simulations, which can be helpful if you need more rigorous evidence that the model is working correctly.
Below we compare the posterior distribution to the true values for the main MMM parameters (saturation lambda, saturation beta, and adstock alpha).
# Load the true parameters used to generate the data
data_path = data_dir / "mmm_multidimensional_example_true_parameters.nc"
true_parameters = xr.open_dataset(data_path)
az.plot_posterior(
mmm.fit_result,
var_names=[
"saturation_lam",
],
figsize=(12, 4),
ref_val={
"saturation_lam": [
{
"channel": "x1",
"ref_val": true_parameters["saturation_lam"].sel(channel="x1").values,
},
{
"channel": "x2",
"ref_val": true_parameters["saturation_lam"].sel(channel="x2").values,
},
]
},
);

az.plot_posterior(
mmm.fit_result,
var_names=[
"saturation_beta",
],
grid=(2, 2),
figsize=(12, 8),
ref_val={
"saturation_beta": [
{
"channel": "x1",
"geo": "geo_a",
"ref_val": true_parameters["saturation_beta"]
.sel(channel="x1", geo="geo_a")
.values,
},
{
"channel": "x2",
"geo": "geo_a",
"ref_val": true_parameters["saturation_beta"]
.sel(channel="x2", geo="geo_a")
.values,
},
{
"channel": "x1",
"geo": "geo_b",
"ref_val": true_parameters["saturation_beta"]
.sel(channel="x1", geo="geo_b")
.values,
},
{
"channel": "x2",
"geo": "geo_b",
"ref_val": true_parameters["saturation_beta"]
.sel(channel="x2", geo="geo_b")
.values,
},
]
},
);

az.plot_posterior(
mmm.fit_result,
var_names=[
"adstock_alpha",
],
grid=(2, 2),
figsize=(12, 8),
ref_val={
"adstock_alpha": [
{
"channel": "x1",
"geo": "geo_a",
"ref_val": true_parameters["adstock_alpha"]
.sel(channel="x1", geo="geo_a")
.values,
},
{
"channel": "x2",
"geo": "geo_a",
"ref_val": true_parameters["adstock_alpha"]
.sel(channel="x2", geo="geo_a")
.values,
},
{
"channel": "x1",
"geo": "geo_b",
"ref_val": true_parameters["adstock_alpha"]
.sel(channel="x1", geo="geo_b")
.values,
},
{
"channel": "x2",
"geo": "geo_b",
"ref_val": true_parameters["adstock_alpha"]
.sel(channel="x2", geo="geo_b")
.values,
},
]
},
);

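As a complement to the visual checks above, we can ask programmatically whether each true value falls inside the 94% HDI of its posterior. This is a rough sketch; coverage in a single simulation is only informal evidence:
# For each parameter, check whether the true simulated value lies inside the 94% HDI.
for var in ["adstock_alpha", "saturation_beta", "saturation_lam"]:
    hdi = az.hdi(mmm.fit_result, var_names=[var], hdi_prob=0.94)[var]
    covered = (true_parameters[var] >= hdi.sel(hdi="lower")) & (
        true_parameters[var] <= hdi.sel(hdi="higher")
    )
    print(f"{var}:\n{covered.to_pandas()}\n")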
Out of Sample Predictions#
It is very important to be able to make predictions out of sample. This is key for model validation, forward-looking scenario planning, and business decision making. As in the MMM Example Notebook, we assume that future spend is the same as on the last date of the training sample. This way we can create a new dataset with the future dates and channel spends and use the model to make predictions.
last_date = x_train["date"].max()
# New dates starting from last in dataset
n_new = 7
new_dates = pd.date_range(start=last_date, periods=1 + n_new, freq="W-MON")[1:]
x_out_of_sample_geo_a = pd.DataFrame({"date": new_dates, "geo": "geo_a"})
x_out_of_sample_geo_b = pd.DataFrame({"date": new_dates, "geo": "geo_b"})
# Same channel spends as last day
x_out_of_sample_geo_a["x1"] = x_train.query("geo == 'geo_a'")["x1"].iloc[-1]
x_out_of_sample_geo_a["x2"] = x_train.query("geo == 'geo_a'")["x2"].iloc[-1]
x_out_of_sample_geo_b["x1"] = x_train.query("geo == 'geo_b'")["x1"].iloc[-1]
x_out_of_sample_geo_b["x2"] = x_train.query("geo == 'geo_b'")["x2"].iloc[-1]
# Other features
## geo_a
x_out_of_sample_geo_a["event_1"] = 0.0
x_out_of_sample_geo_a["event_2"] = 0.0
## geo_b
x_out_of_sample_geo_b["event_1"] = 0.0
x_out_of_sample_geo_b["event_2"] = 0.0
x_out_of_sample = pd.concat([x_out_of_sample_geo_a, x_out_of_sample_geo_b])
# Final dataset to generate out of sample predictions.
x_out_of_sample
| | date | geo | x1 | x2 | event_1 | event_2 |
|---|---|---|---|---|---|---|
0 | 2025-06-23 | geo_a | 0.0 | 6384.065021 | 0.0 | 0.0 |
1 | 2025-06-30 | geo_a | 0.0 | 6384.065021 | 0.0 | 0.0 |
2 | 2025-07-07 | geo_a | 0.0 | 6384.065021 | 0.0 | 0.0 |
3 | 2025-07-14 | geo_a | 0.0 | 6384.065021 | 0.0 | 0.0 |
4 | 2025-07-21 | geo_a | 0.0 | 6384.065021 | 0.0 | 0.0 |
5 | 2025-07-28 | geo_a | 0.0 | 6384.065021 | 0.0 | 0.0 |
6 | 2025-08-04 | geo_a | 0.0 | 6384.065021 | 0.0 | 0.0 |
0 | 2025-06-23 | geo_b | 0.0 | 8090.900533 | 0.0 | 0.0 |
1 | 2025-06-30 | geo_b | 0.0 | 8090.900533 | 0.0 | 0.0 |
2 | 2025-07-07 | geo_b | 0.0 | 8090.900533 | 0.0 | 0.0 |
3 | 2025-07-14 | geo_b | 0.0 | 8090.900533 | 0.0 | 0.0 |
4 | 2025-07-21 | geo_b | 0.0 | 8090.900533 | 0.0 | 0.0 |
5 | 2025-07-28 | geo_b | 0.0 | 8090.900533 | 0.0 | 0.0 |
6 | 2025-08-04 | geo_b | 0.0 | 8090.900533 | 0.0 | 0.0 |
Using the same sample_posterior_predictive method, we can now generate the forecast.
y_out_of_sample = mmm.sample_posterior_predictive(
x_out_of_sample,
extend_idata=False,
include_last_observations=True,
random_seed=rng,
var_names=["y_original_scale"],
)
y_out_of_sample
Sampling: [y]
<xarray.Dataset> Size: 544kB
Dimensions:           (date: 7, geo: 2, sample: 4000)
Coordinates:
  * date              (date) datetime64[ns] 56B 2025-06-23 ... 2025-08-04
  * geo               (geo) <U5 40B 'geo_a' 'geo_b'
  * sample            (sample) object 32kB MultiIndex
  * chain             (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
  * draw              (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    y_original_scale  (date, geo, sample) float64 448kB 7.35e+03 ... 6.859e+03
Attributes:
    created_at:                 2025-08-15T21:52:21.623974+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1
fig, axes = plt.subplots(
nrows=2,
ncols=1,
figsize=(12, 10),
sharex=True,
sharey=True,
layout="constrained",
)
n_train_to_plot = 30
for ax, geo in zip(axes, mmm.model.coords["geo"], strict=True):
for hdi_prob in [0.94, 0.5]:
az.plot_hdi(
x=mmm.model.coords["date"][-n_train_to_plot:],
y=(
mmm.idata["posterior_predictive"].y_original_scale.sel(geo=geo)[
:, :, -n_train_to_plot:
]
),
color="C0",
smooth=False,
hdi_prob=hdi_prob,
fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
ax=ax,
)
az.plot_hdi(
x_out_of_sample.query("geo == @geo")["date"],
(
y_out_of_sample["y_original_scale"]
.sel(geo=geo)
.unstack()
.transpose(..., "date")
),
color="C1",
smooth=False,
hdi_prob=hdi_prob,
fill_kwargs={"alpha": 0.4, "label": f"{hdi_prob: 0.0%} HDI"},
ax=ax,
)
ax.plot(
x_out_of_sample.query("geo == @geo")["date"],
y_out_of_sample["y_original_scale"].sel(geo=geo).mean(dim="sample"),
marker="o",
color="C1",
label="posterior predictive mean",
)
sns.lineplot(
data=data_df.query("(geo == @geo)").tail(n_train_to_plot),
x="date",
y="y",
marker="o",
color="black",
label="observed",
ax=ax,
)
ax.axvline(x=last_date, color="gray", linestyle="--", label="last observation")
ax.legend(
loc="upper center",
bbox_to_anchor=(0.5, -0.15),
ncol=3,
)
ax.set(title=f"{geo}")
fig.suptitle(
"Posterior Predictive - Out of Sample", fontsize=16, fontweight="bold", y=1.03
);

Optimization#
If you want to run optimizations, then you need to use the MultiDimensionalBudgetOptimizerWrapper.
optimizable_model = MultiDimensionalBudgetOptimizerWrapper(
model=mmm, start_date="2021-10-01", end_date="2021-12-31"
)
allocation_xarray, scipy_opt_result = optimizable_model.optimize_budget(
budget=100_000,
)
sample_allocation = optimizable_model.sample_response_distribution(
allocation_strategy=allocation_xarray,
)
Sampling: [y]
This object is an xarray Dataset with the allocation and the posterior predictive responses!
sample_allocation
<xarray.Dataset> Size: 4MB
Dimensions:                                  (date: 21, geo: 2, sample: 4000, channel: 2)
Coordinates:
  * date                                     (date) datetime64[ns] 168B 2021-...
  * geo                                      (geo) <U5 40B 'geo_a' 'geo_b'
  * channel                                  (channel) <U2 16B 'x1' 'x2'
  * sample                                   (sample) object 32kB MultiIndex
  * chain                                    (sample) int64 32kB 0 0 0 ... 3 3 3
  * draw                                     (sample) int64 32kB 0 1 ... 998 999
Data variables:
    y                                        (date, geo, sample) float64 1MB ...
    channel_contribution                     (date, geo, channel, sample) float64 3MB ...
    total_media_contribution_original_scale  (sample) float64 32kB 4.729e+04 ...
    allocation                               (geo, channel) float64 32B 2.622...
    x1                                       (date, geo) float64 336B 2.623e+...
    x2                                       (date, geo) float64 336B 3.105e+...
Attributes:
    created_at:                 2025-08-15T21:52:35.494520+00:00
    arviz_version:              0.22.0
    inference_library:          pymc
    inference_library_version:  5.25.1
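To put a single number on the optimized response, we can summarize the total media contribution implied by the optimal allocation directly from this dataset. A small sketch using plain xarray and numpy:
# Posterior mean and central 94% interval of the total media contribution
# under the optimized allocation.
total_contribution = sample_allocation["total_media_contribution_original_scale"]
print(f"posterior mean: {total_contribution.mean().item():,.0f}")
print("94% interval:", np.quantile(total_contribution.values, [0.03, 0.97]).round(0))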
Once you get the allocation, you can plot the results 🚀
optimizable_model.plot.budget_allocation(
samples=sample_allocation,
);

The graph shows the optimal budget for each channel in each geo, next to its respective mean contribution given the optimal budget. The method automatically identifies the number of dimensions and tries to create a plot based on them.
If you want to see the full uncertainty over time, you can use the plot suite and the allocated_contribution_by_channel_over_time method.
optimizable_model.plot.allocated_contribution_by_channel_over_time(
samples=sample_allocation,
);

If you have a custom model, you can wrap it in the model protocol and use the optimizer afterwards. If your model handles scaling internally, you don't need to modify anything. Otherwise, for the plots, you may want to pass scale_factor=N, e.g.:
optimizable_model.plot.budget_allocation(
samples=sample_allocation,
scale_factor=120
);
Save Model#
You can optionally save the result of your hard work. The model result object (idata) can get very large once we start working in multiple dimensions, so it can sometimes be helpful to compress the idata before saving. Below are a couple of tricks.
# Reduce your posterior (optional)
# clone_idata = mmm.idata.copy()
# clone_idata.posterior = clone_idata.posterior.astype(np.float32)
# clone_idata.posterior = clone_idata.posterior.sel(draw=slice(None, None, 10))
# clone_idata.to_netcdf("multidimensional_model_compressed.nc", groups=["posterior", "fit_data"], engine="h5netcdf")
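To see how much these tricks save, you can compare the in-memory size of the posterior before and after; astype and sel return new objects, so the original idata stays untouched. A rough sketch:
# Compare sizes: original posterior vs. float32 cast plus keeping every 10th draw.
posterior = mmm.idata.posterior
reduced = posterior.astype(np.float32).sel(draw=slice(None, None, 10))
print(f"original posterior: {posterior.nbytes / 1e6:.1f} MB")
print(f"float32 + thinned:  {reduced.nbytes / 1e6:.1f} MB")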
Note
We are very excited about this new feature and the possibilities it opens up. We are looking forward to hearing your feedback!
%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor,nutpie
Last updated: Fri Aug 15 2025
Python implementation: CPython
Python version : 3.13.5
IPython version : 9.4.0
pymc_marketing: 0.15.1
pytensor : 2.31.7
nutpie : 0.15.2
xarray : 2025.7.1
matplotlib : 3.10.3
pymc : 5.25.1
pandas : 2.3.1
seaborn : 0.13.2
numpy : 2.2.6
arviz : 0.22.0
pymc_extras : 0.4.0
pymc_marketing: 0.15.1
Watermark: 2.5.0