\n", " | date | \n", "year | \n", "month | \n", "dayofyear | \n", "t | \n", "influencer_spend | \n", "shipping_threshold | \n", "intercept | \n", "trend | \n", "cs | \n", "cc | \n", "seasonality | \n", "epsilon | \n", "y | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2019-04-01 | \n", "2019 | \n", "4 | \n", "91 | \n", "0 | \n", "0.918883 | \n", "25.0 | \n", "2.0 | \n", "0.778279 | \n", "-0.012893 | \n", "0.006446 | \n", "-0.003223 | \n", "-0.118826 | \n", "2.561363 | \n", "
1 | \n", "2019-04-08 | \n", "2019 | \n", "4 | \n", "98 | \n", "1 | \n", "0.230898 | \n", "25.0 | \n", "2.0 | \n", "0.795664 | \n", "0.225812 | \n", "-0.113642 | \n", "0.056085 | \n", "0.064977 | \n", "2.264874 | \n", "
2 | \n", "2019-04-15 | \n", "2019 | \n", "4 | \n", "105 | \n", "2 | \n", "0.254486 | \n", "25.0 | \n", "2.0 | \n", "0.812559 | \n", "0.451500 | \n", "-0.232087 | \n", "0.109706 | \n", "-0.020269 | \n", "1.998208 | \n", "
3 | \n", "2019-04-22 | \n", "2019 | \n", "4 | \n", "112 | \n", "3 | \n", "0.035995 | \n", "25.0 | \n", "2.0 | \n", "0.828993 | \n", "0.651162 | \n", "-0.347175 | \n", "0.151993 | \n", "0.400209 | \n", "1.701116 | \n", "
4 | \n", "2019-04-29 | \n", "2019 | \n", "4 | \n", "119 | \n", "4 | \n", "0.336013 | \n", "25.0 | \n", "2.0 | \n", "0.844997 | \n", "0.813290 | \n", "-0.457242 | \n", "0.178024 | \n", "0.057609 | \n", "2.003646 | \n", "
Sampler Progress
\n", "Total Chains: 4
\n", "Active Chains: 0
\n", "\n", " Finished Chains:\n", " 4\n", "
\n", "Sampling for 14 seconds
\n", "\n", " Estimated Time to Completion:\n", " now\n", "
\n", "\n", " \n", "Progress | \n", "Draws | \n", "Divergences | \n", "Step Size | \n", "Gradients/Draw | \n", "
---|---|---|---|---|
\n", " \n", " | \n", "2000 | \n", "0 | \n", "0.19 | \n", "63 | \n", "
\n", " \n", " | \n", "2000 | \n", "0 | \n", "0.19 | \n", "31 | \n", "
\n", " \n", " | \n", "2000 | \n", "0 | \n", "0.18 | \n", "79 | \n", "
\n", " \n", " | \n", "2000 | \n", "0 | \n", "0.19 | \n", "31 | \n", "
/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n", "\"ipywidgets\" for Jupyter support\n", " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", "\n" ], "text/plain": [ "/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n", "\"ipywidgets\" for Jupyter support\n", " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling: [y]\n" ] }, { "data": { "text/html": [ "
/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n", "\"ipywidgets\" for Jupyter support\n", " warnings.warn('install \"ipywidgets\" for Jupyter support')\n", "\n" ], "text/plain": [ "/opt/homebrew/envs/pymc-marketing-dev/lib/python3.12/site-packages/rich/live.py:256: UserWarning: install \n", "\"ipywidgets\" for Jupyter support\n", " warnings.warn('install \"ipywidgets\" for Jupyter support')\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "mmm = MMM(\n", " date_column=\"date\",\n", " target_column=\"y\",\n", " adstock=GeometricAdstock(l_max=8),\n", " saturation=LogisticSaturation(),\n", " channel_columns=[\"influencer_spend\"],\n", " control_columns=[\"t\", \"shipping_threshold\"],\n", " yearly_seasonality=2,\n", ")\n", "\n", "x_train = df.drop(columns=[\"y\"])\n", "y_train = df[\"y\"]\n", "\n", "mmm.fit(\n", " X=x_train,\n", " y=y_train,\n", " nuts_sampler=\"nutpie\",\n", " nuts_sampler_kwargs={\n", " \"backend\": \"jax\",\n", " \"gradient_backend\": \"jax\",\n", " },\n", ")\n", "mmm.sample_posterior_predictive(x_train, extend_idata=True);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sensitivity analysis and marginal effects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A multiplicative sweep on influencer spend" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [ "hide-output" ] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n", "Sampling: [y]\n" ] }, { "data": { "text/html": [ "
<xarray.Dataset> Size: 98MB\n", "Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n", "Data variables:\n", " y (chain, draw, date, sweep) float64 49MB -0.6762 ... 0.363\n", " marginal_effects (chain, draw, date, sweep) float64 49MB 1.464 ... 0.5909\n", "Attributes:\n", " sweep_type: multiplicative\n", " var_names: ['influencer_spend']
<xarray.Dataset> Size: 33MB\n", "Dimensions: (chain: 4, draw: 1000, control: 2,\n", " fourier_mode: 4, date: 127,\n", " channel: 1)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 ... 998 999\n", " * control (control) object 16B 'shipping_t...\n", " * fourier_mode (fourier_mode) object 32B 'sin_1...\n", " * date (date) datetime64[ns] 1kB 2019-0...\n", " * channel (channel) <U16 64B 'influencer_s...\n", "Data variables: (12/16)\n", " intercept_contribution (chain, draw) float64 32kB 0.488...\n", " adstock_alpha_logodds__ (chain, draw) float64 32kB -0.43...\n", " saturation_lam_log__ (chain, draw) float64 32kB 1.358...\n", " saturation_beta_log__ (chain, draw) float64 32kB -0.26...\n", " gamma_control (chain, draw, control) float64 64kB ...\n", " gamma_fourier (chain, draw, fourier_mode) float64 128kB ...\n", " ... ...\n", " y_sigma (chain, draw) float64 32kB 0.071...\n", " channel_contribution (chain, draw, date, channel) float64 4MB ...\n", " total_media_contribution_original_scale (chain, draw) float64 32kB 177.3...\n", " control_contribution (chain, draw, date, control) float64 8MB ...\n", " fourier_contribution (chain, draw, date, fourier_mode) float64 16MB ...\n", " yearly_seasonality_contribution (chain, draw, date) float64 4MB ...\n", "Attributes:\n", " created_at: 2025-08-13T10:19:30.246258+00:00\n", " arviz_version: 0.22.0\n", " inference_library: nutpie\n", " inference_library_version: 0.15.2\n", " sampling_time: 14.934990882873535\n", " tuning_steps: 1000\n", " pymc_marketing_version: 0.15.1
<xarray.Dataset> Size: 336kB\n", "Dimensions: (chain: 4, draw: 1000)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999\n", "Data variables:\n", " depth (chain, draw) uint64 32kB 4 6 4 6 5 4 ... 5 4 5 4 5 5\n", " maxdepth_reached (chain, draw) bool 4kB False False ... False False\n", " index_in_trajectory (chain, draw) int64 32kB 4 -27 -11 -9 ... 14 -12 16 9\n", " logp (chain, draw) float64 32kB 140.5 140.7 ... 138.8 140.4\n", " energy (chain, draw) float64 32kB -135.1 -136.4 ... -133.7\n", " diverging (chain, draw) bool 4kB False False ... False False\n", " energy_error (chain, draw) float64 32kB -0.1269 -0.3253 ... 0.00532\n", " step_size (chain, draw) float64 32kB 0.189 0.189 ... 0.1878\n", " step_size_bar (chain, draw) float64 32kB 0.189 0.189 ... 0.1878\n", " mean_tree_accept (chain, draw) float64 32kB 0.9691 0.9881 ... 0.982\n", " mean_tree_accept_sym (chain, draw) float64 32kB 0.8554 0.8769 ... 0.985\n", " n_steps (chain, draw) uint64 32kB 31 95 15 63 ... 63 15 31 47\n", "Attributes:\n", " created_at: 2025-08-13T10:19:30.237387+00:00\n", " arviz_version: 0.22.0
<xarray.Dataset> Size: 2kB\n", "Dimensions: (date: 127)\n", "Coordinates:\n", " * date (date) datetime64[ns] 1kB 2019-04-01 2019-04-08 ... 2021-08-30\n", "Data variables:\n", " y (date) float64 1kB 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n", "Attributes:\n", " created_at: 2025-08-13T10:19:30.861715+00:00\n", " arviz_version: 0.22.0\n", " inference_library: pymc\n", " inference_library_version: 5.25.1
<xarray.Dataset> Size: 6kB\n", "Dimensions: (channel: 1, date: 127, control: 2)\n", "Coordinates:\n", " * channel (channel) <U16 64B 'influencer_spend'\n", " * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " * control (control) <U18 144B 'shipping_threshold' 't'\n", "Data variables:\n", " channel_scale (channel) float64 8B 0.9919\n", " target_scale float64 8B 3.981\n", " channel_data (date, channel) float64 1kB 0.9189 0.2309 ... 0.2797 0.2041\n", " target_data (date) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0\n", " control_data (date, control) float64 2kB 25.0 0.0 25.0 ... 20.0 126.0\n", " dayofyear (date) int32 508B 91 98 105 112 119 ... 214 221 228 235 242\n", "Attributes:\n", " created_at: 2025-08-13T10:19:30.863703+00:00\n", " arviz_version: 0.22.0\n", " inference_library: pymc\n", " inference_library_version: 5.25.1
<xarray.Dataset> Size: 14kB\n", "Dimensions: (index: 127)\n", "Coordinates:\n", " * index (index) int64 1kB 0 1 2 3 4 5 ... 122 123 124 125 126\n", "Data variables: (12/14)\n", " date (index) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " year (index) int32 508B 2019 2019 2019 ... 2021 2021 2021\n", " month (index) int32 508B 4 4 4 4 4 5 5 5 5 ... 7 7 7 8 8 8 8 8\n", " dayofyear (index) int32 508B 91 98 105 112 119 ... 221 228 235 242\n", " t (index) int64 1kB 0 1 2 3 4 5 ... 122 123 124 125 126\n", " influencer_spend (index) float64 1kB 0.9189 0.2309 ... 0.2797 0.2041\n", " ... ...\n", " trend (index) float64 1kB 0.7783 0.7957 0.8126 ... 1.779 1.783\n", " cs (index) float64 1kB -0.01289 0.2258 ... -0.9747 -0.8932\n", " cc (index) float64 1kB 0.006446 -0.1136 ... -0.623 -0.5246\n", " seasonality (index) float64 1kB -0.003223 0.05608 ... -0.7089\n", " epsilon (index) float64 1kB -0.1188 0.06498 ... -0.3317 -0.05244\n", " y (index) float64 1kB 2.561 2.265 1.998 ... 2.734 2.607
<xarray.Dataset> Size: 4MB\n", "Dimensions: (chain: 4, draw: 1000, date: 127)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2019-04-01 2019-04-08 ... 2021-08-30\n", "Data variables:\n", " y (chain, draw, date) float64 4MB 0.6702 0.5258 ... 0.8099 0.6873\n", "Attributes:\n", " created_at: 2025-08-13T10:19:30.859848+00:00\n", " arviz_version: 0.22.0\n", " inference_library: pymc\n", " inference_library_version: 5.25.1
<xarray.Dataset> Size: 98MB\n", "Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n", "Data variables:\n", " y (chain, draw, date, sweep) float64 49MB -0.6762 ... 0.363\n", " marginal_effects (chain, draw, date, sweep) float64 49MB 1.464 ... 0.5909\n", "Attributes:\n", " sweep_type: multiplicative\n", " var_names: ['influencer_spend']
<xarray.Dataset> Size: 488kB\n", "Dimensions: (chain: 4, draw: 1000, control: 2, fourier_mode: 4)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 ... 995 996 997 998 999\n", " * control (control) object 16B 'shipping_threshold' 't'\n", " * fourier_mode (fourier_mode) object 32B 'sin_1' ... 'cos_2'\n", "Data variables:\n", " intercept_contribution (chain, draw) float64 32kB 0.2694 0.2694 ... 0.3854\n", " adstock_alpha_logodds__ (chain, draw) float64 32kB -1.646 -1.646 ... 0.195\n", " saturation_lam_log__ (chain, draw) float64 32kB 0.4302 0.4302 ... 1.437\n", " saturation_beta_log__ (chain, draw) float64 32kB 0.5996 ... -0.08112\n", " gamma_control (chain, draw, control) float64 64kB -0.5422 ... ...\n", " gamma_fourier (chain, draw, fourier_mode) float64 128kB 0.3939...\n", " y_sigma_log__ (chain, draw) float64 32kB 1.165 1.165 ... -2.623\n", " adstock_alpha (chain, draw) float64 32kB 0.1617 0.1617 ... 0.5486\n", " saturation_lam (chain, draw) float64 32kB 1.538 1.538 ... 4.208\n", " saturation_beta (chain, draw) float64 32kB 1.821 1.821 ... 0.9221\n", " y_sigma (chain, draw) float64 32kB 3.205 3.205 ... 0.07259\n", "Attributes:\n", " created_at: 2025-08-13T10:19:30.234662+00:00\n", " arviz_version: 0.22.0
<xarray.Dataset> Size: 336kB\n", "Dimensions: (chain: 4, draw: 1000)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999\n", "Data variables:\n", " depth (chain, draw) uint64 32kB 2 0 2 3 2 3 ... 4 6 4 4 5 4\n", " maxdepth_reached (chain, draw) bool 4kB False False ... False False\n", " index_in_trajectory (chain, draw) int64 32kB 2 0 2 1 -3 1 ... -24 8 6 13 9\n", " logp (chain, draw) float64 32kB -2.411e+03 ... 139.3\n", " energy (chain, draw) float64 32kB 2.986e+03 ... -135.3\n", " diverging (chain, draw) bool 4kB False True ... False False\n", " energy_error (chain, draw) float64 32kB -9.958 0.0 ... 0.2545\n", " step_size (chain, draw) float64 32kB 1.439 0.2431 ... 0.1878\n", " step_size_bar (chain, draw) float64 32kB 1.439 0.4998 ... 0.1878\n", " mean_tree_accept (chain, draw) float64 32kB 1.0 0.0 ... 0.9721 0.8904\n", " mean_tree_accept_sym (chain, draw) float64 32kB 0.08114 0.0 ... 0.9273\n", " n_steps (chain, draw) uint64 32kB 3 1 3 7 3 ... 63 15 15 63 15\n", "Attributes:\n", " created_at: 2025-08-13T10:19:30.240564+00:00\n", " arviz_version: 0.22.0
<xarray.Dataset> Size: 98MB\n", "Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n", "Data variables:\n", " y (chain, draw, date, sweep) float64 49MB -0.4471 ... 0.2292\n", " marginal_effects (chain, draw, date, sweep) float64 49MB -0.3712 ... -1.281\n", "Attributes:\n", " sweep_type: absolute\n", " var_names: ['influencer_spend']
<xarray.Dataset> Size: 98MB\n", "Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " * sweep (sweep) float64 96B 0.0 0.1818 0.3636 ... 1.636 1.818 2.0\n", "Data variables:\n", " y (chain, draw, date, sweep) float64 49MB -0.03469 ... 0....\n", " marginal_effects (chain, draw, date, sweep) float64 49MB 1.17 ... 0.02177\n", "Attributes:\n", " sweep_type: additive\n", " var_names: ['influencer_spend']
<xarray.Dataset> Size: 98MB\n", "Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 12)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " * sweep (sweep) float64 96B 0.0 0.09091 0.1818 ... 0.9091 1.0\n", "Data variables:\n", " y (chain, draw, date, sweep) float64 49MB 0.4111 ... 0.428\n", " marginal_effects (chain, draw, date, sweep) float64 49MB -0.6963 ... -1.128\n", "Attributes:\n", " sweep_type: absolute\n", " var_names: ['shipping_threshold']
<xarray.Dataset> Size: 244MB\n", "Dimensions: (chain: 4, draw: 1000, date: 127, sweep: 30)\n", "Coordinates:\n", " * chain (chain) int64 32B 0 1 2 3\n", " * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999\n", " * date (date) datetime64[ns] 1kB 2019-04-01 ... 2021-08-30\n", " * sweep (sweep) float64 240B 0.0 0.03448 0.06897 ... 0.9655 1.0\n", "Data variables:\n", " y (chain, draw, date, sweep) float64 122MB 0.4907 ... 0.3568\n", " marginal_effects (chain, draw, date, sweep) float64 122MB -4.118 ... -4.627\n", "Attributes:\n", " sweep_type: absolute\n", " var_names: ['shipping_threshold']