Skip to content

Commit a1671a4

Browse files
author
Thomas Baumann
committed
Added a convergence controller for estimating an embedded error using
different collocation problems
1 parent 3adf2a6 commit a1671a4

File tree

3 files changed

+146
-41
lines changed

3 files changed

+146
-41
lines changed

pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from pySDC.core.ConvergenceController import ConvergenceController, Pars
3+
from pySDC.core.ConvergenceController import ConvergenceController, Pars, Status
44
from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld
55
from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
66

@@ -203,3 +203,97 @@ def post_iteration_processing(self, controller, S, **kwargs):
203203
self.send(comm, dest=S.status.slot + 1, data=temp, blocking=True)
204204

205205
return None
206+
207+
208+
class EstimateEmbeddedErrorCollocation(ConvergenceController):
209+
"""
210+
Estimates an embedded error based on changing the underlying quadrature rule. The error estimate is stored as
211+
`error_embedded_estimate_collocation` in the status of the level. Note that we only compute the estimate on the
212+
finest level. The error is stored as a tuple with the first index denoting to which iteration it belongs. This
213+
is useful since the error estimate is not available immediately after, but only when the next collocation problem
214+
is converged to make sure the two solutions are of different accuracy.
215+
216+
Changing the collocation method between iterations happens using the `AdaptiveCollocation` convergence controller.
217+
Please refer to that for documentation on how to use this. Just pass the parameters for that convergence controller
218+
as `adaptive_coll_params` to the parameters for this one and they will be passed on when the `AdaptiveCollocation`
219+
convergence controller is automatically added while loading dependencies.
220+
"""
221+
222+
def setup(self, controller, params, description, **kwargs):
223+
"""
224+
Add a default value for control order to the parameters
225+
226+
Args:
227+
controller (pySDC.Controller): The controller
228+
params (dict): Parameters for the convergence controller
229+
description (dict): The description object used to instantiate the controller
230+
231+
Returns:
232+
dict: Updated parameters
233+
"""
234+
defaults = {
235+
"control_order": 210,
236+
"adaptive_coll_params": {},
237+
**super().setup(controller, params, description, **kwargs),
238+
}
239+
return defaults
240+
241+
def dependencies(self, controller, description, **kwargs):
242+
"""
243+
Load the `AdaptiveCollocation` convergence controller to switch between collocation problems between iterations.
244+
245+
Args:
246+
controller (pySDC.Controller): The controller
247+
description (dict): The description object used to instantiate the controller
248+
"""
249+
from pySDC.implementations.convergence_controller_classes.adaptive_collocation import AdaptiveCollocation
250+
251+
controller.add_convergence_controller(
252+
AdaptiveCollocation, params=self.params.adaptive_coll_params, description=description
253+
)
254+
255+
def post_iteration_processing(self, controller, step, **kwargs):
256+
"""
257+
Compute the embedded error as the difference between the interpolated and the current solution on the finest
258+
level.
259+
260+
Args:
261+
controller (pySDC.Controller.controller): The controller
262+
step (pySDC.Step.step): The current step
263+
"""
264+
if step.status.done:
265+
lvl = step.levels[0]
266+
lvl.sweep.compute_end_point()
267+
self.status.u += [lvl.uend]
268+
self.status.iter += [step.status.iter]
269+
270+
if len(self.status.u) > 1:
271+
lvl.status.error_embedded_estimate_collocation = (
272+
self.status.iter[-2],
273+
abs(self.status.u[-1] - self.status.u[-2]),
274+
)
275+
276+
def setup_status_variables(self, controller, **kwargs):
277+
"""
278+
Add the embedded error variable to the levels and add a status variable for previous steps.
279+
280+
Args:
281+
controller (pySDC.Controller): The controller
282+
"""
283+
self.status = Status(['u', 'iter'])
284+
self.status.u = [] # the solutions of converged collocation problems
285+
self.status.iter = [] # the iteration in which the solution converged
286+
287+
if 'comm' in kwargs.keys():
288+
steps = [controller.S]
289+
else:
290+
if 'active_slots' in kwargs.keys():
291+
steps = [controller.MS[i] for i in kwargs['active_slots']]
292+
else:
293+
steps = controller.MS
294+
where = ["levels", "status"]
295+
for S in steps:
296+
self.add_variable(S, name='error_embedded_estimate_collocation', where=where, init=None)
297+
298+
def reset_status_variables(self, controller, **kwargs):
299+
self.setup_status_variables(controller, **kwargs)

pySDC/implementations/hooks/log_embedded_error_estimate.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,26 @@ class LogEmbeddedErrorEstimate(hooks):
66
Store the embedded error estimate at the end of each step as "error_embedded_estimate".
77
"""
88

9+
def log_error(self, step, level_number, appendix=''):
10+
L = step.levels[level_number]
11+
12+
for flavour in ['', '_collocation']:
13+
if L.status.get(f'error_embedded_estimate{flavour}'):
14+
if flavour == '_collocation':
15+
iter, value = L.status.error_embedded_estimate_collocation
16+
else:
17+
iter = step.status.iter
18+
value = L.status.error_embedded_estimate
19+
self.add_to_stats(
20+
process=step.status.slot,
21+
time=L.time + L.dt,
22+
level=L.level_index,
23+
iter=iter,
24+
sweep=L.status.sweep,
25+
type=f'error_embedded_estimate{flavour}{appendix}',
26+
value=value,
27+
)
28+
929
def post_step(self, step, level_number, appendix=''):
1030
"""
1131
Record embedded error estimate
@@ -18,18 +38,7 @@ def post_step(self, step, level_number, appendix=''):
1838
None
1939
"""
2040
super().post_step(step, level_number)
21-
22-
L = step.levels[level_number]
23-
24-
self.add_to_stats(
25-
process=step.status.slot,
26-
time=L.time + L.dt,
27-
level=L.level_index,
28-
iter=step.status.iter,
29-
sweep=L.status.sweep,
30-
type=f'error_embedded_estimate{appendix}',
31-
value=L.status.get('error_embedded_estimate'),
32-
)
41+
self.log_error(step, level_number, appendix)
3342

3443

3544
class LogEmbeddedErrorEstimatePostIter(LogEmbeddedErrorEstimate):
@@ -51,23 +60,8 @@ def post_iteration(self, step, level_number):
5160
Returns:
5261
None
5362
"""
54-
# check if the estimate is available at all
5563
super().post_iteration(step, level_number)
56-
57-
L = step.levels[level_number]
58-
59-
if not L.status.get('error_embedded_estimate'):
60-
return None
61-
62-
self.add_to_stats(
63-
process=step.status.slot,
64-
time=L.time + L.dt,
65-
level=L.level_index,
66-
iter=step.status.iter - 1,
67-
sweep=L.status.sweep,
68-
type='error_embedded_estimate_post_iteration',
69-
value=L.status.get('error_embedded_estimate'),
70-
)
64+
self.log_error(step, level_number, '_post_iteration')
7165

7266
def post_step(self, step, level_number):
7367
super().post_step(step, level_number, appendix='_post_iteration')

pySDC/projects/Resilience/collocation_adaptivity.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
from pySDC.projects.Resilience.hook import LogData
1010
from pySDC.projects.Resilience.accuracy_check import get_accuracy_order
1111
from pySDC.implementations.convergence_controller_classes.adaptive_collocation import AdaptiveCollocation
12+
from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import (
13+
EstimateEmbeddedErrorCollocation,
14+
)
1215
from pySDC.core.Hooks import hooks
1316
from pySDC.implementations.hooks.log_errors import LogLocalErrorPostIter
17+
from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimatePostIter
1418

1519

1620
# define global parameters for running problems and plotting
@@ -43,11 +47,11 @@
4347
}
4448

4549
special_params = {
46-
'inexact': {AdaptiveCollocation: coll_params_inexact},
47-
'refinement': {AdaptiveCollocation: coll_params_refinement},
48-
'reduce': {AdaptiveCollocation: coll_params_reduce},
50+
'inexact': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_inexact}},
51+
'refinement': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_refinement}},
52+
'reduce': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_reduce}},
4953
'standard': {},
50-
'type': {AdaptiveCollocation: coll_params_type},
54+
'type': {EstimateEmbeddedErrorCollocation: {'adaptive_coll_params': coll_params_type}},
5155
}
5256

5357

@@ -200,7 +204,7 @@ def check_order(prob, coll_name, ax, k_ax):
200204
Tend=dt_range[i],
201205
custom_description=custom_description,
202206
custom_controller_params=custom_controller_parameters,
203-
hook_class=[LogData, LogSweeperParams, LogLocalErrorPostIter],
207+
hook_class=[LogData, LogSweeperParams, LogLocalErrorPostIter, LogEmbeddedErrorEstimatePostIter],
204208
)
205209

206210
sweeper_params = get_sorted(stats, type='sweeper_params', sortby='iter')
@@ -212,31 +216,43 @@ def check_order(prob, coll_name, ax, k_ax):
212216
e_loc = np.array([me[1] for me in get_sorted(stats, type='e_local_post_iteration', sortby='iter')])[
213217
converged_solution
214218
]
219+
220+
e_em_raw = [
221+
me[1] for me in get_sorted(stats, type='error_embedded_estimate_collocation_post_iteration', sortby='iter')
222+
]
223+
e_em = np.array((e_em_raw + [None] if coll_name == 'refinement' else [None] + e_em_raw))
215224
coll_order = np.array([me[1] for me in get_sorted(stats, type='coll_order', sortby='iter')])[converged_solution]
216225

217-
res += [(dt_range[i], e_loc, idx[1:] - idx[:-1], labels, coll_order)]
218-
# res += [(dt_range[i], np.array([me[1] for me in e_loc])[converged_solution], (idx[1:]-idx[:-1])/(idx[:-1]+1)*100, labels)]
226+
res += [(dt_range[i], e_loc, idx[1:] - idx[:-1], labels, coll_order, e_em)]
219227

220228
# assemble sth we can compute the order from
221229
result = {'dt': [me[0] for me in res]}
230+
embedded_errors = {'dt': [me[0] for me in res]}
222231
num_sols = len(res[0][1])
223232
for i in range(num_sols):
224233
result[i] = [me[1][i] for me in res]
234+
embedded_errors[i] = [me[5][i] for me in res]
235+
225236
label = res[0][3][i]
226237
expected_order = res[0][4][i] + 1
227238

228-
order = get_accuracy_order(result, key=i, thresh=1e-9)
229-
assert np.isclose(
230-
np.mean(order), expected_order, atol=0.3
231-
), f"Expected order: {expected_order}, got {order:.2f}!"
239+
ax.scatter(result['dt'], embedded_errors[i], color=CMAP[i])
240+
241+
for me in [result, embedded_errors]:
242+
if None in me[i]:
243+
continue
244+
order = get_accuracy_order(me, key=i, thresh=1e-9)
245+
assert np.isclose(
246+
np.mean(order), expected_order, atol=0.3
247+
), f"Expected order: {expected_order}, got {order:.2f}!"
232248
ax.loglog(result['dt'], result[i], label=f'{label} nodes: order: {np.mean(order):.1f}', color=CMAP[i])
233249

234250
if i > 0:
235251
extra_iter = [me[2][i - 1] for me in res]
236252
k_ax.plot(result['dt'], extra_iter, ls='--', color=CMAP[i])
237253
ax.legend(frameon=False)
238254
ax.set_xlabel(r'$\Delta t$')
239-
ax.set_ylabel(r'$e_\mathrm{local}$')
255+
ax.set_ylabel(r'$e_\mathrm{local}$ (lines), $e_\mathrm{embedded}$ (dots)')
240256
k_ax.set_ylabel(r'extra iterations')
241257

242258

@@ -247,6 +263,7 @@ def order_stuff(prob):
247263
for i in range(len(modes)):
248264
k_axs += [axs.flatten()[i].twinx()]
249265
check_order(prob, modes[i], axs.flatten()[i], k_axs[-1])
266+
axs.flatten()[i].set_title(modes[i])
250267

251268
for i in range(2):
252269
k_axs[i].set_ylabel('')

0 commit comments

Comments
 (0)