@@ -340,7 +340,8 @@ def __init__(
340340 (new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
341341 self .y_pred_treatment = self .model .predict (np .asarray (new_x ))
342342
343- # predicted outcome for counterfactual
343+ # predicted outcome for counterfactual. This is given by removing the influence
344+ # of the interaction term between the group and the post_treatment variable
344345 self .x_pred_counterfactual = (
345346 self .data
346347 # just the treated group
@@ -349,24 +350,28 @@ def __init__(
349350 .query ("post_treatment == True" )
350351 # drop the outcome variable
351352 .drop (self .outcome_variable_name , axis = 1 )
352- # DO AN INTERVENTION. Set the post_treatment variable to False
353- .assign (post_treatment = False )
354353 # We may have multiple units per time point, we only want one time point
355354 .groupby (self .time_variable_name )
356355 .first ()
357356 .reset_index ()
358357 )
359358 assert not self .x_pred_counterfactual .empty
360359 (new_x ,) = build_design_matrices (
361- [self ._x_design_info ], self .x_pred_counterfactual
360+ [self ._x_design_info ], self .x_pred_counterfactual , return_type = "dataframe"
362361 )
362+ # INTERVENTION: set the interaction term between the group and the
363+ # post_treatment variable to zero. This is the counterfactual.
364+ for i , label in enumerate (self .labels ):
365+ if "post_treatment" in label and self .group_variable_name in label :
366+ new_x .iloc [:, i ] = 0
363367 self .y_pred_counterfactual = self .model .predict (np .asarray (new_x ))
364368
365- # calculate causal impact
366- self .causal_impact = (
367- self .y_pred_treatment ["posterior_predictive" ].mu .isel ({"obs_ind" : 1 })
368- - self .y_pred_counterfactual ["posterior_predictive" ].mu .squeeze ()
369- )
369+ # calculate causal impact.
370+ # This is the coefficient on the interaction term
371+ coeff_names = self .idata .posterior .coords ["coeffs" ].data
372+ for i , label in enumerate (coeff_names ):
373+ if "post_treatment" in label and self .group_variable_name in label :
374+ self .causal_impact = self .idata .posterior ["beta" ].isel ({"coeffs" : i })
370375
371376 def plot (self ):
372377 """Plot the results.
0 commit comments