Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 47 additions & 28 deletions marimo/_utils/cell_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,11 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]:
# Step 1: Subtract row minima
for i in range(n):
min_value = min(score_matrix[i])
row = score_matrix[i]
for j in range(n):
score_matrix[i][j] -= min_value
row[j] -= min_value

# Step 2: Subtract column minima

# Step 2: Subtract column minima
for j in range(n):
Expand All @@ -107,53 +110,69 @@ def _hungarian_algorithm(scores: list[list[float]]) -> list[int]:

# Find independent zeros
for i in range(n):
for j in range(n):
if (
score_matrix[i][j] == 0
and row_assignment[i] == -1
and col_assignment[j] == -1
):
row_assignment[i] = j
col_assignment[j] = i
row_assigned = row_assignment[i]
if row_assigned == -1:
for j in range(n):
if score_matrix[i][j] == 0 and col_assignment[j] == -1:
row_assignment[i] = j
col_assignment[j] = i
break

# Step 4: Improve assignment iteratively

# Step 4: Improve assignment iteratively
while True:
assigned_count = sum(1 for x in row_assignment if x != -1)
assigned_count = sum(x != -1 for x in row_assignment)
if assigned_count == n:
break

# Find minimum uncovered value

# Find minimum uncovered value (single loop: precompute covered rows/cols)
uncovered_rows = [i for i in range(n) if row_assignment[i] == -1]
uncovered_cols = [j for j in range(n) if col_assignment[j] == -1]
min_uncovered = float("inf")
for i in range(n):
for j in range(n):
if row_assignment[i] == -1 and col_assignment[j] == -1:
min_uncovered = min(min_uncovered, score_matrix[i][j])
for i in uncovered_rows:
row = score_matrix[i]
for j in uncovered_cols:
val = row[j]
if val < min_uncovered:
min_uncovered = val

if min_uncovered == float("inf"):
break

# Update matrix
for i in range(n):
for j in range(n):
if row_assignment[i] == -1 and col_assignment[j] == -1:
score_matrix[i][j] -= min_uncovered
elif row_assignment[i] != -1 and col_assignment[j] != -1:
score_matrix[i][j] += min_uncovered

# Try to find new assignments
for i in range(n):
# Update matrix (batch according to cover/uncover sets)
covered_rows = [i for i in range(n) if row_assignment[i] != -1]
covered_cols = [j for j in range(n) if col_assignment[j] != -1]

# Subtract min from uncovered positions
for i in uncovered_rows:
row = score_matrix[i]
for j in uncovered_cols:
row[j] -= min_uncovered
# Add min to covered positions
for i in covered_rows:
row = score_matrix[i]
for j in covered_cols:
row[j] += min_uncovered

# Try to find new assignments (avoid redundant checks)
for i in uncovered_rows:
if row_assignment[i] == -1:
for j in range(n):
if score_matrix[i][j] == 0 and col_assignment[j] == -1:
row = score_matrix[i]
for j in uncovered_cols:
if row[j] == 0 and col_assignment[j] == -1:
row_assignment[i] = j
col_assignment[j] = i
break

# Convert to result format
result = [-1] * n
for i in range(n):
if row_assignment[i] != -1:
result[row_assignment[i]] = i
assigned_col = row_assignment[i]
if assigned_col != -1:
result[assigned_col] = i

return result

Expand Down