# Source code for pymc_marketing.mmm.sensitivity_analysis
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Counterfactual sweeps for Marketing Mix Models (MMM)."""
from typing import Literal
import numpy as np
import pandas as pd
import xarray as xr
class SensitivityAnalysis:
    """Perform counterfactual sweeps on a Marketing Mix Model (MMM).

    The analysis perturbs selected predictor columns of ``mmm.X`` over a grid
    of sweep values, re-samples the posterior predictive for every perturbed
    dataset and records the uplift relative to the posterior predictive
    already stored in ``mmm.idata``.
    """

    # Intervention kinds understood by ``create_intervention``.
    _SWEEP_TYPES = ("multiplicative", "additive", "absolute")

    def __init__(self, mmm) -> None:
        """Initialize the SensitivityAnalysis with a reference to the MMM instance.

        Parameters
        ----------
        mmm : MMM
            The marketing mix model instance used for predictions.
        """
        self.mmm = mmm

    def run_sweep(
        self,
        var_names: list[str],
        sweep_values: np.ndarray,
        sweep_type: Literal[
            "multiplicative", "additive", "absolute"
        ] = "multiplicative",
    ) -> "xr.Dataset":
        """Run the model's predict function over the sweep grid and store results.

        For every sweep value the predictors are perturbed with
        ``create_intervention``, the posterior predictive is re-sampled and
        the stored posterior predictive ``y`` is subtracted from it. The
        per-value uplifts are concatenated along a new ``sweep`` dimension and
        merged with their derivative along ``sweep`` (``marginal_effects``).

        The call also stores ``var_names``, ``sweep_values`` and ``sweep_type``
        on the instance, and writes the result into the
        ``sensitivity_analysis`` group of ``mmm.idata``, replacing any
        existing group of that name.

        Parameters
        ----------
        var_names : list[str]
            List of variable names to intervene on.
        sweep_values : np.ndarray
            Array of sweep values.
        sweep_type : Literal["multiplicative", "additive", "absolute"], optional
            Type of intervention to apply, by default "multiplicative".

            - 'multiplicative': Multiply the original predictor values by
              each sweep value.
            - 'additive': Add each sweep value to the original predictor
              values.
            - 'absolute': Set the predictor values directly to each sweep
              value (ignoring original values).

        Returns
        -------
        xr.Dataset
            Dataset containing the sensitivity analysis results.

        Raises
        ------
        ValueError
            If ``sweep_type`` is not supported, if ``sweep_values`` is empty,
            or if the model has no ``idata`` (missing or ``None``).
        """
        # Fail fast on bad arguments, before any sampling is attempted.
        if sweep_type not in self._SWEEP_TYPES:
            raise ValueError(f"Unsupported sweep_type: {sweep_type}")
        if np.size(sweep_values) == 0:
            raise ValueError("sweep_values must contain at least one value.")

        # ``idata`` may be absent or present-but-None; either way there is
        # nothing to compare the counterfactuals against.
        if getattr(self.mmm, "idata", None) is None:
            raise ValueError("idata does not exist. Build the model first and fit.")

        # Store parameters for this run; ``create_intervention`` reads them.
        self.var_names = var_names
        self.sweep_values = sweep_values
        self.sweep_type = sweep_type

        # TODO: Ideally we can use this --------------------------------------------
        # actual = self.mmm._get_group_predictive_data(
        #     group="posterior_predictive", original_scale=True
        # )["y"]
        actual = self.mmm.idata["posterior_predictive"]["y"]
        # --------------------------------------------------------------------------

        predictions = []
        for sweep_value in self.sweep_values:
            X_new = self.create_intervention(sweep_value)
            counterfac = self.mmm.sample_posterior_predictive(
                X_new, extend_idata=False, combined=False, progressbar=False
            )
            # Uplift = counterfactual prediction minus the stored prediction.
            predictions.append(counterfac - actual)

        results = (
            xr.concat(predictions, dim="sweep")
            .assign_coords(sweep=self.sweep_values)
            .transpose(..., "sweep")
        )

        marginal_effects = self.compute_marginal_effects(results, self.sweep_values)
        results = xr.merge(
            [
                results,
                marginal_effects.rename({"y": "marginal_effects"}),
            ]
        ).transpose(..., "sweep")

        # Add metadata to the results
        results.attrs["sweep_type"] = self.sweep_type
        results.attrs["var_names"] = self.var_names

        # Add results to the MMM's idata, dropping a previous run's group first.
        if hasattr(self.mmm.idata, "sensitivity_analysis"):
            delattr(self.mmm.idata, "sensitivity_analysis")
        self.mmm.idata.add_groups({"sensitivity_analysis": results})  # type: ignore

        return results

    def create_intervention(self, sweep_value: float) -> pd.DataFrame:
        """Apply the intervention to the predictors.

        Works on a copy of ``mmm.X``; the model's own data is left untouched.
        Reads ``self.var_names`` and ``self.sweep_type``, which ``run_sweep``
        sets.

        Parameters
        ----------
        sweep_value : float
            Value to multiply by, add, or assign, depending on
            ``self.sweep_type``.

        Returns
        -------
        pd.DataFrame
            Copy of ``mmm.X`` with the columns in ``self.var_names`` modified.

        Raises
        ------
        ValueError
            If ``self.sweep_type`` is not a supported intervention type.
        """
        X_new = self.mmm.X.copy()

        if self.sweep_type not in self._SWEEP_TYPES:
            raise ValueError(f"Unsupported sweep_type: {self.sweep_type}")

        for var_name in self.var_names:
            # Explicit reassignment rather than ``*=`` / ``+=`` so the column
            # is always replaced by a freshly computed Series.
            if self.sweep_type == "multiplicative":
                X_new[var_name] = X_new[var_name] * sweep_value
            elif self.sweep_type == "additive":
                X_new[var_name] = X_new[var_name] + sweep_value
            else:  # "absolute"
                X_new[var_name] = sweep_value

        return X_new

    @staticmethod
    def compute_marginal_effects(results: "xr.Dataset", sweep_values) -> "xr.Dataset":
        """Compute marginal effects via finite differences from the sweep results.

        Parameters
        ----------
        results : xr.Dataset
            Sweep results carrying a ``sweep`` coordinate.
        sweep_values : array-like
            Not read by this method; the spacing is taken from the ``sweep``
            coordinate of ``results``. Kept for interface compatibility.

        Returns
        -------
        xr.Dataset
            ``results`` differentiated along ``sweep``.

        Notes
        -----
        NOTE(review): the finite-difference step presumably needs at least
        two sweep values -- confirm against xarray's ``differentiate``.
        """
        return results.differentiate(coord="sweep")