Skip to content

Commit 3adf2a6

Browse files
authored
Merge pull request #281 from brownbaerchen/interpolate_restarts
Interpolate after restarts
2 parents 5f1eb57 + 6eadfdb commit 3adf2a6

File tree

14 files changed

+471
-28
lines changed

14 files changed

+471
-28
lines changed

pySDC/core/Controller.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@ class controller(object):
3232
Base abstract controller class
3333
"""
3434

35-
def __init__(self, controller_params, description):
35+
def __init__(self, controller_params, description, useMPI=None):
3636
"""
3737
Initialization routine for the base controller
3838
3939
Args:
4040
controller_params (dict): parameter set for the controller and the steps
4141
"""
42+
self.useMPI = useMPI
4243

4344
# check if we have a hook on this list. If not, use default class.
4445
self.__hooks = []
@@ -288,7 +289,7 @@ def add_convergence_controller(self, convergence_controller, description, params
288289
None
289290
'''
290291
# check if we passed any sort of special params
291-
params = {} if params is None else params
292+
params = {**({} if params is None else params), 'useMPI': self.useMPI}
292293

293294
# check if we already have the convergence controller or if we want to have it multiple times
294295
if convergence_controller not in [type(me) for me in self.convergence_controllers] or allow_double:

pySDC/core/ConvergenceController.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
class Pars(FrozenClass):
77
def __init__(self, params):
88
self.control_order = 0 # integer that determines the order in which the convergence controllers are called
9+
self.useMPI = None # depends on the controller
910

1011
for k, v in params.items():
1112
setattr(self, k, v)

pySDC/helpers/plot_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def figsize(textwidth, scale, ratio):
2525
return fig_size
2626

2727

28-
def figsize_by_journal(journal, scale, ratio): # pragma no cover
28+
def figsize_by_journal(journal, scale, ratio): # pragma: no cover
2929
"""
3030
Get figsize for specific journal. If you supply a text height, we will rescale the figure to fit on the page instead
3131
of the parameters supplied.

pySDC/implementations/controller_classes/controller_MPI.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, controller_params, description, comm):
2525
"""
2626

2727
# call parent's initialization routine
28-
super(controller_MPI, self).__init__(controller_params, description)
28+
super().__init__(controller_params, description, useMPI=True)
2929

3030
# create single step per processor
3131
self.S = step(description)

pySDC/implementations/controller_classes/controller_nonMPI.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, num_procs, controller_params, description):
3030
raise ControllerError('predict flag is ignored, use predict_type instead')
3131

3232
# call parent's initialization routine
33-
super(controller_nonMPI, self).__init__(controller_params, description)
33+
super().__init__(controller_params, description, useMPI=False)
3434

3535
self.MS = [stepclass.step(description)]
3636

pySDC/implementations/convergence_controller_classes/adaptivity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pySDC.implementations.convergence_controller_classes.basic_restarting import (
77
BasicRestartingNonMPI,
88
)
9-
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
109
from pySDC.implementations.hooks.log_step_size import LogStepSize
1110

1211

@@ -175,7 +174,7 @@ def dependencies(self, controller, description, **kwargs):
175174
super(Adaptivity, self).dependencies(controller, description)
176175

177176
controller.add_convergence_controller(
178-
EstimateEmbeddedError.get_implementation("nonMPI" if type(controller) == controller_nonMPI else "MPI"),
177+
EstimateEmbeddedError.get_implementation("nonMPI" if not self.params.useMPI else "MPI"),
179178
description=description,
180179
)
181180

pySDC/implementations/convergence_controller_classes/check_convergence.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,9 @@ def dependencies(self, controller, description, **kwargs):
4545
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
4646
EstimateEmbeddedError,
4747
)
48-
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
4948

5049
controller.add_convergence_controller(
51-
EstimateEmbeddedError.get_implementation("nonMPI" if type(controller) == controller_nonMPI else "MPI"),
50+
EstimateEmbeddedError.get_implementation("nonMPI" if not self.params.useMPI else "MPI"),
5251
description=description,
5352
)
5453

pySDC/implementations/convergence_controller_classes/estimate_contraction_factor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22

33
from pySDC.core.ConvergenceController import ConvergenceController
4-
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
54
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedError
65

76

@@ -36,7 +35,7 @@ def dependencies(self, controller, description, **kwargs):
3635
None
3736
"""
3837
controller.add_convergence_controller(
39-
EstimateEmbeddedError.get_implementation("nonMPI" if type(controller) == controller_nonMPI else "MPI"),
38+
EstimateEmbeddedError.get_implementation("nonMPI" if not self.params.useMPI else "MPI"),
4039
description=description,
4140
)
4241

pySDC/implementations/convergence_controller_classes/hotrod.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22

33
from pySDC.core.ConvergenceController import ConvergenceController
4-
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
54
from pySDC.implementations.convergence_controller_classes.estimate_extrapolation_error import (
65
EstimateExtrapolationErrorNonMPI,
76
)
@@ -49,7 +48,7 @@ def dependencies(self, controller, description, **kwargs):
4948
Returns:
5049
None
5150
"""
52-
if type(controller) == controller_nonMPI:
51+
if not self.params.useMPI:
5352
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
5453
EstimateEmbeddedErrorNonMPI,
5554
)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import numpy as np
2+
from pySDC.core.ConvergenceController import ConvergenceController, Status
3+
from pySDC.core.Lagrange import LagrangeApproximation
4+
from pySDC.core.Collocation import CollBase
5+
6+
7+
class InterpolateBetweenRestarts(ConvergenceController):
8+
"""
9+
Interpolate the solution and right hand side to the new set of collocation nodes after a restart.
10+
The idea is that when you adjust the step size between restarts, you already know what the new quadrature method
11+
is going to be and possibly interpolating the current iterate to these results in a better initial guess than
12+
spreading the initial conditions or whatever you usually like to do.
13+
"""
14+
15+
def setup(self, controller, params, description, **kwargs):
16+
"""
17+
Store the initial guess used in the sweeper when no restart has happened
18+
19+
Args:
20+
controller (pySDC.Controller.controller): The controller
21+
params (dict): Parameters for the convergence controller
22+
description (dict): The description object used to instantiate the controller
23+
"""
24+
defaults = {
25+
'control_order': 50,
26+
}
27+
return {**defaults, **super().setup(controller, params, description, **kwargs)}
28+
29+
def setup_status_variables(self, controller, **kwargs):
30+
"""
31+
Add variables to the sweeper containing the interpolated solution and right hand side.
32+
33+
Args:
34+
controller (pySDC.Controller.controller): The controller
35+
"""
36+
self.status = Status(['u_inter', 'f_inter', 'perform_interpolation'])
37+
38+
self.status.u_inter = []
39+
self.status.f_inter = []
40+
self.status.perform_interpolation = False
41+
42+
def post_spread_processing(self, controller, step, **kwargs):
43+
"""
44+
Spread the interpolated values to the collocation nodes. This overrides whatever the sweeper uses for prediction.
45+
46+
Args:
47+
controller (pySDC.Controller.controller): The controller
48+
step (pySDC.Step.step): The current step
49+
"""
50+
if self.status.perform_interpolation:
51+
for i in range(len(step.levels)):
52+
level = step.levels[i]
53+
for m in range(len(level.u)):
54+
level.u[m][:] = self.status.u_inter[i][m][:]
55+
level.f[m][:] = self.status.f_inter[i][m][:]
56+
57+
# reset the status variables
58+
self.status.perform_interpolation = False
59+
self.status.u_inter = []
60+
self.status.f_inter = []
61+
62+
def post_iteration_processing(self, controller, step, **kwargs):
63+
"""
64+
Interpolate the solution and right hand sides and store them in the sweeper, where they will be distributed
65+
accordingly in the prediction step.
66+
67+
This function is called after every iteration instead of just after the step because we might choose to stop
68+
iterating as soon as we have decided to restart. If we let the step continue to iterate, this is not the most
69+
efficient implementation and you may choose to write a different convergence controller.
70+
71+
The interpolation is based on Thibaut's magic.
72+
73+
Args:
74+
controller (pySDC.Controller): The controller
75+
step (pySDC.Step.step): The current step
76+
"""
77+
if step.status.restart and all([level.status.dt_new for level in step.levels]):
78+
for level in step.levels:
79+
nodes_old = level.sweep.coll.nodes.copy()
80+
nodes_new = level.sweep.coll.nodes.copy() * level.status.dt_new / level.params.dt
81+
82+
interpolator = LagrangeApproximation(points=np.append(0, nodes_old))
83+
self.status.u_inter += [(interpolator.getInterpolationMatrix(np.append(0, nodes_new)) @ level.u[:])[:]]
84+
self.status.f_inter += [(interpolator.getInterpolationMatrix(np.append(0, nodes_new)) @ level.f[:])[:]]
85+
86+
self.status.perform_interpolation = True
87+
88+
self.log(
89+
f'Interpolating before restart from dt={level.params.dt:.2e} to dt={level.status.dt_new:.2e}', step
90+
)
91+
else:
92+
self.status.perform_interpolation = False

0 commit comments

Comments
 (0)