plot

MMM-related plotting class.

Examples

Quickstart with MMM:

from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM
import pandas as pd

# Minimal dataset
X = pd.DataFrame(
    {
        "date": pd.date_range("2025-01-01", periods=12, freq="W-MON"),
        "C1": [100, 120, 90, 110, 105, 115, 98, 102, 108, 111, 97, 109],
        "C2": [80, 70, 95, 85, 90, 88, 92, 94, 91, 89, 93, 87],
    }
)
y = pd.Series(
    [230, 260, 220, 240, 245, 255, 235, 238, 242, 246, 233, 249], name="y"
)

mmm = MMM(
    date_column="date",
    channel_columns=["C1", "C2"],
    target_column="y",
    adstock=GeometricAdstock(l_max=10),
    saturation=LogisticSaturation(),
)
mmm.fit(X, y)
mmm.sample_posterior_predictive(X)

# Posterior predictive time series
_ = mmm.plot.posterior_predictive(var=["y"], hdi_prob=0.9)

# Posterior contributions over time (e.g., channel_contribution)
_ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=0.9)

# Channel saturation scatter plot (scaled space by default)
_ = mmm.plot.saturation_scatterplot(original_scale=False)

Wrap a custom PyMC model

Requirements

  • posterior_predictive plots: an az.InferenceData with a posterior_predictive group containing the variable(s) you want to plot, each with a date coordinate.

  • contributions_over_time plots: a posterior group containing the time-series variables you want to plot, each with a date coordinate.

  • saturation plots: a constant_data dataset with the following variables:
      - channel_data: dims include ("date", "channel", ...)
      - channel_scale: dims include ("channel", ...)
      - target_scale: a scalar, or broadcastable to the curve dims
    plus a posterior variable named channel_contribution (or channel_contribution_original_scale when plotting with original_scale=True).

import numpy as np
import pandas as pd
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite

dates = pd.date_range("2025-01-01", periods=30, freq="D")
y_obs = np.random.normal(size=len(dates))

with pm.Model(coords={"date": dates}):
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", 0.0, sigma, observed=y_obs, dims="date")

    idata = pm.sample_prior_predictive(random_seed=1)
    idata.extend(pm.sample(draws=200, chains=2, tune=200, random_seed=1))
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=1))

plot = MMMPlotSuite(idata)
_ = plot.posterior_predictive(var=["y"], hdi_prob=0.9)

Custom contributions_over_time

import numpy as np
import pandas as pd
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite

dates = pd.date_range("2025-01-01", periods=30, freq="D")
x = np.linspace(0, 2 * np.pi, len(dates))
series = np.sin(x)

with pm.Model(coords={"date": dates}):
    pm.Deterministic("component", series, dims="date")
    idata = pm.sample_prior_predictive(random_seed=2)
    idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=2))

plot = MMMPlotSuite(idata)
_ = plot.contributions_over_time(var=["component"], hdi_prob=0.9)

Saturation plots with a custom model

import numpy as np
import pandas as pd
import xarray as xr
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite

dates = pd.date_range("2025-01-01", periods=20, freq="W-MON")
channels = ["C1", "C2"]

# Create constant_data required for saturation plots
channel_data = xr.DataArray(
    np.random.rand(len(dates), len(channels)),
    dims=("date", "channel"),
    coords={"date": dates, "channel": channels},
    name="channel_data",
)
channel_scale = xr.DataArray(
    np.ones(len(channels)),
    dims=("channel",),
    coords={"channel": channels},
    name="channel_scale",
)
target_scale = xr.DataArray(1.0, name="target_scale")

# Build a toy model that yields a matching posterior var
with pm.Model(coords={"date": dates, "channel": channels}):
    # A fake contribution over time per channel (dims must include date & channel)
    contrib = pm.Normal("channel_contribution", 0.0, 1.0, dims=("date", "channel"))

    idata = pm.sample_prior_predictive(random_seed=3)
    idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=3))

# Attach constant_data to idata
idata.constant_data = xr.Dataset(
    {
        "channel_data": channel_data,
        "channel_scale": channel_scale,
        "target_scale": target_scale,
    }
)

plot = MMMPlotSuite(idata)
_ = plot.saturation_scatterplot(original_scale=False)

Notes

  • MMM exposes this suite via the mmm.plot property, which internally passes the model’s idata into MMMPlotSuite.

  • Any PyMC model can use MMMPlotSuite directly, provided its InferenceData contains the groups and variables listed under Requirements above.

Classes

MMMPlotSuite(idata)

Media Mix Model Plot Suite.