Skip to content

Commit f0a4174

Browse files
committed
Refactor parameters
1 parent 9d1d369 commit f0a4174

File tree

1 file changed

+103
-84
lines changed

1 file changed

+103
-84
lines changed

paths_cli/parameters.py

Lines changed: 103 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import click
22
import os
3-
# import openpathsampling as paths
3+
4+
_UNNAMED_STORES = ['snapshots', 'trajectories', 'samples', 'sample_sets',
5+
'steps']
6+
47

58
class AbstractParameter(object):
69
def __init__(self, *args, **kwargs):
@@ -12,11 +15,10 @@ def __init__(self, *args, **kwargs):
1215
def clicked(self, required=False):
1316
raise NotImplementedError()
1417

15-
# we'll use tests of the -h option in the .travis.yml to ensure that the
16-
# .clicked methods work
17-
1818
HELP_MULTIPLE = "; may be used more than once"
1919

20+
# we'll use tests of the -h option in the .travis.yml to ensure that the
21+
# .clicked methods work
2022
class Option(AbstractParameter):
2123
def clicked(self, required=False): # no-cov
2224
return click.option(*self.args, **self.kwargs, required=required)
@@ -77,119 +79,136 @@ def get(self, storage, names):
7779
for name in int_corrected]
7880

7981

82+
class Getter(object):
83+
def __init__(self, store_name):
84+
self.store_name = store_name
85+
86+
def _get(self, storage, name):
87+
store = getattr(storage, self.store_name)
88+
try:
89+
return store[name]
90+
except:
91+
return None
92+
93+
class GetByName(Getter):
94+
def __call__(self, storage, name):
95+
return self._get(storage, name)
96+
97+
class GetByNumber(Getter):
98+
def __call__(self, storage, name):
99+
try:
100+
num = int(name)
101+
except:
102+
return None
103+
104+
return self._get(storage, num)
105+
106+
class GetPredefinedName(Getter):
107+
def __init__(self, store_name, name):
108+
super().__init__(store_name=store_name)
109+
self.name = name
110+
111+
def __call__(self, storage):
112+
return self._get(storage, self.name)
113+
114+
class GetOnly(Getter):
115+
def __call__(self, storage):
116+
store = getattr(storage, self.store_name)
117+
if len(store) == 1:
118+
return store[0]
119+
120+
class GetOnlyNamed(Getter):
121+
def __call__(self, storage):
122+
store = getattr(storage, self.store_name)
123+
named_things = [o for o in store if o.is_named]
124+
if len(named_things) == 1:
125+
return named_things[0]
126+
127+
class GetOnlySnapshot(Getter):
128+
def __init__(self, store_name="snapshots"):
129+
super().__init__(store_name)
130+
131+
def __call__(self, storage):
132+
store = getattr(storage, self.store_name)
133+
if len(store) == 2:
134+
# this is really only 1 snapshot; reversed copy gets saved
135+
return store[0]
136+
137+
138+
def _try_strategies(strategies, storage, **kwargs):
139+
result = None
140+
for strategy in strategies:
141+
result = strategy(storage, **kwargs)
142+
if result is not None:
143+
return result
144+
145+
80146
class OPSStorageLoadSingle(AbstractLoader):
81-
def __init__(self, param, store, fallback=None, num_store=None):
147+
"""Objects that expect to load a single object.
148+
149+
These can sometimes include guesswork to figure out which object is
150+
desired.
151+
"""
152+
def __init__(self, param, store, value_strategies=None,
153+
none_strategies=None):
82154
super(OPSStorageLoadSingle, self).__init__(param)
83155
self.store = store
84-
self.fallback = fallback
85-
if num_store is None:
86-
num_store = store
87-
self.num_store = num_store
156+
if value_strategies is None:
157+
value_strategies = [GetByName(self.store),
158+
GetByNumber(self.store)]
159+
self.value_strategies = value_strategies
160+
161+
if none_strategies is None:
162+
none_strategies = [GetOnly(self.store),
163+
GetOnlyNamed(self.store)]
164+
self.none_strategies = none_strategies
88165

89166
def get(self, storage, name):
90167
store = getattr(storage, self.store)
91-
num_store = getattr(storage, self.num_store)
168+
# num_store = getattr(storage, self.num_store)
92169

93-
result = None
94-
# if the we can get by name/number, do it
95170
if name is not None:
96-
try:
97-
result = store[name]
98-
except:
99-
# on any error, we try everything else
100-
pass
101-
102-
if result is None:
103-
try:
104-
num = int(name)
105-
except ValueError:
106-
pass
107-
else:
108-
result = num_store[num]
109-
110-
if result is not None:
111-
return result
112-
113-
# if only one is named, take it
114-
if self.store != 'tags' and name is None:
115-
# if there's only one of them, take that
116-
if len(store) == 1:
117-
return store[0]
118-
named_things = [o for o in store if o.is_named]
119-
if len(named_things) == 1:
120-
return named_things[0]
121-
122-
if len(num_store) == 1 and name is None:
123-
return num_store[0]
124-
125-
if self.fallback:
126-
result = self.fallback(self, storage, name)
171+
result = _try_strategies(self.value_strategies, storage,
172+
name=name)
173+
else:
174+
result = _try_strategies(self.none_strategies, storage)
127175

128176
if result is None:
129177
raise RuntimeError("Couldn't find %s", name)
130178

131179
return result
132180

133-
134-
def init_traj_fallback(parameter, storage, name):
135-
result = None
136-
137-
if isinstance(name, int):
138-
return storage.trajectories[name]
139-
140-
if name is None:
141-
# fallback to final_conditions, initial_conditions, only trajectory
142-
# the "get" here may need to be changed for new storage
143-
for tag in ['final_conditions', 'initial_conditions']:
144-
result = storage.tags[tag]
145-
if result:
146-
return result
147-
148-
# already tried storage.samplesets
149-
if len(storage.trajectories) == 1:
150-
return storage.trajectories[0]
151-
152-
153-
def init_snap_fallback(parameter, storage, name):
154-
# this is structured so that other things can be added to it later
155-
result = None
156-
157-
if name is None:
158-
result = storage.tags['initial_snapshot']
159-
if result:
160-
return result
161-
162-
if len(storage.snapshots) == 2:
163-
# this is really only 1 snapshot; reversed copy gets saved
164-
return storage.snapshots[0]
165-
166181
ENGINE = OPSStorageLoadSingle(
167182
param=Option('-e', '--engine', help="identifer for the engine"),
168183
store='engines',
169-
fallback=None # for now... I'll add more tricks later
184+
# fallback=None # for now... I'll add more tricks later
170185
)
171186

172187
SCHEME = OPSStorageLoadSingle(
173188
param=Option('-m', '--scheme', help="identifier for the move scheme"),
174189
store='schemes',
175-
fallback=None
190+
# fallback=None
176191
)
177192

178193
INIT_CONDS = OPSStorageLoadSingle(
179194
param=Option('-t', '--init-conds',
180195
help=("identifier for initial conditions "
181196
+ "(sample set or trajectory)")),
182-
store='tags',
183-
num_store='samplesets',
184-
fallback=init_traj_fallback
197+
store='samplesets',
198+
value_strategies=[GetByName('tags'), GetByNumber('samplesets'),
199+
GetByNumber('trajectories')],
200+
none_strategies=[GetOnly('samplesets'), GetOnly('trajectories'),
201+
GetPredefinedName('tags', 'final_conditions'),
202+
GetPredefinedName('tags', 'initial_conditions')]
185203
)
186204

187205
INIT_SNAP = OPSStorageLoadSingle(
188206
param=Option('-f', '--init-frame',
189207
help="identifier for initial snapshot"),
190-
store='tags',
191-
num_store='snapshots',
192-
fallback=init_snap_fallback
208+
store='snapshots',
209+
value_strategies=[GetByName('tags'), GetByNumber('snapshots')],
210+
none_strategies=[GetOnlySnapshot(),
211+
GetPredefinedName('tags', 'initial_snapshot')]
193212
)
194213

195214
CVS = OPSStorageLoadNames(

0 commit comments

Comments
 (0)