Skip to content

Commit 26bc247

Browse files
authored
Merge pull request #284 from brownbaerchen/adaptivity_Q
Adaptivity based on error estimate based on interpolating between quadrature rules based on Thibaut
2 parents fa524d2 + 2430ab5 commit 26bc247

File tree

6 files changed

+252
-15
lines changed

6 files changed

+252
-15
lines changed

pySDC/implementations/convergence_controller_classes/adaptive_collocation.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def switch_sweeper(self, S):
105105
P = L.prob
106106

107107
# store solution of current level which will be interpolated to new level
108-
u_old = L.u.copy()
108+
u_old = [me.flatten() for me in L.u]
109109
nodes_old = L.sweep.coll.nodes.copy()
110110

111111
# change sweeper
@@ -119,12 +119,13 @@ def switch_sweeper(self, S):
119119
# interpolate solution of old collocation problem to new one
120120
nodes_new = L.sweep.coll.nodes.copy()
121121
interpolator = LagrangeApproximation(points=np.append(0, nodes_old))
122+
122123
u_inter = interpolator.getInterpolationMatrix(np.append(0, nodes_new)) @ u_old
123124

124125
# assign the interpolated values to the nodes in the level
125126
for i in range(0, len(u_inter)):
126127
me = P.dtype_u(P.init)
127-
me[:] = u_inter[i]
128+
me[:] = np.reshape(u_inter[i], P.init[0])
128129
L.u[i] = me
129130

130131
# reevaluate rhs
@@ -195,3 +196,29 @@ def post_spread_processing(self, controller, S, **kwargs):
195196
None
196197
"""
197198
self.switch_sweeper(S)
199+
200+
def check_parameters(self, controller, params, description, **kwargs):
201+
"""
202+
Check if we allow the scheme to solve the collocation problems to convergence.
203+
204+
Args:
205+
controller (pySDC.Controller): The controller
206+
params (dict): The params passed for this specific convergence controller
207+
description (dict): The description object used to instantiate the controller
208+
209+
Returns:
210+
bool: Whether the parameters are compatible
211+
str: The error message
212+
"""
213+
if description["level_params"].get("restol", -1.0) <= 1e-16:
214+
return (
215+
False,
216+
"Switching the collocation problems requires solving them to some tolerance that can be reached. Please set attainable `restol` in the level params",
217+
)
218+
if description["step_params"].get("maxiter", -1.0) < 99:
219+
return (
220+
False,
221+
"Switching the collocation problems requires solving them exactly, which may require many iterations please set `maxiter` to at least 99 in the step params",
222+
)
223+
224+
return True, ""

pySDC/implementations/convergence_controller_classes/adaptivity.py

Lines changed: 157 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pySDC.core.ConvergenceController import ConvergenceController
2+
from pySDC.core.ConvergenceController import ConvergenceController, Status
33
from pySDC.implementations.convergence_controller_classes.step_size_limiter import (
44
StepSizeLimiter,
55
)
@@ -202,7 +202,7 @@ def check_parameters(self, controller, params, description, **kwargs):
202202
bool: Whether the parameters are compatible
203203
str: The error message
204204
"""
205-
if description["step_params"].get("restol", -1.0) >= 0:
205+
if description["level_params"].get("restol", -1.0) >= 0:
206206
return (
207207
False,
208208
"Adaptivity needs constant order in time and hence restol in the step parameters has to be \
@@ -218,8 +218,7 @@ def check_parameters(self, controller, params, description, **kwargs):
218218
if "e_tol" not in params.keys():
219219
return (
220220
False,
221-
"Adaptivity needs a local tolerance! Please set some up in description['convergence_control\
222-
_params']['e_tol']!",
221+
"Adaptivity needs a local tolerance! Please pass `e_tol` to the parameters for this convergence controller!",
223222
)
224223

225224
return True, ""
@@ -389,7 +388,7 @@ def check_parameters(self, controller, params, description, **kwargs):
389388
bool: Whether the parameters are compatible
390389
str: The error message
391390
"""
392-
if description["step_params"].get("restol", -1.0) >= 0:
391+
if description["level_params"].get("restol", -1.0) >= 0:
393392
return (
394393
False,
395394
"Adaptivity needs constant order in time and hence restol in the step parameters has to be \
@@ -446,3 +445,156 @@ def get_local_error_estimate(self, controller, S, **kwargs):
446445
float: Embedded error estimate
447446
"""
448447
return S.levels[0].status.residual
448+
449+
450+
class AdaptivityCollocation(AdaptivityBase):
451+
"""
452+
Control the step size via a collocation based estimate of the local error.
453+
The error estimate works by subtracting two solutions to collocation problems with different order. You can
454+
interpolate between collocation methods as much as you want but the adaptive step size selection will always be
455+
based on the last switch of quadrature.
456+
"""
457+
458+
def setup(self, controller, params, description, **kwargs):
459+
"""
460+
Add a default value for control order to the parameters.
461+
462+
Args:
463+
controller (pySDC.Controller): The controller
464+
params (dict): Parameters for the convergence controller
465+
description (dict): The description object used to instantiate the controller
466+
467+
Returns:
468+
dict: Updated parameters
469+
"""
470+
defaults = {
471+
"adaptive_coll_params": {},
472+
"num_colls": 0,
473+
**super().setup(controller, params, description, **kwargs),
474+
"control_order": 220,
475+
}
476+
477+
for key in defaults['adaptive_coll_params'].keys():
478+
if type(defaults['adaptive_coll_params'][key]) == list:
479+
defaults['num_colls'] = max([defaults['num_colls'], len(defaults['adaptive_coll_params'][key])])
480+
481+
return defaults
482+
483+
def setup_status_variables(self, controller, **kwargs):
484+
self.status = Status(['error', 'order'])
485+
self.status.error = []
486+
self.status.order = []
487+
488+
def reset_status_variables(self, controller, **kwargs):
489+
self.setup_status_variables(controller, **kwargs)
490+
491+
def dependencies(self, controller, description, **kwargs):
492+
"""
493+
Load the `EstimateEmbeddedErrorCollocation` convergence controller to estimate the local error by switching
494+
between collocation problems between iterations.
495+
496+
Args:
497+
controller (pySDC.Controller): The controller
498+
description (dict): The description object used to instantiate the controller
499+
"""
500+
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
501+
EstimateEmbeddedErrorCollocation,
502+
)
503+
504+
super().dependencies(controller, description)
505+
506+
params = {'adaptive_coll_params': self.params.adaptive_coll_params}
507+
controller.add_convergence_controller(
508+
EstimateEmbeddedErrorCollocation,
509+
params=params,
510+
description=description,
511+
)
512+
513+
def get_local_error_estimate(self, controller, S, **kwargs):
514+
"""
515+
Get the collocation based embedded error estimate.
516+
517+
Args:
518+
controller (pySDC.Controller): The controller
519+
S (pySDC.Step): The current step
520+
521+
Returns:
522+
float: Embedded error estimate
523+
"""
524+
if len(self.status.error) > 1:
525+
return self.status.error[-1][1]
526+
else:
527+
return 0.0
528+
529+
def post_iteration_processing(self, controller, step, **kwargs):
530+
"""
531+
Get the error estimate and its order if available.
532+
533+
Args:
534+
controller (pySDC.Controller.controller): The controller
535+
step (pySDC.Step.step): The current step
536+
"""
537+
if step.status.done:
538+
lvl = step.levels[0]
539+
self.status.error += [lvl.status.error_embedded_estimate_collocation]
540+
self.status.order += [lvl.sweep.coll.order]
541+
542+
def get_new_step_size(self, controller, S, **kwargs):
543+
if len(self.status.order) == self.params.num_colls:
544+
lvl = S.levels[0]
545+
546+
# compute next step size
547+
order = (
548+
min(self.status.order[-2::]) + 1
549+
) # local order of less accurate of the last two collocation problems
550+
e_est = self.get_local_error_estimate(controller, S)
551+
552+
lvl.status.dt_new = self.compute_optimal_step_size(
553+
self.params.beta, lvl.params.dt, self.params.e_tol, e_est, order
554+
)
555+
self.log(f'Adjusting step size from {lvl.params.dt:.2e} to {lvl.status.dt_new:.2e}', S)
556+
557+
def check_parameters(self, controller, params, description, **kwargs):
558+
"""
559+
Check whether parameters are compatible with whatever assumptions went into the step size functions etc.
560+
For adaptivity, we need to know the order of the scheme.
561+
562+
Args:
563+
controller (pySDC.Controller): The controller
564+
params (dict): The params passed for this specific convergence controller
565+
description (dict): The description object used to instantiate the controller
566+
567+
Returns:
568+
bool: Whether the parameters are compatible
569+
str: The error message
570+
"""
571+
if controller.params.mssdc_jac:
572+
return (
573+
False,
574+
"Adaptivity needs the same order on all steps, please activate Gauss-Seidel multistep mode!",
575+
)
576+
577+
if "e_tol" not in params.keys():
578+
return (
579+
False,
580+
"Adaptivity needs a local tolerance! Please pass `e_tol` to the parameters for this convergence controller!",
581+
)
582+
583+
return True, ""
584+
585+
def determine_restart(self, controller, S, **kwargs):
586+
"""
587+
Check if the step wants to be restarted by comparing the estimate of the local error to a preset tolerance
588+
589+
Args:
590+
controller (pySDC.Controller): The controller
591+
S (pySDC.Step): The current step
592+
593+
Returns:
594+
None
595+
"""
596+
if len(self.status.order) == self.params.num_colls:
597+
e_est = self.get_local_error_estimate(controller, S)
598+
if e_est >= self.params.e_tol:
599+
S.status.restart = True
600+
self.log(f"Restarting: e={e_est:.2e} >= e_tol={self.params.e_tol:.2e}", S)

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def post_iteration_processing(self, controller, step, **kwargs):
270270
if len(self.status.u) > 1:
271271
lvl.status.error_embedded_estimate_collocation = (
272272
self.status.iter[-2],
273-
abs(self.status.u[-1] - self.status.u[-2]),
273+
max([np.finfo(float).eps, abs(self.status.u[-1] - self.status.u[-2])]),
274274
)
275275

276276
def setup_status_variables(self, controller, **kwargs):

pySDC/projects/Resilience/collocation_adaptivity.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,55 @@ def order_stuff(prob):
274274
fig.tight_layout()
275275

276276

277-
def main():
277+
def adaptivity_collocation(plotting=False):
278+
from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityCollocation
279+
280+
e_tol = 1e-7
281+
282+
adaptive_coll_params = {
283+
'num_nodes': [2, 3],
284+
}
285+
286+
convergence_controllers = {}
287+
convergence_controllers[AdaptivityCollocation] = {'adaptive_coll_params': adaptive_coll_params, 'e_tol': e_tol}
288+
289+
step_params = {}
290+
step_params['maxiter'] = 99
291+
292+
level_params = {}
293+
level_params['restol'] = 1e-8
294+
295+
description = {}
296+
description['convergence_controllers'] = convergence_controllers
297+
description['step_params'] = step_params
298+
description['level_params'] = level_params
299+
300+
controller_params = {'logger_level': 30}
301+
302+
stats, controller, _ = run_vdp(custom_description=description, custom_controller_params=controller_params)
303+
304+
e_em = get_sorted(stats, type='error_embedded_estimate_collocation', recomputed=False)
305+
assert (
306+
max([me[1] for me in e_em]) <= e_tol
307+
), "Exceeded threshold for local tolerance when using collocation based adaptivity"
308+
assert (
309+
min([me[1] for me in e_em][1:-1]) >= e_tol / 10
310+
), "Over resolved problem when using collocation based adaptivity"
311+
312+
if plotting:
313+
from pySDC.projects.Resilience.vdp import plot_step_sizes
314+
315+
fig, ax = plt.subplots()
316+
317+
plot_step_sizes(stats, ax, 'error_embedded_estimate_collocation')
318+
319+
320+
def main(plotting=False):
321+
adaptivity_collocation(plotting)
278322
order_stuff(run_advection)
279323
compare_adaptive_collocation(run_vdp)
280324

281325

282326
if __name__ == "__main__":
283-
main()
327+
main(True)
284328
plt.show()

pySDC/projects/Resilience/vdp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pySDC.projects.Resilience.hook import LogData, hook_collection
1212

1313

14-
def plot_step_sizes(stats, ax):
14+
def plot_step_sizes(stats, ax, e_em_key='error_embedded_estimate'):
1515
"""
1616
Plot solution and step sizes to visualize the dynamics in the van der Pol equation.
1717
@@ -28,7 +28,7 @@ def plot_step_sizes(stats, ax):
2828
p = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
2929
t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
3030

31-
e_em = np.array(get_sorted(stats, type='error_embedded_estimate', recomputed=False, sortby='time'))[:, 1]
31+
e_em = np.array(get_sorted(stats, type=e_em_key, recomputed=False, sortby='time'))[:, 1]
3232
dt = np.array(get_sorted(stats, type='dt', recomputed=False, sortby='time'))
3333
restart = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))
3434

pySDC/tests/test_projects/test_resilience/test_adaptive_collocation.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,21 @@
22

33

44
@pytest.mark.base
5-
def test_main():
6-
from pySDC.projects.Resilience.collocation_adaptivity import main
5+
def test_adaptivity_collocation():
6+
from pySDC.projects.Resilience.collocation_adaptivity import adaptivity_collocation
77

8-
main()
8+
adaptivity_collocation(plotting=False)
9+
10+
11+
@pytest.mark.base
12+
def test_error_estimate_order():
13+
from pySDC.projects.Resilience.collocation_adaptivity import order_stuff, run_advection
14+
15+
order_stuff(run_advection)
16+
17+
18+
@pytest.mark.base
19+
def test_adaptive_collocation():
20+
from pySDC.projects.Resilience.collocation_adaptivity import compare_adaptive_collocation, run_vdp
21+
22+
compare_adaptive_collocation(run_vdp)

0 commit comments

Comments
 (0)