diff --git a/create_trace_mapping.py b/create_trace_mapping.py index 900106b..29e3926 100644 --- a/create_trace_mapping.py +++ b/create_trace_mapping.py @@ -1,3 +1,7 @@ +"""Create trace mapping.""" + +from pathlib import Path + import yaml from nemosis import static_table @@ -20,14 +24,14 @@ solar_generator_mapping = draft_solar_generator_to_trace_mapping( solar_gens, solar_traces ) -with open("draft_solar_generator_mapping.yaml", "w") as file: +with Path.open("draft_solar_generator_mapping.yaml", "w") as file: yaml.dump(solar_generator_mapping, file, default_flow_style=False) solar_traces = "/media/nick/Samsung_T5/isp_2024_data/trace_data/solar/solar_2023" rezs = gets_rezs(workbook) solar_rez_mapping = draft_solar_rez_mapping(rezs, solar_traces) -with open("solar_area_mapping.yaml", "w") as file: +with Path.open("solar_area_mapping.yaml", "w") as file: yaml.dump(solar_rez_mapping, file, default_flow_style=False) duids_and_station_names = static_table( @@ -48,12 +52,12 @@ wind_generator_mapping = draft_wind_generator_to_trace_mapping( wind_gens, wind_duids_and_station_names, wind_traces ) -with open("draft_wind_generator_mapping.yaml", "w") as file: +with Path.open("draft_wind_generator_mapping.yaml", "w") as file: yaml.dump(wind_generator_mapping, file, default_flow_style=False, sort_keys=False) wind_traces = "D:/isp_2024_data/trace_data/wind/wind_2023" rezs = gets_rezs(workbook) wind_rez_mapping = draft_wind_rez_mapping(rezs, wind_traces) -with open("draft_wind_rez_mapping.yaml", "w") as file: +with Path.open("draft_wind_rez_mapping.yaml", "w") as file: yaml.dump(wind_rez_mapping, file, default_flow_style=False) diff --git a/generator_to_trace_draft_mapper.py b/generator_to_trace_draft_mapper.py index 54a6256..689a7e6 100644 --- a/generator_to_trace_draft_mapper.py +++ b/generator_to_trace_draft_mapper.py @@ -22,25 +22,23 @@ def get_all_generators(workbook_filepath): additional_gens["Status"] = "additional" existing_gens = existing_gens.rename( - columns={existing_gens.columns.values[0]: "Generator"} + columns={existing_gens.columns.to_numpy[0]: "Generator"} ) committed_gens = committed_gens.rename( - columns={committed_gens.columns.values[0]: "Generator"} + columns={committed_gens.columns.to_numpy[0]: "Generator"} ) anticipated_gens = anticipated_gens.rename( - columns={anticipated_gens.columns.values[0]: "Generator"} + columns={anticipated_gens.columns.to_numpy[0]: "Generator"} ) additional_gens = additional_gens.rename( - columns={additional_gens.columns.values[0]: "Generator"} + columns={additional_gens.columns.to_numpy[0]: "Generator"} ) all_gens = pd.concat( [existing_gens, committed_gens, anticipated_gens, additional_gens] ) - all_gens = all_gens.loc[:, ["Generator", "Technology type"]] - - return all_gens + return all_gens.loc[:, ["Generator", "Technology type"]] def gets_rezs(workbook_filepath): @@ -53,15 +51,12 @@ def gets_rezs(workbook_filepath): ) workbook = Parser(workbook_filepath) rezs = workbook.get_table_from_config(table_config) - rezs = rezs.loc[:, ["Name"]] - return rezs + return rezs.loc[:, ["Name"]] def find_best_match(plant_name, csv_files): best_match = process.extractOne(plant_name, csv_files, scorer=fuzz.token_set_ratio) - best_match = best_match[0] if best_match else None - best_match = best_match - return best_match + return best_match[0] if best_match else None def find_best_match_two_columns(row, csv_files): @@ -91,8 +86,7 @@ def draft_solar_generator_to_trace_mapping(solar_generators, solar_trace_directo solar_generators["CSVFile"] = solar_generators["Generator"].apply( lambda x: find_best_match(x, csv_project_names) ) - solar_generators = solar_generators.set_index("Generator")["CSVFile"].to_dict() - return solar_generators + return solar_generators.set_index("Generator")["CSVFile"].to_dict() def draft_solar_rez_mapping(rezs, rezs_trace_directory): @@ -100,8 +94,7 @@ def draft_solar_rez_mapping(rezs, rezs_trace_directory): csv_file_metadata = [extract_solar_trace_metadata(f) for f in csv_file_names] csv_rez_names = [f["name"] for f in csv_file_metadata if f["file_type"] == "area"] rezs["CSVFile"] = rezs["Name"].apply(lambda x: find_best_match(x, csv_rez_names)) - rezs = rezs.set_index("Name")["CSVFile"].to_dict() - return rezs + return rezs.set_index("Name")["CSVFile"].to_dict() def draft_wind_generator_to_trace_mapping( @@ -118,8 +111,8 @@ def draft_wind_generator_to_trace_mapping( wind_generators["Station Name"] = wind_generators["Generator"].apply( lambda x: find_best_match(x, wind_station_names) ) - wind_generators = pd.merge( - wind_generators, wind_duids_and_station_names, how="left", on="Station Name" + wind_generators = wind_generators.merge( + wind_duids_and_station_names, how="left", on="Station Name" ) wind_generators = wind_generators.drop_duplicates(["Generator"]) @@ -131,8 +124,7 @@ def draft_wind_generator_to_trace_mapping( :, ["Generator", "Station Name", "DUID", "CSVFile"] ] - wind_generators = wind_generators.set_index("Generator").to_dict(orient="index") - return wind_generators + return wind_generators.set_index("Generator").to_dict(orient="index") def draft_wind_rez_mapping(rezs, rezs_trace_directory): @@ -140,5 +132,4 @@ def draft_wind_rez_mapping(rezs, rezs_trace_directory): csv_file_metadata = [extract_wind_trace_metadata(f) for f in csv_file_names] csv_rez_names = [f["name"] for f in csv_file_metadata if f["file_type"] == "area"] rezs["CSVFile"] = rezs["Name"].apply(lambda x: find_best_match(x, csv_rez_names)) - rezs = rezs.set_index("Name")["CSVFile"].to_dict() - return rezs + return rezs.set_index("Name")["CSVFile"].to_dict() diff --git a/noxfile.py b/noxfile.py index 07ac5b1..fcb5491 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,3 +1,5 @@ +"""Nox configuration file.""" + import nox nox.options.default_venv_backend = "uv" diff --git a/src/isp_trace_parser/construct_reference_year_mapping.py b/src/isp_trace_parser/construct_reference_year_mapping.py index 6e384b2..932dd22 100644 --- a/src/isp_trace_parser/construct_reference_year_mapping.py +++ b/src/isp_trace_parser/construct_reference_year_mapping.py @@ -35,4 +35,4 @@ def construct_reference_year_mapping( reference_years = ( reference_years * full_reference_year_cycles ) + reference_years[:partial_cycle_length] - return dict(zip(years, reference_years)) + return dict(zip(years, reference_years, strict=True)) diff --git a/src/isp_trace_parser/demand_traces.py b/src/isp_trace_parser/demand_traces.py index 252777a..b671f5b 100644 --- a/src/isp_trace_parser/demand_traces.py +++ b/src/isp_trace_parser/demand_traces.py @@ -1,7 +1,7 @@ import functools import os from pathlib import Path -from typing import Literal, Optional +from typing import Literal import polars as pl import yaml @@ -47,15 +47,16 @@ class DemandMetadataFilter(BaseModel): reference_year: list of ints specifying reference_years """ - subregion: Optional[list[str]] = None - scenario: Optional[ + subregion: list[str] | None = None + scenario: ( list[Literal["Step Change", "Progressive Change", "Green Energy Exports"]] - ] = None - poe: Optional[list[Literal["POE50", "POE10"]]] = None - demand_type: Optional[ - list[Literal["OPSO_MODELLING", "OPSO_MODELLING_PVLITE", "PV_TOT"]] - ] = None - reference_year: Optional[list[int]] = None + | None + ) = None + poe: list[Literal["POE50", "POE10"]] | None = None + demand_type: ( + list[Literal["OPSO_MODELLING", "OPSO_MODELLING_PVLITE", "PV_TOT"]] | None + ) = None + reference_year: list[int] | None = None @validate_call @@ -135,10 +136,9 @@ def parse_demand_traces( files = get_all_filepaths(input_directory) - with open( + with Path.open( Path(__file__).parent.parent - / Path("isp_trace_name_mapping_configs/demand_scenario_mapping.yaml"), - "r", + / Path("isp_trace_name_mapping_configs/demand_scenario_mapping.yaml") ) as f: demand_scenario_mapping = yaml.safe_load(f) @@ -276,4 +276,4 @@ def extract_metadata_for_all_demand_files( A dictionary with filepaths as keys and metadata dicts as values. """ file_metadata = [extract_demand_trace_metadata(str(f.name)) for f in filenames] - return dict(zip(filenames, file_metadata)) + return dict(zip(filenames, file_metadata, strict=True)) diff --git a/src/isp_trace_parser/get_data.py b/src/isp_trace_parser/get_data.py index 6d620c2..3b03189 100644 --- a/src/isp_trace_parser/get_data.py +++ b/src/isp_trace_parser/get_data.py @@ -1,6 +1,6 @@ import datetime from pathlib import Path -from typing import List, Literal +from typing import Literal import pandas as pd import polars as pl @@ -37,10 +37,11 @@ def _year_range_to_dt_range( end_year, 7, 1 ) - elif year_type == "calendar": + if year_type == "calendar": return datetime.datetime(start_year, 1, 1), datetime.datetime( end_year + 1, 1, 1 ) + raise ValueError(year_type) def _query_parquet_single_reference_year( @@ -48,8 +49,8 @@ def _query_parquet_single_reference_year( end_year: int, reference_year: int, directory: str | Path, - filters: dict[str, any] = None, - select_columns: list[str] = None, + filters: dict[str, any] | None = None, + select_columns: list[str] | None = None, year_type: Literal["fy", "calendar"] = "fy", ) -> pd.DataFrame: """ @@ -103,14 +104,14 @@ def _query_parquet_single_reference_year( # Otherwise select all columns columns_to_select = df_lazy.columns - df = ( + dframe = ( df_lazy.filter(filter_expr) .select(*columns_to_select) .sort("datetime") .collect() ) - return df.to_pandas() + return dframe.to_pandas() def _query_parquet_multiple_reference_years( @@ -136,8 +137,7 @@ def _query_parquet_multiple_reference_years( start_year=year, end_year=year, reference_year=reference_year, **kwargs ) ) - data = pd.concat(data).reset_index(drop=True) - return data + return pd.concat(data).reset_index(drop=True) @validate_call @@ -145,10 +145,10 @@ def get_project_single_reference_year( start_year: int, end_year: int, reference_year: int, - project: str | List, + project: str | list, directory: str | Path, year_type: Literal["fy", "calendar"] = "fy", - select_columns: list[str] = None, + select_columns: list[str] | None = None, ): """ Query project trace data for a single reference year. @@ -237,11 +237,11 @@ def get_zone_single_reference_year( start_year: int, end_year: int, reference_year: int, - zone: str | List, - resource_type: str | List, + zone: str | list, + resource_type: str | list, directory: str | Path, year_type: Literal["fy", "calendar"] = "fy", - select_columns: list[str] = None, + select_columns: list[str] | None = None, ): """ Query zone trace data for a single reference year. @@ -333,13 +333,13 @@ def get_demand_single_reference_year( start_year: int, end_year: int, reference_year: int, - scenario: str | List, - subregion: str | List, - demand_type: str | List, - poe: str | List, + scenario: str | list, + subregion: str | list, + demand_type: str | list, + poe: str | list, directory: str | Path, year_type: Literal["fy", "calendar"] = "fy", - select_columns: list[str] = None, + select_columns: list[str] | None = None, ): """ Query demand trace data for a single reference year. @@ -441,10 +441,10 @@ def get_demand_single_reference_year( @validate_call def get_project_multiple_reference_years( reference_year_mapping: dict[int, int], - project: str | List, + project: str | list, directory: str | Path, year_type: Literal["fy", "calendar"] = "fy", - select_columns: list[str] = None, + select_columns: list[str] | None = None, ): """ Query project trace data across multiple reference years. @@ -530,11 +530,11 @@ def get_project_multiple_reference_years( @validate_call def get_zone_multiple_reference_years( reference_year_mapping: dict[int, int], - zone: str | List, - resource_type: str | List, + zone: str | list, + resource_type: str | list, directory: str | Path, year_type: Literal["fy", "calendar"] = "fy", - select_columns: list[str] = None, + select_columns: list[str] | None = None, ): """ Query zone trace data across multiple reference years. @@ -623,13 +623,13 @@ def get_zone_multiple_reference_years( @validate_call def get_demand_multiple_reference_years( reference_year_mapping: dict[int, int], - scenario: str | List, - subregion: str | List, - demand_type: str | List, - poe: str | List, + scenario: str | list, + subregion: str | list, + demand_type: str | list, + poe: str | list, directory: str | Path, year_type: Literal["fy", "calendar"] = "fy", - select_columns: list[str] = None, + select_columns: list[str] | None = None, ): """ Query demand trace data across multiple reference years. diff --git a/src/isp_trace_parser/input_validation.py b/src/isp_trace_parser/input_validation.py index cfa9164..c57f803 100644 --- a/src/isp_trace_parser/input_validation.py +++ b/src/isp_trace_parser/input_validation.py @@ -4,7 +4,8 @@ def input_directory(path: Path | str) -> Path: path = is_valid_path(path) if not path.is_dir(): - raise ValueError(f"Directory {path} does not exist") + msg = f"Directory {path} does not exist" + raise ValueError(msg) return path @@ -16,9 +17,11 @@ def is_valid_path(path: str | Path) -> Path: try: return Path(path) except (TypeError, ValueError): - raise ValueError(f"Invalid parsed directory path: {path}") + msg = f"Invalid parsed directory path: {path}" + raise ValueError(msg) from None -def start_year_before_end_year(start_year, end_year): +def start_year_before_end_year(start_year, end_year) -> None: if end_year < start_year: - raise ValueError(f"Start year {end_year} < end year {start_year}") + msg = f"Start year {end_year} < end year {start_year}" + raise ValueError(msg) diff --git a/src/isp_trace_parser/metadata_extractors.py b/src/isp_trace_parser/metadata_extractors.py index d579d6f..a3c9539 100644 --- a/src/isp_trace_parser/metadata_extractors.py +++ b/src/isp_trace_parser/metadata_extractors.py @@ -28,7 +28,8 @@ def extract_solar_trace_metadata(filename): match_data["reference_year"] = int(match_data["reference_year"]) return match_data - raise ValueError(f"Filename '{filename}' does not match the expected pattern") + msg = f"Filename '{filename}' does not match the expected pattern" + raise ValueError(msg) def extract_wind_trace_metadata(filename): @@ -57,7 +58,8 @@ def extract_wind_trace_metadata(filename): match_data["reference_year"] = int(match_data["reference_year"]) return match_data - raise ValueError(f"Filename '{filename}' does not match the expected pattern") + msg = f"Filename '{filename}' does not match the expected pattern" + raise ValueError(msg) def extract_demand_trace_metadata(filename): @@ -75,6 +77,7 @@ def extract_demand_trace_metadata(filename): match_data = match.groupdict() match_data["reference_year"] = int(match_data["reference_year"]) return match_data - else: - # If the pattern does not match, raise an error or return None - raise ValueError(f"Filename '{filename}' does not match the expected pattern") + + # If the pattern does not match, raise an error or return None + msg = f"Filename '{filename}' does not match the expected pattern" + raise ValueError(msg) diff --git a/src/isp_trace_parser/optimise_parquet.py b/src/isp_trace_parser/optimise_parquet.py index da00e07..dab2f80 100644 --- a/src/isp_trace_parser/optimise_parquet.py +++ b/src/isp_trace_parser/optimise_parquet.py @@ -1,6 +1,5 @@ from itertools import product from pathlib import Path -from typing import Optional import duckdb from pydantic import validate_call @@ -23,7 +22,7 @@ def partition_traces_by_columns( input_directory: str | Path, output_directory: str | Path, partition_cols: list[str], - sort_by: Optional[list[str]] = ["datetime"], + sort_by: list[str] | None = None, ) -> None: """Partition parquet traces by specified columns with optional sorting. @@ -53,6 +52,10 @@ def partition_traces_by_columns( ... partition_cols=["scenario", "reference_year"] ... ) # doctest: +SKIP """ + # handle default parameters to avoid using mutable objects as argument defaults + if sort_by is None: + sort_by = ["datetime"] + output_path = Path(output_directory) output_path.mkdir(parents=True, exist_ok=True) @@ -70,10 +73,8 @@ def partition_traces_by_columns( partitions = [tuple(val[0] for val in vals) for vals in product(*distinct_values)] for partition_values in partitions: - # print(*partition_values) - conditions = [] - for col, val in zip(partition_cols, partition_values): + for col, val in zip(partition_cols, partition_values, strict=True): if isinstance(val, str): conditions.append(f"{col}='{val}'") else: diff --git a/src/isp_trace_parser/remote/download.py b/src/isp_trace_parser/remote/download.py index 462887c..c2c8c5f 100644 --- a/src/isp_trace_parser/remote/download.py +++ b/src/isp_trace_parser/remote/download.py @@ -54,14 +54,16 @@ def _download_from_manifest( manifest_path = files("isp_trace_parser.remote.manifests") / f"{manifest_name}.txt" if not manifest_path.exists(): - raise FileNotFoundError(f"Manifest file not found: {manifest_path}") + msg = f"Manifest file not found: {manifest_path}" + raise FileNotFoundError(msg) # Read URLs from manifest - with open(manifest_path) as f: + with Path.open(manifest_path) as f: urls = [line.strip() for line in f if line.strip()] if not urls: - raise ValueError(f"No URLs found in manifest: {manifest_path}") + msg = f"No URLs found in manifest: {manifest_path}" + raise ValueError(msg) save_directory = Path(save_directory) @@ -81,8 +83,7 @@ def _download_with_retry( for attempt in range(max_retries): try: _download_file(url, save_directory, strip_levels, unquote_path) - return - except requests.exceptions.RequestException: + except requests.exceptions.RequestException: # noqa: PERF203 if attempt < max_retries - 1: time.sleep(2**attempt) else: @@ -124,10 +125,11 @@ def _download_file( # Strip specified number of directory levels path_parts = url_path.split("/") if strip_levels >= len(path_parts): - raise ValueError( + msg = ( f"Cannot strip {strip_levels} levels from path with only " f"{len(path_parts)} parts: {url_path}" ) + raise ValueError(msg) stripped_path = "/".join(path_parts[strip_levels:]) destination = save_directory / stripped_path @@ -144,7 +146,7 @@ def _download_file( # Write file with progress bar with ( - open(destination, "wb") as f, + Path.open(destination, "wb") as f, tqdm( total=total_size, unit="B", @@ -211,17 +213,16 @@ def fetch_trace_data( # Validate inputs if dataset_type not in ["full", "example"]: - raise ValueError( - f"dataset_type must be 'full' or 'example', got: {dataset_type}" - ) + msg = f"dataset_type must be 'full' or 'example', got: {dataset_type}" + raise ValueError(msg) if dataset_src != "isp_2024": - raise ValueError(f"Only isp_2024 is currently supported, got: {dataset_src}") + msg = f"Only isp_2024 is currently supported, got: {dataset_src}" + raise ValueError(msg) if data_format not in ["processed", "archive"]: - raise ValueError( - f"data_format must be 'processed' or 'archive', got: {data_format}" - ) + msg = f"data_format must be 'processed' or 'archive', got: {data_format}" + raise ValueError(msg) # Construct manifest name and download manifest_name = f"{data_format}/{dataset_type}_{dataset_src}" diff --git a/src/isp_trace_parser/solar_traces.py b/src/isp_trace_parser/solar_traces.py index 8931883..6a855df 100644 --- a/src/isp_trace_parser/solar_traces.py +++ b/src/isp_trace_parser/solar_traces.py @@ -1,7 +1,7 @@ import functools import os from pathlib import Path -from typing import Literal, Optional +from typing import Literal import yaml from joblib import Parallel, delayed @@ -51,10 +51,10 @@ class SolarMetadataFilter(BaseModel): reference_year: list of ints specifying reference_years """ - name: Optional[list[str]] = None - file_type: Optional[list[Literal["zone", "project"]]] = None - resource_type: Optional[list[Literal["SAT", "FFP", "CST"]]] = None - reference_year: Optional[list[int]] = None + name: list[str] | None = None + file_type: list[Literal["zone", "project"]] | None = None + resource_type: list[Literal["SAT", "FFP", "CST"]] | None = None + reference_year: list[int] | None = None @validate_call @@ -136,16 +136,14 @@ def parse_solar_traces( files = get_all_filepaths(input_directory) file_metadata = extract_metadata_for_all_solar_files(files) - with open( + with Path.open( Path(__file__).parent.parent - / Path("isp_trace_name_mapping_configs/solar_project_mapping.yaml"), - "r", + / Path("isp_trace_name_mapping_configs/solar_project_mapping.yaml") ) as f: project_name_mapping = yaml.safe_load(f) - with open( + with Path.open( Path(__file__).parent.parent - / Path("isp_trace_name_mapping_configs/solar_zone_mapping.yaml"), - "r", + / Path("isp_trace_name_mapping_configs/solar_zone_mapping.yaml") ) as f: zone_name_mapping = yaml.safe_load(f) name_mappings = {**project_name_mapping, **zone_name_mapping} @@ -158,7 +156,7 @@ def parse_solar_traces( } project_and_zone_output_names, project_and_zone_input_names = zip( - *name_mappings.items() + *name_mappings.items(), strict=True ) partial_func = functools.partial( @@ -173,12 +171,12 @@ def parse_solar_traces( Parallel(n_jobs=max_workers)( delayed(partial_func)(save_name, old_trace_name) for save_name, old_trace_name in zip( - project_and_zone_output_names, project_and_zone_input_names + project_and_zone_output_names, project_and_zone_input_names, strict=True ) ) else: for save_name, old_trace_name in zip( - project_and_zone_output_names, project_and_zone_input_names + project_and_zone_output_names, project_and_zone_input_names, strict=True ): partial_func(save_name, old_trace_name) @@ -277,7 +275,7 @@ def extract_metadata_for_all_solar_files( A dictionary with filepaths as keys and metadata dicts as values. """ file_metadata = [extract_solar_trace_metadata(str(f.name)) for f in filepaths] - return dict(zip(filepaths, file_metadata)) + return dict(zip(filepaths, file_metadata, strict=True)) def get_unique_resource_types_in_metadata( @@ -293,7 +291,7 @@ def get_unique_resource_types_in_metadata( A list of unique resource types. """ return list( - set(metadata["resource_type"] for metadata in metadata_for_trace_files.values()) + {metadata["resource_type"] for metadata in metadata_for_trace_files.values()} ) diff --git a/src/isp_trace_parser/trace_formatter.py b/src/isp_trace_parser/trace_formatter.py index b6262e4..e694543 100644 --- a/src/isp_trace_parser/trace_formatter.py +++ b/src/isp_trace_parser/trace_formatter.py @@ -65,10 +65,10 @@ def trace_formatter(trace_data: pl.DataFrame) -> pl.DataFrame: value_name="value", ) - def get_hour(time_label): + def get_hour(time_label) -> timedelta: return timedelta(hours=int(time_label) // 2) - def get_minute(time_label): + def get_minute(time_label) -> timedelta: return timedelta(minutes=int(time_label) % 2 * 30) trace_data = trace_data.with_columns( @@ -92,12 +92,10 @@ def get_minute(time_label): ] ) - trace_data = ( + return ( trace_data.with_columns( [(pl.col("datetime") + pl.col("Hour") + pl.col("Minute")).alias("datetime")] ) .select(["datetime", "value"]) .sort("datetime") ) - - return trace_data diff --git a/src/isp_trace_parser/trace_restructure_helper_functions.py b/src/isp_trace_parser/trace_restructure_helper_functions.py index 25c8229..5937ab1 100644 --- a/src/isp_trace_parser/trace_restructure_helper_functions.py +++ b/src/isp_trace_parser/trace_restructure_helper_functions.py @@ -1,4 +1,3 @@ -from datetime import timedelta from pathlib import Path import polars as pl @@ -10,14 +9,13 @@ def get_all_filepaths(directory: Path) -> list[Path]: if directory.is_dir(): return [path for path in Path(directory).rglob("*.csv") if path.is_file()] - else: - raise ValueError(f"{directory} not found.") + msg = f"{directory} not found." + raise ValueError(msg) def read_trace_csv(file: Path) -> pl.DataFrame: pl_types = [pl.Int64] * 3 + [pl.Float64] * 48 - data = pl.read_csv(file, schema_overrides=pl_types) - return data + return pl.read_csv(file, schema_overrides=pl_types) def read_and_format_traces(files: list[Path]) -> list[pl.DataFrame]: @@ -31,10 +29,10 @@ def read_and_format_traces(files: list[Path]) -> list[pl.DataFrame]: def calculate_average_trace(traces: list[pl.DataFrame]) -> pl.DataFrame: combined_traces = pl.concat(traces) - average_trace = combined_traces.group_by("datetime").agg( + # return average trace + return combined_traces.group_by("datetime").agg( [pl.col("value").mean().alias("value")] ) - return average_trace def _frame_with_metadata(trace: pl.DataFrame, file_metadata: dict) -> pl.DataFrame: @@ -72,11 +70,7 @@ def process_and_save_files( ) -> None: traces = read_and_format_traces(files) - if len(traces) > 1: - trace = calculate_average_trace(traces) - else: - trace = traces[0] - + trace = calculate_average_trace(traces) if len(traces) > 1 else traces[0] trace = _frame_with_metadata(trace, file_metadata) save_trace(trace, file_metadata, output_directory, write_output_filepath) @@ -87,21 +81,19 @@ def get_metadata_that_matches_trace_names( ) -> dict[Path, dict[str, str]]: if isinstance(trace_names, str): trace_names = [trace_names] - matching_meta_data = { + # Return matching metadata + return { f: metadata.copy() for f, metadata in all_input_file_metadata.items() if metadata["name"] in trace_names } - return matching_meta_data def get_unique_reference_years_in_metadata( metadata_for_trace_files: dict[Path, dict[str, str]], ) -> list[str]: return list( - set( - metadata["reference_year"] for metadata in metadata_for_trace_files.values() - ) + {metadata["reference_year"] for metadata in metadata_for_trace_files.values()} ) @@ -135,9 +127,12 @@ def check_filter_by_metadata( return True for field, allowed_values in filters.model_dump(exclude_unset=True).items(): - if field in metadata and allowed_values is not None: - if metadata[field] not in allowed_values: - return False + if ( + field in metadata + and allowed_values is not None + and metadata[field] not in allowed_values + ): + return False return True @@ -145,9 +140,7 @@ def check_filter_by_metadata( def get_unique_project_and_zone_names_in_input_files( metadata_for_trace_files: dict[Path, dict[str, str]], ) -> list[str]: - names = [] - for filepath, meta_data in metadata_for_trace_files.items(): - names.append(meta_data["name"]) + names = [meta_data["name"] for meta_data in metadata_for_trace_files.values()] return list(set(names)) @@ -159,9 +152,8 @@ def filter_mapping_by_names_in_input_files( if isinstance(input_name, list): if input_name[0] in names_in_input_files: filtered_mapping[output_name] = input_name - else: - if input_name in names_in_input_files: - filtered_mapping[output_name] = input_name + elif input_name in names_in_input_files: + filtered_mapping[output_name] = input_name return filtered_mapping diff --git a/src/isp_trace_parser/wind_traces.py b/src/isp_trace_parser/wind_traces.py index 316b280..2126a04 100644 --- a/src/isp_trace_parser/wind_traces.py +++ b/src/isp_trace_parser/wind_traces.py @@ -1,7 +1,7 @@ import functools import os from pathlib import Path -from typing import Literal, Optional +from typing import Literal import yaml from joblib import Parallel, delayed @@ -51,10 +51,10 @@ class WindMetadataFilter(BaseModel): reference_year: list of ints specifying reference_years """ - name: Optional[list[str]] = None - file_type: Optional[list[Literal["zone", "project"]]] = None - resource_type: Optional[list[Literal["WH", "WM", "WL", "WX", "wind"]]] = None - reference_year: Optional[list[int]] = None + name: list[str] | None = None + file_type: list[Literal["zone", "project"]] | None = None + resource_type: list[Literal["WH", "WM", "WL", "WX", "wind"]] | None = None + reference_year: list[int] | None = None @validate_call @@ -138,19 +138,17 @@ def parse_wind_traces( files = get_all_filepaths(input_directory) file_metadata = extract_metadata_for_all_wind_files(files) - with open( + with Path.open( Path(__file__).parent.parent - / Path("isp_trace_name_mapping_configs/wind_project_mapping.yaml"), - "r", + / Path("isp_trace_name_mapping_configs/wind_project_mapping.yaml") ) as f: project_name_mappings = yaml.safe_load(f) project_name_mappings = restructure_wind_project_mapping(project_name_mappings) - with open( + with Path.open( Path(__file__).parent.parent - / Path("isp_trace_name_mapping_configs/wind_zone_mapping.yaml"), - "r", + / Path("isp_trace_name_mapping_configs/wind_zone_mapping.yaml") ) as f: zone_name_mappings = yaml.safe_load(f) @@ -161,12 +159,14 @@ def parse_wind_traces( zone_name_mappings = filter_mapping_by_names_in_input_files( zone_name_mappings, project_and_zone_input_names ) - zone_output_names, zone_input_names = zip(*zone_name_mappings.items()) + zone_output_names, zone_input_names = zip(*zone_name_mappings.items(), strict=True) project_name_mappings = filter_mapping_by_names_in_input_files( project_name_mappings, project_and_zone_input_names ) - project_output_names, project_input_names = zip(*project_name_mappings.items()) + project_output_names, project_input_names = zip( + *project_name_mappings.items(), strict=True + ) zone_partial_func = functools.partial( restructure_wind_zone_files, @@ -187,21 +187,27 @@ def parse_wind_traces( Parallel(n_jobs=max_workers)( delayed(zone_partial_func)(save_name, old_trace_name) - for save_name, old_trace_name in zip(zone_output_names, zone_input_names) + for save_name, old_trace_name in zip( + zone_output_names, zone_input_names, strict=True + ) ) Parallel(n_jobs=max_workers)( delayed(project_partial_func)(save_name, old_trace_name) for save_name, old_trace_name in zip( - project_output_names, project_input_names + project_output_names, project_input_names, strict=True ) ) else: - for save_name, old_trace_name in zip(zone_output_names, zone_input_names): + for save_name, old_trace_name in zip( + zone_output_names, zone_input_names, strict=True + ): zone_partial_func(save_name, old_trace_name) - for save_name, old_trace_name in zip(project_output_names, project_input_names): + for save_name, old_trace_name in zip( + project_output_names, project_input_names, strict=True + ): project_partial_func(save_name, old_trace_name) @@ -343,14 +349,14 @@ def extract_metadata_for_all_wind_files(filepaths: list) -> dict: Returns a dict with filepaths as keys and metadata dicts as values. """ file_metadata = [extract_wind_trace_metadata(str(f.name)) for f in filepaths] - return dict(zip(filepaths, file_metadata)) + return dict(zip(filepaths, file_metadata, strict=True)) def get_unique_resource_types_in_metadata( metadata_for_trace_files: dict[str:str], ) -> list: return list( - set(metadata["resource_type"] for metadata in metadata_for_trace_files.values()) + {metadata["resource_type"] for metadata in metadata_for_trace_files.values()} ) diff --git a/tests/conftest.py b/tests/conftest.py index 09d5a86..0f482c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,8 @@ def parsed_trace_trace_directory(request): use_concurrency = request.param - with tempfile.TemporaryDirectory() as tmp_parsed_directory: - tmp_parsed_directory = Path(tmp_parsed_directory) + with tempfile.TemporaryDirectory() as tmpdir: + tmp_parsed_directory = Path(tmpdir) for file_type in ["zone", "project"]: filters = wind_traces.WindMetadataFilter(file_type=[file_type]) @@ -52,7 +52,7 @@ def parsed_trace_trace_directory(request): optimise_parquet.partition_traces_by_columns( input_directory=tmp_parsed_directory / "demand", - output_directory=tmp_parsed_directory / f"demand_optimised", + output_directory=tmp_parsed_directory / "demand_optimised", partition_cols=["scenario", "reference_year"], ) yield tmp_parsed_directory diff --git a/tests/create_end_to_end_test_data.py b/tests/create_end_to_end_test_data.py index e9f2086..f8ca816 100644 --- a/tests/create_end_to_end_test_data.py +++ b/tests/create_end_to_end_test_data.py @@ -13,7 +13,7 @@ def generate_random_data(start_year, end_year): ) # Create a DataFrame for the date components - df = pd.DataFrame( + dframe = pd.DataFrame( {"Year": date_range.year, "Month": date_range.month, "Day": date_range.day} ) @@ -24,8 +24,9 @@ def generate_random_data(start_year, end_year): half_hour_columns = [f"{i:02d}" for i in range(1, 49)] # Combine the date components with the random data - df = pd.concat([df, pd.DataFrame(random_data, columns=half_hour_columns)], axis=1) - return df + return pd.concat( + [dframe, pd.DataFrame(random_data, columns=half_hour_columns)], axis=1 + ) data = generate_random_data(start_year=config.start, end_year=config.end) @@ -53,7 +54,7 @@ def create_solar_csvs(directory): ) -def create_wind_csvs(directory): +def create_wind_csvs(directory) -> None: combos = itertools.product( config.reference_years, simple_flatten(config.wind_projects.values()) ) @@ -71,7 +72,7 @@ def create_wind_csvs(directory): ) -def create_demand_csvs(directory): +def create_demand_csvs(directory) -> None: combos = itertools.product( config.reference_years, config.sub_regions, diff --git a/tests/test_download.py b/tests/test_download.py index 79cc85c..3724d23 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -12,8 +12,8 @@ def test_download_test_file(): """Test download with actual server file.""" - with TemporaryDirectory() as tmp_path: - tmp_path = Path(tmp_path) + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) download._download_file(TEST_URL, tmp_path, strip_levels=0) downloaded = tmp_path / "test" / "test" / "test_file.txt" @@ -24,8 +24,8 @@ def test_download_test_file(): def test_download_with_retry(): """Test retry logic with real server.""" - with TemporaryDirectory() as tmp_path: - tmp_path = Path(tmp_path) + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) download._download_with_retry(TEST_URL, tmp_path, strip_levels=0, max_retries=3) assert (tmp_path / "test" / "test" / "test_file.txt").exists() @@ -37,8 +37,8 @@ def test_fetch_trace_data_with_test_manifest(monkeypatch): with containing a single url ("https://data.openisp.au/test/test/test_file.txt") """ - with TemporaryDirectory() as tmp_path: - tmp_path = Path(tmp_path) + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) # Point to test fixtures instead of production manifests def mock_files(package): @@ -74,8 +74,8 @@ def test_fetch_trace_data(unquote: bool, monkeypatch): manifest with containing a single url ("https://data.openisp.au/test/test/test_file.txt") """ - with TemporaryDirectory() as tmp_path: - tmp_path = Path(tmp_path) + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) # Point to test manifests instead of production manifests def mock_files(package): @@ -97,28 +97,28 @@ def mock_files(package): def test_wrong_source(): # no ISP 2025 data - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Only isp_2024 is currently supported"): download.fetch_trace_data("test", "isp_2025", "/", "archive") def test_wrong_format(): # only archive or processed data (not other) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="data_format must be 'processed' or 'archive'" + ): download.fetch_trace_data("test", "isp_2024", "/", "other") def test_wrong_type(): # only full or example type - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="dataset_type must be 'full' or 'example'"): download.fetch_trace_data("other", "isp_2024", "/", "archive") def test_empty_manifest(monkeypatch): """Test that empty manifest raises ValueError.""" - from importlib.resources import files - - with TemporaryDirectory() as tmp_path: - tmp_path = Path(tmp_path) + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) # Point to test manifest instead of production manifests def mock_files(package): @@ -132,9 +132,9 @@ def mock_files(package): def test_strip_levels_too_high(): """Test that strip_levels >= path parts raises ValueError.""" - with TemporaryDirectory() as tmp_path: - tmp_path = Path(tmp_path) + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) # TEST_URL has path "test/test/test_file.txt" = 3 parts - with pytest.raises(ValueError, match="Cannot strip .* levels"): + with pytest.raises(ValueError, match=r"Cannot strip .* levels"): download._download_file(TEST_URL, tmp_path, strip_levels=10) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index ddb1d0a..5c1facf 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -56,7 +56,7 @@ def test_get_zone_single_reference_year(parsed_trace_trace_directory: Path, year .to_pandas() ) - df = get_zone_single_reference_year( + dframe = get_zone_single_reference_year( start_year=2023, end_year=2024, reference_year=2022, @@ -66,7 +66,7 @@ def test_get_zone_single_reference_year(parsed_trace_trace_directory: Path, year year_type=year_type, ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_get_zone_multiple_reference_year(parsed_trace_trace_directory: Path): @@ -82,7 +82,7 @@ def test_get_zone_multiple_reference_year(parsed_trace_trace_directory: Path): .to_pandas() ) - df = get_zone_multiple_reference_years( + dframe = get_zone_multiple_reference_years( reference_year_mapping={2029: 2022, 2030: 2022}, zone="N1", resource_type="WM", @@ -90,7 +90,7 @@ def test_get_zone_multiple_reference_year(parsed_trace_trace_directory: Path): year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_get_project_single_reference_year(parsed_trace_trace_directory: Path): @@ -108,7 +108,7 @@ def test_get_project_single_reference_year(parsed_trace_trace_directory: Path): .to_pandas() ) - df = get_project_single_reference_year( + dframe = get_project_single_reference_year( start_year=2023, end_year=2024, reference_year=2022, @@ -117,7 +117,7 @@ def test_get_project_single_reference_year(parsed_trace_trace_directory: Path): year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_get_project_multiple_reference_year(parsed_trace_trace_directory: Path): @@ -135,14 +135,14 @@ def test_get_project_multiple_reference_year(parsed_trace_trace_directory: Path) .to_pandas() ) - df = get_project_multiple_reference_years( + dframe = get_project_multiple_reference_years( reference_year_mapping={2029: 2022, 2030: 2022}, project="Broken Hill Solar Farm", directory=parsed_trace_trace_directory / "project_optimised", year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_get_demand_single_reference_year(parsed_trace_trace_directory: Path): @@ -163,7 +163,7 @@ def test_get_demand_single_reference_year(parsed_trace_trace_directory: Path): .to_pandas() ) - df = get_demand_single_reference_year( + dframe = get_demand_single_reference_year( start_year=2023, end_year=2024, reference_year=2011, @@ -175,7 +175,7 @@ def test_get_demand_single_reference_year(parsed_trace_trace_directory: Path): year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_get_demand_multiple_reference_year(parsed_trace_trace_directory: Path): @@ -195,7 +195,7 @@ def test_get_demand_multiple_reference_year(parsed_trace_trace_directory: Path): .to_pandas() ) - df = get_demand_multiple_reference_years( + dframe = get_demand_multiple_reference_years( reference_year_mapping={2029: 2011, 2030: 2011}, scenario="Green Energy Exports", subregion="CNSW", @@ -205,11 +205,11 @@ def test_get_demand_multiple_reference_year(parsed_trace_trace_directory: Path): year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_explicit_select_columns(parsed_trace_trace_directory): - df = get_zone_single_reference_year( + dframe = get_zone_single_reference_year( start_year=2023, end_year=2024, reference_year=2022, @@ -218,11 +218,11 @@ def test_explicit_select_columns(parsed_trace_trace_directory): directory=parsed_trace_trace_directory / "zone_optimised", select_columns=["datetime", "value", "zone"], ) - assert list(df.columns) == ["datetime", "value", "zone"] + assert list(dframe.columns) == ["datetime", "value", "zone"] def test_multi_value_filter(parsed_trace_trace_directory): - df = get_zone_single_reference_year( + dframe = get_zone_single_reference_year( start_year=2023, end_year=2024, reference_year=2022, @@ -231,7 +231,7 @@ def test_multi_value_filter(parsed_trace_trace_directory): directory=parsed_trace_trace_directory / "zone_optimised", ) # Should have both zone column included in output - assert "zone" in df.columns + assert "zone" in dframe.columns def test_wind_project_single_reference_year(parsed_trace_trace_directory): @@ -249,7 +249,7 @@ def test_wind_project_single_reference_year(parsed_trace_trace_directory): .to_pandas() ) - df = wind_project_single_reference_year( + dframe = wind_project_single_reference_year( start_year=2023, end_year=2024, reference_year=2022, @@ -257,7 +257,7 @@ def test_wind_project_single_reference_year(parsed_trace_trace_directory): directory=parsed_trace_trace_directory / "project_optimised", year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_solar_project_single_reference_year(parsed_trace_trace_directory): @@ -275,7 +275,7 @@ def test_solar_project_single_reference_year(parsed_trace_trace_directory): .to_pandas() ) - df = solar_project_single_reference_year( + dframe = solar_project_single_reference_year( start_year=2023, end_year=2024, reference_year=2022, @@ -283,7 +283,7 @@ def test_solar_project_single_reference_year(parsed_trace_trace_directory): directory=parsed_trace_trace_directory / "project_optimised", year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_solar_project_multiple_reference_years(parsed_trace_trace_directory: Path): @@ -301,14 +301,14 @@ def test_solar_project_multiple_reference_years(parsed_trace_trace_directory: Pa .to_pandas() ) - df = solar_project_multiple_reference_years( + dframe = solar_project_multiple_reference_years( reference_years={2029: 2022, 2030: 2022}, project="Broken Hill Solar Farm", directory=parsed_trace_trace_directory / "project_optimised", year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_wind_project_multiple_reference_years(parsed_trace_trace_directory: Path): @@ -326,14 +326,14 @@ def test_wind_project_multiple_reference_years(parsed_trace_trace_directory: Pat .to_pandas() ) - df = wind_project_multiple_reference_years( + dframe = wind_project_multiple_reference_years( reference_years={2029: 2022, 2030: 2022}, project="Bodangora Wind Farm", directory=parsed_trace_trace_directory / "project_optimised", year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_solar_area_single_reference_year(parsed_trace_trace_directory: Path): @@ -350,7 +350,7 @@ def test_solar_area_single_reference_year(parsed_trace_trace_directory: Path): .to_pandas() ) - df = solar_area_single_reference_year( + dframe = solar_area_single_reference_year( start_year=2023, end_year=2024, reference_year=2022, @@ -359,7 +359,7 @@ def test_solar_area_single_reference_year(parsed_trace_trace_directory: Path): directory=parsed_trace_trace_directory / "zone_optimised", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_demand_single_reference_year(parsed_trace_trace_directory: Path): @@ -380,7 +380,7 @@ def test_demand_single_reference_year(parsed_trace_trace_directory: Path): .to_pandas() ) - df = demand_single_reference_year( + dframe = demand_single_reference_year( start_year=2023, end_year=2024, reference_year=2011, @@ -392,7 +392,7 @@ def test_demand_single_reference_year(parsed_trace_trace_directory: Path): year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) def test_demand_multiple_reference_years(parsed_trace_trace_directory: Path): @@ -412,7 +412,7 @@ def test_demand_multiple_reference_years(parsed_trace_trace_directory: Path): .to_pandas() ) - df = demand_multiple_reference_years( + dframe = demand_multiple_reference_years( reference_years={2029: 2011, 2030: 2011}, scenario="Green Energy Exports", subregion="CNSW", @@ -422,4 +422,4 @@ def test_demand_multiple_reference_years(parsed_trace_trace_directory: Path): year_type="fy", ) - pd.testing.assert_frame_equal(test_df, df) + pd.testing.assert_frame_equal(test_df, dframe) diff --git a/tests/test_input_validation.py b/tests/test_input_validation.py index b380d46..d674bd3 100644 --- a/tests/test_input_validation.py +++ b/tests/test_input_validation.py @@ -36,7 +36,7 @@ def test_solar_metadata_filter_valid(valid_input): @pytest.mark.parametrize( - "invalid_input,expected_error", + ("invalid_input", "expected_error"), [ ({"file_type": ["invalid"]}, "Input should be 'zone' or 'project'"), ({"resource_type": ["invalid"]}, "Input should be 'SAT', 'FFP' or 'CST'"), @@ -69,7 +69,7 @@ def test_wind_metadata_filter_valid(valid_input): @pytest.mark.parametrize( - "invalid_input,expected_error", + ("invalid_input", "expected_error"), [ ({"file_type": ["invalid"]}, "Input should be 'zone' or 'project'"), ( @@ -107,7 +107,7 @@ def test_demand_metadata_filter_valid(valid_input): @pytest.mark.parametrize( - "invalid_input,expected_error", + ("invalid_input", "expected_error"), [ ( {"scenario": ["invalid"]}, @@ -173,7 +173,7 @@ def test_construct_reference_year_mapping_validation_valid(): start_year=2030, end_year=2035, reference_years=[2011, 2013, 2018] ) assert isinstance(result, dict) - assert len(result) == 6 + assert len(result) == 2035 - 2030 + 1 assert all(isinstance(k, int) and isinstance(v, int) for k, v in result.items()) @@ -183,7 +183,7 @@ def test_input_directory(tmp_path): valid_dir.mkdir() assert input_validation.input_directory(valid_dir) == valid_dir - with pytest.raises(ValueError, match="Directory .* does not exist"): + with pytest.raises(ValueError, match=r"Directory .* does not exist"): input_validation.input_directory(tmp_path / "non_existent_dir") @@ -238,7 +238,7 @@ def test_is_valid_path_invalid(invalid_path): @pytest.mark.parametrize( - "start,end", + ("start", "end"), [ (2020, 2025), (2020, 2020), @@ -250,7 +250,7 @@ def test_start_year_before_end_year_valid(start, end): @pytest.mark.parametrize( - "start,end", + ("start", "end"), [ (2025, 2020), (0, -10), @@ -258,5 +258,5 @@ def test_start_year_before_end_year_valid(start, end): ], ) def test_start_year_before_end_year_invalid(start, end): - with pytest.raises(ValueError, match="Start year .* < end year"): + with pytest.raises(ValueError, match=r"Start year .* < end year"): input_validation.start_year_before_end_year(start, end) diff --git a/tests/test_optimise_parquet.py b/tests/test_optimise_parquet.py index f7b0efe..53343d5 100644 --- a/tests/test_optimise_parquet.py +++ b/tests/test_optimise_parquet.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize( - "expected_data, file_type", + ("expected_data", "file_type"), [("zone_data_0.parquet", "zone"), ("project_data_0.parquet", "project")], ) def test_optimisation(parsed_trace_trace_directory, expected_data, file_type): diff --git a/tests/test_trace_file_meta_data_extraction.py b/tests/test_trace_file_meta_data_extraction.py index 5e50c9b..3dfee6c 100644 --- a/tests/test_trace_file_meta_data_extraction.py +++ b/tests/test_trace_file_meta_data_extraction.py @@ -1,26 +1,29 @@ from isp_trace_parser import metadata_extractors +YEAR_2011 = 2011 +YEAR_2023 = 2023 + def test_solar_trace_metadata_extraction(): file_name = "Woolooga_SAT_RefYear2023.csv" metadata = metadata_extractors.extract_solar_trace_metadata(file_name) assert metadata["name"] == "Woolooga" assert metadata["resource_type"] == "SAT" - assert metadata["reference_year"] == 2023 + assert metadata["reference_year"] == YEAR_2023 assert metadata["file_type"] == "project" file_name = "Darling_Downs_FFP_RefYear2023.csv" metadata = metadata_extractors.extract_solar_trace_metadata(file_name) assert metadata["name"] == "Darling_Downs" assert metadata["resource_type"] == "FFP" - assert metadata["reference_year"] == 2023 + assert metadata["reference_year"] == YEAR_2023 assert metadata["file_type"] == "project" file_name = "REZ_N0_NSW_Non-REZ_CST_RefYear2023.csv" metadata = metadata_extractors.extract_solar_trace_metadata(file_name) assert metadata["name"] == "N0" assert metadata["resource_type"] == "CST" - assert metadata["reference_year"] == 2023 + assert metadata["reference_year"] == YEAR_2023 assert metadata["file_type"] == "zone" @@ -29,21 +32,21 @@ def test_wind_trace_metadata_extraction(): metadata = metadata_extractors.extract_wind_trace_metadata(file_name) assert metadata["name"] == "ARWF1" assert metadata["resource_type"] == "WIND" - assert metadata["reference_year"] == 2023 + assert metadata["reference_year"] == YEAR_2023 assert metadata["file_type"] == "project" file_name = "CAPTL_WF_RefYear2023.csv" metadata = metadata_extractors.extract_wind_trace_metadata(file_name) assert metadata["name"] == "CAPTL_WF" assert metadata["resource_type"] == "WIND" - assert metadata["reference_year"] == 2023 + assert metadata["reference_year"] == YEAR_2023 assert metadata["file_type"] == "project" file_name = "N8_WH_Cooma-Monaro_RefYear2023.csv" metadata = metadata_extractors.extract_wind_trace_metadata(file_name) assert metadata["name"] == "N8" assert metadata["resource_type"] == "WH" - assert metadata["reference_year"] == 2023 + assert metadata["reference_year"] == YEAR_2023 assert metadata["file_type"] == "zone" @@ -51,7 +54,7 @@ def test_demand_trace_metadata_extraction(): file_name = "VIC_RefYear_2011_STEP_CHANGE_POE10_OPSO_MODELLING.csv" metadata = metadata_extractors.extract_demand_trace_metadata(file_name) assert metadata["subregion"] == "VIC" - assert metadata["reference_year"] == 2011 + assert metadata["reference_year"] == YEAR_2011 assert metadata["scenario"] == "STEP_CHANGE" assert metadata["poe"] == "POE10" assert metadata["demand_type"] == "OPSO_MODELLING" diff --git a/tests/test_trace_formatter.py b/tests/test_trace_formatter.py index 7e7e5fa..bf2f526 100644 --- a/tests/test_trace_formatter.py +++ b/tests/test_trace_formatter.py @@ -53,7 +53,7 @@ def test_trace_formatter(): formatted_data = formatted_data.sort(by=["Year", "Month", "Day", "Period"]) - formatted_data = formatted_data.pivot( + formatted_data = formatted_data.pivot_table( index=["Year", "Month", "Day"], on="Period", values="value" ) diff --git a/tests/test_trace_parsers.py b/tests/test_trace_parsers.py index c045908..fff027d 100644 --- a/tests/test_trace_parsers.py +++ b/tests/test_trace_parsers.py @@ -5,7 +5,7 @@ import pytest from polars.testing import assert_frame_equal -from isp_trace_parser import demand_traces, solar_traces, wind_traces +from isp_trace_parser import demand_traces TEST_DATA = Path(__file__).parent / "test_data" @@ -19,8 +19,8 @@ def test_demand_trace_parsing(use_concurrency: bool): ) test_demand_output_parquet = TEST_DATA / "output" / expected_filename - with tempfile.TemporaryDirectory() as tmp_parsed_directory: - tmp_parsed_directory = Path(tmp_parsed_directory) + with tempfile.TemporaryDirectory() as tmpdir: + tmp_parsed_directory = Path(tmpdir) demand_traces.parse_demand_traces( input_directory=test_demand_csv_directory, @@ -39,7 +39,7 @@ def test_demand_trace_parsing(use_concurrency: bool): @pytest.mark.parametrize( - "expected_filename, file_type", + ("expected_filename", "file_type"), [ ("RefYear2022_Bodangora_Wind_Farm.parquet", "project"), ("RefYear2022_N1_WM.parquet", "zone"), @@ -58,7 +58,7 @@ def test_wind_trace_parsing(parsed_trace_trace_directory, expected_filename, fil @pytest.mark.parametrize( - "expected_filename, file_type", + ("expected_filename", "file_type"), [ ("RefYear2022_N2_CST.parquet", "zone"), ("RefYear2022_Broken_Hill_Solar_Farm_FFP.parquet", "project"),