Skip to content

Commit 3db10a3

Browse files
committed
tests for INIT_CONDS
1 parent d57c6df commit 3db10a3

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

paths_cli/parameters.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(self, param, store, fallback=None, num_store=None):
6161

6262
def get(self, storage, name):
6363
store = getattr(storage, self.store)
64+
num_store = getattr(storage, self.num_store)
6465

6566
result = None
6667
# if the we can get by name/number, do it
@@ -77,20 +78,23 @@ def get(self, storage, name):
7778
except ValueError:
7879
pass
7980
else:
80-
num_store = getattr(storage, self.num_store)
8181
result = num_store[num]
8282

8383
if result is not None:
8484
return result
8585

86-
# if there's only one of them, take that
87-
if len(store) == 1:
88-
return store[0]
8986

9087
# if only one is named, take it
91-
named_things = [o for o in store if o.is_named]
92-
if len(named_things) == 1:
93-
return named_things[0]
88+
if self.store != 'tags':
89+
# if there's only one of them, take that
90+
if len(store) == 1:
91+
return store[0]
92+
named_things = [o for o in store if o.is_named]
93+
if len(named_things) == 1:
94+
return named_things[0]
95+
96+
if len(num_store) == 1:
97+
return num_store[0]
9498

9599
if self.fallback:
96100
result = self.fallback(self, storage, name)
@@ -101,26 +105,24 @@ def get(self, storage, name):
101105
return result
102106

103107

104-
105108
def init_traj_fallback(parameter, storage, name):
106109
result = None
110+
107111
if isinstance(name, int):
108112
return storage.trajectories[name]
109113

110-
if name and os.path.isfile(name):
111-
# TODO: read from file
112-
pass
113-
114114
if name is None:
115115
# fallback to final_conditions, initial_conditions, only trajectory
116116
# the "get" here may need to be changed for new storage
117117
for tag in ['final_conditions', 'initial_conditions']:
118118
result = storage.tags[tag]
119119
if result:
120-
return [result]
120+
return result
121+
122+
# already tried storage.samplesets
123+
if len(storage.trajectories) == 1:
124+
return storage.trajectories[0]
121125

122-
if len(storage.samplesets) == 1:
123-
return [s.trajectory for s in storage.samplesets[0]]
124126

125127

126128
ENGINE = OPSStorageLoadSingle(

paths_cli/tests/test_parameters.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,34 @@ def test_get(self, getter):
200200
'tag-initial': 'initial_conditions'
201201
}[getter_style]
202202
obj = self.PARAMETER.get(storage, get_arg)
203+
assert obj == expected
204+
205+
@pytest.mark.parametrize("num_in_file", [1, 2, 3, 4])
206+
def test_get_none(self, num_in_file):
207+
stored_things = [
208+
self.traj, self.sample_set, self.other_sample_set,
209+
self.other_sample_set
210+
]
211+
to_store = stored_things[:num_in_file]
212+
filename = self._filename("init_conds_" + str(num_in_file) + ".nc")
213+
storage = paths.Storage(filename, mode='w')
214+
for item in to_store:
215+
storage.save(item)
216+
217+
if num_in_file == 3:
218+
storage.tags['initial_conditions'] = self.other_sample_set
219+
elif num_in_file == 4:
220+
storage.tags['final_conditions'] = self.other_sample_set
221+
storage.tags['initial_conditions'] = self.sample_set
203222

204-
pytest.skip()
205-
pass
223+
storage.close()
224+
225+
st = paths.Storage(filename, mode='r')
226+
obj = INIT_CONDS.get(st, None)
227+
assert obj == stored_things[num_in_file - 1]
206228

207-
def test_get_file(self):
208-
pytest.skip()
229+
230+
pass
209231

210232

211233
class MultiParamInstanceTest(ParamInstanceTest):

0 commit comments

Comments
 (0)