|
| 1 | +import pytest |
| 2 | +import os |
| 3 | +import tempfile |
| 4 | +from unittest.mock import patch, Mock |
| 5 | +from click.testing import CliRunner |
| 6 | + |
| 7 | +from paths_cli.commands.md import * |
| 8 | + |
| 9 | +import openpathsampling as paths |
| 10 | + |
| 11 | +from openpathsampling.tests.test_helpers import \ |
| 12 | + make_1d_traj, CalvinistDynamics |
| 13 | + |
| 14 | + |
| 15 | +class TestEnsembleSatisfiedContinueConditions(object): |
| 16 | + def setup(self): |
| 17 | + cv = paths.CoordinateFunctionCV('x', lambda x: x.xyz[0][0]) |
| 18 | + vol_A = paths.CVDefinedVolume(cv, float("-inf"), 0.0) |
| 19 | + vol_B = paths.CVDefinedVolume(cv, 1.0, float("inf")) |
| 20 | + ensembles = [ |
| 21 | + paths.LengthEnsemble(1).named("len1"), |
| 22 | + paths.LengthEnsemble(3).named("len3"), |
| 23 | + paths.SequentialEnsemble([ |
| 24 | + paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_A), |
| 25 | + paths.AllOutXEnsemble(vol_A | vol_B), |
| 26 | + paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_A) |
| 27 | + ]).named('return'), |
| 28 | + paths.SequentialEnsemble([ |
| 29 | + paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_A), |
| 30 | + paths.AllOutXEnsemble(vol_A | vol_B), |
| 31 | + paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_B) |
| 32 | + ]).named('transition'), |
| 33 | + ] |
| 34 | + self.ensembles = {ens.name: ens for ens in ensembles} |
| 35 | + traj_vals = [-0.1, 1.1, 0.5, -0.2, 0.1, -0.3, 0.4, 1.4, -1.0] |
| 36 | + self.trajectory = make_1d_traj(traj_vals) |
| 37 | + self.engine = CalvinistDynamics(traj_vals) |
| 38 | + self.satisfied_when_traj_len = { |
| 39 | + "len1": 1, |
| 40 | + "len3": 3, |
| 41 | + "return": 6, |
| 42 | + "transition": 8, |
| 43 | + } |
| 44 | + self.conditions = EnsembleSatisfiedContinueConditions(ensembles) |
| 45 | + |
| 46 | + |
| 47 | + @pytest.mark.parametrize('trusted', [True, False]) |
| 48 | + @pytest.mark.parametrize('traj_len,expected', [ |
| 49 | + # expected = (num_calls, num_satisfied) |
| 50 | + (0, (1, 0)), |
| 51 | + (1, (2, 1)), |
| 52 | + (2, (3, 1)), |
| 53 | + (3, (3, 2)), |
| 54 | + (5, (2, 2)), |
| 55 | + (6, (3, 3)), |
| 56 | + (7, (1, 3)), |
| 57 | + (8, (3, 4)), |
| 58 | + ]) |
| 59 | + def test_call(self, traj_len, expected, trusted): |
| 60 | + if trusted: |
| 61 | + already_satisfied = [ |
| 62 | + self.ensembles[key] |
| 63 | + for key, val in self.satisfied_when_traj_len.items() |
| 64 | + if traj_len > val |
| 65 | + ] |
| 66 | + for ens in already_satisfied: |
| 67 | + self.conditions.satisfied[ens] = True |
| 68 | + |
| 69 | + traj = self.trajectory[:traj_len] |
| 70 | + mock = Mock(wraps=self.conditions._check_previous_frame) |
| 71 | + self.conditions._check_previous_frame = mock |
| 72 | + expected_calls, expected_satisfied = expected |
| 73 | + result = self.conditions(traj, trusted) |
| 74 | + assert result == (expected_satisfied != 4) |
| 75 | + assert sum(self.conditions.satisfied.values()) == expected_satisfied |
| 76 | + if trusted: |
| 77 | + # only test call count if we're trusted |
| 78 | + assert mock.call_count == expected_calls |
| 79 | + |
| 80 | + def test_generate(self): |
| 81 | + init_snap = self.trajectory[0] |
| 82 | + traj = self.engine.generate(init_snap, self.conditions) |
| 83 | + assert len(traj) == 8 |
0 commit comments