Skip to content

Commit 5677183

Browse files
authored
Merge pull request #1 from dwhswenson/visit-all
2 parents 59c437e + 46b8f1f commit 5677183

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

paths_cli/commands/visit_all.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import click
2+
3+
from paths_cli.parameters import (INPUT_FILE, OUTPUT_FILE, ENGINE, STATES,
4+
INIT_SNAP)
5+
6+
@click.command(
7+
"visit-all",
8+
short_help="Run MD to generate initial trajectories",
9+
)
10+
@INPUT_FILE.clicked(required=True)
11+
@OUTPUT_FILE.clicked(required=True)
12+
@STATES.clicked(required=True)
13+
@ENGINE.clicked(required=False)
14+
@INIT_SNAP.clicked(required=False)
15+
def visit_all(input_file, output_file, state, engine, init_frame):
16+
"""Run until initial trajectory for TPS/MSTPS/MSTIS achieved.
17+
18+
This runs until all given states have been visited. That creates a long
19+
trajectory, subtrajectories of which will work for the initial
20+
trajectories in TPS, MSTPS, or MSTIS. Typically, you'll use a different
21+
engine from the TPS production engine (often high temperature).
22+
"""
23+
storage = INPUT_FILE.get(input_file)
24+
visit_all_main(
25+
output_storage=OUTPUT_FILE.get(output_file),
26+
states=STATES.get(storage, state),
27+
engine=ENGINE.get(storage, engine),
28+
initial_frame=INIT_SNAP.get(storage, init_frame)
29+
)
30+
31+
32+
def visit_all_main(output_storage, states, engine, initial_frame):
33+
import openpathsampling as paths
34+
timestep = getattr(engine, 'timestep', None)
35+
visit_all_ens = paths.VisitAllStatesEnsemble(states, timestep=timestep)
36+
trajectory = engine.generate(initial_frame, [visit_all_ens.can_append])
37+
if output_storage is not None:
38+
output_storage.save(trajectory)
39+
output_storage.tags['final_conditions'] = trajectory
40+
41+
return trajectory, None # no simulation object to return here
42+
43+
44+
CLI = visit_all
45+
SECTION = "Simulation"
46+
REQUIRES_OPS = (1, 0)

paths_cli/tests/commands/__init__.py

Whitespace-only changes.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
3+
import pytest
4+
from unittest.mock import patch, MagicMock
5+
import tempfile
6+
from click.testing import CliRunner
7+
8+
from paths_cli.commands.visit_all import *
9+
10+
import openpathsampling as paths
11+
12+
# patch with this for testing
13+
def print_test(output_storage, states, engine, initial_frame):
14+
print(isinstance(output_storage, paths.Storage))
15+
print(sorted([s.__uuid__ for s in states]))
16+
print(engine.__uuid__)
17+
print(initial_frame.__uuid__)
18+
19+
20+
@pytest.fixture()
21+
def tps_fixture(flat_engine, tps_network_and_traj):
22+
network, traj = tps_network_and_traj
23+
scheme = paths.OneWayShootingMoveScheme(network=network,
24+
selector=paths.UniformSelector(),
25+
engine=flat_engine)
26+
init_conds = scheme.initial_conditions_from_trajectories(traj)
27+
return (scheme, network, flat_engine, init_conds)
28+
29+
@pytest.fixture()
30+
def visit_all_fixture(tps_fixture):
31+
scheme, network, engine, init_conds = tps_fixture
32+
states = sorted(network.all_states, key=lambda x: x.__uuid__)
33+
init_frame = init_conds[0].trajectory[0]
34+
return states, engine, init_frame
35+
36+
37+
@patch('paths_cli.commands.visit_all.visit_all_main', print_test)
38+
def test_visit_all(visit_all_fixture):
39+
# this is an integration test; testing integration click & parameters
40+
states, engine, init_frame = visit_all_fixture
41+
runner = CliRunner()
42+
with runner.isolated_filesystem():
43+
storage = paths.Storage("setup.nc", 'w')
44+
for obj in visit_all_fixture:
45+
storage.save(obj)
46+
storage.tags['initial_snapshot'] = init_frame
47+
storage.close()
48+
49+
results = runner.invoke(
50+
visit_all,
51+
["setup.nc", '-o', 'foo.nc', '-s', 'A', '-s', 'B',
52+
'-e', 'flat', '-f', 'initial_snapshot']
53+
)
54+
55+
expected_output = ("True\n[" + str(states[0].__uuid__) + ", "
56+
+ str(states[1].__uuid__) + "]\n")
57+
expected_output += "\n".join(str(obj.__uuid__)
58+
for obj in [engine, init_frame]) + "\n"
59+
assert results.exit_code == 0
60+
assert results.output == expected_output
61+
62+
def test_visit_all_main(visit_all_fixture):
63+
# just a smoke test here
64+
tempdir = tempfile.mkdtemp()
65+
try:
66+
store_name = os.path.join(tempdir, "visit_all.nc")
67+
storage = paths.Storage(store_name, mode='w')
68+
states, engine, init_frame = visit_all_fixture
69+
traj, foo = visit_all_main(storage, states, engine, init_frame)
70+
assert isinstance(traj, paths.Trajectory)
71+
assert foo is None
72+
assert len(storage.trajectories) == 1
73+
storage.close()
74+
finally:
75+
os.remove(store_name)
76+
os.rmdir(tempdir)

paths_cli/tests/conftest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from openpathsampling.tests.test_helpers import make_1d_traj
2+
from openpathsampling.engines import toy as toys
3+
import openpathsampling as paths
4+
import pytest
5+
6+
@pytest.fixture
7+
def flat_engine():
8+
pes = toys.LinearSlope([0, 0, 0], 0)
9+
topology = toys.Topology(n_spatial=3, masses=[1.0, 1.0, 1.0], pes=pes)
10+
integ = toys.LeapfrogVerletIntegrator(dt=0.1)
11+
options = {'integ': integ,
12+
'n_frames_max': 1000,
13+
'n_steps_per_frame': 1}
14+
engine = toys.Engine(options=options, topology=topology).named("flat")
15+
return engine
16+
17+
@pytest.fixture
18+
def tps_network_and_traj():
19+
cv = paths.CoordinateFunctionCV("x", lambda s: s.xyz[0][0])
20+
state_A = paths.CVDefinedVolume(cv, float("-inf"), 0).named("A")
21+
state_B = paths.CVDefinedVolume(cv, 1, float("inf")).named("B")
22+
network = paths.TPSNetwork(state_A, state_B)
23+
init_traj = make_1d_traj([-0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.1],
24+
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
25+
return (network, init_traj)

0 commit comments

Comments
 (0)