|
| 1 | +import os |
| 2 | + |
| 3 | +import pytest |
| 4 | +from unittest.mock import patch, MagicMock |
| 5 | +import tempfile |
| 6 | +from click.testing import CliRunner |
| 7 | + |
| 8 | +from paths_cli.commands.visit_all import * |
| 9 | + |
| 10 | +import openpathsampling as paths |
| 11 | + |
| 12 | +# patch with this for testing |
| 13 | +def print_test(output_storage, states, engine, initial_frame): |
| 14 | + print(isinstance(output_storage, paths.Storage)) |
| 15 | + print(sorted([s.__uuid__ for s in states])) |
| 16 | + print(engine.__uuid__) |
| 17 | + print(initial_frame.__uuid__) |
| 18 | + |
| 19 | + |
| 20 | +@pytest.fixture() |
| 21 | +def tps_fixture(flat_engine, tps_network_and_traj): |
| 22 | + network, traj = tps_network_and_traj |
| 23 | + scheme = paths.OneWayShootingMoveScheme(network=network, |
| 24 | + selector=paths.UniformSelector(), |
| 25 | + engine=flat_engine) |
| 26 | + init_conds = scheme.initial_conditions_from_trajectories(traj) |
| 27 | + return (scheme, network, flat_engine, init_conds) |
| 28 | + |
| 29 | +@pytest.fixture() |
| 30 | +def visit_all_fixture(tps_fixture): |
| 31 | + scheme, network, engine, init_conds = tps_fixture |
| 32 | + states = sorted(network.all_states, key=lambda x: x.__uuid__) |
| 33 | + init_frame = init_conds[0].trajectory[0] |
| 34 | + return states, engine, init_frame |
| 35 | + |
| 36 | + |
| 37 | +@patch('paths_cli.commands.visit_all.visit_all_main', print_test) |
| 38 | +def test_visit_all(visit_all_fixture): |
| 39 | + # this is an integration test; testing integration click & parameters |
| 40 | + states, engine, init_frame = visit_all_fixture |
| 41 | + runner = CliRunner() |
| 42 | + with runner.isolated_filesystem(): |
| 43 | + storage = paths.Storage("setup.nc", 'w') |
| 44 | + for obj in visit_all_fixture: |
| 45 | + storage.save(obj) |
| 46 | + storage.tags['initial_snapshot'] = init_frame |
| 47 | + storage.close() |
| 48 | + |
| 49 | + results = runner.invoke( |
| 50 | + visit_all, |
| 51 | + ["setup.nc", '-o', 'foo.nc', '-s', 'A', '-s', 'B', |
| 52 | + '-e', 'flat', '-f', 'initial_snapshot'] |
| 53 | + ) |
| 54 | + |
| 55 | + expected_output = ("True\n[" + str(states[0].__uuid__) + ", " |
| 56 | + + str(states[1].__uuid__) + "]\n") |
| 57 | + expected_output += "\n".join(str(obj.__uuid__) |
| 58 | + for obj in [engine, init_frame]) + "\n" |
| 59 | + assert results.exit_code == 0 |
| 60 | + assert results.output == expected_output |
| 61 | + |
| 62 | +def test_visit_all_main(visit_all_fixture): |
| 63 | + # just a smoke test here |
| 64 | + tempdir = tempfile.mkdtemp() |
| 65 | + try: |
| 66 | + store_name = os.path.join(tempdir, "visit_all.nc") |
| 67 | + storage = paths.Storage(store_name, mode='w') |
| 68 | + states, engine, init_frame = visit_all_fixture |
| 69 | + traj, foo = visit_all_main(storage, states, engine, init_frame) |
| 70 | + assert isinstance(traj, paths.Trajectory) |
| 71 | + assert foo is None |
| 72 | + assert len(storage.trajectories) == 1 |
| 73 | + storage.close() |
| 74 | + finally: |
| 75 | + os.remove(store_name) |
| 76 | + os.rmdir(tempdir) |
0 commit comments