Skip to content

Commit e222d06

Browse files
committed
Added the test for the gradient
1 parent 4f98ae4 commit e222d06

3 files changed

Lines changed: 59 additions & 1 deletion

File tree

Modules/SchaMinimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,8 @@ def minimization_step(self, custom_function_gradient = None):
474474
is_diag_ok = True
475475
self.minimizer.update_dyn(new_kl_ratio, dyn_grad, struct_grad)
476476

477+
478+
477479
# Get the new dynamical matrix and strucure after the step
478480
new_dyn, new_struct = self.minimizer.get_dyn_struct()
479481

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function
3+
from __future__ import division
4+
5+
import sys, os
6+
import numpy as np
7+
import cellconstructor as CC
8+
import cellconstructor.Phonons
9+
10+
import sscha, sscha.Ensemble
11+
import sscha.SchaMinimizer
12+
13+
"""
14+
This test makes a simple relaxation of the sample ensemble
15+
provided within this distribution
16+
"""
17+
18+
def test_gradient_comparison(verbose = False):
19+
total_path = os.path.dirname(os.path.abspath(__file__))
20+
os.chdir(total_path)
21+
22+
DATA_PATH = "../../Examples/ensemble_data_test/"
23+
24+
dyn_start = CC.Phonons.Phonons(os.path.join(DATA_PATH, "dyn"))
25+
26+
27+
# Perform the minimization
28+
ens = sscha.Ensemble.Ensemble(dyn_start, 0, dyn_start.GetSupercell())
29+
ens.load(DATA_PATH, 2, 1000)
30+
31+
minim = sscha.SchaMinimizer.SSCHA_Minimizer(ens)
32+
minim.minim_struct = True
33+
minim.min_step_dyn = 0.5
34+
minim.min_step_struc = 0.5
35+
minim.meaningful_factor = 1e-10
36+
minim.max_ka = 5
37+
38+
ka = 0
39+
def compare_gradients(dyn_grad, struct_grad):
40+
if not os.path.exists("grad_{}.dat".format(ka)):
41+
np.savetxt("grad_{}.dat".format(ka), dyn_grad[0,:,:])
42+
else:
43+
correct_grad = np.loadtxt("grad_{}.dat".format(ka))
44+
diff = np.max(np.abs(dyn_grad - correct_grad))
45+
print("KA = {} | difference = {}".format(ka, diff))
46+
47+
48+
49+
minim.init()
50+
minim.run(custom_function_gradient = compare_gradients)
51+
minim.finalize()
52+
53+
54+
55+
if __name__ == "__main__":
56+
test_gradient_comparison(True)

tests/test_simple_relax/test_relax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_simple_relax(verbose = False):
6969
minim.min_step_struc = 0.5
7070
minim.meaningful_factor = 1e-10
7171
minim.init()
72-
minim.run(verbose = 0)
72+
minim.run()
7373
minim.finalize()
7474

7575
# Check the differences in the atomic positions

0 commit comments

Comments
 (0)