11import os
2-
2+ import tempfile
3+ from unittest .mock import MagicMock , patch
34import pytest
45
6+
7+ import openpathsampling as paths
8+ from openpathsampling .tests .test_helpers import make_1d_traj
9+
510from paths_cli .file_copying import *
611
712class Test_PRECOMPUTE_CVS (object ):
8- pass
13+ def setup (self ):
14+ self .tmpdir = tempfile .mkdtemp ()
15+ self .storage_filename = os .path .join (self .tmpdir , "test.nc" )
16+ self .storage = paths .Storage (self .storage_filename , mode = 'w' )
17+ snap = make_1d_traj ([1 ])[0 ]
18+ self .storage .save (snap )
19+ self .cv_x = paths .CoordinateFunctionCV ("x" , lambda s : s .xyz [0 ][0 ])
20+ self .cv_y = paths .CoordinateFunctionCV ("y" , lambda s : s .xyz [0 ][1 ])
21+ self .storage .save ([self .cv_x , self .cv_y ])
22+
23+ def teardown (self ):
24+ self .storage .close ()
25+
26+ for filename in os .listdir (self .tmpdir ):
27+ os .remove (os .path .join (self .tmpdir , filename ))
28+ os .rmdir (self .tmpdir )
29+
30+ @pytest .mark .parametrize ('getter' , ['x' , None , '--' ])
31+ def test_get (self , getter ):
32+ expected = {'x' : [self .cv_x ],
33+ None : [self .cv_x , self .cv_y ],
34+ '--' : []}[getter ]
35+ getter = [] if getter is None else [getter ] # CLI gives a list
36+ cvs = PRECOMPUTE_CVS .get (self .storage , getter )
37+ assert len (cvs ) == len (expected )
38+ assert set (cvs ) == set (expected )
939
1040
1141@pytest .mark .parametrize ('blocksize' , [2 , 3 , 5 , 10 , 12 ])
@@ -22,8 +52,28 @@ def test_make_blocks(blocksize):
2252
2353
2454class TestPrecompute (object ):
55+ def setup (self ):
56+ class RunOnceFunction (object ):
57+ def __init__ (self ):
58+ self .previously_seen = set ([])
59+
60+ def __call__ (self , snap ):
61+ if snap in self .previously_seen :
62+ raise AssertionError ("Second CV eval for " + str (snap ))
63+ self .previously_seen .update ({snap })
64+ return snap .xyz [0 ][0 ]
65+
66+ self .cv = paths .FunctionCV ("test" , RunOnceFunction ())
67+ traj = paths .tests .test_helpers .make_1d_traj ([2 , 1 ])
68+ self .snap = traj [0 ]
69+ self .other_snap = traj [1 ]
70+
2571 def test_precompute_cvs (self ):
26- pytest .skip ()
72+ precompute_cvs ([self .cv ], [self .snap ])
73+ assert self .cv .f .previously_seen == {self .snap }
74+ recalced = self .cv (self .snap ) # AssertionError if func called
75+ assert recalced == 2
76+ assert self .cv .diskcache_enabled is True
2777
2878 def test_precompute_cvs_and_inputs (self ):
2979 pytest .skip ()
0 commit comments