Skip to content

Commit 47b80e5

Browse files
committed
Fixed an error on the broadcasting of the ASE calculator
1 parent 4ab0dee commit 47b80e5

1 file changed

Lines changed: 23 additions & 20 deletions

File tree

Modules/Ensemble.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3034,7 +3034,7 @@ def get_odd_correction(self, include_v4 = False, store_v3 = True,
30343034

30353035

30363036
def compute_ensemble(self, calculator, compute_stress = True, stress_numerical = False,
3037-
cluster = None):
3037+
cluster = None, verbose = True):
30383038
"""
30393039
GET ENERGY AND FORCES
30403040
=====================
@@ -3084,7 +3084,7 @@ def compute_ensemble(self, calculator, compute_stress = True, stress_numerical =
30843084
if is_cluster:
30853085
cluster.compute_ensemble(computing_ensemble, calculator, compute_stress)
30863086
else:
3087-
computing_ensemble.get_energy_forces(calculator, compute_stress, stress_numerical)
3087+
computing_ensemble.get_energy_forces(calculator, compute_stress, stress_numerical, verbose = verbose)
30883088

30893089
if should_i_merge:
30903090
# Remove the noncomputed ensemble from here, and merge
@@ -3203,7 +3203,7 @@ def get_noncomputed(self):
32033203

32043204
return self.split(non_mask)
32053205

3206-
def get_energy_forces(self, ase_calculator, compute_stress = True, stress_numerical = False, skip_computed = False):
3206+
def get_energy_forces(self, ase_calculator, compute_stress = True, stress_numerical = False, skip_computed = False, verbose = False):
32073207
"""
32083208
GET ENERGY AND FORCES FOR THE CURRENT ENSEMBLE
32093209
==============================================
@@ -3231,27 +3231,30 @@ def get_energy_forces(self, ase_calculator, compute_stress = True, stress_numeri
32313231
"""
32323232

32333233
# Setup the calculator for each structure
3234+
parallel = False
32343235
if __MPI__:
32353236
comm = MPI.COMM_WORLD
32363237
size = comm.Get_size()
32373238
rank = comm.Get_rank()
32383239

3239-
# Broad cast to all the structures
3240-
structures = comm.bcast(self.structures, root = 0)
3241-
nat3 = comm.bcast(self.current_dyn.structure.N_atoms* 3* np.prod(self.supercell), root = 0)
3242-
N_rand = comm.bcast(self.N, root=0)
3243-
3244-
# Setup the label of the calculator
3245-
ase_calculator = comm.bcast(ase_calculator, root = 0)
3246-
ase_calculator.set_label("esp_%d" % rank) # Avoid overwriting the same file
3247-
3248-
compute_stress = comm.bcast(compute_stress, root = 0)
3249-
3250-
# Check if the parallelization is correct
3251-
if N_rand % size != 0:
3252-
raise ValueError("Error, for paralelization the ensemble dimension must be a multiple of the processors")
3240+
if size > 1:
3241+
parallel = True
3242+
# Broad cast to all the structures
3243+
structures = comm.bcast(self.structures, root = 0)
3244+
nat3 = comm.bcast(self.current_dyn.structure.N_atoms* 3* np.prod(self.supercell), root = 0)
3245+
N_rand = comm.bcast(self.N, root=0)
3246+
3247+
# Setup the label of the calculator
3248+
#ase_calculator = comm.bcast(ase_calculator, root = 0) # This broadcasting seems causing some issues on some fortran codes called by python (which may interact with MPI)
3249+
ase_calculator.set_label("esp_%d" % rank) # Avoid overwriting the same file
3250+
3251+
compute_stress = comm.bcast(compute_stress, root = 0)
3252+
3253+
# Check if the parallelization is correct
3254+
if N_rand % size != 0:
3255+
raise ValueError("Error, for paralelization the ensemble dimension must be a multiple of the processors")
32533256

3254-
else:
3257+
if not parallel:
32553258
size = 1
32563259
rank = 0
32573260
structures = self.structures
@@ -3296,7 +3299,7 @@ def get_energy_forces(self, ase_calculator, compute_stress = True, stress_numeri
32963299

32973300

32983301
# Print the status
3299-
if rank == 0:
3302+
if rank == 0 and verbose:
33003303
print ("Computing configuration %d / %d" % (i0+1, N_rand / size))
33013304

33023305
# Avoid for errors
@@ -3333,7 +3336,7 @@ def get_energy_forces(self, ase_calculator, compute_stress = True, stress_numeri
33333336

33343337
# Collect all togheter
33353338

3336-
if __MPI__:
3339+
if parallel:
33373340
comm.Allgather([energies, MPI.DOUBLE], [self.energies, MPI.DOUBLE])
33383341
comm.Allgather([forces, MPI.DOUBLE], [total_forces, MPI.DOUBLE])
33393342

0 commit comments

Comments
 (0)