Skip to content

Commit aa725e6

Browse files
Merge pull request #415 from krystophny/fix-save-binary-cluster-lock
Fix save_binary TypeError on objects holding a Cluster
2 parents fe997a7 + 5792fd8 commit aa725e6

2 files changed

Lines changed: 89 additions & 0 deletions

File tree

Modules/Cluster.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,30 @@ def __setattr__(self, name, value):
333333
super(Cluster, self).__setattr__(name, value)
334334

335335

336+
def __getstate__(self):
337+
"""
338+
Return the picklable state of the cluster.
339+
340+
The thread lock created by compute_ensemble_batch cannot be pickled,
341+
so it is dropped here. This allows sscha.Utilities.save_binary to
342+
store objects holding a cluster after a calculation has run.
343+
"""
344+
state = self.__dict__.copy()
345+
state["lock"] = None
346+
return state
347+
348+
349+
def __setstate__(self, state):
350+
"""
351+
Restore the cluster from a pickled state.
352+
353+
The thread lock is transient runtime state and is reset to None,
354+
as after __init__; compute_ensemble_batch recreates it when needed.
355+
"""
356+
state["lock"] = None
357+
self.__dict__.update(state)
358+
359+
336360

337361
def copy_file(self, source, destination, server_source = False, server_dest = True, raise_error=False, **kwargs):
338362
"""
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function
3+
from __future__ import division
4+
5+
import os
6+
import tempfile
7+
import threading
8+
9+
import cellconstructor as CC
10+
import cellconstructor.Phonons
11+
12+
import sscha
13+
import sscha.Cluster
14+
import sscha.Ensemble
15+
import sscha.Relax
16+
import sscha.SchaMinimizer
17+
import sscha.Utilities
18+
19+
"""
20+
Regression test for issue #114: save_binary failed with
21+
TypeError: cannot pickle '_thread.lock' object
22+
when the relax object holds a cluster that already ran a calculation
23+
(Cluster.compute_ensemble_batch stores a threading.Lock on the cluster).
24+
"""
25+
26+
27+
def test_save_binary_relax_with_cluster(verbose=False):
28+
total_path = os.path.dirname(os.path.abspath(__file__))
29+
os.chdir(total_path)
30+
31+
DATA_PATH = "../../Examples/ensemble_data_test/"
32+
33+
dyn = CC.Phonons.Phonons(os.path.join(DATA_PATH, "dyn"))
34+
35+
ens = sscha.Ensemble.Ensemble(dyn, 0, dyn.GetSupercell())
36+
ens.load(DATA_PATH, 2, 10)
37+
38+
minim = sscha.SchaMinimizer.SSCHA_Minimizer(ens)
39+
40+
cluster = sscha.Cluster.Cluster(hostname="localhost")
41+
relax = sscha.Relax.SSCHA(minim, N_configs=10, max_pop=2,
42+
cluster=cluster)
43+
44+
# Cluster.compute_ensemble_batch leaves a threading.Lock on the
45+
# cluster after the ensemble calculation; reproduce that state.
46+
relax.cluster.lock = threading.Lock()
47+
48+
with tempfile.TemporaryDirectory() as tmpdir:
49+
filename = os.path.join(tmpdir, "relax.bin")
50+
sscha.Utilities.save_binary(relax, filename)
51+
52+
loaded = sscha.Utilities.load_binary(filename)
53+
54+
# The lock is transient runtime state and must come back unset.
55+
assert loaded.cluster.lock is None
56+
assert loaded.cluster.hostname == "localhost"
57+
assert loaded.N_configs == relax.N_configs
58+
assert loaded.minim.ensemble.N == ens.N
59+
60+
if verbose:
61+
print("save_binary/load_binary round trip succeeded")
62+
63+
64+
if __name__ == "__main__":
65+
test_save_binary_relax_with_cluster(True)

0 commit comments

Comments
 (0)