Skip to content

Commit 40a503b

Browse files
authored
Merge pull request #29 from dwhswenson/simstore
SimStore Support
2 parents 0aa3aa2 + 17a0319 commit 40a503b

File tree

4 files changed

+136
-16
lines changed

4 files changed

+136
-16
lines changed

devtools/tests_require.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ nose
33
pytest
44
pytest-cov
55
coveralls
6+
# following are for SimStore integration
7+
dill
8+
sqlalchemy

paths_cli/param_core.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,22 +72,48 @@ class StorageLoader(AbstractLoader):
7272
mode : 'r', 'w', or 'a'
7373
the mode for the file
7474
"""
75+
has_simstore_patch = False
7576
def __init__(self, param, mode):
7677
super(StorageLoader, self).__init__(param)
7778
self.mode = mode
7879

80+
@staticmethod
81+
def _is_simstore(name):
82+
return name.endswith(".db") or name.endswith(".sql")
83+
7984
def _workaround(self, name):
8085
# this is messed up... for some reason, storage doesn't create a new
8186
# file in append mode. That may be a bug
8287
import openpathsampling as paths
83-
if self.mode == 'a' and not os.path.exists(name):
88+
needs_workaround = (
89+
self.mode == 'a'
90+
and not os.path.exists(name)
91+
and not self._is_simstore(name)
92+
)
93+
if needs_workaround:
8494
st = paths.Storage(name, mode='w')
8595
st.close()
8696

8797
def get(self, name):
88-
import openpathsampling as paths
89-
self._workaround(name)
90-
return paths.Storage(name, mode=self.mode)
98+
if self._is_simstore(name):
99+
import openpathsampling as paths
100+
from openpathsampling.experimental.storage import \
101+
Storage, monkey_patch_all
102+
103+
if not self.has_simstore_patch:
104+
paths = monkey_patch_all(paths)
105+
paths.InterfaceSet.simstore = True
106+
StorageLoader.has_simstore_patch = True
107+
108+
from openpathsampling.experimental.simstore import \
109+
SQLStorageBackend
110+
backend = SQLStorageBackend(name, mode=self.mode)
111+
storage = Storage.from_backend(backend)
112+
else:
113+
from openpathsampling import Storage
114+
self._workaround(name)
115+
storage = Storage(name, self.mode)
116+
return storage
91117

92118

93119
class OPSStorageLoadNames(AbstractLoader):
@@ -200,10 +226,20 @@ class GetOnlySnapshot(Getter):
200226
def __init__(self, store_name="snapshots"):
201227
super().__init__(store_name)
202228

229+
def _min_num_snapshots(self, storage):
230+
# For netcdfplus, we see 2 snapshots when there is only one
231+
# (reversed copy gets saved). For SimStore, we see only one.
232+
import openpathsampling as paths
233+
if isinstance(storage, paths.netcdfplus.NetCDFPlus):
234+
min_snaps = 2
235+
else:
236+
min_snaps = 1
237+
return min_snaps
238+
203239
def __call__(self, storage):
204240
store = getattr(storage, self.store_name)
205-
if len(store) == 2:
206-
# this is really only 1 snapshot; reversed copy gets saved
241+
min_snaps = self._min_num_snapshots(storage)
242+
if len(store) == min_snaps:
207243
return store[0]
208244

209245

@@ -270,7 +306,14 @@ def get(self, storage, name):
270306
result = _try_strategies(self.none_strategies, storage)
271307

272308
if result is None:
273-
raise RuntimeError("Couldn't find %s", name)
309+
if name is None:
310+
msg = "Couldn't guess which item to use from " + self.store
311+
else:
312+
msg = "Couldn't find {name} is {store}".format(
313+
name=name,
314+
store=self.store
315+
)
316+
raise RuntimeError(msg)
274317

275318
return result
276319

paths_cli/tests/commands/test_append.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,22 @@
88
import openpathsampling as paths
99

1010
def make_input_file(tps_network_and_traj):
11-
input_file = paths.Storage("setup.py", mode='w')
11+
input_file = paths.Storage("setup.nc", mode='w')
1212
for obj in tps_network_and_traj:
1313
input_file.save(obj)
1414

1515
input_file.tags['template'] = input_file.snapshots[0]
1616
input_file.close()
17-
return "setup.py"
17+
return "setup.nc"
1818

1919
def test_append(tps_network_and_traj):
2020
runner = CliRunner()
2121
with runner.isolated_filesystem():
2222
in_file = make_input_file(tps_network_and_traj)
2323
result = runner.invoke(append, [in_file, '-a', 'output.nc',
2424
'--volume', 'A', '--volume', 'B'])
25-
assert result.exit_code == 0
2625
assert result.exception is None
26+
assert result.exit_code == 0
2727
storage = paths.Storage('output.nc', mode='r')
2828
assert len(storage.volumes) == 2
2929
assert len(storage.snapshots) == 0

paths_cli/tests/test_parameters.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,44 @@
22
import tempfile
33
import os
44

5-
import openpathsampling as paths
5+
import paths_cli
66
from openpathsampling.tests.test_helpers import make_1d_traj
77

88
from paths_cli.parameters import *
9+
import openpathsampling as paths
10+
11+
12+
def pre_monkey_patch():
13+
# store things that get monkey-patched; ensure we un-patch
14+
stored_functions = {}
15+
CallableCV = paths.CallableCV
16+
PseudoAttr = paths.netcdfplus.FunctionPseudoAttribute
17+
stored_functions['CallableCV.from'] = CallableCV.from_dict
18+
stored_functions['PseudoAttr.from'] = PseudoAttr.from_dict
19+
stored_functions['TPSNetwork.from'] = paths.TPSNetwork.from_dict
20+
stored_functions['MISTISNetwork.from'] = paths.MISTISNetwork.from_dict
21+
stored_functions['PseudoAttr.to'] = PseudoAttr.to_dict
22+
stored_functions['TPSNetwork.to'] = paths.TPSNetwork.to_dict
23+
stored_functions['MISTISNetwork.to'] = paths.MISTISNetwork.to_dict
24+
return stored_functions
25+
26+
def undo_monkey_patch(stored_functions):
27+
CallableCV = paths.CallableCV
28+
PseudoAttr = paths.netcdfplus.FunctionPseudoAttribute
29+
CallableCV.from_dict = stored_functions['CallableCV.from']
30+
PseudoAttr.from_dict = stored_functions['PseudoAttr.from']
31+
paths.TPSNetwork.from_dict = stored_functions['TPSNetwork.from']
32+
paths.MISTISNetwork.from_dict = stored_functions['MISTISNetwork.from']
33+
PseudoAttr.to_dict = stored_functions['PseudoAttr.to']
34+
paths.TPSNetwork.to_dict = stored_functions['TPSNetwork.to']
35+
paths.MISTISNetwork.to_dict = stored_functions['MISTISNetwork.to']
36+
paths_cli.param_core.StorageLoader.has_simstore_patch = False
37+
paths.InterfaceSet.simstore = False
38+
import importlib
39+
importlib.reload(paths.netcdfplus)
40+
importlib.reload(paths.collectivevariable)
41+
importlib.reload(paths)
42+
943

1044

1145
class ParameterTest(object):
@@ -114,6 +148,17 @@ def setup(self):
114148
def test_get(self, getter):
115149
self._getter_test(getter)
116150

151+
def test_cannot_guess(self):
152+
filename = self._filename('no-guess')
153+
storage = paths.Storage(filename, 'w')
154+
storage.save(self.engine)
155+
storage.save(self.other_engine.named('other'))
156+
storage.close()
157+
158+
storage = paths.Storage(filename, mode='r')
159+
with pytest.raises(RuntimeError):
160+
self.PARAMETER.get(storage, None)
161+
117162

118163
class TestSCHEME(ParamInstanceTest):
119164
PARAMETER = SCHEME
@@ -241,6 +286,17 @@ def test_get_multiple(self):
241286
assert traj0 == self.traj
242287
assert traj1 == self.other_traj
243288

289+
def test_cannot_guess(self):
290+
filename = self._filename('no-guess')
291+
storage = paths.Storage(filename, 'w')
292+
storage.save(self.traj)
293+
storage.save(self.other_traj)
294+
storage.close()
295+
296+
storage = paths.Storage(filename, 'r')
297+
with pytest.raises(RuntimeError):
298+
self.PARAMETER.get(storage, None)
299+
244300

245301
class TestINIT_SNAP(ParamInstanceTest):
246302
PARAMETER = INIT_SNAP
@@ -278,6 +334,18 @@ def test_get(self, getter):
278334
obj = self.PARAMETER.get(storage, get_arg)
279335
assert obj == expected
280336

337+
def test_simstore_single_snapshot(self):
338+
stored_functions = pre_monkey_patch()
339+
filename = os.path.join(self.tempdir, "simstore.db")
340+
storage = APPEND_FILE.get(filename)
341+
storage.save(self.init_snap)
342+
storage.close()
343+
344+
storage = INPUT_FILE.get(filename)
345+
snap = self.PARAMETER.get(storage, None)
346+
assert snap == self.init_snap
347+
undo_monkey_patch(stored_functions)
348+
281349

282350
class MultiParamInstanceTest(ParamInstanceTest):
283351
def _getter_test(self, getter):
@@ -384,21 +452,26 @@ def test_get(self, getter):
384452
self._getter_test(getter)
385453

386454

387-
def test_OUTPUT_FILE():
455+
@pytest.mark.parametrize('ext', ['nc', 'db', 'sql'])
456+
def test_OUTPUT_FILE(ext):
457+
stored_functions = pre_monkey_patch()
388458
tempdir = tempfile.mkdtemp()
389-
filename = os.path.join(tempdir, "test_output_file.nc")
459+
filename = os.path.join(tempdir, "test_output_file." + ext)
390460
assert not os.path.exists(filename)
391461
storage = OUTPUT_FILE.get(filename)
392462
assert os.path.exists(filename)
393463
os.remove(filename)
394464
os.rmdir(tempdir)
465+
undo_monkey_patch(stored_functions)
395466

396-
def test_APPEND_FILE():
467+
@pytest.mark.parametrize('ext', ['nc', 'db', 'sql'])
468+
def test_APPEND_FILE(ext):
469+
stored_functions = pre_monkey_patch()
397470
tempdir = tempfile.mkdtemp()
398-
filename = os.path.join(tempdir, "test_append_file.nc")
471+
filename = os.path.join(tempdir, "test_append_file." + ext)
399472
assert not os.path.exists(filename)
400473
storage = APPEND_FILE.get(filename)
401-
print(storage)
474+
# print(storage) # potentially useful debug; keep
402475
assert os.path.exists(filename)
403476
traj = make_1d_traj([0.0, 1.0])
404477
storage.tags['first_save'] = traj[0]
@@ -412,3 +485,4 @@ def test_APPEND_FILE():
412485
storage.close()
413486
os.remove(filename)
414487
os.rmdir(tempdir)
488+
undo_monkey_patch(stored_functions)

0 commit comments

Comments
 (0)