Skip to content

Commit 9c4f882

Browse files
committed
Modifying to be clear with "target"
1 parent 0d0464d commit 9c4f882

8 files changed

+58
-49
lines changed

data/config_files/basic_model.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ schema_version: 1
22

33
# ----------------------------------------------------------------------
44
model:
5-
target: pymc_marketing.mmm.multidimensional.MMM
5+
target_class: pymc_marketing.mmm.multidimensional.MMM
66
kwargs:
77
date_column: "date"
88
channel_columns: # explicit for reproducibility
@@ -13,11 +13,11 @@ model:
1313

1414
# --- media transformations ---------------------------------------
1515
adstock:
16-
target: pymc_marketing.mmm.GeometricAdstock
16+
target_class: pymc_marketing.mmm.GeometricAdstock
1717
kwargs: {l_max: 12} # any other hyper-parameters here
1818

1919
saturation:
20-
target: pymc_marketing.mmm.MichaelisMentenSaturation
20+
target_class: pymc_marketing.mmm.MichaelisMentenSaturation
2121
kwargs: {} # default α, λ priors inside the class
2222

2323
# ----------------------------------------------------------------------

data/config_files/example_with_original_scale_vars.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ schema_version: 1
22

33
# ----------------------------------------------------------------------
44
model:
5-
target: pymc_marketing.mmm.multidimensional.MMM
5+
target_class: pymc_marketing.mmm.multidimensional.MMM
66
kwargs:
77
date_column: "date"
88
channel_columns: # explicit for reproducibility
@@ -13,11 +13,11 @@ model:
1313

1414
# --- media transformations ---------------------------------------
1515
adstock:
16-
target: pymc_marketing.mmm.GeometricAdstock
16+
target_class: pymc_marketing.mmm.GeometricAdstock
1717
kwargs: {l_max: 12} # any other hyper-parameters here
1818

1919
saturation:
20-
target: pymc_marketing.mmm.MichaelisMentenSaturation
20+
target_class: pymc_marketing.mmm.MichaelisMentenSaturation
2121
kwargs: {} # default α, λ priors inside the class
2222

2323
# ----------------------------------------------------------------------

data/config_files/multi_dimensiona_hierarchical_model_nested_config.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ schema_version: 1
22

33
# ----------------------------------------------------------------------
44
model:
5-
target: pymc_marketing.mmm.multidimensional.MMM
5+
target_class: pymc_marketing.mmm.multidimensional.MMM
66
kwargs:
77
date_column: "date"
88
channel_columns: # explicit for reproducibility
@@ -23,7 +23,7 @@ model:
2323

2424
# --- media transformations ---------------------------------------
2525
adstock:
26-
target: pymc_marketing.mmm.GeometricAdstock
26+
target_class: pymc_marketing.mmm.GeometricAdstock
2727
kwargs:
2828
priors:
2929
alpha:
@@ -34,7 +34,7 @@ model:
3434
l_max: 28
3535

3636
saturation:
37-
target: pymc_marketing.mmm.MichaelisMentenSaturation
37+
target_class: pymc_marketing.mmm.MichaelisMentenSaturation
3838
kwargs:
3939
priors:
4040
alpha:
@@ -51,7 +51,7 @@ model:
5151
# --- model (hierarchical) priors ---------------------------------
5252
model_config:
5353
intercept:
54-
target: pymc_marketing.prior.Prior
54+
target_class: pymc_marketing.prior.Prior
5555
kwargs:
5656
args: ["HalfCauchy"]
5757
beta:
@@ -60,7 +60,7 @@ model:
6060
dims: "market"
6161

6262
likelihood:
63-
target: pymc_marketing.prior.Prior
63+
target_class: pymc_marketing.prior.Prior
6464
kwargs:
6565
args: ["TruncatedNormal"]
6666
lower: 0
@@ -74,10 +74,10 @@ model:
7474
# Effects with complex priors
7575
effects:
7676
# 1. Linear Trend Effect with complex nested priors
77-
- target: pymc_marketing.mmm.additive_effect.LinearTrendEffect
77+
- target_class: pymc_marketing.mmm.additive_effect.LinearTrendEffect
7878
kwargs:
7979
trend:
80-
target: pymc_marketing.mmm.LinearTrend
80+
target_class: pymc_marketing.mmm.LinearTrend
8181
kwargs:
8282
n_changepoints: 2
8383
include_intercept: false
@@ -93,10 +93,10 @@ effects:
9393
prefix: "trend"
9494

9595
# 2. Fourier Effect with complex nested priors
96-
- target: pymc_marketing.mmm.additive_effect.FourierEffect
96+
- target_class: pymc_marketing.mmm.additive_effect.FourierEffect
9797
kwargs:
9898
fourier:
99-
target: pymc_marketing.mmm.WeeklyFourier
99+
target_class: pymc_marketing.mmm.WeeklyFourier
100100
kwargs:
101101
n_order: 3
102102
prefix: "weekly_fourier"

data/config_files/multi_dimensional_hierarchical_model.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ schema_version: 1
22

33
# ----------------------------------------------------------------------
44
model:
5-
target: pymc_marketing.mmm.multidimensional.MMM
5+
target_class: pymc_marketing.mmm.multidimensional.MMM
66
kwargs:
77
date_column: "date"
88
channel_columns: # explicit for reproducibility
@@ -23,11 +23,11 @@ model:
2323

2424
# --- media transformations ---------------------------------------
2525
adstock:
26-
target: pymc_marketing.mmm.GeometricAdstock
26+
target_class: pymc_marketing.mmm.GeometricAdstock
2727
kwargs: {l_max: 12} # any other hyper-parameters here
2828

2929
saturation:
30-
target: pymc_marketing.mmm.LogisticSaturation
30+
target_class: pymc_marketing.mmm.LogisticSaturation
3131
kwargs:
3232
priors:
3333
beta:
@@ -47,15 +47,15 @@ model:
4747
# --- model (hierarchical) priors ---------------------------------
4848
model_config:
4949
intercept:
50-
target: pymc_marketing.prior.Prior
50+
target_class: pymc_marketing.prior.Prior
5151
kwargs:
5252
args: ["Gamma"]
5353
mu: 0.5
5454
sigma: 1
5555
dims: "market"
5656

5757
likelihood:
58-
target: pymc_marketing.prior.Prior
58+
target_class: pymc_marketing.prior.Prior
5959
kwargs:
6060
args: ["Normal"]
6161
sigma:

data/config_files/multi_dimensional_hierarchical_with_arbitrary_effects_model.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ schema_version: 1
22

33
# ----------------------------------------------------------------------
44
model:
5-
target: pymc_marketing.mmm.multidimensional.MMM
5+
target_class: pymc_marketing.mmm.multidimensional.MMM
66
kwargs:
77
date_column: "date"
88
channel_columns: # explicit for reproducibility
@@ -23,11 +23,11 @@ model:
2323

2424
# --- media transformations ---------------------------------------
2525
adstock:
26-
target: pymc_marketing.mmm.GeometricAdstock
26+
target_class: pymc_marketing.mmm.GeometricAdstock
2727
kwargs: {l_max: 12} # any other hyper-parameters here
2828

2929
saturation:
30-
target: pymc_marketing.mmm.LogisticSaturation
30+
target_class: pymc_marketing.mmm.LogisticSaturation
3131
kwargs:
3232
priors:
3333
beta:
@@ -47,15 +47,15 @@ model:
4747
# --- model (hierarchical) priors ---------------------------------
4848
model_config:
4949
intercept:
50-
target: pymc_marketing.prior.Prior
50+
target_class: pymc_marketing.prior.Prior
5151
kwargs:
5252
args: ["Gamma"]
5353
mu: 0.5
5454
sigma: 1
5555
dims: "market"
5656

5757
likelihood:
58-
target: pymc_marketing.prior.Prior
58+
target_class: pymc_marketing.prior.Prior
5959
kwargs:
6060
args: ["Normal"]
6161
sigma:
@@ -69,10 +69,10 @@ model:
6969
# Effects with complex priors
7070
effects:
7171
# 1. Linear Trend Effect with complex nested priors
72-
- target: pymc_marketing.mmm.additive_effect.LinearTrendEffect
72+
- target_class: pymc_marketing.mmm.additive_effect.LinearTrendEffect
7373
kwargs:
7474
trend:
75-
target: pymc_marketing.mmm.LinearTrend
75+
target_class: pymc_marketing.mmm.LinearTrend
7676
kwargs:
7777
n_changepoints: 2
7878
include_intercept: false
@@ -88,10 +88,10 @@ effects:
8888
prefix: "trend"
8989

9090
# 2. Fourier Effect with complex nested priors
91-
- target: pymc_marketing.mmm.additive_effect.FourierEffect
91+
- target_class: pymc_marketing.mmm.additive_effect.FourierEffect
9292
kwargs:
9393
fourier:
94-
target: pymc_marketing.mmm.WeeklyFourier
94+
target_class: pymc_marketing.mmm.WeeklyFourier
9595
kwargs:
9696
n_order: 3
9797
prefix: "weekly_fourier"

data/config_files/multi_dimensional_model.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ schema_version: 1
22

33
# ----------------------------------------------------------------------
44
model:
5-
target: pymc_marketing.mmm.multidimensional.MMM
5+
target_class: pymc_marketing.mmm.multidimensional.MMM
66
kwargs:
77
date_column: "date"
88
channel_columns: # explicit for reproducibility
@@ -14,11 +14,11 @@ model:
1414

1515
# --- media transformations ---------------------------------------
1616
adstock:
17-
target: pymc_marketing.mmm.GeometricAdstock
17+
target_class: pymc_marketing.mmm.GeometricAdstock
1818
kwargs: {l_max: 12} # any other hyper-parameters here
1919

2020
saturation:
21-
target: pymc_marketing.mmm.MichaelisMentenSaturation
21+
target_class: pymc_marketing.mmm.MichaelisMentenSaturation
2222
kwargs: {} # default α, λ priors inside the class
2323

2424
# ----------------------------------------------------------------------

pymc_marketing/mmm/builders/deserializers.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def deserialize_prior(data: dict[str, Any]) -> Prior:
4646
data_copy[key] = deserialize_prior(value)
4747
elif (
4848
isinstance(value, dict)
49-
and "target" in value
50-
and value["target"] == "pymc_marketing.prior.Prior"
49+
and "target_class" in value
50+
and value["target_class"] == "pymc_marketing.prior.Prior"
5151
):
5252
data_copy[key] = deserialize_standard_prior(value)
5353

@@ -111,8 +111,8 @@ def is_standard_prior_dict(data: Any) -> tuple[bool, str]:
111111
return False, ""
112112

113113
if (
114-
"target" in data
115-
and data["target"] == "pymc_marketing.prior.Prior"
114+
"target_class" in data
115+
and data["target_class"] == "pymc_marketing.prior.Prior"
116116
and "kwargs" in data
117117
):
118118
return True, "Prior"
@@ -126,7 +126,7 @@ def deserialize_standard_prior(data: dict[str, Any]) -> Prior:
126126
127127
The expected format is:
128128
{
129-
"target": "pymc_marketing.prior.Prior",
129+
"target_class": "pymc_marketing.prior.Prior",
130130
"kwargs": {
131131
"args": ["Distribution"],
132132
"param1": value1,
@@ -148,7 +148,10 @@ def deserialize_standard_prior(data: dict[str, Any]) -> Prior:
148148
if isinstance(value, dict):
149149
if "distribution" in value:
150150
new_kwargs[key] = deserialize_prior(value)
151-
elif "target" in value and value["target"] == "pymc_marketing.prior.Prior":
151+
elif (
152+
"target_class" in value
153+
and value["target_class"] == "pymc_marketing.prior.Prior"
154+
):
152155
new_kwargs[key] = deserialize_standard_prior(value)
153156

154157
# Create Prior
@@ -165,7 +168,10 @@ def is_priors_dict(data: Any) -> bool:
165168
for _key, value in data.items():
166169
if isinstance(value, dict) and (
167170
"distribution" in value
168-
or ("target" in value and value["target"] == "pymc_marketing.prior.Prior")
171+
or (
172+
"target_class" in value
173+
and value["target_class"] == "pymc_marketing.prior.Prior"
174+
)
169175
):
170176
return True
171177
return False
@@ -178,7 +184,10 @@ def deserialize_priors_dict(data: dict[str, Any]) -> dict[str, Any]:
178184
if isinstance(value, dict):
179185
if "distribution" in value:
180186
result[key] = deserialize_prior(value)
181-
elif "target" in value and value["target"] == "pymc_marketing.prior.Prior":
187+
elif (
188+
"target_class" in value
189+
and value["target_class"] == "pymc_marketing.prior.Prior"
190+
):
182191
result[key] = deserialize_standard_prior(value)
183192
else:
184193
result[key] = value

pymc_marketing/mmm/builders/factories.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,17 @@ def _build_single(spec: Mapping[str, Any]) -> Any:
9090
Notes
9191
-----
9292
Recognised keys
93-
* target : str (mandatory)
93+
* target_class : str (mandatory)
9494
* kwargs : dict (optional)
9595
* args : list (optional positional arguments)
9696
"""
97-
# Ensure target is a string
98-
if not isinstance(spec["target"], str):
97+
# Ensure target_class is a string
98+
if not isinstance(spec["target_class"], str):
9999
raise TypeError(
100-
f"Expected string for 'target' but got {type(spec['target']).__name__}: {spec['target']}"
100+
f"Expected string for 'target_class' but got {type(spec['target_class']).__name__}: {spec['target_class']}"
101101
)
102102

103-
cls = locate(spec["target"])
103+
cls = locate(spec["target_class"])
104104

105105
raw_kwargs: MutableMapping[str, Any] = dict(spec.get("kwargs", {}))
106106
raw_args: Sequence[Any] = raw_kwargs.pop("args", spec.get("args", ()))
@@ -164,21 +164,21 @@ def resolve(value):
164164
if isinstance(value, Mapping) and "scaling" in value:
165165
return value
166166
# nested object
167-
if isinstance(value, Mapping) and "target" in value:
167+
if isinstance(value, Mapping) and "target_class" in value:
168168
return _build_single(value)
169169
# list of nested objects
170170
if (
171171
isinstance(value, list)
172172
and value
173173
and isinstance(value[0], Mapping)
174-
and "target" in value[0]
174+
and "target_class" in value[0]
175175
):
176176
return [_build_single(v) for v in value]
177177
return value
178178

179179

180180
def build(spec: Mapping[str, Any]) -> Any:
181181
"""Public wrapper that checks minimal structure and delegates to _build_single."""
182-
if "target" not in spec:
183-
raise ValueError("Spec must contain a 'target' key.")
182+
if "target_class" not in spec:
183+
raise ValueError("Spec must contain a 'target_class' key.")
184184
return _build_single(spec)

0 commit comments

Comments
 (0)