From ad37fc1bc29654328ded013ef8240bb1549924c2 Mon Sep 17 00:00:00 2001 From: Shruti Patel Date: Thu, 29 Jan 2026 08:39:29 -0800 Subject: [PATCH] Show all healthcheck cards sorted by severity and priority Summary: Previously, the overview analysis only displayed non-passing healthcheck cards. This change shows all healthcheck cards, sorted by severity and priority to surface the most important information first. The sorting order is: 1. ErrorAnalysisCard (errors during computation) 2. FAIL status 3. WARNING status 4. PASS status with priority (BaselineImprovementAnalysis, EarlyStoppingAnalysis - these provide valuable progress metrics even when passing) 5. PASS status (rest) This gives users visibility into the full health of their experiment while keeping critical issues at the top. Differential Revision: D91750384 --- .../healthcheck/healthcheck_analysis.py | 58 ++++++++++++++++++- .../tests/test_healthcheck_analysis.py | 51 ++++++++++++++++ ax/analysis/overview.py | 15 ++--- 3 files changed, 112 insertions(+), 12 deletions(-) create mode 100644 ax/analysis/healthcheck/tests/test_healthcheck_analysis.py diff --git a/ax/analysis/healthcheck/healthcheck_analysis.py b/ax/analysis/healthcheck/healthcheck_analysis.py index 22b5e638f1b..1e3eac451cf 100644 --- a/ax/analysis/healthcheck/healthcheck_analysis.py +++ b/ax/analysis/healthcheck/healthcheck_analysis.py @@ -5,11 +5,14 @@ # pyre-strict +from __future__ import annotations + import json from enum import IntEnum import pandas as pd -from ax.core.analysis_card import AnalysisCard +from ax.analysis.analysis import ErrorAnalysisCard +from ax.core.analysis_card import AnalysisCard, AnalysisCardBase class HealthcheckStatus(IntEnum): @@ -18,6 +21,13 @@ class HealthcheckStatus(IntEnum): WARNING = 2 +# Healthchecks that provide valuable progress info even when passing +PRIORITY_HEALTHCHECKS: set[str] = { + "BaselineImprovementAnalysis", + "EarlyStoppingAnalysis", +} + + class HealthcheckAnalysisCard(AnalysisCard): def get_status(self) -> HealthcheckStatus: return HealthcheckStatus(json.loads(self.blob)["status"]) @@ -49,3 +59,49 @@ def create_healthcheck_analysis_card( } ), ) + + +# Status order for sorting: FAIL first, then WARNING, then PASS +_STATUS_SORT_ORDER: dict[HealthcheckStatus, int] = { + HealthcheckStatus.FAIL: 1, + HealthcheckStatus.WARNING: 2, + HealthcheckStatus.PASS: 3, +} + + +def sort_healthcheck_cards( + cards: list[AnalysisCardBase], +) -> list[AnalysisCardBase]: + """ + Sort healthcheck cards by severity and priority. + + Order: + 1. ErrorAnalysisCard (errors during computation) + 2. FAIL status + 3. WARNING status + 4. PASS status with priority (BaselineImprovement, EarlyStopping, etc.) + 5. PASS status (rest) + + Args: + cards: List of analysis cards (typically HealthcheckAnalysisCard or + ErrorAnalysisCard instances). + + Returns: + Sorted list of cards. + """ + + def sort_key(card: AnalysisCardBase) -> tuple[int, int, str]: + if isinstance(card, ErrorAnalysisCard): + return (0, 0, card.name) + + if isinstance(card, HealthcheckAnalysisCard): + return ( + _STATUS_SORT_ORDER[card.get_status()], + 0 if card.name in PRIORITY_HEALTHCHECKS else 1, + card.name, + ) + + # Fallback for type safety (unreachable in practice) + return (4, 1, card.name) + + return sorted(cards, key=sort_key) diff --git a/ax/analysis/healthcheck/tests/test_healthcheck_analysis.py b/ax/analysis/healthcheck/tests/test_healthcheck_analysis.py new file mode 100644 index 00000000000..07639b71eab --- /dev/null +++ b/ax/analysis/healthcheck/tests/test_healthcheck_analysis.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import pandas as pd +from ax.analysis.analysis import ErrorAnalysisCard +from ax.analysis.healthcheck.healthcheck_analysis import ( + create_healthcheck_analysis_card, + HealthcheckStatus, + sort_healthcheck_cards, +) +from ax.core.analysis_card import AnalysisCardBase +from ax.utils.common.testutils import TestCase + + +def _card(name: str, status: HealthcheckStatus) -> AnalysisCardBase: + return create_healthcheck_analysis_card( + name=name, title=name, subtitle=name, df=pd.DataFrame(), status=status + ) + + +def _error(name: str) -> AnalysisCardBase: + return ErrorAnalysisCard( + name=name, title=name, subtitle=name, df=pd.DataFrame(), blob="" + ) + + +class TestHealthcheckAnalysis(TestCase): + def test_sort_ordering(self) -> None: + cards: list[AnalysisCardBase] = [ + _card("RegularAnalysis", HealthcheckStatus.PASS), + _card("WarningAnalysis", HealthcheckStatus.WARNING), + _error("ErrorAnalysis"), + _card("BaselineImprovementAnalysis", HealthcheckStatus.PASS), + _card("FailAnalysis", HealthcheckStatus.FAIL), + ] + result = sort_healthcheck_cards(cards) + + self.assertEqual( + [c.name for c in result], + [ + "ErrorAnalysis", + "FailAnalysis", + "WarningAnalysis", + "BaselineImprovementAnalysis", + "RegularAnalysis", + ], + ) diff --git a/ax/analysis/overview.py b/ax/analysis/overview.py index d31b8674fc3..bb0f041ef77 100644 --- a/ax/analysis/overview.py +++ b/ax/analysis/overview.py @@ -8,7 +8,7 @@ from typing import Any, final from ax.adapter.base import Adapter -from ax.analysis.analysis import Analysis, ErrorAnalysisCard +from ax.analysis.analysis import Analysis from ax.analysis.diagnostics import DiagnosticAnalysis from ax.analysis.healthcheck.baseline_improvement import BaselineImprovementAnalysis from ax.analysis.healthcheck.can_generate_candidates import ( @@ -19,7 +19,7 @@ ConstraintsFeasibilityAnalysis, ) from ax.analysis.healthcheck.early_stopping_healthcheck import EarlyStoppingAnalysis -from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard +from ax.analysis.healthcheck.healthcheck_analysis import sort_healthcheck_cards from ax.analysis.healthcheck.metric_fetching_errors import MetricFetchingErrorsAnalysis from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis @@ -247,21 +247,14 @@ def compute( if analyis is not None ] - non_passing_health_checks = [ - card - for card in health_check_cards - if (isinstance(card, HealthcheckAnalysisCard) and not card.is_passing()) - or isinstance(card, ErrorAnalysisCard) - ] - health_checks_group = ( AnalysisCardGroup( name="HealthchecksAnalysis", title=HEALTH_CHECK_CARDGROUP_TITLE, subtitle=HEALTH_CHECK_CARDGROUP_SUBTITLE, - children=non_passing_health_checks, + children=sort_healthcheck_cards(health_check_cards), ) - if len(non_passing_health_checks) > 0 + if len(health_check_cards) > 0 else None )