diff --git a/src/isp_workbook_parser/config_model.py b/src/isp_workbook_parser/config_model.py
index 26e6ae0..322e416 100644
--- a/src/isp_workbook_parser/config_model.py
+++ b/src/isp_workbook_parser/config_model.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
 from pathlib import Path
-from typing import Dict, List, Optional
 
 import yaml
 from pydantic import BaseModel
@@ -49,11 +50,11 @@ class TableConfig(BaseModel):
 
     name: str
     sheet_name: str
-    header_rows: int | List[int]
+    header_rows: int | list[int]
     end_row: int
     column_range: str
-    skip_rows: Optional[int | List[int] | Dict[str, int]] = None
-    columns_with_merged_rows: Optional[str | List[str]] = None
+    skip_rows: int | list[int] | dict[str, int] | None = None
+    columns_with_merged_rows: str | list[str] | None = None
     forward_fill_values: bool = True
 
@@ -118,7 +119,7 @@ def load_yaml(path: Path) -> dict[str, TableConfig]:
 
         path: pathlib Path instance specifying the location of the YAML file.
     """
-    with open(path, "r") as f:
+    with Path(path).open() as f:
        config = yaml.safe_load(f)
        f.close()
     if config is not None:
diff --git a/src/isp_workbook_parser/parser.py b/src/isp_workbook_parser/parser.py
index 1520660..3caddca 100644
--- a/src/isp_workbook_parser/parser.py
+++ b/src/isp_workbook_parser/parser.py
@@ -1,5 +1,5 @@
-import glob
-import os
+from __future__ import annotations
+
 import warnings
 from pathlib import Path
 from typing import Any
@@ -107,23 +107,21 @@ def _determine_config_path(
 
     def _check_version_is_supported(self, config_path) -> None:
         """Check the default config directory contains a subdirectory that matches the workbook version number."""
-        versions = os.listdir(config_path)
+        versions = [path.name for path in Path(config_path).iterdir()]
         if self.workbook_version not in versions:
-            raise ValueError(
-                f"The workbook version {self.workbook_version} is not supported."
-            )
+            msg = f"The workbook version {self.workbook_version} is not supported."
+            raise ValueError(msg)
 
     def _load_config(self) -> dict[str, dict[str, Any]]:
         """Load all the YAML files stored in the config directory into a nested dictionary
         with sheet names as keys and table names as second level keys. For robustness across
         workbook versions, the config sheet name is matched with a workbook sheet name in case-agnostic manner.
""" - pattern = os.path.join(self.config_path, "*.yaml") - config_files = glob.glob(pattern) + config_files = Path(self.config_path).glob("*.yaml") configs = {} for file in config_files: config_dict = load_yaml(Path(file)) - for config_name in config_dict.keys(): + for config_name in config_dict: config = config_dict[config_name] config_sheet_name_lowercase = config.sheet_name.lower() sheet_names = [ @@ -132,15 +130,14 @@ def _load_config(self) -> dict[str, dict[str, Any]]: if sheet_name.lower() == config_sheet_name_lowercase ] if len(sheet_names) > 1: - raise TableConfigError( - f"Workbook sheet '{config.sheet_name}' is not unique" - ) - elif len(sheet_names) < 1: - raise TableConfigError( + msg = f"Workbook sheet '{config.sheet_name}' is not unique" + raise TableConfigError(msg) + if len(sheet_names) < 1: + msg = ( f" Sheet '{config.sheet_name}' cannot be found in the workbook" ) - else: - config.sheet_name = sheet_names.pop() + raise TableConfigError(msg) + config.sheet_name = sheet_names.pop() config_dict[config_name] = config configs.update(config_dict) return configs @@ -157,7 +154,7 @@ def _get_table_names_by_sheet(self): return sorted_table_names_by_sheet def _check_data_ends_where_expected( - self, tab: str, end_row: int, range: str, name: str + self, tab: str, end_row: int, cell_range: str, name: str ) -> None: """Check that the cell after the last row of the table in the second column is blank. @@ -165,7 +162,7 @@ def _check_data_ends_where_expected( second column ends appears to be always blank. Therefore, checking that this cell is blank can be used to verify that the config has not specified a table end row that is before the actual last row of the table. """ - first_column = range.split(":")[0] + first_column = cell_range.split(":", maxsplit=1)[0] first_col_index = openpyxl.utils.column_index_from_string(first_column) second_col_index = first_col_index + 1 # We check that value in the second column is blank because sometime the row after the first column will @@ -181,7 +178,7 @@ def _check_data_ends_where_expected( raise TableConfigError(error_message) def _check_no_data_above_first_header_row( - self, tab: str, header_rows: int, range: str, name: str + self, tab: str, header_rows: int, cell_range: str, name: str ) -> None: """Check that the cell before the first header row of the table in the second column is blank. @@ -189,7 +186,7 @@ def _check_no_data_above_first_header_row( second column appears to be always blank. Therefore, checking that this cell is blank can be used to verify that the config has not specified a table header row that is after the first header row of the table. """ - first_column = range.split(":")[0] + first_column = cell_range.split(":", maxsplit=1)[0] first_col_index = openpyxl.utils.column_index_from_string(first_column) second_col_index = first_col_index + 1 @@ -270,7 +267,7 @@ def _check_columns_unique(data: pd.DataFrame, name: str) -> None: raise TableConfigError(error_message) def _check_for_missed_column_on_right_hand_side_of_table( - self, sheet_name: str, start_row: int, end_row: int, range: str, name: str + self, sheet_name: str, start_row: int, end_row: int, cell_range: str, name: str ) -> None: """Checks if there is data in the column adjacent to last column specified in the config. @@ -278,7 +275,7 @@ def _check_for_missed_column_on_right_hand_side_of_table( there is data in the adjacent column can help detect when the column range in the config has been incorrectly specified. 
""" - last_column = range.split(":")[1] + last_column = cell_range.split(":")[1] last_col_index = openpyxl.utils.column_index_from_string(last_column) column_next_to_last_column = openpyxl.utils.get_column_letter( last_col_index + 1 @@ -310,7 +307,7 @@ def _check_for_missed_column_on_right_hand_side_of_table( raise TableConfigError(error_message) def _check_for_missed_column_on_left_hand_side_of_table( - self, sheet_name: str, start_row: int, end_row: int, range: str, name: str + self, sheet_name: str, start_row: int, end_row: int, cell_range: str, name: str ) -> None: """Checks if there is data in the column adjacent to first column specified in the config. @@ -318,7 +315,7 @@ def _check_for_missed_column_on_left_hand_side_of_table( there is data in the adjacent column can help detect when the column range in the config has been incorrectly specified. """ - first_column = range.split(":")[0] + first_column = cell_range.split(":", maxsplit=1)[0] first_col_index = openpyxl.utils.column_index_from_string(first_column) column_next_to_first_column = openpyxl.utils.get_column_letter( first_col_index - 1 @@ -335,10 +332,9 @@ def _check_for_missed_column_on_left_hand_side_of_table( usecols=column_next_to_first_column, nrows=(end_row - start_row), ) - if data[data.columns[0]].isna().all(): - range_error = False - elif ( - "DO NOT DELETE THIS COLUMN" in str(data.columns[0]) + if ( + data[data.columns[0]].isna().all() + or "DO NOT DELETE THIS COLUMN" in str(data.columns[0]) or first_column == "B" ): range_error = False @@ -459,7 +455,7 @@ def _postprocess_percentage_columns_between_0_and_100( if isinstance(sr, list) and cell.row in sr: skipped_rows += 1 continue - elif isinstance(sr, int) and cell.row == sr: + if isinstance(sr, int) and cell.row == sr: skipped_rows += 1 continue if isinstance(cell.value, (int, float)) and "%" in cell.number_format: @@ -475,7 +471,7 @@ def _postprocess_percentage_columns_between_0_and_100( # add the data column index if the entire column consists of percentage values # else, add the individual cells as a list of tuples if len(percentage_cells) == (table_config.end_row - min_row + 1): - percentage_columns.append(set(x[1] for x in percentage_cells).pop()) + percentage_columns.append({x[1] for x in percentage_cells}.pop()) else: percentage_columns.append(percentage_cells) @@ -505,7 +501,7 @@ def get_table_names(self) -> list[str]: return self.table_names_by_sheet def get_table_from_config( - self, table_config: TableConfig, config_checks: bool = True + self, table_config: TableConfig, *, config_checks: bool = True ) -> pd.DataFrame: """Retrieves a table from the assumptions workbook using the config provided and returns as pd.DataFrame. @@ -556,7 +552,7 @@ def get_table_from_config( self._check_table(data, table_config) return data - def get_table(self, table_name: str, config_checks: bool = True) -> pd.DataFrame: + def get_table(self, table_name: str, *, config_checks: bool = True) -> pd.DataFrame: """Retrieves a table from the assumptions workbook and returns as `pd.DataFrame`. Examples @@ -578,22 +574,24 @@ def get_table(self, table_name: str, config_checks: bool = True) -> pd.DataFrame starts and ends where expected and the workbook header matches the config header. """ if not isinstance(table_name, str): - raise ValueError("The parameter table_name must be provided as a string.") - if table_name not in self.table_configs.keys(): + msg = "The parameter table_name must be provided as a string." 
+            raise TypeError(msg)
+        if table_name not in self.table_configs:
             closest = process.extractOne(table_name, self.table_configs.keys())[0]
-            raise ValueError(
+            msg = (
                 f"The table_name ({table_name}) provided is not in the config for this workbook version."
-                + f" Did you mean '{closest}'?"
+                f" Did you mean '{closest}'?"
             )
+            raise ValueError(msg)
         table_config = self.table_configs[table_name]
-        data = self.get_table_from_config(table_config, config_checks=config_checks)
-        return data
+        return self.get_table_from_config(table_config, config_checks=config_checks)
 
     def save_tables(
         self,
         directory: str | Path,
         tables: list[str] | str = "all",
+        *,
         config_checks: bool = True,
     ) -> None:
         """Saves tables from the provided workbook to the specified directory as CSV files.
@@ -619,18 +617,19 @@
             directory.mkdir(parents=True)
 
         if not directory.is_dir():
-            raise ValueError("The path provided is not a directory.")
+            msg = "The path provided is not a directory."
+            raise ValueError(msg)
 
-        if not (isinstance(tables, str) or isinstance(tables, list)):
-            raise ValueError(
-                "The parameter tables must be provided as str or list[str]."
-            )
+        if not isinstance(tables, (str, list)):
+            msg = "The parameter tables must be provided as str or list[str]."
+            raise TypeError(msg)
 
         if isinstance(tables, str) and tables != "all":
-            raise ValueError(
+            msg = (
                 "If the parameter tables is provided as a str it must \n",
                 f"have the value 'all' but '{tables}' was provided.",
             )
+            raise ValueError(msg)
 
         if tables == "all":
             tables = self.table_configs.keys()
diff --git a/src/isp_workbook_parser/read_table.py b/src/isp_workbook_parser/read_table.py
index dec0acd..21362dc 100644
--- a/src/isp_workbook_parser/read_table.py
+++ b/src/isp_workbook_parser/read_table.py
@@ -1,11 +1,14 @@
-from typing import List, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
 
 import numpy as np
 import openpyxl
 import openpyxl.utils
 import pandas as pd
 
-from isp_workbook_parser import TableConfig
+if TYPE_CHECKING:
+    from isp_workbook_parser import TableConfig
 
 from .sanitisers import _column_name_sanitiser
@@ -80,60 +83,62 @@ def read_table(workbook_file: pd.ExcelFile, table: TableConfig) -> pd.DataFrame:
                 df, table.columns_with_merged_rows, table.column_range
             )
         return df
+
+    df_initial = pd.read_excel(
+        workbook_file,
+        sheet_name=table.sheet_name,
+        header=(table.header_rows[0] - 1),
+        usecols=table.column_range,
+        nrows=(table.end_row - table.header_rows[0]),
+        # do not parse dtypes
+        dtype="object",
+    )
+    df_initial.columns = _column_name_sanitiser(df_initial.columns)
+    # check that header_rows list is sorted
+    if sorted(table.header_rows) != table.header_rows:
+        msg = "header_rows must be sorted in ascending order"
+        raise ValueError(msg)
+    # check that the header_rows are adjacent
+    if set(np.diff(table.header_rows)) != {1}:
+        msg = "header_rows must specify adjacent (consecutive) rows"
+        raise ValueError(msg)
+    # start processing multiple header rows
+    header_rows_in_table = table.header_rows[-1] - table.header_rows[0]
+    initial_header = pd.Series(df_initial.columns)
+    ffilled_initial_header = _ffill_highest_header(initial_header)
+    filled_headers = []
+    # ffill intermediate header rows
+    for i in range(header_rows_in_table - 1):
+        if i == 0:
+            preceding_header = initial_header
+        filled_headers.append(
+            _ffill_intermediate_header_row(df_initial.iloc[i, :], preceding_header)
+        )
+        preceding_header = df_initial.iloc[i, :]
+    # process last header row
+    if not filled_headers:
+        processed_last_header = _process_last_header_row(
+            df_initial.iloc[header_rows_in_table - 1, :], ffilled_initial_header
+        )
     else:
-        df_initial = pd.read_excel(
-            workbook_file,
-            sheet_name=table.sheet_name,
-            header=(table.header_rows[0] - 1),
-            usecols=table.column_range,
-            nrows=(table.end_row - table.header_rows[0]),
-            # do not parse dtypes
-            dtype="object",
+        processed_last_header = _process_last_header_row(
+            df_initial.iloc[header_rows_in_table - 1, :], filled_headers[-1]
         )
-        df_initial.columns = _column_name_sanitiser(df_initial.columns)
-        # check that header_rows list is sorted
-        assert sorted(table.header_rows) == table.header_rows
-        # check that the header_rows are adjacent
-        assert set(np.diff(table.header_rows)) == set([1])
-        # start processing multiple header rows
-        header_rows_in_table = table.header_rows[-1] - table.header_rows[0]
-        initial_header = pd.Series(df_initial.columns)
-        ffilled_initial_header = _ffill_highest_header(initial_header)
-        filled_headers = []
-        # ffill intermediate header rows
-        for i in range(0, header_rows_in_table - 1):
-            if i == 0:
-                preceding_header = initial_header
-            filled_headers.append(
-                _ffill_intermediate_header_row(df_initial.iloc[i, :], preceding_header)
-            )
-            preceding_header = df_initial.iloc[i, :]
-        # process last header row
-        if not filled_headers:
-            processed_last_header = _process_last_header_row(
-                df_initial.iloc[header_rows_in_table - 1, :], ffilled_initial_header
-            )
-        else:
-            processed_last_header = _process_last_header_row(
-                df_initial.iloc[header_rows_in_table - 1, :], filled_headers[-1]
-            )
-        filled_headers.append(processed_last_header)
-        # add separators manually - ignore any "" entries
-        for series in filled_headers:
-            series[series != ""] = "_" + series[series != ""]
-        merged_headers = ffilled_initial_header.str.cat(filled_headers)
-        df_cleaned = _build_cleaned_dataframe(
-            df_initial, header_rows_in_table, merged_headers, table.forward_fill_values
+    filled_headers.append(processed_last_header)
+    # add separators manually - ignore any "" entries
+    for series in filled_headers:
+        series[series != ""] = "_" + series[series != ""]
+    merged_headers = ffilled_initial_header.str.cat(filled_headers)
+    df_cleaned = _build_cleaned_dataframe(
+        df_initial,
+        header_rows_in_table,
+        merged_headers,
+        forward_fill_values=table.forward_fill_values,
+    )
+    if table.skip_rows:
+        df_cleaned = _skip_rows_in_dataframe(
+            df_cleaned, table.skip_rows, table.header_rows[-1]
         )
-        if table.skip_rows:
-            df_cleaned = _skip_rows_in_dataframe(
-                df_cleaned, table.skip_rows, table.header_rows[-1]
-            )
-        if table.columns_with_merged_rows:
-            df_cleaned = _handle_merged_rows(
-                df_cleaned, table.columns_with_merged_rows, table.column_range
-            )
-        return df_cleaned
+    if table.columns_with_merged_rows:
+        df_cleaned = _handle_merged_rows(
+            df_cleaned, table.columns_with_merged_rows, table.column_range
+        )
+    return df_cleaned
 
 
 def _ffill_highest_header(initial_header: pd.Series) -> pd.Series:
@@ -142,8 +147,7 @@
     a multi-header table
     """
     initial_header[initial_header.str.contains("Unnamed")] = pd.NA
-    ffill_initial_header = initial_header.ffill().reset_index(drop=True).fillna("")
-    return ffill_initial_header
+    return initial_header.ffill().reset_index(drop=True).fillna("")
 
 
 def _ffill_intermediate_header_row(
@@ -172,8 +176,7 @@
         int_header.iloc[n] = pd.NA
 
     _ffill_intermediate_header = int_header.reset_index(drop=True).fillna("")
-    _ffill_intermediate_header = _column_name_sanitiser(_ffill_intermediate_header)
-    return _ffill_intermediate_header
+    return _column_name_sanitiser(_ffill_intermediate_header)
 
 
 def _process_last_header_row(
@@ -187,14 +190,14 @@
     """
     last_header = last_header.reset_index(drop=True).fillna("")
     last_header = _column_name_sanitiser(last_header)
-    last_header = last_header.where(last_header != preceding_header, "")
-    return last_header
+    return last_header.where(last_header != preceding_header, "")
 
 
 def _build_cleaned_dataframe(
     df_initial: pd.DataFrame,
     header_rows_in_table: int,
     new_headers: pd.Series,
+    *,
     forward_fill_values: bool,
 ) -> pd.DataFrame:
     """
@@ -208,12 +211,11 @@
     df_cleaned.columns = new_headers
     if forward_fill_values:
         df_cleaned = df_cleaned.ffill(axis=1)
-    df_cleaned = df_cleaned.reset_index(drop=True)
-    return df_cleaned
+    return df_cleaned.reset_index(drop=True)
 
 
 def _skip_rows_in_dataframe(
-    df: pd.DataFrame, config_skip_rows: Union[int, List[int]], last_header_row: int
+    df: pd.DataFrame, config_skip_rows: int | list[int], last_header_row: int
 ) -> pd.DataFrame:
     """
     Drop rows specified by `skip_rows` by applying an offset from the header and
@@ -227,13 +229,12 @@
         skip_rows = np.subtract(skip_rows, last_header_row + 1)
     else:
         skip_rows = np.subtract(config_skip_rows, last_header_row + 1)
-    dropped = df_reset_index.drop(index=skip_rows).reset_index(drop=True)
-    return dropped
+    return df_reset_index.drop(index=skip_rows).reset_index(drop=True)
 
 
 def _handle_merged_rows(
     df: pd.DataFrame,
-    config_cols_with_merged_rows: Union[str, List[str]],
+    config_cols_with_merged_rows: str | list[str],
     column_range: str,
 ) -> pd.DataFrame:
     """
@@ -243,9 +244,7 @@
         cols = [config_cols_with_merged_rows]
     else:
         cols = config_cols_with_merged_rows
-    actual_col_indices = list(
-        map(lambda col: _find_data_column_index(col, column_range), cols)
-    )
+    actual_col_indices = [_find_data_column_index(col, column_range) for col in cols]
     for index in actual_col_indices:
         df.iloc[:, index] = df.iloc[:, index].ffill()
     return df
@@ -267,7 +266,7 @@
     (zero-indexed)
     """
     first_col_index = openpyxl.utils.column_index_from_string(
-        column_range_from_table_config.split(":")[0]
+        column_range_from_table_config.split(":", maxsplit=1)[0]
    )
     data_col_index = openpyxl.utils.column_index_from_string(column_alphabetical)
     return data_col_index - first_col_index
diff --git a/src/isp_workbook_parser/sanitisers.py b/src/isp_workbook_parser/sanitisers.py
index 237875f..f7caeec 100644
--- a/src/isp_workbook_parser/sanitisers.py
+++ b/src/isp_workbook_parser/sanitisers.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import contextlib
 import re
 
 import numpy as np
@@ -8,6 +11,8 @@
 
 def _column_name_sanitiser(columns: pd.Index | pd.Series) -> pd.Index | pd.Series:
     """
+    Sanitise column names.
+
     Sanitises column names by:
     1. Removing 'versioning' from column names introduced by `mangle_dupe_cols`
        in pandas parser, e.g. 'Generator.1' is sanitised to 'Generator'
@@ -22,8 +27,7 @@ def _column_name_sanitiser(columns: pd.Index | pd.Series) -> pd.Index | pd.Serie
     columns = columns.str.strip()
     columns = _replace_series_newlines_with_whitespace(columns)
     columns = _remove_series_double_whitespaces(columns)
-    columns = _remove_column_name_trailing_footnotes(columns)
-    return columns
+    return _remove_column_name_trailing_footnotes(columns)
 
 
 def _custom_string_replacements(
@@ -54,7 +58,7 @@ def _values_casting_and_sanitisation(df: pd.DataFrame) -> pd.DataFrame:
     will return `pd.NA`
     """
     df = _replace_dataframe_hyphens_with_na(df)
-    for object_col in df.dtypes[df.dtypes == "object"].keys():
+    for object_col in df.dtypes[df.dtypes == "object"].index:
         try:
             df.loc[:, object_col] = pd.to_numeric(df[object_col])
         except (ValueError, TypeError):
@@ -73,10 +77,8 @@
             ):
                 df.loc[where_str_values, object_col] = series_func(df[object_col])
         # re-attempt conversion following sanitisation
-        try:
+        with contextlib.suppress(ValueError, TypeError):
             df[object_col] = pd.to_numeric(df[object_col])
-        except (ValueError, TypeError):
-            pass
     return df
@@ -159,8 +161,7 @@ def _remove_series_notes_after_values(
         r"\1",
         regex=True,
     )
-    series = series.str.replace(r"^\-\s?(?:(\([\w\s\.\<\=\-\(\)]+)+)", "", regex=True)
-    return series
+    return series.str.replace(r"^\-\s?(?:(\([\w\s\.\<\=\-\(\)]+)+)", "", regex=True)
 
 
 def _extract_numeric_value_millions(
@@ -184,9 +185,8 @@ def extract(val):
             if num_str.replace(".", "", 1).isdigit():
                 # Convert to float and multiply by 1,000,000
                 return float(num_str) * 1_000_000
-            else:
-                # If not a valid number, return the original value
-                return val
+            # If not a valid number, return the original value
+            return val
         # Return value unchanged if pattern does not match
         return val
diff --git a/tests/test_packaged_table_configs.py b/tests/test_packaged_table_configs.py
index 2da7a1e..8b2e964 100644
--- a/tests/test_packaged_table_configs.py
+++ b/tests/test_packaged_table_configs.py
@@ -9,10 +9,10 @@
 
 @pytest.mark.parametrize("workbook_version_folder", workbook_path.iterdir())
 def test_packaged_table_configs_for_each_version(workbook_version_folder: Path):
-    xl_file = [file for file in workbook_version_folder.glob("[!.]*.xls*")]
-    assert (
-        len(xl_file) == 1
-    ), f"There should only be one Excel workbook in each version sub-directory, got {xl_file}"
+    xl_file = list(workbook_version_folder.glob("[!.]*.xls*"))
+    if len(xl_file) != 1:
+        msg = f"There should only be one Excel workbook in each version sub-directory, got {xl_file}"
+        raise RuntimeError(msg)
     workbook_name = xl_file.pop()
     workbook = Parser(workbook_name)
@@ -43,8 +43,8 @@ def test_packaged_table_configs_for_each_version(workbook_version_folder: Path):
             error_tables[table_name] = e
     if error_tables:
         error_str = ""
-        for key in error_tables:
-            error_str += key + ":" + str(error_tables[key]) + "\n"
+        for key, value in error_tables.items():
+            error_str += key + ":" + str(value) + "\n"
         raise TableLoadError(error_str)
 
 
diff --git a/tests/test_read_table_functionality.py b/tests/test_read_table_functionality.py
index abadcef..6e79235 100644
--- a/tests/test_read_table_functionality.py
+++ b/tests/test_read_table_functionality.py
@@ -11,8 +11,10 @@ def test_skip_single_row_in_single_header_row_table(workbook_v6):
         skip_rows=30,
     )
     df = workbook_v6.get_table_from_config(table_config)
-    assert len(df) == (table_config.end_row - table_config.header_rows - 1)
-    assert df[df.Technology.str.contains("Hydrogen")].empty
+    if len(df) != (table_config.end_row - table_config.header_rows - 1):
+        msg = "unexpected number of rows after skipping a single row"
+        raise ValueError(msg)
+    if not df[df.Technology.str.contains("Hydrogen")].empty:
+        msg = "the skipped 'Hydrogen' row should not appear in the parsed table"
+        raise ValueError(msg)
 
 
 def test_skip_multiple_rows_in_single_header_row_table(workbook_v6):
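Review note on `_check_version_is_supported`: `os.listdir()` returns a list of name strings, while `Path.iterdir()` yields `Path` objects, so a straight swap would make the `workbook_version not in versions` check always true. That is why the rewrite above compares against `.name`. A minimal sketch of the difference, using a throwaway directory:

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp)
    (config_path / "6.0").mkdir()

    assert "6.0" in os.listdir(config_path)          # os.listdir returns strings
    assert "6.0" not in list(config_path.iterdir())  # iterdir yields Path objects
    assert "6.0" in [p.name for p in config_path.iterdir()]  # compare by name
```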
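Review note on `_values_casting_and_sanitisation`: the `.keys()` dropped from the `df.dtypes` loop was not redundant — on a `pandas.Series`, `.keys()` is an alias for `.index`, whereas iterating the Series directly yields its values (dtype objects, not column names). The rewrite above keeps `.index` so the loop still sees column names:

```python
import pandas as pd

df = pd.DataFrame({"a": ["1", "2"], "b": [1.0, 2.0]})
object_dtypes = df.dtypes[df.dtypes == "object"]

for value in object_dtypes:
    assert value == object                 # iterating yields dtype objects
assert list(object_dtypes.index) == ["a"]  # column names live on the index
```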
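Review note on the bare `*` added to `get_table`, `get_table_from_config`, `save_tables`, and `_build_cleaned_dataframe`: it makes the trailing boolean flags keyword-only, which is why the internal `_build_cleaned_dataframe` call in `read_table` now has to pass `forward_fill_values=` by name. A sketch of the pattern (the function here is hypothetical, not part of the library):

```python
def parse_table(table_name: str, *, config_checks: bool = True) -> str:
    # parameters after the bare * can only be passed by keyword
    return f"{table_name} (checks {'on' if config_checks else 'off'})"

parse_table("coal_prices", config_checks=False)  # OK: the flag is named at the call site
# parse_table("coal_prices", False)  # TypeError: too many positional arguments
```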
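Review note on `contextlib.suppress` in `sanitisers.py`: it is a drop-in replacement for a `try`/`except`/`pass` block — the body runs, and exceptions of the listed types are swallowed, so a failed `pd.to_numeric` re-attempt simply leaves the column as it was:

```python
import contextlib

import pandas as pd

series = pd.Series(["1", "2", "not a number"], dtype="object")
with contextlib.suppress(ValueError, TypeError):
    series = pd.to_numeric(series)  # raises ValueError, which is suppressed

assert series.dtype == object  # the assignment above never completed
```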
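Review note on `from __future__ import annotations`: with postponed evaluation, annotations are stored as strings, so the PEP 604/585 unions used throughout (`int | list[int]`, `dict[str, int] | None`) need not be valid at runtime, and the `TYPE_CHECKING` guard in `read_table.py` avoids importing `TableConfig` (and any import cycle) at runtime. One hedge: code that introspects annotations at runtime — pydantic models such as `TableConfig`, for instance — must still be able to resolve them, so this mainly benefits modules whose annotations are never evaluated. A condensed sketch:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated only by static type checkers, never at runtime
    from isp_workbook_parser import TableConfig

def read_table_sketch(table: TableConfig, header_rows: int | list[int]) -> None:
    # at runtime, the annotations above are stored as plain strings
    ...
```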