-
Notifications
You must be signed in to change notification settings - Fork 4
Cory implicit adams #4
base: main
Are you sure you want to change the base?
Changes from all commits
cb564e5
e465177
44baab1
1d1e717
1905ae4
f9f6619
62de9f8
c1b563c
47d9b73
4b8c9bf
dd7990a
6945909
b066fbc
266d87d
e854c74
83c7530
cc24a36
112b32b
75d3b7e
571640e
aa19b84
4d41c8e
6d4d9ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,10 +1,10 @@ | ||||||||||
| """Adams-Bashforth ODE solvers.""" | ||||||||||
| """Adams-Bashforth and Adams-Moulton ODE solvers.""" | ||||||||||
|
|
||||||||||
|
|
||||||||||
| __copyright__ = """ | ||||||||||
| Copyright (C) 2007 Andreas Kloeckner | ||||||||||
| Copyright (C) 2014, 2015 Matt Wala | ||||||||||
| Copyright (C) 2015 Cory Mikida | ||||||||||
| Copyright (C) 2015, 2020 Cory Mikida | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| __license__ = """ | ||||||||||
|
|
@@ -37,6 +37,7 @@ | |||||||||
| .. autoclass:: AdamsIntegrationFunctionFamily | ||||||||||
| .. autoclass:: AdamsMonomialIntegrationFunctionFamily | ||||||||||
| .. autoclass:: AdamsBashforthMethodBuilder | ||||||||||
| .. autoclass:: AdamsMoultonMethodBuilder | ||||||||||
| """ | ||||||||||
|
|
||||||||||
|
|
||||||||||
|
|
@@ -416,4 +417,341 @@ def rk_bootstrap(self, cb): | |||||||||
|
|
||||||||||
| # }}} | ||||||||||
|
|
||||||||||
| # {{{ am method | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class AdamsMoultonMethodBuilder(MethodBuilder): | ||||||||||
| """ | ||||||||||
| User-supplied context: | ||||||||||
| <state> + component_id: The value that is integrated | ||||||||||
| <func> + component_id: The right hand side | ||||||||||
|
Comment on lines
+425
to
+427
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect this won't render the way you intend. Check via |
||||||||||
|
|
||||||||||
| .. automethod:: __init__ | ||||||||||
| .. automethod:: generate | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| def __init__(self, component_id, function_family=None, state_filter_name=None, | ||||||||||
| hist_length=None, static_dt=False, order=None, _extra_bootstrap=False): | ||||||||||
| """ | ||||||||||
| :arg function_family: Accepts an instance of | ||||||||||
| :class:`AdamsIntegrationFunctionFamily` | ||||||||||
| or an integer, in which case the classical monomial function family | ||||||||||
| with the order given by the integer is used. | ||||||||||
| :arg static_dt: If *True*, changing the timestep during time integration | ||||||||||
| is not allowed. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| if function_family is not None and order is not None: | ||||||||||
| raise ValueError("may not specify both function_family and order") | ||||||||||
|
|
||||||||||
| if function_family is None: | ||||||||||
| function_family = order | ||||||||||
| del order | ||||||||||
|
|
||||||||||
| if isinstance(function_family, int): | ||||||||||
| function_family = AdamsMonomialIntegrationFunctionFamily(function_family) | ||||||||||
|
|
||||||||||
| super(AdamsMoultonMethodBuilder, self).__init__() | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| self.function_family = function_family | ||||||||||
|
|
||||||||||
| if hist_length is None: | ||||||||||
| hist_length = len(function_family) | ||||||||||
|
|
||||||||||
| self.hist_length = hist_length | ||||||||||
| self.static_dt = static_dt | ||||||||||
| self.extra_bootstrap = _extra_bootstrap | ||||||||||
|
|
||||||||||
| self.component_id = component_id | ||||||||||
|
|
||||||||||
| # Declare variables | ||||||||||
| self.step = var("<p>step") | ||||||||||
| self.function = var("<func>" + component_id) | ||||||||||
| self.history = \ | ||||||||||
| [var("<p>f_n_minus_" + str(i)) for i in range(hist_length - 1, 0, -1)] | ||||||||||
|
|
||||||||||
| if not self.static_dt: | ||||||||||
| self.time_history = [ | ||||||||||
| var("<p>t_n_minus_" + str(i)) | ||||||||||
| for i in range(hist_length - 1, 0, -1)] | ||||||||||
|
|
||||||||||
| self.state = var("<state>" + component_id) | ||||||||||
| self.t = var("<t>") | ||||||||||
| self.dt = var("<dt>") | ||||||||||
|
|
||||||||||
| if state_filter_name is not None: | ||||||||||
| self.state_filter = var("<func>" + state_filter_name) | ||||||||||
| else: | ||||||||||
| self.state_filter = None | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems this constructor is largely the same as the AB one. Why is its code duplicated? (e.g. instead of inheriting it from a common base) |
||||||||||
|
|
||||||||||
| def generate(self): | ||||||||||
| """ | ||||||||||
| :returns: :class:`dagrt.language.DAGCode` | ||||||||||
| """ | ||||||||||
| from pytools import UniqueNameGenerator | ||||||||||
| name_gen = UniqueNameGenerator() | ||||||||||
|
|
||||||||||
| from dagrt.language import DAGCode, CodeBuilder | ||||||||||
|
|
||||||||||
| array = var("<builtin>array") | ||||||||||
| rhs_next_var = var("rhs_next_var") | ||||||||||
|
|
||||||||||
| # Initialization | ||||||||||
| with CodeBuilder(name="initialization") as cb_init: | ||||||||||
| cb_init(self.step, 1) | ||||||||||
|
|
||||||||||
| # Primary | ||||||||||
| with CodeBuilder(name="primary") as cb_primary: | ||||||||||
|
|
||||||||||
| rhs_var_to_unknown = {} | ||||||||||
| unkvar = cb_primary.fresh_var("unk") | ||||||||||
| rhs_var_to_unknown[rhs_next_var] = unkvar | ||||||||||
|
|
||||||||||
| # In implicit mode, the time history must | ||||||||||
| # include the *next* point in time. | ||||||||||
| if not self.static_dt: | ||||||||||
| time_history_data = self.time_history + [self.t + self.dt] | ||||||||||
| time_hist_var = var(name_gen("time_history")) | ||||||||||
| cb_primary(time_hist_var, array(self.hist_length)) | ||||||||||
| for i in range(self.hist_length): | ||||||||||
| cb_primary(time_hist_var[i], time_history_data[i] - self.t) | ||||||||||
|
|
||||||||||
| time_hist = time_hist_var | ||||||||||
| t_end = self.dt | ||||||||||
|
Comment on lines
+512
to
+519
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I imagine this time history management code occurs similarly somewhere else. Could you factor this out into a utility? |
||||||||||
| dt_factor = 1 | ||||||||||
|
|
||||||||||
| else: | ||||||||||
| time_hist = list(range(-self.hist_length+2, 0+2)) # noqa pylint:disable=invalid-unary-operand-type | ||||||||||
| dt_factor = self.dt | ||||||||||
| t_end = 1 | ||||||||||
|
|
||||||||||
| # Implicit setup - rhs_next_var is an unknown, needs implicit solve. | ||||||||||
| equations = [] | ||||||||||
| unknowns = set() | ||||||||||
| knowns = set() | ||||||||||
|
|
||||||||||
| unknowns.add(rhs_next_var) | ||||||||||
|
|
||||||||||
| # Update history | ||||||||||
| history = self.history + [rhs_next_var] | ||||||||||
|
|
||||||||||
| # Set up the actual Adams-Moulton step. | ||||||||||
| ab_sum = emit_adams_integration( | ||||||||||
| cb_primary, name_gen, | ||||||||||
| self.function_family, | ||||||||||
| time_hist, history, | ||||||||||
| 0, t_end) | ||||||||||
|
|
||||||||||
| state_est = self.state + dt_factor * ab_sum | ||||||||||
| if self.state_filter is not None: | ||||||||||
| state_est = self.state_filter(state_est) | ||||||||||
|
|
||||||||||
| # Build the implicit solve expression. | ||||||||||
| from dagrt.expression import collapse_constants | ||||||||||
| from pymbolic.mapper.distributor import DistributeMapper as DistMap | ||||||||||
| solve_expression = collapse_constants( | ||||||||||
| rhs_next_var - self.eval_rhs(self.t + self.dt, | ||||||||||
| DistMap()(state_est)), | ||||||||||
| list(unknowns) + [self.state], | ||||||||||
| cb_primary.assign, cb_primary.fresh_var) | ||||||||||
| equations.append(solve_expression) | ||||||||||
|
|
||||||||||
| # {{{ emit solve if possible | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does "if possible" mean here? |
||||||||||
|
|
||||||||||
| if unknowns and len(unknowns) == len(equations): | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this have an |
||||||||||
| # got a square system, let's solve | ||||||||||
| assignees = [unk.name for unk in unknowns] | ||||||||||
|
|
||||||||||
| from pymbolic import substitute | ||||||||||
| subst_dict = dict( | ||||||||||
| (rhs_var.name, rhs_var_to_unknown[rhs_var]) | ||||||||||
| for rhs_var in unknowns) | ||||||||||
|
|
||||||||||
| cb_primary.assign_implicit( | ||||||||||
| assignees=assignees, | ||||||||||
| solve_components=[ | ||||||||||
| rhs_var_to_unknown[unk].name | ||||||||||
| for unk in unknowns], | ||||||||||
| expressions=[ | ||||||||||
| substitute(eq, subst_dict) | ||||||||||
| for eq in equations], | ||||||||||
|
|
||||||||||
| # TODO: Could supply a starting guess | ||||||||||
| other_params={ | ||||||||||
| "guess": self.state}, | ||||||||||
| solver_id="solve") | ||||||||||
|
|
||||||||||
| del equations[:] | ||||||||||
| knowns.update(unknowns) | ||||||||||
| unknowns.clear() | ||||||||||
|
|
||||||||||
| # }}} | ||||||||||
|
|
||||||||||
| # Update the state now that we've solved. | ||||||||||
| cb_primary(self.state, state_est) | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How come the solve result is not passed through the state filter? |
||||||||||
|
|
||||||||||
| # Rotate history and time history. | ||||||||||
| for i in range(self.hist_length - 1): | ||||||||||
| cb_primary(self.history[i], history[i + 1]) | ||||||||||
|
|
||||||||||
| if not self.static_dt: | ||||||||||
| cb_primary(self.time_history[i], time_history_data[i + 1]) | ||||||||||
|
Comment on lines
+592
to
+597
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I imagine this time history management code occurs similarly somewhere else. Could you factor this out into a utility? |
||||||||||
|
|
||||||||||
| cb_primary(self.t, self.t + self.dt) | ||||||||||
| cb_primary.yield_state(expression=self.state, | ||||||||||
| component_id=self.component_id, | ||||||||||
| time_id="", time=self.t) | ||||||||||
|
|
||||||||||
| if self.hist_length == 1: | ||||||||||
| # The first order method requires no bootstrapping. | ||||||||||
| return DAGCode( | ||||||||||
| phases={ | ||||||||||
| "initial": cb_init.as_execution_phase(next_phase="primary"), | ||||||||||
| "primary": cb_primary.as_execution_phase(next_phase="primary") | ||||||||||
| }, | ||||||||||
| initial_phase="initial") | ||||||||||
|
|
||||||||||
| # Bootstrap | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| with CodeBuilder(name="bootstrap") as cb_bootstrap: | ||||||||||
| self.rk_bootstrap(cb_bootstrap) | ||||||||||
| cb_bootstrap(self.t, self.t + self.dt) | ||||||||||
| cb_bootstrap.yield_state(expression=self.state, | ||||||||||
| component_id=self.component_id, | ||||||||||
| time_id="", time=self.t) | ||||||||||
| cb_bootstrap(self.step, self.step + 1) | ||||||||||
| # Bootstrap length is typically one less because of implicit, | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. less than what? |
||||||||||
| # but if we are comparing with IMEX MRAM, we need one more | ||||||||||
| # bootstrap step. | ||||||||||
|
Comment on lines
+622
to
+623
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I fully understand this comment. What's being compared here? |
||||||||||
| if self.extra_bootstrap: | ||||||||||
| with cb_bootstrap.if_(self.step, "==", self.hist_length): | ||||||||||
| cb_bootstrap.switch_phase("primary") | ||||||||||
| else: | ||||||||||
| with cb_bootstrap.if_(self.step, "==", self.hist_length - 1): | ||||||||||
| cb_bootstrap.switch_phase("primary") | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
|
||||||||||
| return DAGCode( | ||||||||||
| phases={ | ||||||||||
| "initialization": cb_init.as_execution_phase("bootstrap"), | ||||||||||
| "bootstrap": cb_bootstrap.as_execution_phase("bootstrap"), | ||||||||||
| "primary": cb_primary.as_execution_phase("primary"), | ||||||||||
| }, | ||||||||||
| initial_phase="initialization") | ||||||||||
|
|
||||||||||
| def eval_rhs(self, t, y): | ||||||||||
| """Return a node that evaluates the RHS at the given time and | ||||||||||
| component value.""" | ||||||||||
| from pymbolic.primitives import CallWithKwargs | ||||||||||
| return CallWithKwargs(function=self.function, | ||||||||||
| parameters=(), | ||||||||||
| kw_parameters={"t": t, self.component_id: y}) | ||||||||||
|
|
||||||||||
| def rk_bootstrap(self, cb): | ||||||||||
| """Initialize the timestepper with an IMPLICIT RK method.""" | ||||||||||
|
|
||||||||||
| equations = [] | ||||||||||
| unknowns = set() | ||||||||||
| knowns = set() | ||||||||||
| rhs_var_to_unknown = {} | ||||||||||
|
|
||||||||||
| from leap.rk import IMPLICIT_ORDER_TO_RK_METHOD_BUILDER | ||||||||||
| rk_method = IMPLICIT_ORDER_TO_RK_METHOD_BUILDER[self.function_family.order] | ||||||||||
| rk_tableau = tuple(zip(rk_method.c, rk_method.a_implicit)) | ||||||||||
| rk_coeffs = rk_method.output_coeffs | ||||||||||
|
|
||||||||||
| if self.extra_bootstrap: | ||||||||||
| first_save_step = 2 | ||||||||||
| else: | ||||||||||
| first_save_step = 1 | ||||||||||
|
|
||||||||||
| with cb.if_(self.step, "==", first_save_step): | ||||||||||
| # Save the first RHS to the AM history | ||||||||||
| rhs_var = var("rhs_var") | ||||||||||
|
|
||||||||||
| cb(rhs_var, self.eval_rhs(self.t, self.state)) | ||||||||||
| cb(self.history[0], rhs_var) | ||||||||||
|
|
||||||||||
| if not self.static_dt: | ||||||||||
| cb(self.time_history[0], self.t) | ||||||||||
|
|
||||||||||
| # Stage loop | ||||||||||
| rhss = [var("rk_rhs_" + str(i)) for i in range(len(rk_tableau))] | ||||||||||
| for stage_num, (c, coeffs) in enumerate(rk_tableau): | ||||||||||
| stage = self.state + sum(self.dt * coeff * rhss[j] | ||||||||||
| for (j, coeff) | ||||||||||
| in enumerate(coeffs)) | ||||||||||
|
|
||||||||||
| if self.state_filter is not None: | ||||||||||
| stage = self.state_filter(stage) | ||||||||||
|
|
||||||||||
| # In a DIRK setting, the unknown is always the same RHS | ||||||||||
| # as the stage number. | ||||||||||
| unknowns.add(rhss[stage_num]) | ||||||||||
|
Comment on lines
+685
to
+687
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like this assumes DIRK. Does it check that the given method is DIRK? |
||||||||||
| unkvar = cb.fresh_var("unk_s%d" % (stage_num)) | ||||||||||
| rhs_var_to_unknown[rhss[stage_num]] = unkvar | ||||||||||
| from dagrt.expression import collapse_constants | ||||||||||
| solve_expression = collapse_constants( | ||||||||||
| rhss[stage_num] - self.eval_rhs(self.t + c*self.dt, stage), | ||||||||||
| list(unknowns) + [self.state], | ||||||||||
| cb.assign, cb.fresh_var) | ||||||||||
| equations.append(solve_expression) | ||||||||||
|
|
||||||||||
| # {{{ emit solve if possible | ||||||||||
|
|
||||||||||
| if unknowns and len(unknowns) == len(equations): | ||||||||||
| # got a square system, let's solve | ||||||||||
| assignees = [unk.name for unk in unknowns] | ||||||||||
|
|
||||||||||
| from pymbolic import substitute | ||||||||||
| subst_dict = dict( | ||||||||||
| (rhs_var.name, rhs_var_to_unknown[rhs_var]) | ||||||||||
| for rhs_var in unknowns) | ||||||||||
|
|
||||||||||
| cb.assign_implicit( | ||||||||||
| assignees=assignees, | ||||||||||
| solve_components=[ | ||||||||||
| rhs_var_to_unknown[unk].name | ||||||||||
| for unk in unknowns], | ||||||||||
| expressions=[ | ||||||||||
| substitute(eq, subst_dict) | ||||||||||
| for eq in equations], | ||||||||||
|
|
||||||||||
| # TODO: Could supply a starting guess | ||||||||||
| other_params={ | ||||||||||
| "guess": self.state}, | ||||||||||
| solver_id="solve") | ||||||||||
|
|
||||||||||
| del equations[:] | ||||||||||
| knowns.update(unknowns) | ||||||||||
| unknowns.clear() | ||||||||||
|
|
||||||||||
| # }}} | ||||||||||
|
|
||||||||||
| # Merge the values of the RHSs. | ||||||||||
| rk_comb = sum(coeff * rhss[j] for j, coeff in enumerate(rk_coeffs)) | ||||||||||
|
|
||||||||||
| state_est = self.state + self.dt * rk_comb | ||||||||||
| if self.state_filter is not None: | ||||||||||
| state_est = self.state_filter(state_est) | ||||||||||
|
|
||||||||||
| # Assign the value of the new state. | ||||||||||
| cb(self.state, state_est) | ||||||||||
|
|
||||||||||
| # Save the "next" RHS to the AM history | ||||||||||
| rhs_next_var = var("rhs_next_var") | ||||||||||
|
|
||||||||||
| cb(rhs_next_var, self.eval_rhs(self.t + self.dt, self.state)) | ||||||||||
|
|
||||||||||
| for i in range(1, len(self.history)): | ||||||||||
| if self.extra_bootstrap: | ||||||||||
| save_crit = i+1 | ||||||||||
| else: | ||||||||||
| save_crit = i | ||||||||||
|
|
||||||||||
| with cb.if_(self.step, "==", save_crit): | ||||||||||
| cb(self.history[i], rhs_next_var) | ||||||||||
|
|
||||||||||
| if not self.static_dt: | ||||||||||
| cb(self.time_history[i], self.t + self.dt) | ||||||||||
|
|
||||||||||
| # }}} | ||||||||||
|
|
||||||||||
| # vim: fdm=marker | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.