diff --git a/tide/processing.py b/tide/processing.py index 63eacad..330b953 100644 --- a/tide/processing.py +++ b/tide/processing.py @@ -1,3 +1,5 @@ +from collections import defaultdict + import pandas as pd import numpy as np import datetime as dt @@ -1273,19 +1275,30 @@ def _fit_implementation(self, X: pd.Series | pd.DataFrame, y=None): for req, method in self.tide_format_methods.items(): self.columns_methods.append((tide_request(X.columns, req), method)) - return self - - def _transform_implementation(self, X: pd.Series | pd.DataFrame): + def _transform_implementation(self, X: pd.DataFrame): check_is_fitted(self, attributes=["feature_names_in_"]) + if not self.columns_methods: - agg_dict = {col: self.method for col in X.columns} + col_to_method = {col: self.method for col in X.columns} else: - agg_dict = {col: agg for cols, agg in self.columns_methods for col in cols} + col_to_method = {} + for cols, method in self.columns_methods: + for col in cols: + col_to_method[col] = method for col in X.columns: - if col not in agg_dict.keys(): - agg_dict[col] = self.method + col_to_method.setdefault(col, self.method) + + method_to_cols = defaultdict(list) + for col, method in col_to_method.items(): + method_to_cols[method].append(col) + + results = [] + for method, cols in method_to_cols.items(): + res = X[cols].resample(self.rule).agg(method) + results.append(res) - return X.resample(rule=self.rule).agg(agg_dict)[X.columns] + out = pd.concat(results, axis=1) + return out[X.columns] class AddTimeLag(BaseProcessing):