diff --git a/models/evaluate.py b/models/evaluate.py index 9db829d..9b7b6a5 100644 --- a/models/evaluate.py +++ b/models/evaluate.py @@ -71,7 +71,7 @@ def main(): summaries.append(summary) df = pd.DataFrame(summaries) - print(df.to_string()) + print(df[["model_name", "task_name", "test_error", "inference_time_s"]].to_string()) output_path = f"{display_name}.csv" df.to_csv(output_path, index=False) print(f"Saved to {output_path}") diff --git a/models/statsforecast/model.py b/models/statsforecast/model.py index 8678b5a..ca7707d 100644 --- a/models/statsforecast/model.py +++ b/models/statsforecast/model.py @@ -89,6 +89,7 @@ def _predict_window( if (train_df["ds"] > pd.Timestamp.max).any(): train_df["ds"] = train_df.groupby("unique_id", sort=False).cumcount() + predictor.freq = 1 if self.max_context_length is not None: train_df = (