@@ -31,9 +31,9 @@ def setup(self):
3131 ]).named ('transition' ),
3232 ]
3333 self .ensembles = {ens .name : ens for ens in ensembles }
34- traj_vals = [- 0.1 , 1.1 , 0.5 , - 0.2 , 0.1 , - 0.3 , 0.4 , 1.4 , - 1.0 ]
35- self .trajectory = make_1d_traj (traj_vals )
36- self .engine = CalvinistDynamics (traj_vals )
34+ self . traj_vals = [- 0.1 , 1.1 , 0.5 , - 0.2 , 0.1 , - 0.3 , 0.4 , 1.4 , - 1.0 ]
35+ self .trajectory = make_1d_traj (self . traj_vals )
36+ self .engine = CalvinistDynamics (self . traj_vals )
3737 self .satisfied_when_traj_len = {
3838 "len1" : 1 ,
3939 "len3" : 3 ,
@@ -76,6 +76,10 @@ def test_call(self, traj_len, expected, trusted):
7676 # only test call count if we're trusted
7777 assert mock .call_count == expected_calls
7878
79+ def test_long_traj_untrusted (self ):
80+ traj = make_1d_traj (self .traj_vals + [1.0 , 1.2 , 1.3 , 1.4 ])
81+ assert self .conditions (traj ) is False
82+
7983 def test_generate (self ):
8084 init_snap = self .trajectory [0 ]
8185 traj = self .engine .generate (init_snap , self .conditions )
@@ -118,17 +122,25 @@ def test_md(md_fixture):
118122 assert results .output == expected_output
119123 assert results .exit_code == 0
120124
121- def test_md_main (md_fixture ):
125+ @pytest .mark .parametrize ('inp' , ['nsteps' , 'ensemble' ])
126+ def test_md_main (md_fixture , inp ):
122127 tempdir = tempfile .mkdtemp ()
123128 try :
124129 store_name = os .path .join (tempdir , "md.nc" )
125130 storage = paths .Storage (store_name , mode = 'w' )
126- engine , ensemble , snapshot = md_fixture
131+ engine , ens , snapshot = md_fixture
132+ if inp == 'nsteps' :
133+ nsteps , ensembles = 5 , None
134+ elif inp == 'ensemble' :
135+ nsteps , ensembles = None , [ens ]
136+ else :
137+ raise RuntimeError ("pytest went crazy" )
138+
127139 traj , foo = md_main (
128140 output_storage = storage ,
129141 engine = engine ,
130- ensembles = [ ensemble ] ,
131- nsteps = None ,
142+ ensembles = ensembles ,
143+ nsteps = nsteps ,
132144 initial_frame = snapshot
133145 )
134146 assert isinstance (traj , paths .Trajectory )
@@ -140,3 +152,11 @@ def test_md_main(md_fixture):
140152 os .remove (store_name )
141153 os .rmdir (tempdir )
142154
155+ def test_md_main_error (md_fixture ):
156+ engine , ensemble , snapshot = md_fixture
157+ with pytest .raises (RuntimeError ):
158+ md_main (output_storage = None ,
159+ engine = engine ,
160+ ensembles = [ensemble ],
161+ nsteps = 5 ,
162+ initial_frame = snapshot )
0 commit comments