Skip to content

Commit 2da6ef3

Browse files
committed
file setup for TestINIT_TRAJ
1 parent 95628c8 commit 2da6ef3

File tree

4 files changed

+98
-11
lines changed

4 files changed

+98
-11
lines changed

paths_cli/commands/contents.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,4 @@ def get_section_string_nameable(section, store, get_named):
6565

6666
CLI = nclist
6767
SECTION = "Miscellaneous"
68+
REQUIRES_OPS = (1, 0)

paths_cli/commands/strip_snapshots.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,4 @@ def strip_snapshots_main(input_storage, output_storage, cvs, blocksize):
7373

7474
CLI = strip_snapshots
7575
SECTION = "Miscellaneous"
76+
REQUIRES_OPS = (1, 0)

paths_cli/parameters.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def get(self, storage, name):
6464
result = None
6565
# if the we can get by name/number, do it
6666
if name is not None:
67-
result = store[name]
67+
# note: new storage may do a try/except here
68+
result = store[name]
6869

6970
if result is None:
7071
try:
@@ -95,6 +96,24 @@ def get(self, storage, name):
9596

9697
return result
9798

99+
def init_traj_fallback(parameter, storage, name):
100+
result = None
101+
if name and os.path.isfile(name):
102+
# TODO: read from file
103+
pass
104+
105+
if name is None:
106+
# fallback to final_conditions, initial_conditions, only trajectory
107+
# the "get" here may need to be changed for new storage
108+
for tag in ['final_conditions', 'initial_conditions']:
109+
result = storage.tags[tag]
110+
if result:
111+
return [result]
112+
113+
if len(storage.samplesets) == 1:
114+
return [s.trajectory for s in storage.samplesets[0]]
115+
116+
98117
ENGINE = OPSStorageLoadSingle(
99118
param=Option('-e', '--engine', help="identifer for the engine"),
100119
store='engines',
@@ -112,7 +131,7 @@ def get(self, storage, name):
112131
help="identifier for initial trajectory"),
113132
store='tags',
114133
num_store='trajectories',
115-
fallback=None # for now
134+
fallback=init_traj_fallback
116135
)
117136

118137
CVS = OPSStorageLoadNames(

paths_cli/tests/test_parameters.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44

55
import openpathsampling as paths
6+
from openpathsampling.tests.test_helpers import make_1d_traj
67

78
from paths_cli.parameters import *
89

@@ -98,6 +99,7 @@ def teardown(self):
9899
os.remove(os.path.join(self.tempdir, temp_f))
99100
os.rmdir(self.tempdir)
100101

102+
101103
class TestENGINE(ParamInstanceTest):
102104
PARAMETER = ENGINE
103105
def setup(self):
@@ -126,16 +128,72 @@ def test_get(self, getter):
126128
self._getter_test(getter)
127129

128130

129-
class TestINIT_TRAJ(object):
131+
class TestINIT_TRAJ(ParamInstanceTest):
132+
PARAMETER = INIT_TRAJ
130133
def setup(self):
131-
pytest.skip()
132-
pass
134+
super(TestINIT_TRAJ, self).setup()
135+
self.traj = make_1d_traj([-0.1, 1.0, 4.4, 7.7, 10.01])
136+
ensemble = self.scheme.network.sampling_ensembles[0]
137+
self.sample_set = paths.SampleSet([
138+
paths.Sample(trajectory=self.traj,
139+
replica=0,
140+
ensemble=ensemble)
141+
])
142+
self.other_traj = make_1d_traj([-1.0, 1.0, 100.0])
143+
self.other_sample_set = paths.SampleSet([
144+
paths.Sample(trajectory=self.other_traj,
145+
replica=0,
146+
ensemble=ensemble)
147+
])
148+
149+
@staticmethod
150+
def _parse_getter(getter):
151+
split_up = getter.split('-')
152+
get_type = split_up[-1]
153+
getter_style = "-".join(split_up[:-1])
154+
return get_type, getter_style
133155

134-
@pytest.mark.parametrize("getter", ['name', 'number', 'tag-final',
135-
'tag-initial', 'file'])
156+
def create_file(self, getter):
157+
filename = self._filename(getter)
158+
storage = paths.Storage(filename, 'w')
159+
storage.save(self.traj)
160+
storage.save(self.other_traj)
161+
get_type, getter_style = self._parse_getter(getter)
162+
main, other = {
163+
'traj': (self.traj, self.other_traj),
164+
'sset': (self.sample_set, self.other_sample_set)
165+
}[get_type]
166+
if get_type == 'sset':
167+
storage.save(self.sample_set)
168+
storage.save(self.other_sample_set)
169+
170+
tag, other_tag = {
171+
'name': ('traj', None),
172+
'number': (None, None),
173+
'tag-final': ('final_conditions', 'initial_conditions'),
174+
'tag-initial': ('initial_conditions', None)
175+
}[getter_style]
176+
if tag:
177+
storage.tags[tag] = main
178+
179+
if other_tag:
180+
storage.tags[other_tag] = other
181+
storage.close()
182+
return filename
183+
184+
@pytest.mark.parametrize("getter", [
185+
'name-traj', 'number-traj', 'tag-final-traj', 'tag-initial-traj',
186+
'name-sset', 'number-sset', 'tag-final-sset', 'tag-initial-sset'
187+
])
136188
def test_get(self, getter):
189+
filename = self.create_file(getter)
190+
pytest.skip()
137191
pass
138192

193+
def test_get_file(self):
194+
pytest.skip()
195+
196+
139197
class TestCVS(ParamInstanceTest):
140198
PARAMETER = CVS
141199
def setup(self):
@@ -147,12 +205,20 @@ def setup(self):
147205
def test_get(self, getter):
148206
self._getter_test(getter)
149207

150-
class TestSTATES(object):
208+
209+
class TestSTATES(ParamInstanceTest):
151210
PARAMETER = STATES
152211
def setup(self):
153-
pytest.skip()
154-
pass
212+
super(TestSTATES, self).setup()
213+
self.get_arg = {'name': "A", 'number': 0}
214+
self.obj = self.state_A
155215

156216
@pytest.mark.parametrize("getter", ['name', 'number'])
157217
def test_get(self, getter):
158-
pass
218+
self._getter_test(getter)
219+
220+
@pytest.mark.parametrize("getter", ['name', 'number'])
221+
def test_get_other(self, getter):
222+
self.get_arg = {'name': 'B', 'number': 1}
223+
self.obj = self.state_B
224+
self._getter_test(getter)

0 commit comments

Comments
 (0)