plot#
MMM related plotting class.
Examples#
Quickstart with MMM:
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM
import pandas as pd
# Minimal dataset
X = pd.DataFrame(
{
"date": pd.date_range("2025-01-01", periods=12, freq="W-MON"),
"C1": [100, 120, 90, 110, 105, 115, 98, 102, 108, 111, 97, 109],
"C2": [80, 70, 95, 85, 90, 88, 92, 94, 91, 89, 93, 87],
}
)
y = pd.Series(
[230, 260, 220, 240, 245, 255, 235, 238, 242, 246, 233, 249], name="y"
)
mmm = MMM(
date_column="date",
channel_columns=["C1", "C2"],
target_column="y",
adstock=GeometricAdstock(l_max=10),
saturation=LogisticSaturation(),
)
mmm.fit(X, y)
mmm.sample_posterior_predictive(X)
# Posterior predictive time series
_ = mmm.plot.posterior_predictive(var=["y"], hdi_prob=0.9)
# Posterior contributions over time (e.g., channel_contribution)
_ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=0.9)
# Channel saturation scatter plot (scaled space by default)
_ = mmm.plot.saturation_scatterplot(original_scale=False)
Wrap a custom PyMC model#
Requirements
- posterior_predictive plots: an `az.InferenceData` with a `posterior_predictive` group containing the variable(s) you want to plot, with a `date` coordinate.
- contributions_over_time plots: a `posterior` group with time‑series variables (with a `date` coordinate).
- saturation plots: a `constant_data` dataset with the variables:
  - `channel_data`: dims include ("date", "channel", ...)
  - `channel_scale`: dims include ("channel", ...)
  - `target_scale`: scalar or broadcastable to the curve dims
  and a `posterior` variable named `channel_contribution` (or `channel_contribution_original_scale` if plotting with `original_scale=True`).
import numpy as np
import pandas as pd
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite
dates = pd.date_range("2025-01-01", periods=30, freq="D")
y_obs = np.random.normal(size=len(dates))
with pm.Model(coords={"date": dates}):
sigma = pm.HalfNormal("sigma", 1.0)
pm.Normal("y", 0.0, sigma, observed=y_obs, dims="date")
idata = pm.sample_prior_predictive(random_seed=1)
idata.extend(pm.sample(draws=200, chains=2, tune=200, random_seed=1))
idata.extend(pm.sample_posterior_predictive(idata, random_seed=1))
plot = MMMPlotSuite(idata)
_ = plot.posterior_predictive(var=["y"], hdi_prob=0.9)
Custom contributions_over_time#
import numpy as np
import pandas as pd
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite
dates = pd.date_range("2025-01-01", periods=30, freq="D")
x = np.linspace(0, 2 * np.pi, len(dates))
series = np.sin(x)
with pm.Model(coords={"date": dates}):
pm.Deterministic("component", series, dims="date")
idata = pm.sample_prior_predictive(random_seed=2)
idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=2))
plot = MMMPlotSuite(idata)
_ = plot.contributions_over_time(var=["component"], hdi_prob=0.9)
Saturation plots with a custom model#
import numpy as np
import pandas as pd
import xarray as xr
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite
dates = pd.date_range("2025-01-01", periods=20, freq="W-MON")
channels = ["C1", "C2"]
# Create constant_data required for saturation plots
channel_data = xr.DataArray(
np.random.rand(len(dates), len(channels)),
dims=("date", "channel"),
coords={"date": dates, "channel": channels},
name="channel_data",
)
channel_scale = xr.DataArray(
np.ones(len(channels)),
dims=("channel",),
coords={"channel": channels},
name="channel_scale",
)
target_scale = xr.DataArray(1.0, name="target_scale")
# Build a toy model that yields a matching posterior var
with pm.Model(coords={"date": dates, "channel": channels}):
# A fake contribution over time per channel (dims must include date & channel)
contrib = pm.Normal("channel_contribution", 0.0, 1.0, dims=("date", "channel"))
idata = pm.sample_prior_predictive(random_seed=3)
idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=3))
# Attach constant_data to idata
idata.constant_data = xr.Dataset(
{
"channel_data": channel_data,
"channel_scale": channel_scale,
"target_scale": target_scale,
}
)
plot = MMMPlotSuite(idata)
_ = plot.saturation_scatterplot(original_scale=False)
Notes#
- `MMM` exposes this suite via the `mmm.plot` property, which internally passes the model’s `idata` into `MMMPlotSuite`.
- Any PyMC model can use `MMMPlotSuite` directly if its `InferenceData` contains the needed groups/variables described above.
Classes
| `MMMPlotSuite(idata)` | Media Mix Model Plot Suite. |