diff --git a/ax/analysis/healthcheck/healthcheck_analysis.py b/ax/analysis/healthcheck/healthcheck_analysis.py index 22b5e638f1b..1e3eac451cf 100644 --- a/ax/analysis/healthcheck/healthcheck_analysis.py +++ b/ax/analysis/healthcheck/healthcheck_analysis.py @@ -5,11 +5,14 @@ # pyre-strict +from __future__ import annotations + import json from enum import IntEnum import pandas as pd -from ax.core.analysis_card import AnalysisCard +from ax.analysis.analysis import ErrorAnalysisCard +from ax.core.analysis_card import AnalysisCard, AnalysisCardBase class HealthcheckStatus(IntEnum): @@ -18,6 +21,13 @@ class HealthcheckStatus(IntEnum): WARNING = 2 +# Healthchecks that provide valuable progress info even when passing +PRIORITY_HEALTHCHECKS: set[str] = { + "BaselineImprovementAnalysis", + "EarlyStoppingAnalysis", +} + + class HealthcheckAnalysisCard(AnalysisCard): def get_status(self) -> HealthcheckStatus: return HealthcheckStatus(json.loads(self.blob)["status"]) @@ -49,3 +59,49 @@ def create_healthcheck_analysis_card( } ), ) + + +# Status order for sorting: FAIL first, then WARNING, then PASS +_STATUS_SORT_ORDER: dict[HealthcheckStatus, int] = { + HealthcheckStatus.FAIL: 1, + HealthcheckStatus.WARNING: 2, + HealthcheckStatus.PASS: 3, +} + + +def sort_healthcheck_cards( + cards: list[AnalysisCardBase], +) -> list[AnalysisCardBase]: + """ + Sort healthcheck cards by severity and priority. + + Order: + 1. ErrorAnalysisCard (errors during computation) + 2. FAIL status + 3. WARNING status + 4. PASS status with priority (BaselineImprovement, EarlyStopping, etc.) + 5. PASS status (rest) + + Args: + cards: List of analysis cards (typically HealthcheckAnalysisCard or + ErrorAnalysisCard instances). + + Returns: + Sorted list of cards. + """ + + def sort_key(card: AnalysisCardBase) -> tuple[int, int, str]: + if isinstance(card, ErrorAnalysisCard): + return (0, 0, card.name) + + if isinstance(card, HealthcheckAnalysisCard): + return ( + _STATUS_SORT_ORDER[card.get_status()], + 0 if card.name in PRIORITY_HEALTHCHECKS else 1, + card.name, + ) + + # Fallback for type safety (unreachable in practice) + return (4, 1, card.name) + + return sorted(cards, key=sort_key) diff --git a/ax/analysis/healthcheck/tests/test_healthcheck_analysis.py b/ax/analysis/healthcheck/tests/test_healthcheck_analysis.py new file mode 100644 index 00000000000..07639b71eab --- /dev/null +++ b/ax/analysis/healthcheck/tests/test_healthcheck_analysis.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import pandas as pd +from ax.analysis.analysis import ErrorAnalysisCard +from ax.analysis.healthcheck.healthcheck_analysis import ( + create_healthcheck_analysis_card, + HealthcheckStatus, + sort_healthcheck_cards, +) +from ax.core.analysis_card import AnalysisCardBase +from ax.utils.common.testutils import TestCase + + +def _card(name: str, status: HealthcheckStatus) -> AnalysisCardBase: + return create_healthcheck_analysis_card( + name=name, title=name, subtitle=name, df=pd.DataFrame(), status=status + ) + + +def _error(name: str) -> AnalysisCardBase: + return ErrorAnalysisCard( + name=name, title=name, subtitle=name, df=pd.DataFrame(), blob="" + ) + + +class TestHealthcheckAnalysis(TestCase): + def test_sort_ordering(self) -> None: + cards: list[AnalysisCardBase] = [ + _card("RegularAnalysis", HealthcheckStatus.PASS), + _card("WarningAnalysis", HealthcheckStatus.WARNING), + _error("ErrorAnalysis"), + _card("BaselineImprovementAnalysis", HealthcheckStatus.PASS), + _card("FailAnalysis", HealthcheckStatus.FAIL), + ] + result = sort_healthcheck_cards(cards) + + self.assertEqual( + [c.name for c in result], + [ + "ErrorAnalysis", + "FailAnalysis", + "WarningAnalysis", + "BaselineImprovementAnalysis", + "RegularAnalysis", + ], + ) diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index d31b8674fc3..bb0f041ef77 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -8,7 +8,7 @@ from typing import Any, final from ax.adapter.base import Adapter -from ax.analysis.analysis import Analysis, ErrorAnalysisCard +from ax.analysis.analysis import Analysis from ax.analysis.diagnostics import DiagnosticAnalysis from ax.analysis.healthcheck.baseline_improvement import BaselineImprovementAnalysis from ax.analysis.healthcheck.can_generate_candidates import ( @@ -19,7 +19,7 @@ ConstraintsFeasibilityAnalysis, ) from ax.analysis.healthcheck.early_stopping_healthcheck import EarlyStoppingAnalysis -from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard +from ax.analysis.healthcheck.healthcheck_analysis import sort_healthcheck_cards from ax.analysis.healthcheck.metric_fetching_errors import MetricFetchingErrorsAnalysis from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis @@ -247,21 +247,14 @@ def compute( if analyis is not None ] - non_passing_health_checks = [ - card - for card in health_check_cards - if (isinstance(card, HealthcheckAnalysisCard) and not card.is_passing()) - or isinstance(card, ErrorAnalysisCard) - ] - health_checks_group = ( AnalysisCardGroup( name="HealthchecksAnalysis", title=HEALTH_CHECK_CARDGROUP_TITLE, subtitle=HEALTH_CHECK_CARDGROUP_SUBTITLE, - children=non_passing_health_checks, + children=sort_healthcheck_cards(health_check_cards), ) - if len(non_passing_health_checks) > 0 + if len(health_check_cards) > 0 else None )