diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2d0aa38 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.12.7 + hooks: + # Run the linter. + - id: ruff-check + args: + - --fix + - --select=E4,E7,E9,F + # Run the formatter. + - id: ruff-format diff --git a/docs/source/_static/APOCADO_yaml.yml b/docs/source/_static/APOCADO_yaml.yml index a55db72..2928b04 100644 --- a/docs/source/_static/APOCADO_yaml.yml +++ b/docs/source/_static/APOCADO_yaml.yml @@ -10,3 +10,4 @@ 'f_min': null 'f_max': null 'score': null + 'filename_format': "%Y-%m-%dT%H-%M-%S_000.wav" diff --git a/docs/source/example_overview.ipynb b/docs/source/example_overview.ipynb index d559501..5657ce3 100644 --- a/docs/source/example_overview.ipynb +++ b/docs/source/example_overview.ipynb @@ -10,22 +10,28 @@ "[^download]: This notebook can be downloaded as **{nb-download}\n", "repr(example_overview.ipynb) **.\n", "\n", - "This very simple example shows you how to make an overview plot provided a yaml configuration file and a result csv." + "This basic example shows you how to make an overview plot provided a YAML configuration file and a result CSV file." ] }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "", + "id": "feb4595b96c1bbff" + }, { "cell_type": "code", - "execution_count": null, "id": "6a230b00c14cc64e", "metadata": {}, - "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from post_processing.dataclass.data_aplose import DataAplose" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -35,14 +41,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "c19ddde8bf965ee8", "metadata": {}, - "outputs": [], "source": [ - "yaml_file = Path(r\"resource/APOCADO_yaml.yml\")\n", + "yaml_file = Path(r\"_static/APOCADO_yaml.yml\")\n", "data = DataAplose.from_yaml(file=yaml_file)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -52,15 +58,15 @@ }, { "cell_type": "code", - "execution_count": null, "id": "cf59962e0e23eb96", "metadata": {}, - "outputs": [], "source": [ "data.overview()\n", "plt.tight_layout()\n", "plt.show()" - ] + ], + "outputs": [], + "execution_count": null } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index d202dd8..ee3e533 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dev = [ "sphinx-book-theme>=1.1.4", "sphinx-copybutton>=0.5.2", "coverage>=7.11.0", + "pre-commit>=4.5.1", ] [tool.ruff.lint.flake8-copyright] diff --git a/src/post_processing/dataclass/data_aplose.py b/src/post_processing/dataclass/data_aplose.py index e1d3cae..e57feec 100644 --- a/src/post_processing/dataclass/data_aplose.py +++ b/src/post_processing/dataclass/data_aplose.py @@ -8,12 +8,18 @@ from __future__ import annotations import logging -from copy import copy from typing import TYPE_CHECKING import matplotlib.dates as mdates import matplotlib.pyplot as plt -from pandas import DataFrame, Series, Timedelta, Timestamp, concat, date_range +from pandas import ( + DataFrame, + Series, + Timedelta, + Timestamp, + concat, + date_range, +) from pandas.tseries import offsets from post_processing.dataclass.detection_filter import DetectionFilter @@ -27,10 +33,10 @@ ) from post_processing.utils.metrics_utils import detection_perf from post_processing.utils.plot_utils import ( - agreement, heatmap, histo, overview, + plot_annotator_agreement, scatter, timeline, ) @@ -88,7 +94,13 @@ def _get_locator_from_offset( class DataAplose: """A class to handle APLOSE formatted data.""" - def __init__(self, df: DataFrame = None) -> None: + def __init__( + self, + df: DataFrame = None, + *, + begin: Timestamp = None, + end: Timestamp = None, + ) -> None: """Initialize a DataAplose object from a DataFrame. Parameters @@ -107,8 +119,8 @@ def __init__(self, df: DataFrame = None) -> None: ).reset_index(drop=True) self.annotators = sorted(set(self.df["annotator"])) if df is not None else None self.labels = sorted(set(self.df["annotation"])) if df is not None else None - self.begin = min(self.df["start_datetime"]) if df is not None else None - self.end = max(self.df["end_datetime"]) if df is not None else None + self.begin = min(self.df["start_datetime"]) if begin is None else begin + self.end = max(self.df["end_datetime"]) if end is None else end self.dataset = sorted(set(self.df["dataset"])) if df is not None else None self.lat = None self.lon = None @@ -183,12 +195,10 @@ def change_tz(self, tz: str | tzinfo) -> None: """ self.df["start_datetime"] = [ - elem.tz_convert(tz) - for elem in self.df["start_datetime"] + elem.tz_convert(tz) for elem in self.df["start_datetime"] ] self.df["end_datetime"] = [ - elem.tz_convert(tz) - for elem in self.df["end_datetime"] + elem.tz_convert(tz) for elem in self.df["end_datetime"] ] self.begin = self.begin.tz_convert(tz) self.end = self.end.tz_convert(tz) @@ -293,27 +303,21 @@ def overview(self, annotator: list[str] | None = None) -> None: def detection_perf( self, - annotators: tuple[str, str], - labels: tuple[str, str], - timestamps: list[Timestamp] | None = None, + annotators: tuple[str, str] | list[str], + labels: tuple[str, str] | list[str], ) -> tuple[float, float, float]: - """Compute performances metrics for detection. + """Compute performance metrics for detection. - Performances are computed with a reference annotator in - comparison with a second annotator/detector. - Precision and recall are computed in regard - with a reference annotator/label pair. + Precision and recall are computed in regard to a reference annotator/label pair. Parameters ---------- annotators: [str, str] List of the two annotators to compare. - First annotator is chosen as reference. + The first annotator is chosen as a reference. labels: [str, str] List of the two labels to compare. - First label is chosen as reference. - timestamps: list[Timestamp], optional - A list of Timestamps to base the computation on. + The first label is chosen as a reference. Returns ------- @@ -331,10 +335,16 @@ def detection_perf( if isinstance(labels, str): labels = [labels] ref = (annotators[0], labels[0]) + + if len(set(df_filtered["end_time"])) > 1: + msg = "Multiple time bins detected in DataFrame." + raise ValueError(msg) + timebin = Timedelta(df_filtered["end_time"].iloc[0], "s") + return detection_perf( df=df_filtered, ref=ref, - timestamps=timestamps, + time=date_range(self.begin, self.end, freq=timebin), ) def plot( @@ -355,29 +365,29 @@ def plot( Parameters ---------- - mode: str - Type of plot to generate. - Must be one of {"histogram", "scatter", "heatmap", "agreement"}. - ax: plt.Axes - Matplotlib Axes object to plot on. - annotator: str | list[str] - The selected annotator or list of annotators. - label: str | list[str] - The selected label or list of labels. - **kwargs: Additional keyword arguments depending on the mode. - - legend: bool - Whether to show the legend. - - season: bool - Whether to show the season. - - show_rise_set: bool - Whether to show sunrise and sunset times. - - color: str | list[str] - Color(s) for the bars. - - bin_size: Timedelta | BaseOffset - Bin size for the histogram. - - effort: Series - The timestamps intervals corresponding to the observation effort. - If provided, data will be normalized by observation effort. + mode: str + Type of plot to generate. + Must be one of {"histogram", "scatter", "heatmap", "agreement"}. + ax: plt.Axes + Matplotlib Axes object to plot on. + annotator: str | list[str] + The selected annotator or list of annotators. + label: str | list[str] + The selected label or list of labels. + **kwargs: Additional keyword arguments depending on the mode. + - legend: bool + Whether to show the legend. + - season: bool + Whether to show the season. + - show_rise_set: bool + Whether to show sunrise and sunset times. + - color: str | list[str] + Color(s) for the bars. + - bin_size: Timedelta | BaseOffset + Bin size for the histogram. + - effort: Series + The timestamp intervals corresponding to the observation effort. + If provided, data will be normalized by observation effort. """ df_filtered = self.filter_df( @@ -386,19 +396,20 @@ def plot( ) time = date_range(self.begin, self.end) + bin_size = kwargs.get("bin_size") + legend = kwargs.get("legend", True) + color = kwargs.get("color") + season = kwargs.get("season") + effort = kwargs.get("effort") + show_rise_set = kwargs.get("show_rise_set", True) if mode == "histogram": - bin_size = kwargs.get("bin_size") - legend = kwargs.get("legend", True) - color = kwargs.get("color") - season = kwargs.get("season") - effort = kwargs.get("effort") + ax.set_xlim(time[0], time[-1]) if not bin_size: msg = "'bin_size' missing for histogram plot." raise ValueError(msg) df_counts = get_count(df_filtered, bin_size) detection_size = Timedelta(max(df_filtered["end_time"]), "s") - return histo( df=df_counts, ax=ax, @@ -412,10 +423,7 @@ def plot( ) if mode == "heatmap": - show_rise_set = kwargs.get("show_rise_set", True) - season = kwargs.get("season", False) - bin_size = kwargs.get("bin_size") - + ax.set_xlim(time[0], time[-1]) return heatmap( df=df_filtered, ax=ax, @@ -427,31 +435,31 @@ def plot( ) if mode == "scatter": - show_rise_set = kwargs.get("show_rise_set", True) - season = kwargs.get("season", False) - effort = kwargs.get("effort") - - return scatter(df=df_filtered, - ax=ax, - time_range=time, - show_rise_set=show_rise_set, - season=season, - coordinates=self.coordinates, - effort=effort, - ) + ax.set_xlim(time[0], time[-1]) + return scatter( + df=df_filtered, + ax=ax, + time_range=time, + show_rise_set=show_rise_set, + season=season, + coordinates=self.coordinates, + effort=effort, + ) if mode == "agreement": - bin_size = kwargs.get("bin_size") - return agreement(df=df_filtered, bin_size=bin_size, ax=ax) + if not bin_size: + msg = "'bin_size' missing for agreement plot." + raise ValueError(msg) + df_counts = get_count(df_filtered, bin_size) + return plot_annotator_agreement(df=df_counts, bin_size=bin_size, ax=ax) if mode == "timeline": + ax.set_xlim(time[0], time[-1]) color = kwargs.get("color") - df_filtered = self.filter_df( annotator, label, ) - return timeline(df=df_filtered, ax=ax, color=color) msg = f"Unsupported plot mode: {mode}" @@ -509,60 +517,86 @@ def from_filters( if isinstance(filters, DetectionFilter): filters = [filters] cls_list = [cls(load_detections(fil)) for fil in filters] + + for cls_obj, fil in zip(cls_list, filters, strict=True): + cls.reshape(cls_obj, fil.begin, fil.end) + if len(cls_list) == 1: return cls_list[0] + if concat: return cls.concatenate(cls_list) return cls_list @classmethod def concatenate( - cls, data_list: list[DataAplose], + cls, + data_list: list[DataAplose], ) -> DataAplose: """Concatenate a list of DataAplose objects into one.""" df_concat = ( - concat([data.df for data in data_list], ignore_index=True) + concat( + [data.df for data in data_list], + ignore_index=True, + ) .sort_values( - by=["start_datetime", + by=[ + "start_datetime", "end_datetime", "annotator", "annotation", - ], + ], ) .reset_index(drop=True) ) - obj = cls(df_concat) + + obj = cls( + df=df_concat, + begin=min(obj.begin for obj in data_list), + end=max(obj.end for obj in data_list), + ) + if isinstance(get_timezone(df_concat), list): obj.change_tz("utc") - msg = ("Several timezones found in DataFrame," - " all timestamps are converted to UTC.") + msg = ( + "Several timezones found in DataFrame," + " all timestamps are converted to UTC." + ) logging.info(msg) return obj def reshape(self, begin: Timestamp = None, end: Timestamp = None) -> DataAplose: - """Reshape the DataAplose with new begin and/or end.""" - new_data = copy(self) - + """Reshape the DataAplose with a new beginning and/or end.""" if not any([begin, end]): - msg = "Must provide begin and/or end timestamps." - raise ValueError(msg) + msg = "No begin and end timestamps provided for reshape of DataAplose instance." + logging.debug(msg) + return self - tz = get_timezone(new_data.df) + tz = get_timezone(self.df) if begin: - new_data.begin = begin + self.begin = begin if not begin.tz: - new_data.begin = begin.tz_localize(tz) + self.begin = begin.tz_localize(tz) if end: - new_data.end = end + self.end = end if not end.tz: - new_data.end = end.tz_localize(tz) + self.end = end.tz_localize(tz) - new_data.df = new_data.df[ - (new_data.df["start_datetime"] >= new_data.begin) & - (new_data.df["end_datetime"] <= new_data.end) + if self.begin >= self.end: + msg = "Begin timestamp is not anterior than end timestamp." + raise ValueError(msg) + + self.df = self.df[ + (self.df["start_datetime"] >= self.begin) + & (self.df["end_datetime"] <= self.end) ] - new_data.dataset = get_dataset(new_data.df) - new_data.labels = get_labels(new_data.df) - new_data.annotators = get_annotators(new_data.df) - return new_data + if self.df.empty: + msg = "DataFrame is empty after reshaping." + raise ValueError(msg) + + self.dataset = get_dataset(self.df) + self.labels = get_labels(self.df) + self.annotators = get_annotators(self.df) + + return self diff --git a/src/post_processing/dataclass/detection_filter.py b/src/post_processing/dataclass/detection_filter.py index 09776ae..41538ca 100644 --- a/src/post_processing/dataclass/detection_filter.py +++ b/src/post_processing/dataclass/detection_filter.py @@ -71,8 +71,8 @@ def from_yaml( @classmethod def from_dict( - cls, - parameters: dict, + cls, + parameters: dict, ) -> DetectionFilter | list[DetectionFilter]: """Return a DetectionFilter object from a dict. @@ -105,8 +105,7 @@ def from_dict( if filters_dict.get("end"): filters_dict["end"] = Timestamp(filters_dict["end"]) if filters_dict.get("timestamp_file"): - filters_dict["timestamp_file"] = Path( - filters_dict["timestamp_file"]) + filters_dict["timestamp_file"] = Path(filters_dict["timestamp_file"]) filters.append(cls(**filters_dict)) diff --git a/src/post_processing/utils/core_utils.py b/src/post_processing/utils/core_utils.py index c149155..11449e9 100644 --- a/src/post_processing/utils/core_utils.py +++ b/src/post_processing/utils/core_utils.py @@ -7,11 +7,8 @@ import astral import easygui -import numpy as np from astral.sun import sunrise, sunset from matplotlib import pyplot as plt -from osekit.config import TIMESTAMP_FORMAT_AUDIO_FILE -from osekit.utils.timestamp_utils import strftime_osmose_format, strptime_from_text from pandas import ( DataFrame, DatetimeIndex, @@ -25,15 +22,6 @@ from pandas.tseries import offsets from pandas.tseries.offsets import BaseOffset -from post_processing.utils.filtering_utils import ( - get_annotators, - get_dataset, - get_labels, - get_max_freq, - get_max_time, - get_timezone, -) - if TYPE_CHECKING: from datetime import tzinfo from pathlib import Path @@ -134,16 +122,16 @@ def get_sun_times( # Convert sunrise and sunset to decimal hours h_sunrise.append( - dt_sunrise.hour + - dt_sunrise.minute / 60 + - dt_sunrise.second / 3600 + - dt_sunrise.microsecond / 3_600_000_000, + dt_sunrise.hour + + dt_sunrise.minute / 60 + + dt_sunrise.second / 3600 + + dt_sunrise.microsecond / 3_600_000_000, ) h_sunset.append( - dt_sunset.hour + - dt_sunset.minute / 60 + - dt_sunset.second / 3600 + - dt_sunset.microsecond / 3_600_000_000, + dt_sunset.hour + + dt_sunset.minute / 60 + + dt_sunset.second / 3600 + + dt_sunset.microsecond / 3_600_000_000, ) return h_sunrise, h_sunset @@ -195,83 +183,6 @@ def get_coordinates() -> tuple: return lat, lon -def add_weak_detection( - df: DataFrame, - datetime_format: str = TIMESTAMP_FORMAT_AUDIO_FILE, - max_time: Timedelta | None = None, - max_freq: float | None = None, -) -> DataFrame: - """Add weak detections APLOSE formatted DataFrame with only strong detections. - - Parameters - ---------- - df: DataFrame - An APLOSE formatted DataFrame. - datetime_format: str - A string corresponding to the datetime format in the `filename` column - max_time: Timedelta - Size of the weak detections - max_freq: float - Height of the weak detections - - """ - annotators = get_annotators(df) - labels = get_labels(df) - dataset_id = get_dataset(df) - tz = get_timezone(df) - - if not max_freq: - max_freq = get_max_freq(df) - if not max_time: - max_time = Timedelta(get_max_time(df), "s") - - df["start_datetime"] = [ - strftime_osmose_format(start) for start in df["start_datetime"] - ] - df["end_datetime"] = [ - strftime_osmose_format(stop) for stop in df["end_datetime"] - ] - - for ant in annotators: - for lbl in labels: - filenames = ( - df[(df["annotator"] == ant) & (df["annotation"] == lbl)]["filename"] - .drop_duplicates() - .tolist() - ) - for f in filenames: - test = df[(df["filename"] == f) & (df["annotation"] == lbl)]["type"] - if test.any(): - start_datetime = strptime_from_text( - text=f, - datetime_template=datetime_format, - ) - - if not start_datetime.tz: - start_datetime = tz.localize(start_datetime) - - end_datetime = start_datetime + Timedelta(max_time, unit="s") - new_line = [ - dataset_id, - f, - 0, - max_time.total_seconds(), - 0, - max_freq, - lbl, - ant, - strftime_osmose_format(start_datetime), - strftime_osmose_format(end_datetime), - "WEAK", - ] - - if "score" in df.columns: - new_line.append(np.nan) - df.loc[df.index.max() + 1] = new_line - - return df.sort_values(by=["start_datetime", "annotator"]).reset_index(drop=True) - - def json2df(json_path: Path) -> DataFrame: """Convert a metadatax JSON file into a DataFrame. @@ -313,9 +224,12 @@ def add_season_period( msg = "Axes have no data" raise ValueError(msg) + patches = ax.patches + bins = date_range( - start=Timestamp(ax.get_xlim()[0], unit="D"), - end=Timestamp(ax.get_xlim()[1], unit="D"), + start=Timestamp(ax.get_xlim()[0], unit="D").round("1ms"), + end=Timestamp(ax.get_xlim()[1], unit="D").round("1ms"), + freq=Timedelta(patches[0].get_width(), "D").round("1ms"), ) season_colors = { @@ -414,7 +328,9 @@ def add_recording_period( ax.set_ylim(ax.dataLim.ymin, ax.dataLim.ymax) -def get_count(df: DataFrame, bin_size: Timedelta | BaseOffset) -> DataFrame: +def get_count( + df: DataFrame, bin_size: Timedelta | BaseOffset, time: DatetimeIndex | None = None +) -> DataFrame: """Count observations per label and annotator. This function groups a DataFrame of events into uniform time bins and counts the @@ -426,6 +342,8 @@ def get_count(df: DataFrame, bin_size: Timedelta | BaseOffset) -> DataFrame: APLOSE-formatted DataFrame. bin_size : Timedelta | offsets Width or frequency of bins. + time: DatetimeIndex + DatetimeIndex from a specified beginning to end Returns ------- @@ -438,7 +356,7 @@ def get_count(df: DataFrame, bin_size: Timedelta | BaseOffset) -> DataFrame: msg = "`df` contains no data" raise ValueError(msg) - datetime_list = list(df["start_datetime"]) + datetime_list = list(df["start_datetime"]) if time is None else time.to_list() bins, bin_size = get_time_range_and_bin_size(datetime_list, bin_size) @@ -459,7 +377,7 @@ def get_count(df: DataFrame, bin_size: Timedelta | BaseOffset) -> DataFrame: def get_labels_and_annotators(df: DataFrame) -> tuple[list, list]: - """Extract and align annotation labels and annotators from a DataFrame. + """Extract and align annotation labels and annotators from an APLOSE DataFrame. If only one label is present, it is duplicated to match the number of annotators. Similarly, if one annotator is present, it is duplicated to match the labels. @@ -509,8 +427,9 @@ def get_time_range_and_bin_size( bin_size: Timedelta | BaseOffset, ) -> tuple[DatetimeIndex, Timedelta]: """Return time vector given a bin size.""" - if (not isinstance(timestamp_list, list) or - not all(isinstance(ts, Timestamp) for ts in timestamp_list)): + if not isinstance(timestamp_list, list) or not all( + isinstance(ts, Timestamp) for ts in timestamp_list + ): msg = "`timestamp_list` must be a list[Timestamp]" raise TypeError(msg) @@ -546,7 +465,7 @@ def round_begin_end_timestamps( if isinstance(bin_size, Timedelta): start = min(timestamp_list).floor(bin_size) - end = max(timestamp_list).ceil(bin_size) + end = max(timestamp_list).floor(bin_size) + bin_size return start, end, bin_size if isinstance(bin_size, BaseOffset): diff --git a/src/post_processing/utils/filtering_utils.py b/src/post_processing/utils/filtering_utils.py index faf2fd7..50d7486 100644 --- a/src/post_processing/utils/filtering_utils.py +++ b/src/post_processing/utils/filtering_utils.py @@ -7,8 +7,10 @@ import datetime from typing import TYPE_CHECKING +import numpy as np import pytz -from osekit.utils.timestamp_utils import strptime_from_text +from osekit.config import TIMESTAMP_FORMAT_AUDIO_FILE +from osekit.utils.timestamp_utils import strftime_osmose_format, strptime_from_text from pandas import ( DataFrame, Timedelta, @@ -19,6 +21,8 @@ to_datetime, ) +from post_processing.utils.core_utils import get_count + if TYPE_CHECKING: from pathlib import Path @@ -41,8 +45,10 @@ def find_delimiter(file: Path) -> str: dialect = sniffer.sniff(sample) if dialect.delimiter not in allowed_delimiters: - msg = (f"Could not determine delimiter for '{file}': " - f"unsupported delimiter '{dialect.delimiter}'") + msg = ( + f"Could not determine delimiter for '{file}': " + f"unsupported delimiter '{dialect.delimiter}'" + ) raise ValueError(msg) return dialect.delimiter @@ -259,11 +265,12 @@ def read_dataframe(file: Path, rows: int | None = None) -> DataFrame: """Read an APLOSE-formatted CSV file into a DataFrame.""" delimiter = find_delimiter(file) return ( - read_csv(file, - sep=delimiter, - parse_dates=["start_datetime", "end_datetime"], - nrows=rows, - ) + read_csv( + file, + sep=delimiter, + parse_dates=["start_datetime", "end_datetime"], + nrows=rows, + ) .drop_duplicates() .dropna(subset=["annotation"]) .sort_values(by=["start_datetime", "end_datetime"]) @@ -273,12 +280,18 @@ def read_dataframe(file: Path, rows: int | None = None) -> DataFrame: def get_annotators(df: DataFrame) -> list[str]: """Return the annotator list of APLOSE DataFrame.""" - return sorted(set(df["annotator"])) + if len(df) == 1: + return df["annotator"][0] + annotators = sorted(set(df["annotator"])) + return annotators if len(annotators) > 1 else annotators[0] -def get_labels(df: DataFrame) -> list[str]: +def get_labels(df: DataFrame) -> str | list[str]: """Return the label list of APLOSE DataFrame.""" - return sorted(set(df["annotation"])) + if len(df) == 1: + return df["annotation"][0] + labels = sorted(set(df["annotation"])) + return labels if len(labels) > 1 else labels[0] def get_max_freq(df: DataFrame) -> float: @@ -293,6 +306,8 @@ def get_max_time(df: DataFrame) -> float: def get_dataset(df: DataFrame) -> str | list[str]: """Return dataset list of APLOSE DataFrame.""" + if len(df) == 1: + return df["dataset"][0] datasets = sorted(set(df["dataset"])) return datasets if len(datasets) > 1 else datasets[0] @@ -327,8 +342,9 @@ def get_canonical_tz(tz: datetime.tzinfo) -> pytz.tzinfo.BaseTzInfo: raise TypeError(msg) -def get_timezone(df: DataFrame)\ - -> pytz.tzinfo.BaseTzInfo | list[pytz.tzinfo.BaseTzInfo]: +def get_timezone( + df: DataFrame, +) -> pytz.tzinfo.BaseTzInfo | list[pytz.tzinfo.BaseTzInfo]: """Return timezone(s) from APLOSE DataFrame. Parameters @@ -476,7 +492,10 @@ def _process_annotator_label_pair( # Build vectors filename_vector = _build_filename_vector( - time_vector, ts_detect_beg, timestamp_audio, filenames, + time_vector, + ts_detect_beg, + timestamp_audio, + filenames, ) detect_vec = _build_detection_vector(time_vector, ts_detect_beg, ts_detect_end) @@ -492,7 +511,13 @@ def _process_annotator_label_pair( return None return _create_result_dataframe( - file_vector, start_datetime, timebin_new, max_freq, dataset, label, annotator, + file_vector, + start_datetime, + timebin_new, + max_freq, + dataset, + label, + annotator, ) @@ -538,11 +563,20 @@ def reshape_timebin( df = _normalize_timezones(df) # Process each annotator-label combination + annotators = [annotators] if isinstance(annotators, str) else annotators + labels = [labels] if isinstance(labels, str) else labels + results = [] for ant in annotators: for lbl in labels: result = _process_annotator_label_pair( - df, ant, lbl, timebin_new, timestamp_audio, max_freq, dataset, + df, + ant, + lbl, + timebin_new, + timestamp_audio, + max_freq, + dataset, ) if result is not None: results.append(result) @@ -572,10 +606,11 @@ def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]: """ tz = get_timezone(df) timestamps = [ - strptime_from_text( - ts, - datetime_template=date_parser, - ) for ts in df["filename"] + strptime_from_text( + ts, + datetime_template=date_parser, + ) + for ts in df["filename"] ] if all(t.tz is None for t in timestamps): @@ -621,10 +656,11 @@ def load_detections(filters: DetectionFilter) -> DataFrame: df = filter_by_freq(df, filters.f_min, filters.f_max) df = filter_by_score(df, filters.score) filename_ts = get_filename_timestamps(df, filters.filename_format) - df = reshape_timebin(df, - timebin_new=filters.timebin_new, - timestamp_audio=filename_ts, - ) + df = reshape_timebin( + df, + timebin_new=filters.timebin_new, + timestamp_audio=filename_ts, + ) annotators = get_annotators(df) if len(annotators) > 1 and filters.user_sel in {"union", "intersection"}: @@ -633,6 +669,81 @@ def load_detections(filters: DetectionFilter) -> DataFrame: return df.sort_values(by=["start_datetime", "end_datetime"]).reset_index(drop=True) +def add_weak_detection( + df: DataFrame, + datetime_format: str = TIMESTAMP_FORMAT_AUDIO_FILE, + max_time: Timedelta | None = None, + max_freq: float | None = None, +) -> DataFrame: + """Add weak detections APLOSE formatted DataFrame with only strong detections. + + Parameters + ---------- + df: DataFrame + An APLOSE formatted DataFrame. + datetime_format: str + A string corresponding to the datetime format in the `filename` column + max_time: Timedelta + Size of the weak detections + max_freq: float + Height of the weak detections + + """ + annotators = get_annotators(df) + labels = get_labels(df) + dataset_id = get_dataset(df) + tz = get_timezone(df) + + if not max_freq: + max_freq = get_max_freq(df) + if not max_time: + max_time = Timedelta(get_max_time(df), "s") + + df["start_datetime"] = [ + strftime_osmose_format(start) for start in df["start_datetime"] + ] + df["end_datetime"] = [strftime_osmose_format(stop) for stop in df["end_datetime"]] + + for ant in annotators: + for lbl in labels: + filenames = ( + df[(df["annotator"] == ant) & (df["annotation"] == lbl)]["filename"] + .drop_duplicates() + .tolist() + ) + for f in filenames: + test = df[(df["filename"] == f) & (df["annotation"] == lbl)]["type"] + if test.any(): + start_datetime = strptime_from_text( + text=f, + datetime_template=datetime_format, + ) + + if not start_datetime.tz: + start_datetime = tz.localize(start_datetime) + + end_datetime = start_datetime + Timedelta(max_time, unit="s") + new_line = [ + dataset_id, + f, + 0, + max_time.total_seconds(), + 0, + max_freq, + lbl, + ant, + strftime_osmose_format(start_datetime), + strftime_osmose_format(end_datetime), + "WEAK", + ] + + if "score" in df.columns: + new_line.append(np.nan) + df.loc[df.index.max() + 1] = new_line + + return df.sort_values(by=["start_datetime", "annotator"]).reset_index(drop=True) + + def intersection_or_union(df: DataFrame, user_sel: str) -> DataFrame: """Compute intersection or union of annotations from multiple annotators.""" annotators = get_annotators(df) @@ -640,6 +751,14 @@ def intersection_or_union(df: DataFrame, user_sel: str) -> DataFrame: msg = "Not enough annotators detected" raise ValueError(msg) + datasets = get_dataset(df) + labels = get_labels(df) + end_frequency = get_max_freq(df) + + annotators = [annotators] if isinstance(annotators, str) else annotators + datasets = [datasets] if isinstance(datasets, str) else datasets + labels = [labels] if isinstance(labels, str) else labels + if user_sel == "all": return df @@ -647,21 +766,34 @@ def intersection_or_union(df: DataFrame, user_sel: str) -> DataFrame: msg = "'user_sel' must be either 'intersection' or 'union'" raise ValueError(msg) - # Count how many annotators marked each (start_datetime, annotation) pair - counts = df.groupby(["annotation", "start_datetime"])["annotator"].transform( - "nunique", - ) + timebin = Timedelta(df["end_time"].iloc[0], "s") + df_count = get_count(df, timebin) if user_sel == "intersection": - df_result = df[counts == len(annotators)] + # Keep only time bins where ALL annotators detected something annotator_name = " ∩ ".join(annotators) - else: # union - df_result = df[counts >= 1] + dataset_name = " ∩ ".join(datasets) + label_name = " ∩ ".join(labels) + mask = (df_count > 0).all(axis=1) + else: + # Keep only time bins where AT LEAST ONE annotator detected something annotator_name = " ∪ ".join(annotators) # noqa: RUF001 + dataset_name = " ∪ ".join(datasets) # noqa: RUF001 + label_name = " ∪ ".join(labels) # noqa: RUF001 + mask = (df_count > 0).any(axis=1) - return ( - df_result.drop_duplicates(subset=["annotation", "start_datetime"]) - .assign(annotator=annotator_name) - .sort_values("start_datetime") - .reset_index(drop=True) - ) + # Get the selected timestamps + selected_times = df_count.index[mask] + + # Filter original df to rows whose start_time falls in selected_times + result = df[df["start_datetime"].isin(selected_times)].copy() + + # Drop duplicates keeping one row per timebin + result = result.drop_duplicates(subset=["start_datetime"]) + + result = result.assign(annotator=annotator_name) + result = result.assign(end_frequency=end_frequency) + result = result.assign(annotation=label_name) + result = result.assign(dataset=dataset_name) + + return result diff --git a/src/post_processing/utils/metrics_utils.py b/src/post_processing/utils/metrics_utils.py index b610d7f..4740bdc 100644 --- a/src/post_processing/utils/metrics_utils.py +++ b/src/post_processing/utils/metrics_utils.py @@ -3,28 +3,29 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING import numpy as np -from numpy import ndarray -from pandas import DataFrame, Series, Timedelta, Timestamp, date_range +from pandas import DataFrame, DatetimeIndex, Timedelta -if TYPE_CHECKING: - from post_processing.dataclass.recording_period import RecordingPeriod +from post_processing.utils.core_utils import get_count +from post_processing.utils.filtering_utils import ( + get_annotators, + get_labels, + get_max_time, + intersection_or_union, +) def detection_perf( df: DataFrame, - timestamps: list[Timestamp] | None = None, *, ref: tuple[str, str], + time: DatetimeIndex | None = None, ) -> tuple[float, float, float]: - """Compute performances metrics for detection. + """Compute the performance metrics for detection. - Performances are computed with a reference annotator in - comparison with a second annotator/detector. - Precision and recall are computed in regard - with a reference annotator/label pair. + Performances are computed with a reference annotator/label pair + in comparison to a second annotator/label pair. Parameters ---------- @@ -32,8 +33,8 @@ def detection_perf( APLOSE formatted detection/annotation DataFrame ref: tuple[str, str] Tuple of annotator/detector pairs. - timestamps: list[Timestamp] - A list of Timestamps to base the computation on. + time: DatetimeIndex + DatetimeIndex from a specified beginning to end Returns ------- @@ -42,145 +43,113 @@ def detection_perf( f_score: float """ - datetime_begin = df["start_datetime"].min() - datetime_end = df["start_datetime"].max() - df_freq = str(df["end_time"].max()) + "s" - labels = df["annotation"].unique().tolist() - annotators = df["annotator"].unique().tolist() - - num_annotators = 2 - if len(annotators) != num_annotators: + annotators = get_annotators(df) + if len(annotators) != 2: # noqa: PLR2004 msg = f"Two annotators needed, DataFrame contains {len(annotators)} annotators" raise ValueError(msg) - if not timestamps: - timestamps = [ - ts.timestamp() - for ts in date_range( - start=datetime_begin, - end=datetime_end, - freq=df_freq, - ) - ] - else: - timestamps = [ts.timestamp() for ts in timestamps] - - # df1 - REFERENCE - selected_annotator1 = ref[0] - selected_label1 = ref[1] - selected_annotations1 = df[ - (df["annotator"] == selected_annotator1) & (df["annotation"] == selected_label1) - ] - vec1 = _map_datetimes_to_vector(df=selected_annotations1, timestamps=timestamps) - - # df2 - selected_annotator2 = ( - next(ant for ant in annotators if ant != selected_annotator1) - if len(annotators) == 2 # noqa: PLR2004 - else selected_annotator1 - ) - selected_label2 = ( - next(lbl for lbl in labels if lbl != selected_label1) + labels = get_labels(df) + + timebin = Timedelta(get_max_time(df), "s") + df_count = get_count(df, timebin, time) + + # reference annotator and label + annotator1, label1 = ref + annotations1 = df[(df["annotator"] == annotator1) & (df["annotation"] == label1)] + if annotations1.empty: + msg = f"No detection found for {annotator1}/{label1}" + raise ValueError(msg) + vec1 = df_count[f"{label1}-{annotator1}"] + + # second annotator and label + annotator2 = next(ant for ant in annotators if ant != annotator1) + label2 = ( + next(lbl for lbl in labels if lbl != label1) if len(labels) == 2 # noqa: PLR2004 - else selected_label1 + else label1 ) - selected_annotations2 = df[ - (df["annotator"] == selected_annotator2) & (df["annotation"] == selected_label2) - ] - vec2 = _map_datetimes_to_vector(selected_annotations2, timestamps) - - # Metrics computation - true_pos = int(np.sum((vec1 == 1) & (vec2 == 1))) - false_pos = int(np.sum((vec1 == 0) & (vec2 == 1))) - false_neg = int(np.sum((vec1 == 1) & (vec2 == 0))) - true_neg = int(np.sum((vec1 == 0) & (vec2 == 0))) - error = int(np.sum((vec1 != 0) & (vec1 != 1) | (vec2 != 0) & (vec2 != 1))) - - if error != 0: - msg = f"Error : {error}" + vec2 = df_count[f"{label2}-{annotator2}"] + + # metrics computation + confusion_matrix = { + "true_pos": int(np.sum((vec1 == 1) & (vec2 == 1))), + "false_pos": int(np.sum((vec1 == 0) & (vec2 == 1))), + "false_neg": int(np.sum((vec1 == 1) & (vec2 == 0))), + "true_neg": int(np.sum((vec1 == 0) & (vec2 == 0))), + "error": int(np.sum((vec1 != 0) & (vec1 != 1) | (vec2 != 0) & (vec2 != 1))), + } + + if confusion_matrix["error"] != 0: + msg = f"{confusion_matrix['error']} errors in metric computation." raise ValueError(msg) - msg_result = "- Detection results -\n\n" - msg_result += f"True positive : {true_pos}\n" - msg_result += f"True negative : {true_neg}\n" - msg_result += f"False positive : {false_pos}\n" - msg_result += f"False negative : {false_neg}\n\n" - - if true_pos + false_pos == 0 or false_neg + true_pos == 0: - msg = "Precision/Recall computation impossible" + if ( + confusion_matrix["true_pos"] + confusion_matrix["false_pos"] == 0 + or confusion_matrix["false_neg"] + confusion_matrix["true_pos"] == 0 + ): + msg = "Precision/Recall computation impossible." raise ValueError(msg) - msg_result += f"Precision : {true_pos / (true_pos + false_pos):.2f}\n" - msg_result += f"Recall : {true_pos / (false_neg + true_pos):.2f}\n" - - precision = true_pos / (true_pos + false_pos) - recall = true_pos / (true_pos + false_neg) - f_score = 2 * (precision * recall) / (precision + recall) - msg_result += f"F-score : {f_score:.2f}\n\n" + _log_detection_results( + selection1=(annotator1, label1), + selection2=(annotator2, label2), + matrix=confusion_matrix, + df=df, + ) - msg_result += ( - f"Config 1 : {selected_annotator1}/{selected_label1} \n" - f"Config 2 : {selected_annotator2}/{selected_label2}\n\n" + return ( + _get_precision(confusion_matrix), + _get_recall(confusion_matrix), + _get_f_score(confusion_matrix), ) - logging.debug(msg_result) - logging.info(f"Precision: {precision:.2f}") - logging.info(f"Recall: {recall:.2f}") - logging.info(f"F-score: {f_score:.2f}") - return precision, recall, f_score +def _get_precision(confusion_matrix: dict) -> float: + """Compute precision.""" + tp = confusion_matrix["true_pos"] + fp = confusion_matrix["false_pos"] + return tp / (tp + fp) -def _map_datetimes_to_vector(df: DataFrame, timestamps: list[int]) -> ndarray: - """Map datetime ranges to a binary vector indicating overlap with timestamp bins. +def _get_recall(confusion_matrix: dict) -> float: + """Compute recall.""" + tp = confusion_matrix["true_pos"] + fn = confusion_matrix["false_neg"] + return tp / (tp + fn) - Parameters - ---------- - df : DataFrame - APLOSE-formatted DataFrame. - timestamps : list of int - List of UNIX timestamps representing bin start times. - Returns - ------- - ndarray - Binary array (0/1) where 1 indicates overlap with a bin. +def _get_f_score(confusion_matrix: dict) -> float: + """Compute F-score.""" + precision = _get_precision(confusion_matrix) + recall = _get_recall(confusion_matrix) + return 2 * (precision * recall) / (precision + recall) - """ - starts = df["start_datetime"].astype("int64") // 10**9 - ends = df["end_datetime"].astype("int64") // 10**9 - timebin = int(df["end_time"].iloc[0]) # duration in seconds - - timestamps = np.array(timestamps) - ts_start = timestamps - ts_end = timestamps + timebin - - vec = np.zeros(len(timestamps), dtype=int) - - for start, end in zip(starts, ends, strict=False): - overlap = (ts_start < end) & (ts_end > start) - vec[overlap] = 1 - - return vec - - -def normalize_counts_by_effort(counts: DataFrame, - effort: RecordingPeriod, - time_bin: Timedelta, - ) -> DataFrame: - """Normalize detection counts given the observation effort.""" - timebin_origin = effort.timebin_origin - effort_series = effort.counts - effort_intervals = effort_series.index - effort_series.index = [interval.left for interval in effort_series.index] - for col in counts.columns: - effort_ratio = effort_series * (timebin_origin / time_bin) - effort_ratio = Series( - np.where((effort_ratio > 0) & (effort_ratio < 1), 1.0, effort_ratio), - index=effort_series.index, - name=effort_series.name, - ) - counts[f"{col}"] = ((counts[col] / effort_ratio.reindex(counts[col].index)) - .clip(upper=1)) - effort_series.index = effort_intervals - return counts + +def _log_detection_results( + selection1: tuple[str, str], + selection2: tuple[str, str], + matrix: dict, + df: DataFrame, +) -> None: + """Log detection performance results.""" + annotator1, label1 = selection1 + annotator2, label2 = selection2 + precision = _get_precision(matrix) + recall = _get_recall(matrix) + f_score = _get_f_score(matrix) + + msg_result = ( + f"{' Detection results ':#^50}\n" + f"{'Config 1:':<10}{f'{annotator1}/{label1}':>40}\n" + f"{'Config 2:':<10}{f'{annotator2}/{label2}':>40}\n\n" + f"{'True positive:':<25}{matrix['true_pos']:>25}\n" + f"{'True negative:':<25}{matrix['true_neg']:>25}\n" + f"{'False positive:':<25}{matrix['false_pos']:>25}\n" + f"{'False negative:':<25}{matrix['false_neg']:>25}\n\n" + f"{'Precision:':<25}{precision:>25.2f}\n" + f"{'Recall:':<25}{recall:>25.2f}\n" + f"{'F-score:':<25}{f_score:>25.2f}\n\n" + f"{'Union:':<25}{len(intersection_or_union(df, 'union')):>25.0f}\n" + f"{'Intersection:':<25}{len(intersection_or_union(df, 'intersection')):>25.0f}\n" + ) + logging.info(msg_result) diff --git a/src/post_processing/utils/plot_utils.py b/src/post_processing/utils/plot_utils.py index 8d12fa3..8f84334 100644 --- a/src/post_processing/utils/plot_utils.py +++ b/src/post_processing/utils/plot_utils.py @@ -12,7 +12,7 @@ from matplotlib import dates as mdates from matplotlib.dates import num2date from matplotlib.patches import Patch -from numpy import ceil, histogram, polyfit +from numpy import ceil, polyfit from pandas import ( DataFrame, DatetimeIndex, @@ -24,7 +24,6 @@ ) from pandas.tseries import frequencies from scipy.stats import pearsonr -from seaborn import scatterplot from post_processing.utils.core_utils import ( add_season_period, @@ -32,8 +31,7 @@ get_labels_and_annotators, get_sun_times, get_time_range_and_bin_size, - round_begin_end_timestamps, - timedelta_to_str, + timedelta_to_str, round_begin_end_timestamps, ) from post_processing.utils.filtering_utils import ( filter_by_annotator, @@ -75,18 +73,18 @@ def histo( - legend: bool Whether to show the legend. - color: str | list[str] - Colour or list of colours for the histogram bars. - If not provided, default colours will be used. + Color or list of colors for the histogram bars. + If not provided, default colors will be used. - season: bool Whether to show the season. - coordinates: tuple[float, float] The coordinates of the plotted detections. - effort: RecordingPeriod Object corresponding to the observation effort. - If provided, data will be normalised by observation effort. + If provided, data will be normalized by observation effort. """ - labels, annotators = zip(*[col.rsplit("-", 1) for col in df.columns], strict=False) + labels, annotators = zip(*[col.rsplit("-", 1) for col in df.columns], strict=True) labels = list(labels) annotators = list(annotators) @@ -122,7 +120,7 @@ def histo( offset = i * bar_width.total_seconds() / 86400 bar_kwargs = { - "width": bar_width.total_seconds() / 86400, + "width": (bar_width.total_seconds() / 86400), "align": "edge", "edgecolor": "black", "color": color[i], @@ -153,8 +151,8 @@ def histo( ) if season: - if lat is None or lon is None: - get_coordinates() + if lat is None: + lat, _ = get_coordinates() add_season_period(ax, northern=lat >= 0) @@ -469,77 +467,81 @@ def wrap_text(text: str) -> str: ax.set_xticklabels(new_labels, rotation=0) -def agreement( +def plot_annotator_agreement( df: DataFrame, bin_size: Timedelta | BaseOffset, - ax: plt.Axes, + ax: Axes, ) -> None: - """Compute and visualise agreement between two annotators. + """Plot inter-annotator agreement with linear regression. - This function compares annotation timestamps from two annotators over a time range. - It also fits and plots a linear regression line and displays the coefficient - of determination (R²) on the plot. + Creates a scatter plot comparing annotation counts between two annotators + across time bins. Fits a linear regression line and displays the coefficient + of determination (R²) in the legend. Parameters ---------- df : DataFrame - APLOSE-formatted DataFrame. - It must contain The annotations of two annotators. - + APLOSE-formatted DataFrame containing annotations from exactly two annotators. bin_size : Timedelta | BaseOffset - The size of each time bin for aggregating annotation timestamps. - - ax : matplotlib.axes.Axes - Matplotlib axes object where the scatterplot and regression line will be drawn. - - """ - labels, annotators = get_labels_and_annotators(df) + Size of each time bin for aggregating annotation timestamps. + ax : plt.Axes + Matplotlib axes object where the scatter plot and regression line will be drawn. - datetimes = [ - list( - df[ - (df["annotator"] == annotators[i]) & (df["annotation"] == labels[i]) - ]["start_datetime"], - ) - for i in range(2) - ] + Notes + ----- + The function modifies the provided axes object in place and does not return a value. + Each point in the scatter plot represents the annotation counts from both annotators + within a single time bin. - # scatter plot - n_annot_max = bin_size.total_seconds() / df["end_time"].iloc[0] + Examples + -------- + >>> fig, ax = plt.subplots() + >>> plot_annotator_agreement(df, Timedelta(hours=1), ax) - freq = ( - bin_size if isinstance(bin_size, Timedelta) else str(bin_size.n) + bin_size.name - ) + """ + labels, annotators = zip(*[col.rsplit("-", 1) for col in df.columns], strict=True) + labels = list(labels) + annotators = list(annotators) - bins = date_range( - start=df["start_datetime"].min().floor(bin_size), - end=df["start_datetime"].max().ceil(bin_size), - freq=freq, + ax.scatter( + df[f"{labels[0]}-{annotators[0]}"], + df[f"{labels[1]}-{annotators[1]}"], + zorder=3, ) - - df_hist = ( - DataFrame( - { - annotators[0]: histogram(datetimes[0], bins=bins)[0], - annotators[1]: histogram(datetimes[1], bins=bins)[0], - }, - ) - / n_annot_max + coefficients = polyfit( + df[f"{labels[0]}-{annotators[0]}"], + df[f"{labels[1]}-{annotators[1]}"], + deg=1, ) - - scatterplot(data=df_hist, x=annotators[0], y=annotators[1], ax=ax) - - coefficients = polyfit(df_hist[annotators[0]], df_hist[annotators[1]], 1) poly = np.poly1d(coefficients) - ax.plot(df_hist[annotators[0]], poly(df_hist[annotators[0]]), lw=1) - - ax.set_xlabel(f"{annotators[0]}\n{labels[0]}") - ax.set_ylabel(f"{annotators[1]}\n{labels[1]}") - ax.grid(linestyle="-", linewidth=0.2) + r, _ = pearsonr( + df[f"{labels[0]}-{annotators[0]}"], + df[f"{labels[1]}-{annotators[1]}"], + ) # R² + ax.plot( + sorted(df[f"{labels[0]}-{annotators[0]}"]), + poly(sorted(df[f"{labels[0]}-{annotators[0]}"])), + lw=0.5, + color="k", + alpha=0.5, + linestyle="-", + label=f"R²={r**2:.2f}", + zorder=2, + ) - # Pearson correlation (R²) - r, _ = pearsonr(df_hist[annotators[0]], df_hist[annotators[1]]) - ax.text(0.05, 0.85, f"R² = {r**2:.2f}", transform=ax.transAxes) + ax.set_xlabel(f"""annotator: {annotators[0]}\nlabel: {labels[0]}""") + ax.set_ylabel(f"""annotator: {annotators[1]}\nlabel: {labels[1]}""") + ax.grid( + linestyle="-", + linewidth=0.2, + zorder=1, + ) + ax.legend( + loc="upper left", + frameon=True, + framealpha=1, + fontsize=8, + ) def timeline( diff --git a/tests/conftest.py b/tests/conftest.py index a6299e3..3d8998e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,8 @@ from pandas import DataFrame, read_csv from pandas.tseries import frequencies +from post_processing.dataclass.data_aplose import DataAplose + SAMPLE = """dataset,filename,start_time,end_time,start_frequency,end_frequency,annotation,annotator,start_datetime,end_datetime,type,score sample_dataset,2025_01_25_06_20_00,0.0,10.0,0.0,72000.0,lbl2,ann2,2025-01-25T06:20:00.000+00:00,2025-01-25T06:20:10.000+00:00,WEAK,0.11 sample_dataset,2025_01_25_06_20_00,3.46662989520132,4.02371759514617,7523.0,15257.0,lbl2,ann2,2025-01-25T06:20:03.466+00:00,2025-01-25T06:20:04.023+00:00,BOX,0.23 @@ -62,8 +64,6 @@ sample_dataset,2025_01_25_06_20_30,0.0,10.0,0.0,72000.0,lbl2,ann1,2025-01-25T06:20:30.000+00:00,2025-01-25T06:20:40.000+00:00,WEAK,0.56 sample_dataset,2025_01_25_06_20_30,2.1042471042471,3.00330943188086,8296.0,18562.0,lbl2,ann1,2025-01-25T06:20:32.104+00:00,2025-01-25T06:20:33.003+00:00,BOX,0.13 sample_dataset,2025_01_25_06_20_30,4.04026475455047,5.41919470490899,7312.0,21515.0,lbl2,ann1,2025-01-25T06:20:34.040+00:00,2025-01-25T06:20:35.419+00:00,BOX,0.92 -sample_dataset,2025_01_25_06_20_30,0.0,10.0,0.0,72000.0,lbl2,ann1,2025-01-25T06:20:30.000+00:00,2025-01-25T06:20:40.000+00:00,WEAK,0.76 -sample_dataset,2025_01_25_06_20_30,0.0,10.0,0.0,72000.0,lbl2,ann1,2025-01-25T06:20:30.000+00:00,2025-01-25T06:20:40.000+00:00,WEAK,0.31 sample_dataset,2025_01_25_06_20_30,0.0,10.0,0.0,72000.0,lbl1,ann1,2025-01-25T06:20:30.000+00:00,2025-01-25T06:20:40.000+00:00,WEAK,0.41 sample_dataset,2025_01_25_06_20_30,0.0,10.0,0.0,72000.0,lbl2,ann6,2025-01-25T06:20:30.000+00:00,2025-01-25T06:20:40.000+00:00,WEAK,0.15 sample_dataset,2025_01_25_06_20_30,1.97186982901269,2.93160507446222,8578.0,17578.0,lbl2,ann6,2025-01-25T06:20:31.971+00:00,2025-01-25T06:20:32.931+00:00,BOX,0.49 @@ -133,9 +133,7 @@ sample_dataset,2025_01_26_06_20_20,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED """ -# --------------------------------------------------------------------------- # Fake recording planning CSV used for tests -# --------------------------------------------------------------------------- RECORDING_PLANNING_CSV = """start_recording,end_recording,start_deployment,end_deployment 2024-01-01 00:00:00+0000,2024-04-09 02:00:00+0000,2024-01-02 00:00:00+0000,2024-04-30 02:00:00+0000 2024-04-30 01:00:00+0000,2024-07-14 06:00:00+0000,2024-04-30 02:00:00+0000,2024-07-06 14:00:00+0000 @@ -145,7 +143,17 @@ @pytest.fixture def sample_df() -> DataFrame: df = read_csv(io.StringIO(SAMPLE), parse_dates=["start_datetime", "end_datetime"]) - return df.sort_values(["start_datetime", "end_datetime", "annotator", "annotation"]).reset_index(drop=True) + return df.sort_values([ + "start_datetime", + "end_datetime", + "annotator", + "annotation", + ]).reset_index(drop=True) + + +@pytest.fixture +def sample_data_aplose(sample_df: DataFrame) -> DataAplose: + return DataAplose(sample_df) @pytest.fixture @@ -157,8 +165,12 @@ def sample_status() -> DataFrame: def sample_csv_result(tmp_path: Path, sample_df: DataFrame) -> Path: result_file = tmp_path / "results.csv" df_copy = sample_df.copy() - df_copy["start_datetime"] = [strftime_osmose_format(ts) for ts in sample_df["start_datetime"]] - df_copy["end_datetime"] = [strftime_osmose_format(ts) for ts in sample_df["end_datetime"]] + df_copy["start_datetime"] = [ + strftime_osmose_format(ts) for ts in sample_df["start_datetime"] + ] + df_copy["end_datetime"] = [ + strftime_osmose_format(ts) for ts in sample_df["end_datetime"] + ] df_copy.to_csv(result_file, index=False) return result_file @@ -172,9 +184,9 @@ def sample_csv_timestamp(tmp_path: Path, sample_status: DataFrame) -> Path: @pytest.fixture def sample_yaml( - tmp_path: Path, - sample_csv_result: Path, - sample_csv_timestamp: Path, + tmp_path: Path, + sample_csv_result: Path, + sample_csv_timestamp: Path, ) -> Path: yaml_content = { f"{sample_csv_result}": { @@ -221,7 +233,6 @@ def sample_audio(tmp_path: Path) -> Path: @pytest.fixture def tmp_audio_dir(tmp_path: Path) -> Path: - def create_file(path: Path, size: int = 2048): """Create a file of given size in bytes.""" path.write_bytes(os.urandom(size)) @@ -248,6 +259,7 @@ def recording_planning_csv(tmp_path) -> Path: @pytest.fixture def recording_planning_config(recording_planning_csv): """Minimal config object compatible with RecordingPeriod.from_path.""" + class RecordingPlanningConfig: timestamp_file: Path = recording_planning_csv timebin_origin = frequencies.to_offset("1min") diff --git a/tests/test_DataAplose.py b/tests/test_DataAplose.py index 9b9516c..d14e447 100644 --- a/tests/test_DataAplose.py +++ b/tests/test_DataAplose.py @@ -1,12 +1,15 @@ +from copy import copy from pathlib import Path +from typing import ContextManager import matplotlib.dates as mdates import matplotlib.pyplot as plt import pytest -from pandas import DataFrame, Timedelta +from pandas import DataFrame, Timedelta, Timestamp from pandas.tseries import frequencies from post_processing.dataclass.data_aplose import DataAplose +from post_processing.utils.filtering_utils import get_timezone def test_data_aplose_init(sample_df: DataFrame) -> None: @@ -26,8 +29,7 @@ def test_filter_df_single_pair(sample_df: DataFrame) -> None: assert sorted(set(filtered_data["annotation"])) == ["lbl1"] assert sorted(set(filtered_data["annotator"])) == ["ann1"] expected = sample_df[ - (sample_df["annotator"] == "ann1") & - (sample_df["annotation"] == "lbl1") + (sample_df["annotator"] == "ann1") & (sample_df["annotation"] == "lbl1") ].reset_index(drop=True) assert filtered_data.equals(expected) @@ -38,10 +40,18 @@ def test_change_tz(sample_df: DataFrame) -> None: data.change_tz(new_tz) start_dt = data.df["start_datetime"] end_dt = data.df["end_datetime"] - assert all(ts.tz.zone == new_tz for ts in start_dt), f"The detection start timestamps have to be in {new_tz} timezone" - assert all(ts.tz.zone == new_tz for ts in end_dt), f"The detection end timestamps have to be in {new_tz} timezone" - assert data.begin.tz.zone == new_tz, f"The begin value of the DataAplose has to be in {new_tz} timezone" - assert data.end.tz.zone == new_tz, f"The end value of the DataAplose has to be in {new_tz} timezone" + assert all(ts.tz.zone == new_tz for ts in start_dt), ( + f"The detection start timestamps have to be in {new_tz} timezone" + ) + assert all(ts.tz.zone == new_tz for ts in end_dt), ( + f"The detection end timestamps have to be in {new_tz} timezone" + ) + assert data.begin.tz.zone == new_tz, ( + f"The begin value of the DataAplose has to be in {new_tz} timezone" + ) + assert data.end.tz.zone == new_tz, ( + f"The end value of the DataAplose has to be in {new_tz} timezone" + ) def test_filter_df_multiple_pairs(sample_df: DataFrame) -> None: @@ -75,7 +85,7 @@ def test_filter_df_invalid_label(sample_df: DataFrame) -> None: def test_filter_df_invalid_combination( - sample_df: DataFrame, + sample_df: DataFrame, ) -> None: data = DataAplose(sample_df) with pytest.raises( @@ -87,7 +97,7 @@ def test_filter_df_invalid_combination( def test_filter_df_invalid_lists_size( - sample_df: DataFrame, + sample_df: DataFrame, ) -> None: data = DataAplose(sample_df) with pytest.raises( @@ -144,7 +154,9 @@ def test_plot_scatter_heatmap_timeline(sample_df: DataFrame, mode: str) -> None: data.lat = 0 bin_size = frequencies.to_offset("1d") fig, ax = plt.subplots() - data.plot(mode=mode, ax=ax, annotator="ann1", label="lbl1", bin_size=bin_size, color="red") + data.plot( + mode=mode, ax=ax, annotator="ann1", label="lbl1", bin_size=bin_size, color="red" + ) def test_heatmap_wrong_bin(sample_df: DataFrame) -> None: @@ -154,7 +166,14 @@ def test_heatmap_wrong_bin(sample_df: DataFrame) -> None: bins = frequencies.to_offset("10s") fig, ax = plt.subplots() with pytest.raises(ValueError, match="`bin_size` must be >= 24h for heatmap mode."): - data.plot(mode="heatmap", ax=ax, annotator="ann1", label="lbl1", bin_size=bins, color="red") + data.plot( + mode="heatmap", + ax=ax, + annotator="ann1", + label="lbl1", + bin_size=bins, + color="red", + ) def test_plot_invalid_mode(sample_df: DataFrame) -> None: @@ -167,27 +186,33 @@ def test_plot_invalid_mode(sample_df: DataFrame) -> None: def test_plot_agreement(sample_df: DataFrame) -> None: data = DataAplose(sample_df) fig, ax = plt.subplots() - data.plot(mode="agreement", - ax=ax, annotator=["ann1", "ann2"], - label="lbl1", - bin_size=Timedelta("10s"), - ) + data.plot( + mode="agreement", + ax=ax, + annotator=["ann1", "ann2"], + label="lbl1", + bin_size=Timedelta("10s"), + ) def test_set_ax(sample_df: DataFrame) -> None: da = DataAplose(sample_df) - fig, ax = plt.subplots() + _, ax = plt.subplots() ax = da.set_ax(ax, Timedelta("7h"), "%Y-%m-%d") locator = ax.xaxis.get_major_locator() assert isinstance(locator, mdates.HourLocator) def test_from_yaml( - sample_yaml: Path, - sample_df: DataFrame, + sample_yaml: Path, + sample_df: DataFrame, ) -> None: df_from_yaml = DataAplose.from_yaml(file=sample_yaml).df - df_expected = DataAplose(sample_df).filter_df(annotator="ann1", label="lbl1").reset_index(drop=True) + df_expected = ( + DataAplose(sample_df) + .filter_df(annotator="ann1", label="lbl1") + .reset_index(drop=True) + ) assert df_from_yaml.equals(df_expected) @@ -212,3 +237,149 @@ def test_concat(sample_yaml: Path, sample_df: DataFrame) -> None: assert got.equals(exp), f"Mismatch in {attr}" else: assert got == exp, f"Mismatch in {attr}" + + +# %% Reshape + + +@pytest.mark.parametrize( + ("begin", "end", "expected"), + [ + pytest.param( + Timestamp("2025-01-26T06:20:09.999+00:00"), + None, + pytest.raises(ValueError, match=r"DataFrame is empty after reshaping."), + id="new_begin_after_original_end", + ), + pytest.param( + None, + Timestamp("2025-01-25T06:20:00.001+00:00"), + pytest.raises(ValueError, match=r"DataFrame is empty after reshaping."), + id="new_end_before_original_begin", + ), + pytest.param( + Timestamp("2024-12-31"), + Timestamp("2024-01-01"), + pytest.raises( + ValueError, match=r"Begin timestamp is not anterior than end timestamp." + ), + id="begin_after_end_inverted_range", + ), + pytest.param( + Timestamp("2050-01-01", tz="UTC"), + Timestamp("2050-12-31", tz="UTC"), + pytest.raises(ValueError, match=r"DataFrame is empty after reshaping."), + id="tz_aware_future_range_no_data", + ), + pytest.param( + Timestamp("1990-01-01", tz="America/New_York"), + Timestamp("1990-12-31", tz="America/New_York"), + pytest.raises(ValueError, match=r"DataFrame is empty after reshaping."), + id="tz_aware_past_range_no_data", + ), + ], +) +def test_reshape_errors( + sample_data_aplose: DataAplose, + begin: Timestamp | None, + end: Timestamp | None, + expected: ContextManager[Exception], +) -> None: + """Test that reshape function handles error cases appropriately.""" + with expected: + sample_data_aplose.reshape(begin, end) + + +@pytest.mark.parametrize( + ("begin", "end", "should_filter"), + [ + pytest.param( + None, + None, + False, + id="no_timestamps_provided", + ), + pytest.param( + Timestamp("1990-01-01", tz="UTC"), + None, + True, + id="tz_aware_begin_only", + ), + pytest.param( + None, + Timestamp("2050-12-31", tz="UTC"), + True, + id="tz_aware_end_only", + ), + pytest.param( + Timestamp("2025-01-24", tz="Europe/Paris"), + Timestamp("2025-01-27", tz="Europe/Paris"), + True, + id="tz_aware_both_timestamps", + ), + pytest.param( + Timestamp("1990-01-01"), + None, + True, + id="tz_naive_begin_only", + ), + pytest.param( + None, + Timestamp("2050-12-31"), + True, + id="tz_naive_end_only", + ), + pytest.param( + Timestamp("1990-01-01"), + Timestamp("2050-12-31"), + True, + id="tz_naive_both_timestamps", + ), + pytest.param( + Timestamp("1990-01-01", tz="Europe/Paris"), + Timestamp("2050-12-31", tz="America/New_York"), + True, + id="tz_aware_different_timezone", + ), + pytest.param( + Timestamp("2025-01-26T00:00:00.000+00:00"), + Timestamp("2025-01-26T10:00:00.000+00:00"), + True, + id="narrowing_detections", + ), + ], +) +def test_reshape_valid_cases( + sample_data_aplose: DataAplose, + begin: Timestamp | None, + end: Timestamp | None, + should_filter: bool, +) -> None: + """Test that reshape handles valid timestamp cases appropriately.""" + reshaped = copy(sample_data_aplose) + reshaped.reshape(begin=begin, end=end) + + original_tz = get_timezone(sample_data_aplose.df) + + # Check timestamps were updated + if begin is not None: + if begin.tz is None: + begin = begin.tz_localize(original_tz) + assert reshaped.begin != sample_data_aplose.begin + assert reshaped.begin == begin + + if end is not None: + if end.tz is None: + end = end.tz_localize(original_tz) + assert reshaped.end != sample_data_aplose.end + assert reshaped.end == end + + # Check timezone was applied for tz-naive timestamps + assert reshaped.begin.tz is not None + assert reshaped.end.tz is not None + + # Check filtering behavior + if should_filter: + assert reshaped.shape <= sample_data_aplose.shape + assert all(reshaped.df["start_datetime"] >= reshaped.begin) + assert all(reshaped.df["end_datetime"] <= reshaped.end) diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py index e72e482..6cfa494 100644 --- a/tests/test_core_utils.py +++ b/tests/test_core_utils.py @@ -10,7 +10,6 @@ from post_processing.utils.core_utils import ( add_recording_period, add_season_period, - add_weak_detection, get_coordinates, get_count, get_labels_and_annotators, @@ -23,6 +22,7 @@ set_bar_height, timedelta_to_str, ) +from post_processing.utils.filtering_utils import add_weak_detection def test_coordinates_valid_input(monkeypatch: pytest.MonkeyPatch) -> None: @@ -30,6 +30,7 @@ def test_coordinates_valid_input(monkeypatch: pytest.MonkeyPatch) -> None: def fake_box(msg: str, title: str, fields: list[str]) -> list[str]: return inputs + monkeypatch.setattr("easygui.multenterbox", fake_box) lat, lon = get_coordinates() assert lat == 42 # noqa: PLR2004 @@ -39,6 +40,7 @@ def fake_box(msg: str, title: str, fields: list[str]) -> list[str]: def test_coordinates_cancelled_input(monkeypatch: pytest.MonkeyPatch) -> None: def fake_box(msg: str, title: str, fields: list[str]) -> None: return None + monkeypatch.setattr("easygui.multenterbox", fake_box) with pytest.raises(TypeError, match="was cancelled"): get_coordinates() @@ -49,6 +51,7 @@ def test_coordinates_invalid_then_valid_input(monkeypatch: pytest.MonkeyPatch) - def fake_box(msg: str, title: str, fields: list[str]) -> list[str]: return inputs.pop(0) + monkeypatch.setattr("easygui.multenterbox", fake_box) lat, lon = get_coordinates() assert lat == 45.0 # noqa: PLR2004 @@ -60,6 +63,7 @@ def test_coordinates_non_numeric_input(monkeypatch: pytest.MonkeyPatch) -> None: def fake_box(msg: str, title: str, fields: list[str]) -> list[str]: return inputs.pop(0) + monkeypatch.setattr("easygui.multenterbox", fake_box) lat, lon = get_coordinates() assert lat == 10.0 # noqa: PLR2004 @@ -77,7 +81,6 @@ def fake_box(msg: str, title: str, fields: list[str]) -> list[str]: (Timestamp("2025-02-28"), True, ("winter", 2024)), (Timestamp("2024-02-29"), True, ("winter", 2023)), (Timestamp("2025-12-25"), True, ("winter", 2025)), - # Southern hemisphere (Timestamp("2025-03-15"), False, ("autumn", 2025)), (Timestamp("2025-06-21"), False, ("winter", 2025)), @@ -95,21 +98,26 @@ def test_get_season(ts: Timestamp, northern: bool, expected: tuple[str, int]) -> @pytest.mark.parametrize( ("start", "stop", "lat", "lon"), [ - (Timestamp("2025-06-01 00:00:00+00:00"), - Timestamp("2025-06-03 23:59:59+00:00"), - 49.4333, - -1.5167), - (Timestamp("2025-12-21 00:00:00+00:00"), - Timestamp("2025-12-22 23:59:59+00:00"), - -34.9011, - -56.1645), + ( + Timestamp("2025-06-01 00:00:00+00:00"), + Timestamp("2025-06-03 23:59:59+00:00"), + 49.4333, + -1.5167, + ), + ( + Timestamp("2025-12-21 00:00:00+00:00"), + Timestamp("2025-12-22 23:59:59+00:00"), + -34.9011, + -56.1645, + ), ], ) -def test_get_sun_times_valid_input(start: Timestamp, - stop: Timestamp, - lat: float, - lon: float, - ) -> None: +def test_get_sun_times_valid_input( + start: Timestamp, + stop: Timestamp, + lat: float, + lon: float, +) -> None: results = get_sun_times(start, stop, lat, lon) h_sunrise, h_sunset = results @@ -128,10 +136,12 @@ def test_get_sun_times_valid_input(start: Timestamp, @pytest.mark.parametrize( ("start", "stop", "lat", "lon"), [ - (Timestamp("2025-06-01 00:00:00"), - Timestamp("2025-06-03 23:59:59"), - 49.4333, - -1.5167), + ( + Timestamp("2025-06-01 00:00:00"), + Timestamp("2025-06-03 23:59:59"), + 49.4333, + -1.5167, + ), ], ) def test_get_sun_times_naive_timestamps( @@ -146,18 +156,21 @@ def test_get_sun_times_naive_timestamps( # %% get_count + def test_get_count_basic(sample_df: DataFrame) -> None: df = DataAplose(sample_df).filter_df(annotator="ann1", label="lbl1") - result = get_count(df, bin_size=Timedelta("30min")) + result = get_count(df, bin_size=Timedelta("1min")) expected = sample_df[ - (sample_df["annotator"] == "ann1") & - (sample_df["annotation"] == "lbl1") + (sample_df["annotator"] == "ann1") & (sample_df["annotation"] == "lbl1") ] - assert list(result.index) == date_range( - Timestamp("2025-01-25 06:00:00+0000"), - Timestamp("2025-01-26 06:00:00+0000"), - freq="30min", - ).to_list() + assert ( + list(result.index) + == date_range( + Timestamp("2025-01-25 06:20:00+0000"), + Timestamp("2025-01-26 06:20:00+0000"), + freq="1min", + ).to_list() + ) assert result.columns == ["lbl1-ann1"] assert sum(result["lbl1-ann1"].tolist()) == len(expected) @@ -166,9 +179,9 @@ def test_get_count_multiple_annotators(sample_df: DataFrame) -> None: df = DataAplose(sample_df).filter_df(annotator=["ann1", "ann2"], label="lbl1") result = get_count(df, bin_size=Timedelta("1d")) expected = sample_df[ - (sample_df["annotator"].isin(["ann1", "ann2"])) & - (sample_df["annotation"] == "lbl1") - ] + (sample_df["annotator"].isin(["ann1", "ann2"])) + & (sample_df["annotation"] == "lbl1") + ] assert set(result.columns) == {"lbl1-ann1", "lbl1-ann2"} assert result["lbl1-ann1"].sum() == len(expected[expected["annotator"] == "ann1"]) @@ -176,12 +189,14 @@ def test_get_count_multiple_annotators(sample_df: DataFrame) -> None: def test_get_count_multiple_labels(sample_df: DataFrame) -> None: - df = DataAplose(sample_df).filter_df(annotator="ann5", label=["lbl1", "lbl2", "lbl3"]) + df = DataAplose(sample_df).filter_df( + annotator="ann5", label=["lbl1", "lbl2", "lbl3"] + ) result = get_count(df, bin_size=Timedelta("1day")) expected = sample_df[ - (sample_df["annotator"] == "ann5") & - (sample_df["annotation"].isin(["lbl1", "lbl2", "lbl3"])) - ] + (sample_df["annotator"] == "ann5") + & (sample_df["annotation"].isin(["lbl1", "lbl2", "lbl3"])) + ] assert set(result.columns) == {"lbl1-ann5", "lbl2-ann5", "lbl3-ann5"} assert result["lbl1-ann5"].sum() == len(expected[expected["annotation"] == "lbl1"]) @@ -190,19 +205,29 @@ def test_get_count_multiple_labels(sample_df: DataFrame) -> None: def test_get_count_multiple_labels_annotators(sample_df: DataFrame) -> None: - df = DataAplose(sample_df).filter_df(annotator=["ann1", "ann2"], - label=["lbl1", "lbl2"], - ) + df = DataAplose(sample_df).filter_df( + annotator=["ann1", "ann2"], + label=["lbl1", "lbl2"], + ) result = get_count(df, bin_size=Timedelta("1day")) assert set(result.columns) == {"lbl1-ann1", "lbl2-ann2"} - assert result["lbl1-ann1"].sum() == len(sample_df[(sample_df["annotation"] == "lbl1") & (sample_df["annotator"] == "ann1")]) - assert result["lbl2-ann2"].sum() == len(sample_df[(sample_df["annotation"] == "lbl2") & (sample_df["annotator"] == "ann2")]) + assert result["lbl1-ann1"].sum() == len( + sample_df[ + (sample_df["annotation"] == "lbl1") & (sample_df["annotator"] == "ann1") + ] + ) + assert result["lbl2-ann2"].sum() == len( + sample_df[ + (sample_df["annotation"] == "lbl2") & (sample_df["annotator"] == "ann2") + ] + ) def test_get_count_empty_df() -> None: with pytest.raises(ValueError, match="`df` contains no data"): get_count(DataFrame(), Timedelta("1h")) + # %% get_labels_and_annotators @@ -240,6 +265,7 @@ def test_get_labels_and_annotators_empty_dataframe() -> None: with pytest.raises(ValueError, match="`df` contains no data"): get_labels_and_annotators(DataFrame()) + # %% localize_timestamps @@ -272,6 +298,7 @@ def test_mixed_naive_and_aware() -> None: assert localized[0].tzinfo.zone == tz.zone assert localized[1].tzinfo.zone == tz.zone + # %% get_time_range_and_bin_size @@ -280,7 +307,9 @@ def test_time_range_timedelta() -> None: bin_size = Timedelta("1h") time_range, returned_bin = get_time_range_and_bin_size(timestamps, bin_size) - expected = date_range(start="2025-08-20 12:00:00", end="2025-08-20 15:00:00", freq="1h") + expected = date_range( + start="2025-08-20 12:00:00", end="2025-08-20 15:00:00", freq="1h" + ) assert (time_range == expected).all() assert returned_bin == bin_size @@ -290,7 +319,9 @@ def test_time_range_baseoffset() -> None: bin_size = frequencies.to_offset("1h") time_range, returned_bin = get_time_range_and_bin_size(timestamps, bin_size) - expected = date_range(start="2025-08-20 12:00:00", end="2025-08-20 15:00:00", freq="1h") + expected = date_range( + start="2025-08-20 12:00:00", end="2025-08-20 15:00:00", freq="1h" + ) assert (time_range == expected).all() assert returned_bin == bin_size @@ -305,16 +336,21 @@ def test_empty_timestamp_list() -> None: def test_invalid_timestamp_list_type() -> None: timestamps = "not_a_list" bin_size = Timedelta("1h") - with pytest.raises(TypeError, match=r"`timestamp_list` must be a list\[Timestamp\]"): + with pytest.raises( + TypeError, match=r"`timestamp_list` must be a list\[Timestamp\]" + ): get_time_range_and_bin_size(timestamps, bin_size) def test_invalid_timestamp_list_content() -> None: timestamps = [Timestamp("2025-08-20"), "not_a_timestamp"] bin_size = Timedelta("1h") - with pytest.raises(TypeError, match=r"`timestamp_list` must be a list\[Timestamp\]"): + with pytest.raises( + TypeError, match=r"`timestamp_list` must be a list\[Timestamp\]" + ): get_time_range_and_bin_size(timestamps, bin_size) + # %% round_begin_end_timestamps @@ -354,12 +390,15 @@ def test_round_begin_end_timestamps_valid_entry_2() -> None: Timestamp("2025-01-01 10:15:00"), Timestamp("2025-01-03 18:45:00"), ] - start, end, bin_size = round_begin_end_timestamps(ts_list, frequencies.to_offset("1h")) + start, end, bin_size = round_begin_end_timestamps( + ts_list, frequencies.to_offset("1h") + ) assert start == Timestamp("2025-01-01 10:00:00") assert end == Timestamp("2025-01-03 19:00:00") assert bin_size == Timedelta("1h") + # %% timedelta_to_str @@ -391,22 +430,24 @@ def test_add_wd(sample_df: DataFrame) -> None: # %% add_season_period -def test_add_season_valid() -> None: - fig, ax = plt.subplots() - start = Timestamp("2025-01-01T00:00:00+00:00") - stop = Timestamp("2025-01-02T00:00:00+00:00") - ts = date_range(start=start, end=stop, freq="H", tz="UTC") - values = list(range(len(ts))) - ax.plot(ts, values) +def test_add_season_valid() -> None: + _, ax = plt.subplots() + start = Timestamp("2025-01-01") + stop = Timestamp("2026-01-01") + freq = Timedelta("1d") + ts = date_range(start=start, end=stop, freq=freq, tz="UTC") + values = [date.day for date in ts] + [ax.bar(loc + freq, height) for loc, height in zip(ts, values, strict=True)] add_season_period(ax=ax) def test_add_season_no_data() -> None: - fig, ax = plt.subplots() + _, ax = plt.subplots() with pytest.raises(ValueError, match=r"have no data"): add_season_period(ax=ax) + # %% add_recording_period @@ -437,6 +478,7 @@ def test_add_recording_period_no_data() -> None: with pytest.raises(ValueError, match=r"have no data"): add_recording_period(df=df, ax=ax) + # %% set_bar_height @@ -457,6 +499,7 @@ def test_set_bar_height_no_data() -> None: with pytest.raises(ValueError, match=r"have no data"): set_bar_height(ax=ax) + # %% json2df diff --git a/tests/test_filtering_utils.py b/tests/test_filtering_utils.py index 3ec3760..ceb0940 100644 --- a/tests/test_filtering_utils.py +++ b/tests/test_filtering_utils.py @@ -24,22 +24,27 @@ get_max_freq, get_max_time, get_timezone, - intersection_or_union, read_dataframe, reshape_timebin, + intersection_or_union, ) + # %% find delimiter -@pytest.mark.parametrize(("delimiter", "rows"), [ - (",", [["a", "b", "c"], ["1", "2", "3"]]), - (";", [["x", "y", "z"], ["4", "5", "6"]]), -]) -def test_find_delimiter_valid(tmp_path: Path, - delimiter: str, - rows: list[list[str]], - ) -> None: +@pytest.mark.parametrize( + ("delimiter", "rows"), + [ + (",", [["a", "b", "c"], ["1", "2", "3"]]), + (";", [["x", "y", "z"], ["4", "5", "6"]]), + ], +) +def test_find_delimiter_valid( + tmp_path: Path, + delimiter: str, + rows: list[list[str]], +) -> None: file = tmp_path / "test.csv" with file.open("w", newline="") as f: writer = csv.writer(f, delimiter=delimiter) @@ -49,7 +54,9 @@ def test_find_delimiter_valid(tmp_path: Path, assert detected == delimiter -def test_find_delimiter_invalid(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: +def test_find_delimiter_invalid( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: file = tmp_path / "bad.csv" file.write_text("a,b,c") @@ -84,6 +91,7 @@ def test_find_delimiter_unsupported_delimiter(tmp_path: Path) -> None: # %% filter utils + # filter_by_time @pytest.mark.parametrize( "begin, end", @@ -173,13 +181,13 @@ def test_filter_by_label_invalid(sample_df: DataFrame) -> None: "f_min, f_max", [ pytest.param( - 500, # valid lower bound + 500, # valid lower bound None, id="valid_f_min_only", ), pytest.param( None, - 60000, # valid upper bound + 60000, # valid upper bound id="valid_f_max_only", ), pytest.param( @@ -337,6 +345,7 @@ def test_get_timezone_several(sample_df: DataFrame) -> None: assert pytz.UTC in tz assert pytz.FixedOffset(420) in tz + # %% read DataFrame @@ -361,7 +370,7 @@ def test_read_dataframe_drop_duplicates_and_na(tmp_path: Path) -> None: "start_datetime,end_datetime,annotation\n" "2025-01-01 12:00:00,2025-01-01 12:05:00,whale\n" "2025-01-01 12:00:00,2025-01-01 12:05:00,whale\n" # duplicate - "2025-01-01 13:00:00,2025-01-01 13:05:00,\n", # NaN annotation + "2025-01-01 13:00:00,2025-01-01 13:05:00,\n", # NaN annotation ) df = read_dataframe(csv_file) @@ -395,6 +404,7 @@ def test_read_dataframe_nrows(tmp_path: Path) -> None: # %% reshape_timebin + def test_no_timebin_returns_original(sample_df: DataFrame) -> None: df_out = reshape_timebin(sample_df, timebin_new=None, timestamp_audio=None) assert df_out.equals(sample_df) @@ -419,8 +429,9 @@ def test_no_timebin_several_tz(sample_df: DataFrame) -> None: [sample_df, DataFrame([new_row])], ignore_index=False, ) - timestamp_wav = to_datetime(sample_df["filename"], - format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(pytz.UTC) + timestamp_wav = to_datetime( + sample_df["filename"], format="%Y_%m_%d_%H_%M_%S" + ).dt.tz_localize(pytz.UTC) df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=None) assert df_out.equals(sample_df) @@ -482,7 +493,6 @@ def test_no_timebin_original_timebin(sample_df: DataFrame) -> None: "lbl2", "lbl2", "lbl1", - ], "annotator": [ "ann1", @@ -503,12 +513,11 @@ def test_no_timebin_original_timebin(sample_df: DataFrame) -> None: "ann3", "ann4", "ann5", - ], - "start_datetime": [Timestamp("2025-01-25 06:20:00+00:00")] * 11 + - [Timestamp("2025-01-26 06:20:00+00:00")] * 7, - "end_datetime": [Timestamp("2025-01-25 06:21:00+00:00")] * 11 + - [Timestamp("2025-01-26 06:21:00+00:00")] * 7, + "start_datetime": [Timestamp("2025-01-25 06:20:00+00:00")] * 11 + + [Timestamp("2025-01-26 06:20:00+00:00")] * 7, + "end_datetime": [Timestamp("2025-01-25 06:21:00+00:00")] * 11 + + [Timestamp("2025-01-26 06:21:00+00:00")] * 7, "type": ["WEAK"] * 18, }, ) @@ -540,18 +549,22 @@ def test_reshape_daily_multiple_bins(sample_df: DataFrame) -> None: sample_df["filename"], format="%Y_%m_%d_%H_%M_%S", ).dt.tz_localize(tz) - df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=Timedelta(days=1)) + df_out = reshape_timebin( + sample_df, timestamp_audio=timestamp_wav, timebin_new=Timedelta(days=1) + ) assert not df_out.empty assert all(df_out["end_time"] == 86400.0) - assert df_out["start_datetime"].min() >= sample_df["start_datetime"].min().floor("D") + assert df_out["start_datetime"].min() >= sample_df["start_datetime"].min().floor( + "D" + ) assert df_out["end_datetime"].max() <= sample_df["end_datetime"].max().ceil("D") def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None: - tz = get_timezone(sample_df) - timestamp_wav = to_datetime(sample_df["filename"], - format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz) + timestamp_wav = to_datetime( + sample_df["filename"], format="%Y_%m_%d_%H_%M_%S" + ).dt.tz_localize(tz) df_out = reshape_timebin( sample_df, timestamp_audio=timestamp_wav, @@ -565,14 +578,18 @@ def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None: def test_empty_result_when_no_matching(sample_df: DataFrame) -> None: tz = get_timezone(sample_df) - timestamp_wav = to_datetime(sample_df["filename"], - format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz) + timestamp_wav = to_datetime( + sample_df["filename"], format="%Y_%m_%d_%H_%M_%S" + ).dt.tz_localize(tz) with pytest.raises(ValueError, match="DataFrame is empty"): - reshape_timebin(DataFrame(), timestamp_audio=timestamp_wav, timebin_new=Timedelta(hours=1)) + reshape_timebin( + DataFrame(), timestamp_audio=timestamp_wav, timebin_new=Timedelta(hours=1) + ) # %% ensure_no_invalid + def test_ensure_no_invalid_empty() -> None: try: ensure_no_invalid([], "label") @@ -597,31 +614,39 @@ def test_ensure_no_invalid_single_element() -> None: assert "baz" in str(exc_info.value) assert "features" in str(exc_info.value) + # %% intersection / union -def test_intersection(sample_df) -> None: - df_result = intersection_or_union(sample_df[sample_df["annotator"].isin(["ann1", "ann2"])], user_sel="intersection") +def test_intersection(sample_df: DataFrame) -> None: + df_result = intersection_or_union( + sample_df[sample_df["annotator"].isin(["ann1", "ann2"])], + user_sel="intersection", + ) - assert set(df_result["annotation"]) == {"lbl1", "lbl2"} + assert set(df_result["annotation"]) == {"lbl1 ∩ lbl2"} assert set(df_result["annotator"]) == {"ann1 ∩ ann2"} -def test_union(sample_df) -> None: - df_result = intersection_or_union(sample_df[sample_df["annotator"].isin(["ann1", "ann2"])], user_sel="union") +def test_union(sample_df: DataFrame) -> None: + df_result = intersection_or_union( + sample_df[sample_df["annotator"].isin(["ann1", "ann2"])], user_sel="union" + ) - assert set(df_result["annotation"]) == {"lbl1", "lbl2"} + assert set(df_result["annotation"]) == {"lbl1 ∪ lbl2"} assert set(df_result["annotator"]) == {"ann1 ∪ ann2"} -def test_all_user_sel_returns_original(sample_df) -> None: +def test_all_user_sel_returns_original(sample_df: DataFrame) -> None: df_result = intersection_or_union(sample_df, user_sel="all") assert len(df_result) == len(sample_df) -def test_invalid_user_sel_raises(sample_df) -> None: - with pytest.raises(ValueError, match="'user_sel' must be either 'intersection' or 'union'"): +def test_invalid_user_sel_raises(sample_df: DataFrame) -> None: + with pytest.raises( + ValueError, match="'user_sel' must be either 'intersection' or 'union'" + ): intersection_or_union(sample_df, user_sel="invalid") diff --git a/tests/test_metric_utils.py b/tests/test_metric_utils.py index 35717e7..48890c2 100644 --- a/tests/test_metric_utils.py +++ b/tests/test_metric_utils.py @@ -1,16 +1,102 @@ +from contextlib import nullcontext +from typing import ContextManager + import pytest +from _pytest.monkeypatch import MonkeyPatch +from numpy import array from pandas import DataFrame from post_processing.utils.metrics_utils import detection_perf -def test_detection_perf(sample_df: DataFrame) -> None: - try: - detection_perf(df=sample_df[sample_df["annotator"].isin(["ann1", "ann4"])], ref=("ann1", "lbl1")) - except ValueError: - pytest.fail("test_detection_perf raised ValueError unexpectedly.") +@pytest.mark.parametrize( + ("filter_annotator", "filter_annotation", "ref", "expected"), + [ + pytest.param( + ["ann1", "ann4"], + None, + ("ann1", "lbl1"), + nullcontext(), + id="no_timestamps_provided", + ), + pytest.param( + ["ann1"], + None, + ("ann1", "lbl1"), + pytest.raises(ValueError, match="Two annotators needed"), + id="one_annotator_provided", + ), + pytest.param( + ["ann1", "ann6"], + ["lbl1"], + ("ann1", "lbl6"), + pytest.raises(ValueError, match="No detection found for ann1/lbl6"), + id="empty_ref_df", + ), + ], +) +def test_detection_perf( + sample_df: DataFrame, + filter_annotator: list[str, str], + filter_annotation: list[str, str], + ref: tuple[str, str], + expected: ContextManager[Exception], +) -> None: + filtered_df = sample_df[sample_df["type"] == "WEAK"] + if filter_annotator: + filtered_df = filtered_df[filtered_df["annotator"].isin(filter_annotator)] + if filter_annotation: + filtered_df = filtered_df[filtered_df["annotation"].isin(filter_annotation)] + + with expected: + detection_perf(df=filtered_df, ref=ref) + + +def test_detection_perf_confusion_matrix_errors( + monkeypatch: MonkeyPatch, + sample_df: DataFrame, +) -> None: + def fake_get_count(*args, **kwargs) -> DataFrame: + return DataFrame({ + "lbl2-ann1": array([1, 1, 1, 0, 666, 1]), + "lbl2-ann2": array([1, 0, 2, 1, 0, 1234]), + }) + + monkeypatch.setattr("post_processing.utils.metrics_utils.get_count", fake_get_count) + + filtered_df = sample_df[ + (sample_df["annotation"] == "lbl2") + & (sample_df["annotator"].isin(["ann1", "ann2"])) + & (sample_df["type"] == "WEAK") + ] + + with pytest.raises(ValueError, match="3 errors in metric computation"): + detection_perf( + filtered_df, + ref=("ann1", "lbl2"), + ) + + +def test_detection_perf_confusion_matrix_no_data( + monkeypatch: MonkeyPatch, + sample_df: DataFrame, +) -> None: + def fake_get_count(*args, **kwargs) -> DataFrame: + return DataFrame({ + "lbl2-ann1": array([0] * 10), + "lbl2-ann2": array([0] * 10), + }) + + monkeypatch.setattr("post_processing.utils.metrics_utils.get_count", fake_get_count) + filtered_df = sample_df[ + (sample_df["annotation"] == "lbl2") + & (sample_df["annotator"].isin(["ann1", "ann2"])) + & (sample_df["type"] == "WEAK") + ] -def test_detection_perf_one_annotator(sample_df: DataFrame) -> None: - with pytest.raises(ValueError, match="Two annotators needed"): - detection_perf(df=sample_df[sample_df["annotator"] == "ann1"], ref=("ann1", "lbl1")) + with pytest.raises(ValueError, match="Precision/Recall computation impossible"): + detection_perf( + filtered_df, + ref=("ann1", "lbl2"), + ) diff --git a/uv.lock b/uv.lock index 771c085..18e95a0 100644 --- a/uv.lock +++ b/uv.lock @@ -254,6 +254,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009, upload-time = "2024-09-04T20:44:45.309Z" }, ] +[[package]] +name = "cfgv" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" }, +] + [[package]] name = "charset-normalizer" version = "3.4.3" @@ -485,6 +494,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, ] +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -521,6 +539,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" }, ] +[[package]] +name = "filelock" +version = "3.20.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" }, +] + [[package]] name = "fonttools" version = "4.58.5" @@ -578,6 +605,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, + { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" }, + { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" }, { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, { url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, { url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, @@ -587,6 +616,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, { url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, { url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" }, { url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, { url = "https://files.pythonhosted.org/packages/22/5c/85273fd7cc388285632b0498dbbab97596e04b154933dfe0f3e68156c68c/greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0", size = 273586, upload-time = "2025-08-07T13:16:08.004Z" }, { url = "https://files.pythonhosted.org/packages/d1/75/10aeeaa3da9332c2e761e4c50d4c3556c21113ee3f0afa2cf5769946f7a3/greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f", size = 686346, upload-time = "2025-08-07T13:42:59.944Z" }, @@ -594,6 +625,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/8b/29aae55436521f1d6f8ff4e12fb676f3400de7fcf27fccd1d4d17fd8fecd/greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1", size = 694659, upload-time = "2025-08-07T13:53:17.759Z" }, { url = "https://files.pythonhosted.org/packages/92/2e/ea25914b1ebfde93b6fc4ff46d6864564fba59024e928bdc7de475affc25/greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735", size = 695355, upload-time = "2025-08-07T13:18:34.517Z" }, { url = "https://files.pythonhosted.org/packages/72/60/fc56c62046ec17f6b0d3060564562c64c862948c9d4bc8aa807cf5bd74f4/greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337", size = 657512, upload-time = "2025-08-07T13:18:33.969Z" }, + { url = "https://files.pythonhosted.org/packages/23/6e/74407aed965a4ab6ddd93a7ded3180b730d281c77b765788419484cdfeef/greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269", size = 1612508, upload-time = "2025-11-04T12:42:23.427Z" }, + { url = "https://files.pythonhosted.org/packages/0d/da/343cd760ab2f92bac1845ca07ee3faea9fe52bee65f7bcb19f16ad7de08b/greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681", size = 1680760, upload-time = "2025-11-04T12:42:25.341Z" }, { url = "https://files.pythonhosted.org/packages/e3/a5/6ddab2b4c112be95601c13428db1d8b6608a8b6039816f2ba09c346c08fc/greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01", size = 303425, upload-time = "2025-08-07T13:32:27.59Z" }, ] @@ -634,6 +667,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "identify" +version = "2.6.16" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/8d/e8b97e6bd3fb6fb271346f7981362f1e04d6a7463abd0de79e1fda17c067/identify-2.6.16.tar.gz", hash = "sha256:846857203b5511bbe94d5a352a48ef2359532bc8f6727b5544077a0dcfb24980", size = 99360, upload-time = "2026-01-12T18:58:58.201Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/58/40fbbcefeda82364720eba5cf2270f98496bdfa19ea75b4cccae79c698e6/identify-2.6.16-py2.py3-none-any.whl", hash = "sha256:391ee4d77741d994189522896270b787aed8670389bfd60f326d677d64a6dfb0", size = 99202, upload-time = "2026-01-12T18:58:56.627Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -1289,6 +1331,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, +] + [[package]] name = "notebook" version = "7.4.5" @@ -1573,6 +1624,7 @@ dev = [ { name = "coverage" }, { name = "myst-nb" }, { name = "notebook" }, + { name = "pre-commit" }, { name = "pytest" }, { name = "ruff" }, { name = "setuptools" }, @@ -1605,6 +1657,7 @@ dev = [ { name = "coverage", specifier = ">=7.11.0" }, { name = "myst-nb", specifier = ">=1.3.0" }, { name = "notebook", specifier = ">=7.4.4" }, + { name = "pre-commit", specifier = ">=4.5.1" }, { name = "pytest" }, { name = "ruff", specifier = ">=0.11.10" }, { name = "setuptools", specifier = "<66.2" }, @@ -1613,6 +1666,22 @@ dev = [ { name = "sphinx-copybutton", specifier = ">=0.5.2" }, ] +[[package]] +name = "pre-commit" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/f1/6d86a29246dfd2e9b6237f0b5823717f60cad94d47ddc26afa916d21f525/pre_commit-4.5.1.tar.gz", hash = "sha256:eb545fcff725875197837263e977ea257a402056661f09dae08e4b149b030a61", size = 198232, upload-time = "2025-12-16T21:14:33.552Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" }, +] + [[package]] name = "prometheus-client" version = "0.22.1" @@ -2465,6 +2534,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "virtualenv" +version = "20.36.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/aa/a3/4d310fa5f00863544e1d0f4de93bddec248499ccf97d4791bc3122c9d4f3/virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba", size = 6032239, upload-time = "2026-01-09T18:21:01.296Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/2a/dc2228b2888f51192c7dc766106cd475f1b768c10caaf9727659726f7391/virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f", size = 6008258, upload-time = "2026-01-09T18:20:59.425Z" }, +] + [[package]] name = "wcwidth" version = "0.2.13"