diff --git a/src/isp_trace_parser/get_data.py b/src/isp_trace_parser/get_data.py index 6d620c2..2149158 100644 --- a/src/isp_trace_parser/get_data.py +++ b/src/isp_trace_parser/get_data.py @@ -7,6 +7,11 @@ from pydantic import validate_call +def nem_datetime(*args, **kwargs): + """Wrap datetime so that we can introduce timezone-aware datatimes in future.""" + return datetime.datetime(*args, **kwargs) + + def _year_range_to_dt_range( start_year: int, end_year: int, year_type: Literal["fy", "calendar"] = "fy" ): @@ -33,14 +38,10 @@ def _year_range_to_dt_range( """ if year_type == "fy": - return datetime.datetime(start_year - 1, 7, 1), datetime.datetime( - end_year, 7, 1 - ) + return nem_datetime(start_year - 1, 7, 1), nem_datetime(end_year, 7, 1) elif year_type == "calendar": - return datetime.datetime(start_year, 1, 1), datetime.datetime( - end_year + 1, 1, 1 - ) + return nem_datetime(start_year, 1, 1), nem_datetime(end_year + 1, 1, 1) def _query_parquet_single_reference_year( diff --git a/tests/test_get_data.py b/tests/test_get_data.py index ddb1d0a..7720c36 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -15,6 +15,7 @@ get_project_single_reference_year, get_zone_multiple_reference_years, get_zone_single_reference_year, + nem_datetime, solar_area_single_reference_year, solar_project_multiple_reference_years, solar_project_single_reference_year, @@ -29,16 +30,16 @@ def test_year_range_to_dt_range_fy(): """Test financial year conversion.""" start_dt, end_dt = _year_range_to_dt_range(2022, 2024, year_type="fy") - assert start_dt == datetime.datetime(2021, 7, 1, 0, 0) - assert end_dt == datetime.datetime(2024, 7, 1, 0, 0) + assert start_dt == nem_datetime(2021, 7, 1, 0, 0) + assert end_dt == nem_datetime(2024, 7, 1, 0, 0) def test_year_range_to_dt_range_calendar(): """Test calendar year conversion.""" start_dt, end_dt = _year_range_to_dt_range(2022, 2024, year_type="calendar") - assert start_dt == datetime.datetime(2022, 1, 1, 0, 0) - assert end_dt == datetime.datetime(2025, 1, 1, 0, 0) + assert start_dt == nem_datetime(2022, 1, 1, 0, 0) + assert end_dt == nem_datetime(2025, 1, 1, 0, 0) @pytest.mark.parametrize("year_type", ["fy", "calendar"]) @@ -74,8 +75,8 @@ def test_get_zone_multiple_reference_year(parsed_trace_trace_directory: Path): test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2028, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2030, 7, 1)) + (pl.col("datetime") > nem_datetime(2028, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2030, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -100,8 +101,8 @@ def test_get_project_single_reference_year(parsed_trace_trace_directory: Path): test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2022, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2024, 7, 1)) + (pl.col("datetime") > nem_datetime(2022, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2024, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -127,8 +128,8 @@ def test_get_project_multiple_reference_year(parsed_trace_trace_directory: Path) test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2028, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2030, 7, 1)) + (pl.col("datetime") > nem_datetime(2028, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2030, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -187,8 +188,8 @@ def test_get_demand_multiple_reference_year(parsed_trace_trace_directory: Path): test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2028, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2030, 7, 1)) + (pl.col("datetime") > nem_datetime(2028, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2030, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -241,8 +242,8 @@ def test_wind_project_single_reference_year(parsed_trace_trace_directory): test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2022, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2024, 7, 1)) + (pl.col("datetime") > nem_datetime(2022, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2024, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -267,8 +268,8 @@ def test_solar_project_single_reference_year(parsed_trace_trace_directory): test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2022, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2024, 7, 1)) + (pl.col("datetime") > nem_datetime(2022, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2024, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -293,8 +294,8 @@ def test_solar_project_multiple_reference_years(parsed_trace_trace_directory: Pa test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2028, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2030, 7, 1)) + (pl.col("datetime") > nem_datetime(2028, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2030, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -318,8 +319,8 @@ def test_wind_project_multiple_reference_years(parsed_trace_trace_directory: Pat test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2028, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2030, 7, 1)) + (pl.col("datetime") > nem_datetime(2028, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2030, 7, 1)) ) .select(["datetime", "value"]) .collect() @@ -404,8 +405,8 @@ def test_demand_multiple_reference_years(parsed_trace_trace_directory: Path): test_df = ( test_df_lazy.filter( - (pl.col("datetime") > datetime.datetime(2028, 7, 1)) - & (pl.col("datetime") <= datetime.datetime(2030, 7, 1)) + (pl.col("datetime") > nem_datetime(2028, 7, 1)) + & (pl.col("datetime") <= nem_datetime(2030, 7, 1)) ) .select(["datetime", "value"]) .collect()