@@ -176,13 +176,21 @@ def __init__(
176176 data : pd .DataFrame ,
177177 formula : str ,
178178 time_variable_name : str ,
179+ group_variable_name : str ,
180+ treated : str ,
181+ untreated : str ,
179182 model = None ,
180183 ** kwargs ,
181184 ):
182185 super ().__init__ (model = model , ** kwargs )
183186 self .data = data
184187 self .formula = formula
185188 self .time_variable_name = time_variable_name
189+ self .group_variable_name = group_variable_name
190+ self .treated = treated # level of the group_variable_name that was treated
191+ self .untreated = (
192+ untreated # level of the group_variable_name that was untreated
193+ )
186194 y , X = dmatrices (formula , self .data )
187195 self ._y_design_info = y .design_info
188196 self ._x_design_info = X .design_info
@@ -194,32 +202,66 @@ def __init__(
194202 self .model .fit (X = self .X , y = self .y )
195203
196204 # predicted outcome for control group
197- self .x_pred_control = pd .DataFrame (
198- {"group" : [0 , 0 ], "t" : [0.0 , 1.0 ], "post_treatment" : [0 , 0 ]}
205+ self .x_pred_control = (
206+ self .data
207+ # just the untreated group
208+ .query (f"{ self .group_variable_name } == @self.untreated" )
209+ # drop the outcome variable
210+ .drop (self .outcome_variable_name , axis = 1 )
211+ # There may be multiple units per time point; keep only one row per time point
212+ .groupby (self .time_variable_name )
213+ .first ()
214+ .reset_index ()
199215 )
200216 assert not self .x_pred_control .empty
201217 (new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_control )
202218 self .y_pred_control = self .model .predict (np .asarray (new_x ))
203219
204220 # predicted outcome for treatment group
205- self .x_pred_treatment = pd .DataFrame (
206- {"group" : [1 , 1 ], "t" : [0.0 , 1.0 ], "post_treatment" : [0 , 1 ]}
221+ self .x_pred_treatment = (
222+ self .data
223+ # just the treated group
224+ .query (f"{ self .group_variable_name } == @self.treated" )
225+ # drop the outcome variable
226+ .drop (self .outcome_variable_name , axis = 1 )
227+ # There may be multiple units per time point; keep only one row per time point
228+ .groupby (self .time_variable_name )
229+ .first ()
230+ .reset_index ()
207231 )
208232 assert not self .x_pred_treatment .empty
209233 (new_x ,) = build_design_matrices ([self ._x_design_info ], self .x_pred_treatment )
210234 self .y_pred_treatment = self .model .predict (np .asarray (new_x ))
211235
212- # predicted outcome for counterfactual
213- self .x_pred_counterfactual = pd .DataFrame (
214- {"group" : [1 ], "t" : [1.0 ], "post_treatment" : [0 ]}
236+ # predicted outcome for counterfactual. This is given by removing the influence
237+ # of the interaction term between the group and the post_treatment variable
238+ self .x_pred_counterfactual = (
239+ self .data
240+ # just the treated group
241+ .query (f"{ self .group_variable_name } == @self.treated" )
242+ # just the treatment period(s)
243+ .query ("post_treatment == True" )
244+ # drop the outcome variable
245+ .drop (self .outcome_variable_name , axis = 1 )
246+ # There may be multiple units per time point; keep only one row per time point
247+ .groupby (self .time_variable_name )
248+ .first ()
249+ .reset_index ()
215250 )
216251 assert not self .x_pred_counterfactual .empty
217252 (new_x ,) = build_design_matrices (
218- [self ._x_design_info ], self .x_pred_counterfactual
253+ [self ._x_design_info ], self .x_pred_counterfactual , return_type = "dataframe"
219254 )
255+ # INTERVENTION: set the interaction term between the group and the
256+ # post_treatment variable to zero. This is the counterfactual.
257+ for i , label in enumerate (self .labels ):
258+ if "post_treatment" in label and self .group_variable_name in label :
259+ new_x .iloc [:, i ] = 0
220260 self .y_pred_counterfactual = self .model .predict (np .asarray (new_x ))
221261
222262 # calculate causal impact
263+ # This is the coefficient on the interaction term
264+ # TODO: THIS IS NOT YET CORRECT
223265 self .causal_impact = self .y_pred_treatment [1 ] - self .y_pred_counterfactual [0 ]
224266
225267 def plot (self ):
0 commit comments