Skip to content

Commit ff522dd

Browse files
author
Thomas Baumann
committed
Added adaptivity based on different collocation methods
1 parent a1671a4 commit ff522dd

File tree

6 files changed

+239
-13
lines changed

6 files changed

+239
-13
lines changed

pySDC/implementations/convergence_controller_classes/adaptive_collocation.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,29 @@ def post_spread_processing(self, controller, S, **kwargs):
195195
None
196196
"""
197197
self.switch_sweeper(S)
198+
199+
def check_parameters(self, controller, params, description, **kwargs):
200+
"""
201+
Check if we allow the scheme to solve the collocation problems to convergence.
202+
203+
Args:
204+
controller (pySDC.Controller): The controller
205+
params (dict): The params passed for this specific convergence controller
206+
description (dict): The description object used to instantiate the controller
207+
208+
Returns:
209+
bool: Whether the parameters are compatible
210+
str: The error message
211+
"""
212+
if description["level_params"].get("restol", -1.0) <= 1e-16:
213+
return (
214+
False,
215+
"Switching the collocation problems requires solving them to some tolerance that can be reached. Please set attainable `restol` in the level params",
216+
)
217+
if description["step_params"].get("maxiter", -1.0) < 99:
218+
return (
219+
False,
220+
"Switching the collocation problems requires solving them exactly, which may require many iterations please set `maxiter` to at least 99 in the step params",
221+
)
222+
223+
return True, ""

pySDC/implementations/convergence_controller_classes/adaptivity.py

Lines changed: 147 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from pySDC.core.ConvergenceController import ConvergenceController
2+
from pySDC.core.ConvergenceController import ConvergenceController, Status
33
from pySDC.implementations.convergence_controller_classes.step_size_limiter import (
44
StepSizeLimiter,
55
)
@@ -202,7 +202,7 @@ def check_parameters(self, controller, params, description, **kwargs):
202202
bool: Whether the parameters are compatible
203203
str: The error message
204204
"""
205-
if description["step_params"].get("restol", -1.0) >= 0:
205+
if description["level_params"].get("restol", -1.0) >= 0:
206206
return (
207207
False,
208208
"Adaptivity needs constant order in time and hence restol in the step parameters has to be \
@@ -218,8 +218,7 @@ def check_parameters(self, controller, params, description, **kwargs):
218218
if "e_tol" not in params.keys():
219219
return (
220220
False,
221-
"Adaptivity needs a local tolerance! Please set some up in description['convergence_control\
222-
_params']['e_tol']!",
221+
"Adaptivity needs a local tolerance! Please pass `e_tol` to the parameters for this convergence controller!",
223222
)
224223

225224
return True, ""
@@ -389,7 +388,7 @@ def check_parameters(self, controller, params, description, **kwargs):
389388
bool: Whether the parameters are compatible
390389
str: The error message
391390
"""
392-
if description["step_params"].get("restol", -1.0) >= 0:
391+
if description["level_params"].get("restol", -1.0) >= 0:
393392
return (
394393
False,
395394
"Adaptivity needs constant order in time and hence restol in the step parameters has to be \
@@ -446,3 +445,146 @@ def get_local_error_estimate(self, controller, S, **kwargs):
446445
float: Embedded error estimate
447446
"""
448447
return S.levels[0].status.residual
448+
449+
450+
class AdaptivityCollocation(AdaptivityBase):
451+
def setup(self, controller, params, description, **kwargs):
452+
"""
453+
Add a default value for control order to the parameters.
454+
455+
Args:
456+
controller (pySDC.Controller): The controller
457+
params (dict): Parameters for the convergence controller
458+
description (dict): The description object used to instantiate the controller
459+
460+
Returns:
461+
dict: Updated parameters
462+
"""
463+
defaults = {
464+
"adaptive_coll_params": {},
465+
"num_colls": 0,
466+
**super().setup(controller, params, description, **kwargs),
467+
"control_order": 220,
468+
}
469+
470+
for key in defaults['adaptive_coll_params'].keys():
471+
if type(defaults['adaptive_coll_params'][key]) == list:
472+
defaults['num_colls'] = max([defaults['num_colls'], len(defaults['adaptive_coll_params'][key])])
473+
474+
return defaults
475+
476+
def setup_status_variables(self, controller, **kwargs):
477+
self.status = Status(['error', 'order'])
478+
self.status.error = []
479+
self.status.order = []
480+
481+
def reset_status_variables(self, controller, **kwargs):
482+
self.setup_status_variables(controller, **kwargs)
483+
484+
def dependencies(self, controller, description, **kwargs):
485+
"""
486+
Load the `EstimateEmbeddedErrorCollocation` convergence controller to estimate the local error by switching
487+
between collocation problems between iterations.
488+
489+
Args:
490+
controller (pySDC.Controller): The controller
491+
description (dict): The description object used to instantiate the controller
492+
"""
493+
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
494+
EstimateEmbeddedErrorCollocation,
495+
)
496+
497+
super().dependencies(controller, description)
498+
499+
params = {'adaptive_coll_params': self.params.adaptive_coll_params}
500+
controller.add_convergence_controller(
501+
EstimateEmbeddedErrorCollocation,
502+
params=params,
503+
description=description,
504+
)
505+
506+
def get_local_error_estimate(self, controller, S, **kwargs):
507+
"""
508+
Get the collocation based embedded error estimate.
509+
510+
Args:
511+
controller (pySDC.Controller): The controller
512+
S (pySDC.Step): The current step
513+
514+
Returns:
515+
float: Embedded error estimate
516+
"""
517+
if len(self.status.error) > 1:
518+
return self.status.error[-1][1]
519+
else:
520+
return 0.0
521+
522+
def post_iteration_processing(self, controller, step, **kwargs):
523+
"""
524+
Get the error estimate and its order if available.
525+
526+
Args:
527+
controller (pySDC.Controller.controller): The controller
528+
step (pySDC.Step.step): The current step
529+
"""
530+
if step.status.done:
531+
lvl = step.levels[0]
532+
self.status.error += [lvl.status.error_embedded_estimate_collocation]
533+
self.status.order += [lvl.sweep.coll.order]
534+
535+
def get_new_step_size(self, controller, S, **kwargs):
536+
if len(self.status.order) == self.params.num_colls:
537+
lvl = S.levels[0]
538+
539+
# compute next step size
540+
order = self.status.order[-2] + 1 # local order of second to most accurate solution
541+
e_est = self.get_local_error_estimate(controller, S)
542+
lvl.status.dt_new = self.compute_optimal_step_size(
543+
self.params.beta, lvl.params.dt, self.params.e_tol, e_est, order
544+
)
545+
self.log(f'Adjusting step size from {lvl.params.dt:.2e} to {lvl.status.dt_new:.2e}', S)
546+
547+
def check_parameters(self, controller, params, description, **kwargs):
548+
"""
549+
Check whether parameters are compatible with whatever assumptions went into the step size functions etc.
550+
For adaptivity, we need to know the order of the scheme.
551+
552+
Args:
553+
controller (pySDC.Controller): The controller
554+
params (dict): The params passed for this specific convergence controller
555+
description (dict): The description object used to instantiate the controller
556+
557+
Returns:
558+
bool: Whether the parameters are compatible
559+
str: The error message
560+
"""
561+
if controller.params.mssdc_jac:
562+
return (
563+
False,
564+
"Adaptivity needs the same order on all steps, please activate Gauss-Seidel multistep mode!",
565+
)
566+
567+
if "e_tol" not in params.keys():
568+
return (
569+
False,
570+
"Adaptivity needs a local tolerance! Please pass `e_tol` to the parameters for this convergence controller!",
571+
)
572+
573+
return True, ""
574+
575+
def determine_restart(self, controller, S, **kwargs):
576+
"""
577+
Check if the step wants to be restarted by comparing the estimate of the local error to a preset tolerance
578+
579+
Args:
580+
controller (pySDC.Controller): The controller
581+
S (pySDC.Step): The current step
582+
583+
Returns:
584+
None
585+
"""
586+
if len(self.status.order) == self.params.num_colls:
587+
e_est = self.get_local_error_estimate(controller, S)
588+
if e_est >= self.params.e_tol:
589+
S.status.restart = True
590+
self.log(f"Restarting: e={e_est:.2e} >= e_tol={self.params.e_tol:.2e}", S)

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def post_iteration_processing(self, controller, step, **kwargs):
270270
if len(self.status.u) > 1:
271271
lvl.status.error_embedded_estimate_collocation = (
272272
self.status.iter[-2],
273-
abs(self.status.u[-1] - self.status.u[-2]),
273+
max([np.finfo(float).eps, abs(self.status.u[-1] - self.status.u[-2])]),
274274
)
275275

276276
def setup_status_variables(self, controller, **kwargs):

pySDC/projects/Resilience/collocation_adaptivity.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,55 @@ def order_stuff(prob):
274274
fig.tight_layout()
275275

276276

277-
def main():
277+
def adaptivity_collocation(plotting=False):
278+
from pySDC.implementations.convergence_controller_classes.adaptivity import AdaptivityCollocation
279+
280+
e_tol = 1e-7
281+
282+
adaptive_coll_params = {
283+
'num_nodes': [2, 3],
284+
}
285+
286+
convergence_controllers = {}
287+
convergence_controllers[AdaptivityCollocation] = {'adaptive_coll_params': adaptive_coll_params, 'e_tol': e_tol}
288+
289+
step_params = {}
290+
step_params['maxiter'] = 99
291+
292+
level_params = {}
293+
level_params['restol'] = 1e-8
294+
295+
description = {}
296+
description['convergence_controllers'] = convergence_controllers
297+
description['step_params'] = step_params
298+
description['level_params'] = level_params
299+
300+
controller_params = {'logger_level': 30}
301+
302+
stats, controller, _ = run_vdp(custom_description=description, custom_controller_params=controller_params)
303+
304+
e_em = get_sorted(stats, type='error_embedded_estimate_collocation', recomputed=False)
305+
assert (
306+
max([me[1] for me in e_em]) <= e_tol
307+
), "Exceeded threshold for local tolerance when using collocation based adaptivity"
308+
assert (
309+
min([me[1] for me in e_em][1:-1]) >= e_tol / 10
310+
), "Over resolved problem when using collocation based adaptivity"
311+
312+
if plotting:
313+
from pySDC.projects.Resilience.vdp import plot_step_sizes
314+
315+
fig, ax = plt.subplots()
316+
317+
plot_step_sizes(stats, ax, 'error_embedded_estimate_collocation')
318+
319+
320+
def main(plotting=False):
321+
adaptivity_collocation(plotting)
278322
order_stuff(run_advection)
279323
compare_adaptive_collocation(run_vdp)
280324

281325

282326
if __name__ == "__main__":
283-
main()
327+
main(True)
284328
plt.show()

pySDC/projects/Resilience/vdp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pySDC.projects.Resilience.hook import LogData, hook_collection
1212

1313

14-
def plot_step_sizes(stats, ax):
14+
def plot_step_sizes(stats, ax, e_em_key='error_embedded_estimate'):
1515
"""
1616
Plot solution and step sizes to visualize the dynamics in the van der Pol equation.
1717
@@ -28,7 +28,7 @@ def plot_step_sizes(stats, ax):
2828
p = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
2929
t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
3030

31-
e_em = np.array(get_sorted(stats, type='error_embedded_estimate', recomputed=False, sortby='time'))[:, 1]
31+
e_em = np.array(get_sorted(stats, type=e_em_key, recomputed=False, sortby='time'))[:, 1]
3232
dt = np.array(get_sorted(stats, type='dt', recomputed=False, sortby='time'))
3333
restart = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))
3434

pySDC/tests/test_projects/test_resilience/test_adaptive_collocation.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,21 @@
22

33

44
@pytest.mark.base
5-
def test_main():
6-
from pySDC.projects.Resilience.collocation_adaptivity import main
5+
def test_adaptivity_collocation():
6+
from pySDC.projects.Resilience.collocation_adaptivity import adaptivity_collocation
77

8-
main()
8+
adaptivity_collocation(plotting=False)
9+
10+
11+
@pytest.mark.base
12+
def test_error_estimate_order():
13+
from pySDC.projects.Resilience.collocation_adaptivity import order_stuff, run_advection
14+
15+
order_stuff(run_advection)
16+
17+
18+
@pytest.mark.base
19+
def test_adaptive_collocation():
20+
from pySDC.projects.Resilience.collocation_adaptivity import compare_adaptive_collocation, run_vdp
21+
22+
compare_adaptive_collocation(run_vdp)

0 commit comments

Comments
 (0)