Skip to content

Commit 7033c0d

Browse files
committed
tests for precomputing CVs
1 parent 748705b commit 7033c0d

File tree

1 file changed

+53
-3
lines changed

1 file changed

+53
-3
lines changed

paths_cli/tests/test_file_copying.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,41 @@
11
import os
2-
2+
import tempfile
3+
from unittest.mock import MagicMock, patch
34
import pytest
45

6+
7+
import openpathsampling as paths
8+
from openpathsampling.tests.test_helpers import make_1d_traj
9+
510
from paths_cli.file_copying import *
611

712
class Test_PRECOMPUTE_CVS(object):
8-
pass
13+
def setup(self):
14+
self.tmpdir = tempfile.mkdtemp()
15+
self.storage_filename = os.path.join(self.tmpdir, "test.nc")
16+
self.storage = paths.Storage(self.storage_filename, mode='w')
17+
snap = make_1d_traj([1])[0]
18+
self.storage.save(snap)
19+
self.cv_x = paths.CoordinateFunctionCV("x", lambda s: s.xyz[0][0])
20+
self.cv_y = paths.CoordinateFunctionCV("y", lambda s: s.xyz[0][1])
21+
self.storage.save([self.cv_x, self.cv_y])
22+
23+
def teardown(self):
24+
self.storage.close()
25+
26+
for filename in os.listdir(self.tmpdir):
27+
os.remove(os.path.join(self.tmpdir, filename))
28+
os.rmdir(self.tmpdir)
29+
30+
@pytest.mark.parametrize('getter', ['x', None, '--'])
31+
def test_get(self, getter):
32+
expected = {'x': [self.cv_x],
33+
None: [self.cv_x, self.cv_y],
34+
'--': []}[getter]
35+
getter = [] if getter is None else [getter] # CLI gives a list
36+
cvs = PRECOMPUTE_CVS.get(self.storage, getter)
37+
assert len(cvs) == len(expected)
38+
assert set(cvs) == set(expected)
939

1040

1141
@pytest.mark.parametrize('blocksize', [2, 3, 5, 10, 12])
@@ -22,8 +52,28 @@ def test_make_blocks(blocksize):
2252

2353

2454
class TestPrecompute(object):
55+
def setup(self):
56+
class RunOnceFunction(object):
57+
def __init__(self):
58+
self.previously_seen = set([])
59+
60+
def __call__(self, snap):
61+
if snap in self.previously_seen:
62+
raise AssertionError("Second CV eval for " + str(snap))
63+
self.previously_seen.update({snap})
64+
return snap.xyz[0][0]
65+
66+
self.cv = paths.FunctionCV("test", RunOnceFunction())
67+
traj = paths.tests.test_helpers.make_1d_traj([2, 1])
68+
self.snap = traj[0]
69+
self.other_snap = traj[1]
70+
2571
def test_precompute_cvs(self):
26-
pytest.skip()
72+
precompute_cvs([self.cv], [self.snap])
73+
assert self.cv.f.previously_seen == {self.snap}
74+
recalced = self.cv(self.snap) # AssertionError if func called
75+
assert recalced == 2
76+
assert self.cv.diskcache_enabled is True
2777

2878
def test_precompute_cvs_and_inputs(self):
2979
pytest.skip()

0 commit comments

Comments
 (0)