Skip to content

Commit 98f3868

Browse files
committed
Add tests for md
1 parent fbe73da commit 98f3868

File tree

2 files changed

+135
-4
lines changed

2 files changed

+135
-4
lines changed

paths_cli/commands/md.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from paths_cli.parameters import (INPUT_FILE, OUTPUT_FILE, ENGINE,
77
MULTI_ENSEMBLE, INIT_SNAP)
88

9+
import logging
10+
logger = logging.getLogger(__name__)
11+
912
@click.command(
1013
"md",
1114
short_help=("Run MD for fixed time or until a given ensemble is "
@@ -31,18 +34,63 @@ def md(input_file, output_file, engine, ensemble, nsteps, init_frame):
3134
)
3235

3336

37+
class EnsembleSatisfiedContinueConditions(object):
38+
def __init__(self, ensembles):
39+
self.satisfied = {ens: False for ens in ensembles}
40+
41+
def _check_previous_frame(self, trajectory, start, unsatisfied):
42+
# TODO: add some debug logging in here
43+
if -start > len(trajectory):
44+
# we've done the whole traj; don't keep going
45+
return False
46+
subtraj = trajectory[start:]
47+
logger.debug(str(subtraj) + "/" + str(trajectory))
48+
for ens in unsatisfied:
49+
if not ens.strict_can_prepend(subtraj, trusted=True):
50+
# test if we can't prepend because we satsify
51+
self.satisfied[ens] = ens(subtraj) or ens(subtraj[1:])
52+
unsatisfied.remove(ens)
53+
return bool(unsatisfied)
54+
55+
def _call_untrusted(self, trajectory):
56+
self.satisfied = {ens: False for ens in self.satisfied}
57+
for i in range(1, len(trajectory)):
58+
keep_going = self(trajectory[:i], trusted=True)
59+
if not keep_going:
60+
return False
61+
return self(trajectory, trusted=True)
62+
63+
def __call__(self, trajectory, trusted=False):
64+
if not trusted:
65+
return self._call_untrusted(trajectory)
66+
67+
# below here, trusted is True
68+
unsatisfied = [ens for ens, done in self.satisfied.items()
69+
if not done]
70+
# TODO: update on how many ensembles left, what frame number we are
71+
72+
if not unsatisfied:
73+
return False
74+
75+
start = -1
76+
while self._check_previous_frame(trajectory, start, unsatisfied):
77+
start -= 1
78+
79+
return not all(self.satisfied.values())
80+
81+
3482
def md_main(output_storage, engine, ensembles, nsteps, initial_frame):
3583
import openpathsampling as paths
3684
if nsteps is not None and ensembles:
3785
raise RuntimeError("Options --ensemble and --nsteps cannot both be"
3886
" used at once.")
3987

40-
if nsteps is not None:
41-
ens = paths.LengthEnsemble(nsteps)
88+
if ensembles:
89+
continue_cond = EnsembleSatisfiedContinueConditions(ensembles)
4290
else:
43-
ens = functools.reduce(operator.and_, ensembles)
91+
continue_cond = paths.LengthEnsemble(nsteps).can_append
4492

45-
trajectory = engine.generate(initial_frame, running=ens.can_append)
93+
trajectory = engine.generate(initial_frame, running=continue_cond)
4694
if output_storage is not None:
4795
output_storage.save(trajectory)
4896
output_storage.tags['final_conditions'] = trajectory
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import pytest
2+
import os
3+
import tempfile
4+
from unittest.mock import patch, Mock
5+
from click.testing import CliRunner
6+
7+
from paths_cli.commands.md import *
8+
9+
import openpathsampling as paths
10+
11+
from openpathsampling.tests.test_helpers import \
12+
make_1d_traj, CalvinistDynamics
13+
14+
15+
class TestEnsembleSatisfiedContinueConditions(object):
16+
def setup(self):
17+
cv = paths.CoordinateFunctionCV('x', lambda x: x.xyz[0][0])
18+
vol_A = paths.CVDefinedVolume(cv, float("-inf"), 0.0)
19+
vol_B = paths.CVDefinedVolume(cv, 1.0, float("inf"))
20+
ensembles = [
21+
paths.LengthEnsemble(1).named("len1"),
22+
paths.LengthEnsemble(3).named("len3"),
23+
paths.SequentialEnsemble([
24+
paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_A),
25+
paths.AllOutXEnsemble(vol_A | vol_B),
26+
paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_A)
27+
]).named('return'),
28+
paths.SequentialEnsemble([
29+
paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_A),
30+
paths.AllOutXEnsemble(vol_A | vol_B),
31+
paths.LengthEnsemble(1) & paths.AllInXEnsemble(vol_B)
32+
]).named('transition'),
33+
]
34+
self.ensembles = {ens.name: ens for ens in ensembles}
35+
traj_vals = [-0.1, 1.1, 0.5, -0.2, 0.1, -0.3, 0.4, 1.4, -1.0]
36+
self.trajectory = make_1d_traj(traj_vals)
37+
self.engine = CalvinistDynamics(traj_vals)
38+
self.satisfied_when_traj_len = {
39+
"len1": 1,
40+
"len3": 3,
41+
"return": 6,
42+
"transition": 8,
43+
}
44+
self.conditions = EnsembleSatisfiedContinueConditions(ensembles)
45+
46+
47+
@pytest.mark.parametrize('trusted', [True, False])
48+
@pytest.mark.parametrize('traj_len,expected', [
49+
# expected = (num_calls, num_satisfied)
50+
(0, (1, 0)),
51+
(1, (2, 1)),
52+
(2, (3, 1)),
53+
(3, (3, 2)),
54+
(5, (2, 2)),
55+
(6, (3, 3)),
56+
(7, (1, 3)),
57+
(8, (3, 4)),
58+
])
59+
def test_call(self, traj_len, expected, trusted):
60+
if trusted:
61+
already_satisfied = [
62+
self.ensembles[key]
63+
for key, val in self.satisfied_when_traj_len.items()
64+
if traj_len > val
65+
]
66+
for ens in already_satisfied:
67+
self.conditions.satisfied[ens] = True
68+
69+
traj = self.trajectory[:traj_len]
70+
mock = Mock(wraps=self.conditions._check_previous_frame)
71+
self.conditions._check_previous_frame = mock
72+
expected_calls, expected_satisfied = expected
73+
result = self.conditions(traj, trusted)
74+
assert result == (expected_satisfied != 4)
75+
assert sum(self.conditions.satisfied.values()) == expected_satisfied
76+
if trusted:
77+
# only test call count if we're trusted
78+
assert mock.call_count == expected_calls
79+
80+
def test_generate(self):
81+
init_snap = self.trajectory[0]
82+
traj = self.engine.generate(init_snap, self.conditions)
83+
assert len(traj) == 8

0 commit comments

Comments
 (0)