-
Notifications
You must be signed in to change notification settings - Fork 245
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
1d830b8
7f087b3
214d882
d6c4d4a
78f8a0b
1c9d517
11db48b
83dfb04
d47a106
1f93a45
eea3a52
11d1429
4637ac2
ac1da7e
e9b3533
dc3dd77
1fd4a02
a0c45c1
93c6e3f
ef8d1ac
fa5acac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
| from devito.finite_differences.derivative import Derivative | ||
| from devito.tools import as_tuple | ||
|
|
||
| from devito.types.multistage import resolve_method | ||
|
|
||
| __all__ = ['solve', 'linsolve'] | ||
|
|
||
|
|
||
|
|
@@ -56,9 +58,16 @@ def solve(eq, target, **kwargs): | |
|
|
||
| # We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions | ||
| if len(sols) > 1: | ||
| return target.new_from_mat(sols) | ||
| sols_temp = target.new_from_mat(sols) | ||
| else: | ||
| sols_temp = sols[0] | ||
|
|
||
| method = kwargs.get("method", None) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a string. The idea is that the user provides a string to identify which time integrator to apply. |
||
| if method is not None: | ||
| method_cls = resolve_method(method) | ||
| return method_cls(target, sols_temp)._evaluate(**kwargs) | ||
| else: | ||
| return sols[0] | ||
| return sols_temp | ||
|
|
||
|
|
||
| def linsolve(expr, target, **kwargs): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| InvalidOperator) | ||
| from devito.logger import (debug, info, perf, warning, is_log_enabled_for, | ||
| switch_log_level) | ||
| from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims | ||
| from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims | ||
| from devito.ir.clusters import ClusterGroup, clusterize | ||
| from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, | ||
| MetaCall, derive_parameters, iet_build) | ||
|
|
@@ -36,7 +36,6 @@ | |
| disk_layer) | ||
| from devito.types.dimension import Thickness | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please run the linter (
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| __all__ = ['Operator'] | ||
|
|
||
|
|
||
|
|
@@ -327,6 +326,8 @@ def _lower_exprs(cls, expressions, **kwargs): | |
| * Apply substitution rules; | ||
| * Shift indices for domain alignment. | ||
| """ | ||
| expressions = lower_multistage(expressions) | ||
|
|
||
| expand = kwargs['options'].get('expand', True) | ||
|
|
||
| # Specialization is performed on unevaluated expressions | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this file should be moved to somewhere like
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I renamed the class to Regarding the file location, it’s currently in |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,190 @@ | ||||||
| from .equation import Eq | ||||||
|
||||||
| from .dense import Function | ||||||
| from devito.symbolics import uxreplace | ||||||
|
|
||||||
| from functools import cached_property | ||||||
|
|
||||||
| # from devito.ir.support import SymbolRegistry | ||||||
|
|
||||||
| from .array import Array # Trying Array | ||||||
|
|
||||||
|
|
||||||
| method_registry = {} | ||||||
|
|
||||||
| def register_method(cls): | ||||||
|
||||||
| method_registry[cls.__name__] = cls | ||||||
| return cls | ||||||
|
|
||||||
|
|
||||||
| def resolve_method(method): | ||||||
| try: | ||||||
| return method_registry[method] | ||||||
| except KeyError: | ||||||
| raise ValueError(f"The time integrator '{method}' is not implemented.") | ||||||
|
|
||||||
|
|
||||||
| class MultiStage(Eq): | ||||||
| """ | ||||||
| Abstract base class for multi-stage time integration methods | ||||||
| (e.g., Runge-Kutta schemes) in Devito. | ||||||
|
|
||||||
| This class wraps a symbolic equation of the form `target = rhs` and | ||||||
| provides a mechanism to associate a time integration scheme via the | ||||||
| `method` argument. Subclasses must implement the `_evaluate` method to | ||||||
| generate stage-wise update expressions. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| rhs : expr-like | ||||||
| The right-hand side of the equation to integrate. | ||||||
| target : Function | ||||||
| The time-updated symbol on the left-hand side, e.g., `u` or `u.forward`. | ||||||
| method : str or None | ||||||
|
||||||
| A string identifying the time integration method (e.g., 'RK44'), | ||||||
| which must correspond to a class defined in the global scope and | ||||||
| implementing `_evaluate`. If None, no method is applied. | ||||||
|
|
||||||
| Attributes | ||||||
| ---------- | ||||||
| eq : Eq | ||||||
| The symbolic equation `target = rhs`. | ||||||
| method : class | ||||||
| The integration method class resolved from the `method` string. | ||||||
| """ | ||||||
|
|
||||||
| def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None, **kwargs): | ||||||
|
||||||
| return super().__new__(cls, lhs, rhs=rhs, subdomain=subdomain, coefficients=coefficients, implicit_dims=implicit_dims, **kwargs) | ||||||
|
|
||||||
| def _evaluate(self, **kwargs): | ||||||
| raise NotImplementedError( | ||||||
| f"_evaluate() must be implemented in the subclass {self.__class__.__name__}") | ||||||
|
|
||||||
|
|
||||||
| class RK(MultiStage): | ||||||
|
||||||
| class RK(MultiStage): | |
| class RungeKutta(MultiStage): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just use concrete args with type hinting rather than kwargs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: Probably doesn't need caching
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't need counters like this. Use the SymbolRegistry (called sregistry) for this operator build and purge this kwarg
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to incorporate this suggestion, but I could use a bit of help. I think I should create a variable like
sym_registry=SymbolRegistry()
and then call
sym_registry.make_name(prefix='k')
However, I’m not sure where exactly I should declare the variable, or if this is the best approach overall.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: they will be Devito Eq objects
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: these are Array now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are internal to Devito, should not appear in operator arguments, and should not be touched by the user, and so should use Array, not Function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The thing is, when I try using Array, the second test fails, but if I use Functions, it works without any errors. I'm not sure if I'm doing something wrong.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think RK4 and indeed all RK methods should be instances of the RK Class
Then you no longer need all of the boilerplate code below, which is just setting up Butcher tableau
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or the coefficients should be class attributes and set by the child class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I didn't understand well. Aren’t they already implemented as described in the second option of your comment?
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would set these as tuples in the __init__. Definitely should not be mutable if set on a class level.
I would personally instead have a
def __init__(self):
a = (...
b = (...
c = (...
super.__init__(a=a, b=b, c=c)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did something like that... could you check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you do
return [_lower_multistage(expr, **kwargs) for i in exprs for expr in i]?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did something like that...