Skip to content

Commit c526122

Browse files
committed
Add unit tests for new parameters
1 parent 424aa6b commit c526122

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

paths_cli/parameters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def init_snap_fallback(parameter, storage, name):
202202
param=Option('--volume', type=str, multiple=True,
203203
help='name or index of volume' + HELP_MULTIPLE),
204204
store='volumes'
205-
) #TODO:unit tests
205+
)
206206

207207
MULTI_ENGINE = OPSStorageLoadNames(
208208
param=Option('--engine', type=str, multiple=True,
@@ -221,19 +221,19 @@ def init_snap_fallback(parameter, storage, name):
221221
param=Option('--tag', type=str, multiple=True,
222222
help='tag for object' + HELP_MULTIPLE),
223223
store='tags'
224-
) # TODO: unit testsA
224+
)
225225

226226
MULTI_NETWORK = OPSStorageLoadNames(
227227
param=Option('--network', type=str, multiple=True,
228228
help='name or index of network' + HELP_MULTIPLE),
229229
store='networks'
230-
) # TODO: unit tests
230+
)
231231

232232
MULTI_SCHEME = OPSStorageLoadNames(
233233
param=Option('--scheme', type=str, multiple=True,
234234
help='name or index of move scheme' + HELP_MULTIPLE),
235235
store='schemes'
236-
) # TODO: unit tests
236+
)
237237

238238
INPUT_FILE = StorageLoader(
239239
param=Argument('input_file',

paths_cli/tests/test_parameters.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,15 @@ def setup(self):
5050
self.state_B = paths.CVDefinedVolume(
5151
self.cv, 10, float("inf")
5252
).named("B")
53-
network = paths.TPSNetwork(self.state_A, self.state_B)
53+
self.network = paths.TPSNetwork(self.state_A,
54+
self.state_B).named('network')
5455
self.scheme = paths.OneWayShootingMoveScheme(
55-
network,
56+
self.network,
5657
paths.UniformSelector(),
5758
self.engine
5859
).named("scheme")
5960
self.other_scheme = paths.OneWayShootingMoveScheme(
60-
network,
61+
self.network,
6162
paths.UniformSelector(),
6263
self.other_engine
6364
)
@@ -310,6 +311,71 @@ def test_get_other(self, getter):
310311
self.obj = self.state_B
311312
self._getter_test(getter)
312313

314+
315+
class MULTITest(MultiParamInstanceTest):
316+
# Abstract base class for tests of MULTI_* parameters
317+
# These parameters require name or number input, otherwise error
318+
@pytest.mark.parametrize("getter", ['name', 'number'])
319+
def test_get(self, getter):
320+
self._getter_test(getter)
321+
322+
def test_get_none(self):
323+
filename = self.create_file("none")
324+
storage = paths.Storage(filename, mode='r')
325+
with pytest.raises(TypeError):
326+
# TypeError: 'NoneType' object is not iterable
327+
self.PARAMETER.get(storage, None)
328+
329+
330+
class TestMULTI_VOLUME(TestSTATES, MULTITest):
331+
PARAMETER = MULTI_VOLUME
332+
# MULTI_VOLUME is basically the same as STATES, with different option
333+
# parameters
334+
335+
336+
class TestMULTI_ENGINE(MULTITest):
337+
PARAMETER = MULTI_ENGINE
338+
def setup(self):
339+
super(TestMULTI_ENGINE, self).setup()
340+
self.get_arg = {'name': ["engine"], 'number': [0]}
341+
self.obj = self.engine
342+
343+
344+
class TestMulti_NETWORK(MULTITest):
345+
PARAMETER = MULTI_NETWORK
346+
def setup(self):
347+
super(TestMulti_NETWORK, self).setup()
348+
self.get_arg = {'name': ['network'], 'number': [0]}
349+
self.obj = self.network
350+
351+
352+
class TestMULTI_SCHEME(MULTITest):
353+
PARAMETER = MULTI_SCHEME
354+
def setup(self):
355+
super(TestMULTI_SCHEME, self).setup()
356+
self.get_arg = {'name': ['scheme'], 'number': [0]}
357+
self.obj = self.scheme
358+
359+
360+
class TestMULTI_TAG(MULTITest):
361+
PARAMETER = MULTI_TAG
362+
def setup(self):
363+
super(TestMULTI_TAG, self).setup()
364+
self.obj = make_1d_traj([1.0, 2.0, 3.0])
365+
self.get_arg = {'name': ['traj']}
366+
367+
def create_file(self, getter):
368+
filename = self._filename(getter)
369+
storage = paths.Storage(filename, 'w')
370+
storage.tag['traj'] = self.obj
371+
storage.close()
372+
return filename
373+
374+
@pytest.mark.parametrize("getter", ['name'])
375+
def test_get(self, getter):
376+
self._getter_test(getter)
377+
378+
313379
def test_OUTPUT_FILE():
314380
tempdir = tempfile.mkdtemp()
315381
filename = os.path.join(tempdir, "test_output_file.nc")

0 commit comments

Comments
 (0)