|
| 1 | +import collections |
| 2 | +import functools |
| 3 | +import os |
| 4 | +import tempfile |
| 5 | +from unittest.mock import MagicMock, patch |
| 6 | +import pytest |
| 7 | + |
| 8 | +import openpathsampling as paths |
| 9 | +from openpathsampling.tests.test_helpers import make_1d_traj |
| 10 | + |
| 11 | +from paths_cli.file_copying import * |
| 12 | + |
| 13 | +class Test_PRECOMPUTE_CVS(object): |
| 14 | + def setup(self): |
| 15 | + self.tmpdir = tempfile.mkdtemp() |
| 16 | + self.storage_filename = os.path.join(self.tmpdir, "test.nc") |
| 17 | + self.storage = paths.Storage(self.storage_filename, mode='w') |
| 18 | + snap = make_1d_traj([1])[0] |
| 19 | + self.storage.save(snap) |
| 20 | + self.cv_x = paths.CoordinateFunctionCV("x", lambda s: s.xyz[0][0]) |
| 21 | + self.cv_y = paths.CoordinateFunctionCV("y", lambda s: s.xyz[0][1]) |
| 22 | + self.storage.save([self.cv_x, self.cv_y]) |
| 23 | + |
| 24 | + def teardown(self): |
| 25 | + self.storage.close() |
| 26 | + |
| 27 | + for filename in os.listdir(self.tmpdir): |
| 28 | + os.remove(os.path.join(self.tmpdir, filename)) |
| 29 | + os.rmdir(self.tmpdir) |
| 30 | + |
| 31 | + @pytest.mark.parametrize('getter', ['x', None, '--']) |
| 32 | + def test_get(self, getter): |
| 33 | + expected = {'x': [self.cv_x], |
| 34 | + None: [self.cv_x, self.cv_y], |
| 35 | + '--': []}[getter] |
| 36 | + getter = [] if getter is None else [getter] # CLI gives a list |
| 37 | + cvs = PRECOMPUTE_CVS.get(self.storage, getter) |
| 38 | + assert len(cvs) == len(expected) |
| 39 | + assert set(cvs) == set(expected) |
| 40 | + |
| 41 | + |
| 42 | +@pytest.mark.parametrize('blocksize', [2, 3, 5, 10, 12]) |
| 43 | +def test_make_blocks(blocksize): |
| 44 | + expected_lengths = {2: [2, 2, 2, 2, 2], |
| 45 | + 3: [3, 3, 3, 1], |
| 46 | + 5: [5, 5], |
| 47 | + 10: [10], |
| 48 | + 12: [10]}[blocksize] |
| 49 | + ll = list(range(10)) |
| 50 | + blocks = make_blocks(ll, blocksize) |
| 51 | + assert [len(block) for block in blocks] == expected_lengths |
| 52 | + assert sum(blocks, []) == ll |
| 53 | + |
| 54 | + |
| 55 | +class TestPrecompute(object): |
| 56 | + def setup(self): |
| 57 | + class RunOnceFunction(object): |
| 58 | + def __init__(self): |
| 59 | + self.previously_seen = set([]) |
| 60 | + |
| 61 | + def __call__(self, snap): |
| 62 | + if snap in self.previously_seen: |
| 63 | + raise AssertionError("Second CV eval for " + str(snap)) |
| 64 | + self.previously_seen.update({snap}) |
| 65 | + return snap.xyz[0][0] |
| 66 | + |
| 67 | + self.cv = paths.FunctionCV("test", RunOnceFunction()) |
| 68 | + traj = make_1d_traj([2, 1]) |
| 69 | + self.snap = traj[0] |
| 70 | + self.other_snap = traj[1] |
| 71 | + |
| 72 | + def test_precompute_cvs(self): |
| 73 | + precompute_cvs([self.cv], [self.snap]) |
| 74 | + assert self.cv.f.previously_seen == {self.snap} |
| 75 | + recalced = self.cv(self.snap) # AssertionError if func called |
| 76 | + assert recalced == 2 |
| 77 | + assert self.cv.diskcache_enabled is True |
| 78 | + |
| 79 | + @pytest.mark.parametrize('cvs', [['test'], None]) |
| 80 | + def test_precompute_cvs_and_inputs(self, cvs): |
| 81 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 82 | + storage = paths.Storage(os.path.join(tmpdir, "test.nc"), |
| 83 | + mode='w') |
| 84 | + traj = make_1d_traj(list(range(10))) |
| 85 | + cv = paths.FunctionCV("test", lambda s: s.xyz[0][0]) |
| 86 | + storage.save(traj) |
| 87 | + storage.save(cv) |
| 88 | + |
| 89 | + if cvs is not None: |
| 90 | + cvs = [storage.cvs[cv] for cv in cvs] |
| 91 | + |
| 92 | + precompute_func, blocks = precompute_cvs_func_and_inputs( |
| 93 | + input_storage=storage, |
| 94 | + cvs=cvs, |
| 95 | + blocksize=2 |
| 96 | + ) |
| 97 | + assert len(blocks) == 5 |
| 98 | + for block in blocks: |
| 99 | + assert len(block) == 2 |
| 100 | + |
| 101 | + # smoke test: only effect should be caching results |
| 102 | + precompute_func(blocks[0]) |
| 103 | + |
| 104 | + |
| 105 | +def test_rewrite_file(): |
| 106 | + # making a mock for storage instead of actually testing integration |
| 107 | + class FakeStore(object): |
| 108 | + def __init__(self): |
| 109 | + self._stores = collections.defaultdict(list) |
| 110 | + |
| 111 | + def store(self, obj, store_name): |
| 112 | + self._stores[store_name].append(obj) |
| 113 | + |
| 114 | + stage_names = ['foo', 'bar'] |
| 115 | + storage = FakeStore() |
| 116 | + store_funcs = { |
| 117 | + name: functools.partial(storage.store, store_name=name) |
| 118 | + for name in stage_names |
| 119 | + } |
| 120 | + stage_mapping = { |
| 121 | + 'foo': (store_funcs['foo'], [0, 1, 2]), |
| 122 | + 'bar': (store_funcs['bar'], [[3], [4], [5]]) |
| 123 | + } |
| 124 | + silent_tqdm = lambda x, desc=None, leave=True: x |
| 125 | + with patch('paths_cli.file_copying.tqdm', silent_tqdm): |
| 126 | + rewrite_file(stage_names, stage_mapping) |
| 127 | + |
| 128 | + assert storage._stores['foo'] == [0, 1, 2] |
| 129 | + assert storage._stores['bar'] == [[3], [4], [5]] |
0 commit comments