diff --git a/marimo/_utils/cell_matching.py b/marimo/_utils/cell_matching.py index 2468eee6deb..1b203383a26 100644 --- a/marimo/_utils/cell_matching.py +++ b/marimo/_utils/cell_matching.py @@ -92,8 +92,11 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]: # Step 1: Subtract row minima for i in range(n): min_value = min(score_matrix[i]) + row = score_matrix[i] for j in range(n): - score_matrix[i][j] -= min_value + row[j] -= min_value + + # Step 2: Subtract column minima # Step 2: Subtract column minima for j in range(n): @@ -107,44 +110,59 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]: # Find independent zeros for i in range(n): - for j in range(n): - if ( - score_matrix[i][j] == 0 - and row_assignment[i] == -1 - and col_assignment[j] == -1 - ): - row_assignment[i] = j - col_assignment[j] = i + row_assigned = row_assignment[i] + if row_assigned == -1: + for j in range(n): + if score_matrix[i][j] == 0 and col_assignment[j] == -1: + row_assignment[i] = j + col_assignment[j] = i + break + + # Step 4: Improve assignment iteratively # Step 4: Improve assignment iteratively while True: - assigned_count = sum(1 for x in row_assignment if x != -1) + assigned_count = sum(x != -1 for x in row_assignment) if assigned_count == n: break # Find minimum uncovered value + + # Find minimum uncovered value (single loop: precompute covered rows/cols) + uncovered_rows = [i for i in range(n) if row_assignment[i] == -1] + uncovered_cols = [j for j in range(n) if col_assignment[j] == -1] min_uncovered = float("inf") - for i in range(n): - for j in range(n): - if row_assignment[i] == -1 and col_assignment[j] == -1: - min_uncovered = min(min_uncovered, score_matrix[i][j]) + for i in uncovered_rows: + row = score_matrix[i] + for j in uncovered_cols: + val = row[j] + if val < min_uncovered: + min_uncovered = val if min_uncovered == float("inf"): break - # Update matrix - for i in range(n): - for j in range(n): - if row_assignment[i] == -1 and col_assignment[j] == -1: - score_matrix[i][j] -= min_uncovered - elif row_assignment[i] != -1 and col_assignment[j] != -1: - score_matrix[i][j] += min_uncovered - - # Try to find new assignments - for i in range(n): + # Update matrix (batch according to cover/uncover sets) + covered_rows = [i for i in range(n) if row_assignment[i] != -1] + covered_cols = [j for j in range(n) if col_assignment[j] != -1] + + # Subtract min from uncovered positions + for i in uncovered_rows: + row = score_matrix[i] + for j in uncovered_cols: + row[j] -= min_uncovered + # Add min to covered positions + for i in covered_rows: + row = score_matrix[i] + for j in covered_cols: + row[j] += min_uncovered + + # Try to find new assignments (avoid redundant checks) + for i in uncovered_rows: if row_assignment[i] == -1: - for j in range(n): - if score_matrix[i][j] == 0 and col_assignment[j] == -1: + row = score_matrix[i] + for j in uncovered_cols: + if row[j] == 0 and col_assignment[j] == -1: row_assignment[i] = j col_assignment[j] = i break @@ -152,8 +170,9 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]: # Convert to result format result = [-1] * n for i in range(n): - if row_assignment[i] != -1: - result[row_assignment[i]] = i + assigned_col = row_assignment[i] + if assigned_col != -1: + result[assigned_col] = i return result