11import pytest
22from unittest import mock
33from paths_cli .tests .compiling .utils import mock_compiler
4+ from paths_cli .compiling .plugins import CVCompilerPlugin
5+ from paths_cli .compiling .core import Parameter
46
57import yaml
68import numpy as np
@@ -23,20 +25,20 @@ def setup(self):
2325 }
2426
2527 self .func = {
26- 'inline' : "\n " .join (["name: foo" , "type: mdtraj" ]),
28+ 'inline' : "\n " + "\n " .join ([
29+ "name: foo" ,
30+ "type: fake_type" ,
31+ "input_data: bar" ,
32+ ]),
2733 'external' : 'foo'
2834 }
2935
30- def create_inputs (self , inline , periodic ):
31- yml = "\n " .join (["type: cv-volume" , "cv: {func}" ,
32- "lambda_min: 0" , "lambda_max: 1" ])
33-
3436 def set_periodic (self , periodic ):
3537 if periodic == 'periodic' :
3638 self .named_objs_dict ['foo' ]['period_max' ] = 'np.pi'
3739 self .named_objs_dict ['foo' ]['period_min' ] = '-np.pi'
3840
39- @pytest .mark .parametrize ('inline' , ['external' , 'external ' ])
41+ @pytest .mark .parametrize ('inline' , ['external' , 'inline ' ])
4042 @pytest .mark .parametrize ('periodic' , ['periodic' , 'nonperiodic' ])
4143 def test_build_cv_volume (self , inline , periodic ):
4244 self .set_periodic (periodic )
@@ -47,14 +49,29 @@ def test_build_cv_volume(self, inline, periodic):
4749 mock_cv = CoordinateFunctionCV (lambda s : s .xyz [0 ][0 ],
4850 period_min = period_min ,
4951 period_max = period_max ).named ('foo' )
52+
53+ patch_loc = 'paths_cli.compiling.root_compiler._COMPILERS'
54+
5055 if inline == 'external' :
51- patch_loc = 'paths_cli.compiling.root_compiler._COMPILERS'
5256 compilers = {
5357 'cv' : mock_compiler ('cv' , named_objs = {'foo' : mock_cv })
5458 }
55- with mock .patch .dict (patch_loc , compilers ):
56- vol = build_cv_volume (dct )
57- elif inline == 'internal' :
59+ elif inline == 'inline' :
60+ fake_plugin = CVCompilerPlugin (
61+ name = "fake_type" ,
62+ parameters = [Parameter ('input_data' , str )],
63+ builder = lambda input_data : mock_cv
64+ )
65+ compilers = {
66+ 'cv' : mock_compiler (
67+ 'cv' ,
68+ type_dispatch = {'fake_type' : fake_plugin }
69+ )
70+ }
71+ else : # -no-cov-
72+ raise RuntimeError ("Should never get here" )
73+
74+ with mock .patch .dict (patch_loc , compilers ):
5875 vol = build_cv_volume (dct )
5976
6077 in_state = make_1d_traj ([0.5 ])[0 ]
0 commit comments