Skip to content

Commit 8e9fa46

Browse files
committed
more tests; docstrings
1 parent 38e01ba commit 8e9fa46

File tree

2 files changed

+64
-6
lines changed

2 files changed

+64
-6
lines changed

paths_cli/commands/md.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,43 @@ def md(input_file, output_file, engine, ensemble, nsteps, init_frame):
3939
)
4040

4141
class ProgressReporter(object):
42+
"""Generic class for a callable that reports progress.
43+
44+
Base class for ends-with-ensemble and fixed-length tricks.
45+
46+
Parameters
47+
----------
48+
timestep : Any
49+
timestep, optionally with units
50+
update_freq : int
51+
how often to report updates
52+
"""
4253
def __init__(self, timestep, update_freq):
4354
self.timestep = timestep
4455
self.update_freq = update_freq
4556

4657
def steps_progress_string(self, n_steps):
58+
"""Return string for number of frames run and time elapsed
59+
60+
Not newline-terminated.
61+
"""
4762
report_str = "Ran {n_steps} frames"
4863
if self.timestep is not None:
49-
report_str += " [{}]".format(str(n_steps * timestep))
64+
report_str += " [{}]".format(str(n_steps * self.timestep))
5065
report_str += '.'
51-
return report_str
66+
return report_str.format(n_steps=n_steps)
5267

5368
def progress_string(self, n_steps):
69+
"""Return the progress string. Subclasses may override.
70+
"""
5471
report_str = self.steps_progress_string(n_steps) + "\n"
5572
return report_str.format(n_steps=n_steps)
5673

57-
58-
def report_progress(self, n_steps):
74+
def report_progress(self, n_steps, force=False):
75+
"""Report the progress to the terminal.
76+
"""
5977
import openpathsampling as paths
60-
if n_steps % self.update_freq == 0:
78+
if (n_steps % self.update_freq == 0) or force:
6179
string = self.progress_string(n_steps)
6280
paths.tools.refresh_output(string)
6381

@@ -77,6 +95,10 @@ class EnsembleSatisfiedContinueConditions(ProgressReporter):
7795
----------
7896
ensembles: List[:class:`openpathsampling.Ensemble`]
7997
the ensembles to satisfy
98+
timestep : Any
99+
timestep, optionally with units
100+
update_freq : int
101+
how often to report updates
80102
"""
81103
def __init__(self, ensembles, timestep=None, update_freq=10):
82104
super().__init__(timestep, update_freq)
@@ -136,7 +158,16 @@ def __call__(self, trajectory, trusted=False):
136158

137159

138160
class FixedLengthContinueCondition(ProgressReporter):
139-
"""
161+
"""Continuation condition for fixed-length runs.
162+
163+
Parameters
164+
----------
165+
length : int
166+
final length of the trajectory in frames
167+
timestep : Any
168+
timestep, optionally with units
169+
update_freq : int
170+
how often to report updates
140171
"""
141172
def __init__(self, length, timestep=None, update_freq=10):
142173
super().__init__(timestep, update_freq)
@@ -161,6 +192,7 @@ def md_main(output_storage, engine, ensembles, nsteps, initial_frame):
161192
continue_cond = FixedLengthContinueCondition(nsteps)
162193

163194
trajectory = engine.generate(initial_frame, running=continue_cond)
195+
continue_cond.report_progress(len(trajectory) - 1, force=True)
164196
paths_cli.utils.tag_final_result(trajectory, output_storage,
165197
'final_conditions')
166198
return trajectory, None

paths_cli/tests/commands/test_md.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,32 @@
1111
from openpathsampling.tests.test_helpers import \
1212
make_1d_traj, CalvinistDynamics
1313

14+
class TestProgressReporter(object):
15+
def setup(self):
16+
self.progress = ProgressReporter(timestep=None, update_freq=5)
17+
18+
@pytest.mark.parametrize('timestep', [None, 0.1])
19+
def test_progress_string(self, timestep):
20+
progress = ProgressReporter(timestep, update_freq=5)
21+
expected = "Ran 25 frames"
22+
if timestep is not None:
23+
expected += " [2.5]"
24+
expected += '.\n'
25+
assert progress.progress_string(25) == expected
26+
27+
@pytest.mark.parametrize('n_steps', [0, 5, 6])
28+
@pytest.mark.parametrize('force', [True, False])
29+
@patch('openpathsampling.tools.refresh_output',
30+
lambda s: print(s, end=''))
31+
def test_report_progress(self, n_steps, force, capsys):
32+
self.progress.report_progress(n_steps, force)
33+
expected = "Ran {n_steps} frames.\n".format(n_steps=n_steps)
34+
out, err = capsys.readouterr()
35+
if (n_steps in [0, 5]) or force:
36+
assert out == expected
37+
else:
38+
assert out == ""
39+
1440
class TestEnsembleSatisfiedContinueConditions(object):
1541
def setup(self):
1642
cv = paths.CoordinateFunctionCV('x', lambda x: x.xyz[0][0])

0 commit comments

Comments
 (0)