22import tempfile
33import os
44
5- import openpathsampling as paths
5+ import paths_cli
66from openpathsampling .tests .test_helpers import make_1d_traj
77
88from paths_cli .parameters import *
9+ import openpathsampling as paths
10+
11+
12+ def pre_monkey_patch ():
13+ # store things that get monkey-patched; ensure we un-patch
14+ stored_functions = {}
15+ CallableCV = paths .CallableCV
16+ PseudoAttr = paths .netcdfplus .FunctionPseudoAttribute
17+ stored_functions ['CallableCV.from' ] = CallableCV .from_dict
18+ stored_functions ['PseudoAttr.from' ] = PseudoAttr .from_dict
19+ stored_functions ['TPSNetwork.from' ] = paths .TPSNetwork .from_dict
20+ stored_functions ['MISTISNetwork.from' ] = paths .MISTISNetwork .from_dict
21+ stored_functions ['PseudoAttr.to' ] = PseudoAttr .to_dict
22+ stored_functions ['TPSNetwork.to' ] = paths .TPSNetwork .to_dict
23+ stored_functions ['MISTISNetwork.to' ] = paths .MISTISNetwork .to_dict
24+ return stored_functions
25+
26+ def undo_monkey_patch (stored_functions ):
27+ CallableCV = paths .CallableCV
28+ PseudoAttr = paths .netcdfplus .FunctionPseudoAttribute
29+ CallableCV .from_dict = stored_functions ['CallableCV.from' ]
30+ PseudoAttr .from_dict = stored_functions ['PseudoAttr.from' ]
31+ paths .TPSNetwork .from_dict = stored_functions ['TPSNetwork.from' ]
32+ paths .MISTISNetwork .from_dict = stored_functions ['MISTISNetwork.from' ]
33+ PseudoAttr .to_dict = stored_functions ['PseudoAttr.to' ]
34+ paths .TPSNetwork .to_dict = stored_functions ['TPSNetwork.to' ]
35+ paths .MISTISNetwork .to_dict = stored_functions ['MISTISNetwork.to' ]
36+ paths_cli .param_core .StorageLoader .has_simstore_patch = False
37+ paths .InterfaceSet .simstore = False
38+ import importlib
39+ importlib .reload (paths .netcdfplus )
40+ importlib .reload (paths .collectivevariable )
41+ importlib .reload (paths )
42+
943
1044
1145class ParameterTest (object ):
@@ -114,6 +148,17 @@ def setup(self):
114148 def test_get (self , getter ):
115149 self ._getter_test (getter )
116150
151+ def test_cannot_guess (self ):
152+ filename = self ._filename ('no-guess' )
153+ storage = paths .Storage (filename , 'w' )
154+ storage .save (self .engine )
155+ storage .save (self .other_engine .named ('other' ))
156+ storage .close ()
157+
158+ storage = paths .Storage (filename , mode = 'r' )
159+ with pytest .raises (RuntimeError ):
160+ self .PARAMETER .get (storage , None )
161+
117162
118163class TestSCHEME (ParamInstanceTest ):
119164 PARAMETER = SCHEME
@@ -241,6 +286,17 @@ def test_get_multiple(self):
241286 assert traj0 == self .traj
242287 assert traj1 == self .other_traj
243288
289+ def test_cannot_guess (self ):
290+ filename = self ._filename ('no-guess' )
291+ storage = paths .Storage (filename , 'w' )
292+ storage .save (self .traj )
293+ storage .save (self .other_traj )
294+ storage .close ()
295+
296+ storage = paths .Storage (filename , 'r' )
297+ with pytest .raises (RuntimeError ):
298+ self .PARAMETER .get (storage , None )
299+
244300
245301class TestINIT_SNAP (ParamInstanceTest ):
246302 PARAMETER = INIT_SNAP
@@ -278,6 +334,18 @@ def test_get(self, getter):
278334 obj = self .PARAMETER .get (storage , get_arg )
279335 assert obj == expected
280336
337+ def test_simstore_single_snapshot (self ):
338+ stored_functions = pre_monkey_patch ()
339+ filename = os .path .join (self .tempdir , "simstore.db" )
340+ storage = APPEND_FILE .get (filename )
341+ storage .save (self .init_snap )
342+ storage .close ()
343+
344+ storage = INPUT_FILE .get (filename )
345+ snap = self .PARAMETER .get (storage , None )
346+ assert snap == self .init_snap
347+ undo_monkey_patch (stored_functions )
348+
281349
282350class MultiParamInstanceTest (ParamInstanceTest ):
283351 def _getter_test (self , getter ):
@@ -384,21 +452,26 @@ def test_get(self, getter):
384452 self ._getter_test (getter )
385453
386454
387- def test_OUTPUT_FILE ():
455+ @pytest .mark .parametrize ('ext' , ['nc' , 'db' , 'sql' ])
456+ def test_OUTPUT_FILE (ext ):
457+ stored_functions = pre_monkey_patch ()
388458 tempdir = tempfile .mkdtemp ()
389- filename = os .path .join (tempdir , "test_output_file.nc" )
459+ filename = os .path .join (tempdir , "test_output_file." + ext )
390460 assert not os .path .exists (filename )
391461 storage = OUTPUT_FILE .get (filename )
392462 assert os .path .exists (filename )
393463 os .remove (filename )
394464 os .rmdir (tempdir )
465+ undo_monkey_patch (stored_functions )
395466
396- def test_APPEND_FILE ():
467+ @pytest .mark .parametrize ('ext' , ['nc' , 'db' , 'sql' ])
468+ def test_APPEND_FILE (ext ):
469+ stored_functions = pre_monkey_patch ()
397470 tempdir = tempfile .mkdtemp ()
398- filename = os .path .join (tempdir , "test_append_file.nc" )
471+ filename = os .path .join (tempdir , "test_append_file." + ext )
399472 assert not os .path .exists (filename )
400473 storage = APPEND_FILE .get (filename )
401- print (storage )
474+ # print(storage) # potentially useful debug; keep
402475 assert os .path .exists (filename )
403476 traj = make_1d_traj ([0.0 , 1.0 ])
404477 storage .tags ['first_save' ] = traj [0 ]
@@ -412,3 +485,4 @@ def test_APPEND_FILE():
412485 storage .close ()
413486 os .remove (filename )
414487 os .rmdir (tempdir )
488+ undo_monkey_patch (stored_functions )
0 commit comments