From ba062059d3424d5aa83c2e8dbcebc29dd7bf7751 Mon Sep 17 00:00:00 2001 From: architd Date: Mon, 15 Jun 2026 14:39:46 -0700 Subject: [PATCH 1/2] fix: preserve system IDs during inflight refill Signed-off-by: architd --- nvalchemi/dynamics/base.py | 11 +++++++++++ test/dynamics/test_inflight.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/nvalchemi/dynamics/base.py b/nvalchemi/dynamics/base.py index 690ee2cc..1833bfa9 100644 --- a/nvalchemi/dynamics/base.py +++ b/nvalchemi/dynamics/base.py @@ -2047,6 +2047,17 @@ def refill_check(self, batch: Batch, exit_status: int) -> Batch | None: device = result.device for key, default_fn in self._bookkeeping_keys.items(): new_tensor = default_fn(n_total, device) + # Preserve values already carried by appended replacements before + # restoring the prefix for systems that stayed active. + result_vals = getattr(result, key, None) + if result_vals is not None: + result_vals = ( + result_vals.unsqueeze(-1) + if result_vals.dim() == 1 + else result_vals + ) + if result_vals.shape == new_tensor.shape: + new_tensor.copy_(result_vals) remaining_vals = getattr(batch, key, None) if remaining_vals is not None and n_remaining > 0: src = remaining_vals[remaining_indices] diff --git a/test/dynamics/test_inflight.py b/test/dynamics/test_inflight.py index 7a7da0ef..218f325f 100644 --- a/test/dynamics/test_inflight.py +++ b/test/dynamics/test_inflight.py @@ -635,6 +635,22 @@ def test_refill_writes_bookkeeping_to_storage(self) -> None: assert "status" in result + def test_refill_preserves_replacement_system_id(self) -> None: + """Replacement systems keep sampler-assigned system IDs after refill.""" + dataset = MockDataset([(1, 0)] * 3) + sampler = SizeAwareSampler(dataset, max_atoms=2) + dynamics = BaseDynamics(model=self.model, sampler=sampler, device_type="cpu") + + batch = sampler.build_initial_batch() + assert batch.system_id.view(-1).tolist() == [0, 1] + + batch["status"] = torch.tensor([[1], [1]]) + result = dynamics.refill_check(batch, exit_status=1) + + assert result is not None + assert result.system_id.view(-1).tolist() == [2] + assert result.status.view(-1).tolist() == [0] + def test_refill_partial_replacement(self) -> None: """When sampler has fewer replacements than graduated, batch shrinks. From dbcef64e40acfb82498aa297e2c6151d499a6f30 Mon Sep 17 00:00:00 2001 From: architd Date: Mon, 15 Jun 2026 15:56:37 -0700 Subject: [PATCH 2/2] fix: validate refill bookkeeping and cover mixed IDs Signed-off-by: architd --- nvalchemi/dynamics/base.py | 5 +++++ test/dynamics/test_inflight.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/nvalchemi/dynamics/base.py b/nvalchemi/dynamics/base.py index 1833bfa9..082c0d12 100644 --- a/nvalchemi/dynamics/base.py +++ b/nvalchemi/dynamics/base.py @@ -2058,6 +2058,11 @@ def refill_check(self, batch: Batch, exit_status: int) -> Batch | None: ) if result_vals.shape == new_tensor.shape: new_tensor.copy_(result_vals) + else: + raise RuntimeError( + f"Bookkeeping key '{key}' has shape {result_vals.shape} " + f"after refill, expected {new_tensor.shape}." + ) remaining_vals = getattr(batch, key, None) if remaining_vals is not None and n_remaining > 0: src = remaining_vals[remaining_indices] diff --git a/test/dynamics/test_inflight.py b/test/dynamics/test_inflight.py index 218f325f..df8d115d 100644 --- a/test/dynamics/test_inflight.py +++ b/test/dynamics/test_inflight.py @@ -651,6 +651,22 @@ def test_refill_preserves_replacement_system_id(self) -> None: assert result.system_id.view(-1).tolist() == [2] assert result.status.view(-1).tolist() == [0] + def test_refill_preserves_mixed_system_ids(self) -> None: + """Remaining and replacement systems both keep their system IDs.""" + dataset = MockDataset([(1, 0)] * 3) + sampler = SizeAwareSampler(dataset, max_atoms=2) + dynamics = BaseDynamics(model=self.model, sampler=sampler, device_type="cpu") + + batch = sampler.build_initial_batch() + assert batch.system_id.view(-1).tolist() == [0, 1] + + batch["status"] = torch.tensor([[1], [0]]) + result = dynamics.refill_check(batch, exit_status=1) + + assert result is not None + assert result.system_id.view(-1).tolist() == [1, 2] + assert result.status.view(-1).tolist() == [0, 0] + def test_refill_partial_replacement(self) -> None: """When sampler has fewer replacements than graduated, batch shrinks.