Skip to content

Commit 472574d

Browse files
committed
shooting-point-selector compilers
also commenting out TIS networks for now
1 parent c455e9c commit 472574d

File tree

7 files changed

+66
-26
lines changed

7 files changed

+66
-26
lines changed

paths_cli/compiling/_gendocs/docs_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def generate_plugin_rst(self, plugin, category_name,
104104
default="",
105105
description="name this object in order to reuse it",
106106
)
107-
rst += self.format_parameter(name_param, type_str=" (*string*)")
107+
rst += self.format_parameter(name_param, type_str=" (string)")
108108
for param in plugin.parameters:
109109
type_str = f" ({json_type_to_string(param.json_type)})"
110110
rst += self.format_parameter(param, type_str)

paths_cli/compiling/networks.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,22 @@ def tis_trans_info(dct):
6969
)
7070

7171

72-
MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
73-
parameters=[Parameter('trans_info', mistis_trans_info)],
74-
builder=Builder('openpathsampling.MISTISNetwork'),
75-
name='mistis'
76-
)
72+
# MISTIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
73+
# parameters=[Parameter('trans_info', mistis_trans_info)],
74+
# builder=Builder('openpathsampling.MISTISNetwork'),
75+
# name='mistis'
76+
# )
7777

7878

79-
TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
80-
builder=Builder('openpathsampling.MISTISNetwork'),
81-
parameters=[Parameter('trans_info', tis_trans_info)],
82-
name='tis'
83-
)
79+
# TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
80+
# builder=Builder('openpathsampling.MISTISNetwork'),
81+
# parameters=[Parameter('trans_info', tis_trans_info)],
82+
# name='tis'
83+
# )
8484

8585
# old names not yet replaced in testing THESE ARE WHY WE'RE DOUBLING! GET
8686
# RID OF THEM! (also, use an is-check)
8787
build_tps_network = TPS_NETWORK_PLUGIN
88-
build_mistis_network = MISTIS_NETWORK_PLUGIN
89-
build_tis_network = TIS_NETWORK_PLUGIN
9088

9189

9290
NETWORK_COMPILER = CategoryPlugin(NetworkCompilerPlugin, aliases=['networks'])

paths_cli/compiling/plugins.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,11 @@ class SchemeCompilerPlugin(InstanceCompilerPlugin):
4545

4646
class StrategyCompilerPlugin(InstanceCompilerPlugin):
4747
category = 'strategy'
48+
49+
50+
class ShootingPointSelectorPlugin(InstanceCompilerPlugin):
51+
category = 'shooting-point-selector'
52+
53+
54+
class InterfaceSetPlugin(InstanceCompilerPlugin):
55+
category = 'interface-set'

paths_cli/compiling/root_compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def _get_compiler(category):
7070
canonical_name = _canonical_name(category)
7171
# create a new compiler if none exists
7272
if canonical_name is None:
73-
canonical_name = category
74-
_COMPILERS[category] = CategoryCompiler(None, category)
73+
canonical_name = clean_input_key(category)
74+
_COMPILERS[canonical_name] = CategoryCompiler(None, category)
7575
return _COMPILERS[canonical_name]
7676

7777

paths_cli/compiling/shooting.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
)
44
from paths_cli.compiling.root_compiler import compiler_for
55
from paths_cli.compiling.tools import custom_eval
6+
from paths_cli.compiling.plugins import ShootingPointSelectorPlugin
67

78

8-
build_uniform_selector = InstanceCompilerPlugin(
9+
build_uniform_selector = ShootingPointSelectorPlugin(
910
builder=Builder('openpathsampling.UniformSelector'),
1011
parameters=[],
1112
name='uniform',
@@ -19,7 +20,7 @@ def _remapping_gaussian_stddev(dct):
1920
return dct
2021

2122

22-
build_gaussian_selector = InstanceCompilerPlugin(
23+
build_gaussian_selector = ShootingPointSelectorPlugin(
2324
builder=Builder('openpathsampling.GaussianBiasSelector',
2425
remapper=_remapping_gaussian_stddev),
2526
parameters=[
@@ -31,10 +32,10 @@ def _remapping_gaussian_stddev(dct):
3132
)
3233

3334

34-
shooting_selector_compiler = CategoryCompiler(
35-
type_dispatch={
36-
'uniform': build_uniform_selector,
37-
'gaussian': build_gaussian_selector,
38-
},
39-
label='shooting-selectors'
40-
)
35+
# shooting_selector_compiler = CategoryCompiler(
36+
# type_dispatch={
37+
# 'uniform': build_uniform_selector,
38+
# 'gaussian': build_gaussian_selector,
39+
# },
40+
# label='shooting-point-selectors'
41+
# )

paths_cli/compiling/strategies.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from paths_cli.compiling.core import Builder, Parameter
2-
from paths_cli.compiling.shooting import shooting_selector_compiler
2+
# from paths_cli.compiling.shooting import shooting_selector_compiler
33
from paths_cli.compiling.plugins import (
44
StrategyCompilerPlugin, CategoryPlugin
55
)
66
from paths_cli.compiling.root_compiler import compiler_for
77
from paths_cli.compiling.json_type import json_type_ref
88

9+
shooting_selector_compiler = compiler_for('shooting-point-selector')
10+
911

1012
def _strategy_name(class_name):
1113
return f"openpathsampling.strategies.{class_name}"

paths_cli/tests/compiling/test_root_compiler.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,34 @@
2020
def foo_compiler():
2121
return CategoryCompiler(None, 'foo')
2222

23+
2324
@pytest.fixture
2425
def foo_compiler_plugin():
2526
return CategoryPlugin(Mock(category='foo', __name__='foo'), ['bar'])
2627

28+
2729
@pytest.fixture
2830
def foo_baz_builder_plugin():
29-
builder = InstanceCompilerPlugin(lambda: "FOO" , [], name='baz',
31+
builder = InstanceCompilerPlugin(lambda: "FOO", [], name='baz',
3032
aliases=['qux'])
3133
builder.category = 'foo'
3234
return builder
3335

36+
3437
### CONSTANTS ##############################################################
3538

3639
COMPILER_LOC = "paths_cli.compiling.root_compiler._COMPILERS"
3740
BASE = "paths_cli.compiling.root_compiler."
3841

42+
3943
### TESTS ##################################################################
4044

4145
@pytest.mark.parametrize('input_string', ["foo-bar", "FOO_bar", "foo bar",
4246
"foo_bar", "foo BAR"])
4347
def test_clean_input_key(input_string):
4448
assert clean_input_key(input_string) == "foo_bar"
4549

50+
4651
@pytest.mark.parametrize('input_name', ['canonical', 'alias'])
4752
def test_canonical_name(input_name):
4853
compilers = {'canonical': "FOO"}
@@ -51,6 +56,7 @@ def test_canonical_name(input_name):
5156
patch.dict(BASE + "_ALIASES", aliases) as aliases_:
5257
assert _canonical_name(input_name) == "canonical"
5358

59+
5460
class TestCategoryCompilerProxy:
5561
def setup(self):
5662
self.compiler = CategoryCompiler(None, "foo")
@@ -86,6 +92,7 @@ def _bar_dispatch(dct):
8692
with patch.dict(COMPILER_LOC, {'foo': foo_compiler}):
8793
assert proxy(user_input) == "bazbaz"
8894

95+
8996
def test_compiler_for_nonexisting():
9097
# if nothing is ever registered with the compiler, then compiler_for
9198
# should error
@@ -97,20 +104,23 @@ def test_compiler_for_nonexisting():
97104
with pytest.raises(RuntimeError, match="No CategoryCompiler"):
98105
proxy._proxy
99106

107+
100108
def test_compiler_for_existing(foo_compiler):
101109
# if a compiler already exists when compiler_for is called, then
102110
# compiler_for should get that as its proxy
103111
with patch.dict(COMPILER_LOC, {'foo': foo_compiler}):
104112
proxy = compiler_for('foo')
105113
assert proxy._proxy is foo_compiler
106114

115+
107116
def test_compiler_for_unregistered(foo_compiler):
108117
# if a compiler is registered after compiler_for is called, then
109118
# compiler_for should use that as its proxy
110119
proxy = compiler_for('foo')
111120
with patch.dict(COMPILER_LOC, {'foo': foo_compiler}):
112121
assert proxy._proxy is foo_compiler
113122

123+
114124
def test_compiler_for_registered_alias(foo_compiler):
115125
# if compiler_for is registered as an alias, compiler_for should still
116126
# get the correct compiler
@@ -121,12 +131,14 @@ def test_compiler_for_registered_alias(foo_compiler):
121131
proxy = compiler_for('bar')
122132
assert proxy._proxy is foo_compiler
123133

134+
124135
def test_get_compiler_existing(foo_compiler):
125136
# if a compiler has been registered, then _get_compiler should return the
126137
# registered compiler
127138
with patch.dict(COMPILER_LOC, {'foo': foo_compiler}):
128139
assert _get_compiler('foo') is foo_compiler
129140

141+
130142
def test_get_compiler_nonexisting(foo_compiler):
131143
# if a compiler has not been registered, then _get_compiler should create
132144
# the compiler
@@ -136,6 +148,18 @@ def test_get_compiler_nonexisting(foo_compiler):
136148
assert compiler.label == 'foo'
137149
assert 'foo' in _COMPILERS
138150

151+
152+
def test_get_compiler_nonstandard_name_multiple():
153+
# regression test based on real issue -- there was an error where
154+
# non-canonical names (e.g., names that involved hyphens instead of
155+
# underscores) would overwrite the previously created compiler instead
156+
# of getting the identical object
157+
with patch.dict(COMPILER_LOC, {}):
158+
c1 = _get_compiler('non-canonical-name')
159+
c2 = _get_compiler('non-canonical-name')
160+
assert c1 is c2
161+
162+
139163
@pytest.mark.parametrize('canonical,aliases,expected', [
140164
('foo', ['bar', 'baz'], ['foo', 'bar', 'baz']),
141165
('foo', ['baz', 'bar'], ['foo', 'baz', 'bar']),
@@ -149,6 +173,7 @@ def test_get_registration_names(canonical, aliases, expected):
149173
type(plugin).name = PropertyMock(return_value=canonical)
150174
assert _get_registration_names(plugin) == expected
151175

176+
152177
def test_register_compiler_plugin(foo_compiler_plugin):
153178
# _register_compiler_plugin should register compilers that don't exist
154179
compilers = {}
@@ -162,6 +187,7 @@ def test_register_compiler_plugin(foo_compiler_plugin):
162187

163188
assert 'foo' not in _COMPILERS
164189

190+
165191
@pytest.mark.parametrize('duplicate_of', ['canonical', 'alias'])
166192
@pytest.mark.parametrize('duplicate_from', ['canonical', 'alias'])
167193
def test_register_compiler_plugin_duplicate(duplicate_of, duplicate_from):
@@ -184,6 +210,7 @@ def test_register_compiler_plugin_duplicate(duplicate_of, duplicate_from):
184210
with pytest.raises(CategoryCompilerRegistrationError):
185211
_register_compiler_plugin(plugin)
186212

213+
187214
@pytest.mark.parametrize('compiler_exists', [True, False])
188215
def test_register_builder_plugin(compiler_exists, foo_baz_builder_plugin,
189216
foo_compiler):
@@ -203,6 +230,7 @@ def test_register_builder_plugin(compiler_exists, foo_baz_builder_plugin,
203230
assert type_dispatch['baz'] is foo_baz_builder_plugin
204231
assert type_dispatch['qux'] is foo_baz_builder_plugin
205232

233+
206234
def test_register_plugins_unit(foo_compiler_plugin, foo_baz_builder_plugin):
207235
# register_plugins should correctly sort builder and compiler plugins,
208236
# and call the correct registration functions
@@ -212,6 +240,7 @@ def test_register_plugins_unit(foo_compiler_plugin, foo_baz_builder_plugin):
212240
assert builder.called_once_with(foo_baz_builder_plugin)
213241
assert compiler.called_once_with(foo_compiler_plugin)
214242

243+
215244
def test_register_plugins_integration(foo_compiler_plugin,
216245
foo_baz_builder_plugin):
217246
# register_plugins should correctly register plugins
@@ -225,6 +254,7 @@ def test_register_plugins_integration(foo_compiler_plugin,
225254
type_dispatch = _COMPILERS['foo'].type_dispatch
226255
assert type_dispatch['baz'] is foo_baz_builder_plugin
227256

257+
228258
def test_sort_user_categories():
229259
# sorted user categories should match the expected compile order
230260
aliases = {'quux': 'qux'}
@@ -245,6 +275,7 @@ def test_sort_user_categories():
245275
# check that we unset properly (test the test)
246276
assert paths_cli.compiling.root_compiler.COMPILE_ORDER[0] == 'engine'
247277

278+
248279
def test_do_compile():
249280
# compiler should correctly compile a basic input dict
250281
compilers = {

0 commit comments

Comments
 (0)