99from pySDC .projects .Resilience .hook import LogData
1010from pySDC .projects .Resilience .accuracy_check import get_accuracy_order
1111from pySDC .implementations .convergence_controller_classes .adaptive_collocation import AdaptiveCollocation
12+ from pySDC .implementations .convergence_controller_classes .estimate_embedded_error import (
13+ EstimateEmbeddedErrorCollocation ,
14+ )
1215from pySDC .core .Hooks import hooks
1316from pySDC .implementations .hooks .log_errors import LogLocalErrorPostIter
17+ from pySDC .implementations .hooks .log_embedded_error_estimate import LogEmbeddedErrorEstimatePostIter
1418
1519
1620# define global parameters for running problems and plotting
4347}
4448
4549special_params = {
46- 'inexact' : {AdaptiveCollocation : coll_params_inexact },
47- 'refinement' : {AdaptiveCollocation : coll_params_refinement },
48- 'reduce' : {AdaptiveCollocation : coll_params_reduce },
50+ 'inexact' : {EstimateEmbeddedErrorCollocation : { 'adaptive_coll_params' : coll_params_inexact } },
51+ 'refinement' : {EstimateEmbeddedErrorCollocation : { 'adaptive_coll_params' : coll_params_refinement } },
52+ 'reduce' : {EstimateEmbeddedErrorCollocation : { 'adaptive_coll_params' : coll_params_reduce } },
4953 'standard' : {},
50- 'type' : {AdaptiveCollocation : coll_params_type },
54+ 'type' : {EstimateEmbeddedErrorCollocation : { 'adaptive_coll_params' : coll_params_type } },
5155}
5256
5357
@@ -200,7 +204,7 @@ def check_order(prob, coll_name, ax, k_ax):
200204 Tend = dt_range [i ],
201205 custom_description = custom_description ,
202206 custom_controller_params = custom_controller_parameters ,
203- hook_class = [LogData , LogSweeperParams , LogLocalErrorPostIter ],
207+ hook_class = [LogData , LogSweeperParams , LogLocalErrorPostIter , LogEmbeddedErrorEstimatePostIter ],
204208 )
205209
206210 sweeper_params = get_sorted (stats , type = 'sweeper_params' , sortby = 'iter' )
@@ -212,31 +216,43 @@ def check_order(prob, coll_name, ax, k_ax):
212216 e_loc = np .array ([me [1 ] for me in get_sorted (stats , type = 'e_local_post_iteration' , sortby = 'iter' )])[
213217 converged_solution
214218 ]
219+
220+ e_em_raw = [
221+ me [1 ] for me in get_sorted (stats , type = 'error_embedded_estimate_collocation_post_iteration' , sortby = 'iter' )
222+ ]
223+ e_em = np .array ((e_em_raw + [None ] if coll_name == 'refinement' else [None ] + e_em_raw ))
215224 coll_order = np .array ([me [1 ] for me in get_sorted (stats , type = 'coll_order' , sortby = 'iter' )])[converged_solution ]
216225
217- res += [(dt_range [i ], e_loc , idx [1 :] - idx [:- 1 ], labels , coll_order )]
218- # res += [(dt_range[i], np.array([me[1] for me in e_loc])[converged_solution], (idx[1:]-idx[:-1])/(idx[:-1]+1)*100, labels)]
226+ res += [(dt_range [i ], e_loc , idx [1 :] - idx [:- 1 ], labels , coll_order , e_em )]
219227
220228 # assemble sth we can compute the order from
221229 result = {'dt' : [me [0 ] for me in res ]}
230+ embedded_errors = {'dt' : [me [0 ] for me in res ]}
222231 num_sols = len (res [0 ][1 ])
223232 for i in range (num_sols ):
224233 result [i ] = [me [1 ][i ] for me in res ]
234+ embedded_errors [i ] = [me [5 ][i ] for me in res ]
235+
225236 label = res [0 ][3 ][i ]
226237 expected_order = res [0 ][4 ][i ] + 1
227238
228- order = get_accuracy_order (result , key = i , thresh = 1e-9 )
229- assert np .isclose (
230- np .mean (order ), expected_order , atol = 0.3
231- ), f"Expected order: { expected_order } , got { order :.2f} !"
239+ ax .scatter (result ['dt' ], embedded_errors [i ], color = CMAP [i ])
240+
241+ for me in [result , embedded_errors ]:
242+ if None in me [i ]:
243+ continue
244+ order = get_accuracy_order (me , key = i , thresh = 1e-9 )
245+ assert np .isclose (
246+ np .mean (order ), expected_order , atol = 0.3
247+ ), f"Expected order: { expected_order } , got { order :.2f} !"
232248 ax .loglog (result ['dt' ], result [i ], label = f'{ label } nodes: order: { np .mean (order ):.1f} ' , color = CMAP [i ])
233249
234250 if i > 0 :
235251 extra_iter = [me [2 ][i - 1 ] for me in res ]
236252 k_ax .plot (result ['dt' ], extra_iter , ls = '--' , color = CMAP [i ])
237253 ax .legend (frameon = False )
238254 ax .set_xlabel (r'$\Delta t$' )
239- ax .set_ylabel (r'$e_\mathrm{local}$' )
255+ ax .set_ylabel (r'$e_\mathrm{local}$ (lines), $e_\mathrm{embedded}$ (dots) ' )
240256 k_ax .set_ylabel (r'extra iterations' )
241257
242258
@@ -247,6 +263,7 @@ def order_stuff(prob):
247263 for i in range (len (modes )):
248264 k_axs += [axs .flatten ()[i ].twinx ()]
249265 check_order (prob , modes [i ], axs .flatten ()[i ], k_axs [- 1 ])
266+ axs .flatten ()[i ].set_title (modes [i ])
250267
251268 for i in range (2 ):
252269 k_axs [i ].set_ylabel ('' )
0 commit comments