Skip to content

Commit a629e66

Browse files
committed
Add INIT_SNAP parameter
1 parent a6584ae commit a629e66

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

paths_cli/parameters.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@ def init_traj_fallback(parameter, storage, name):
126126
return storage.trajectories[0]
127127

128128

129+
def init_snap_fallback(parameter, storage, name):
130+
# this is structured so that other things can be added to it later
131+
result = None
132+
133+
if name is None:
134+
result = storage.tags['initial_snapshot']
135+
if result:
136+
return result
137+
138+
if len(storage.snapshots) == 2:
139+
# this is really only 1 snapshot; reversed copy gets saved
140+
return storage.snapshots[0]
129141

130142
ENGINE = OPSStorageLoadSingle(
131143
param=Option('-e', '--engine', help="identifer for the engine"),
@@ -142,12 +154,20 @@ def init_traj_fallback(parameter, storage, name):
142154
INIT_CONDS = OPSStorageLoadSingle(
143155
param=Option('-t', '--init-conds',
144156
help=("identifier for initial conditions "
145-
+ "(sample set or trajectory")),
157+
+ "(sample set or trajectory)")),
146158
store='tags',
147159
num_store='samplesets',
148160
fallback=init_traj_fallback
149161
)
150162

163+
INIT_SNAP = OPSStorageLoadSingle(
164+
param=Option('-f', '--init-frame',
165+
help="identifier for initial snapshot"),
166+
store='tags',
167+
num_store='snapshots',
168+
fallback=init_snap_fallback
169+
)
170+
151171
CVS = OPSStorageLoadNames(
152172
param=Option('--cv', type=str, multiple=True,
153173
help='name of CV; may select more than once'),

paths_cli/tests/test_parameters.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,41 @@ def test_get_none(self, num_in_file):
233233
obj = INIT_CONDS.get(st, None)
234234
assert obj == stored_things[num_in_file - 1]
235235

236+
class TestINIT_SNAP(ParamInstanceTest):
237+
PARAMETER = INIT_SNAP
238+
def setup(self):
239+
super(TestINIT_SNAP, self).setup()
240+
traj = make_1d_traj([1.0, 2.0])
241+
self.other_snap = traj[0]
242+
self.init_snap = traj[1]
243+
244+
def create_file(self, getter):
245+
filename = self._filename(getter)
246+
storage = paths.Storage(filename, 'w')
247+
storage.save(self.other_snap)
248+
if getter != 'none-num':
249+
storage.save(self.init_snap)
250+
storage.tags['initial_snapshot'] = self.init_snap
251+
storage.close()
252+
return filename
236253

237-
pass
254+
@pytest.mark.parametrize('getter', ['none-num', 'tag', 'number',
255+
'none-tag'])
256+
def test_get(self, getter):
257+
# get by number is 2 because of snapshot duplication in storage
258+
get_arg = {'none-num': None,
259+
'none-tag': None,
260+
'tag': 'initial_snapshot',
261+
'number': 2}[getter]
262+
expected = {'none-num': self.other_snap,
263+
'none-tag': self.init_snap,
264+
'tag': self.init_snap,
265+
'number': self.init_snap}[getter]
266+
filename = self.create_file(getter)
267+
storage = paths.Storage(filename, mode='r')
268+
269+
obj = self.PARAMETER.get(storage, get_arg)
270+
assert obj == expected
238271

239272

240273
class MultiParamInstanceTest(ParamInstanceTest):

0 commit comments

Comments
 (0)