Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1d830b8
implementation of multi-stage time integrators
fernanvr May 5, 2025
7f087b3
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Jun 13, 2025
214d882
Return of first PR comments
fernanvr Jun 13, 2025
d6c4d4a
Return of first PR comments
fernanvr Jun 13, 2025
78f8a0b
2nd PR revision
fernanvr Jun 23, 2025
1c9d517
2nd PR revision
fernanvr Jun 23, 2025
11db48b
2rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
83dfb04
3rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
d47a106
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
1f93a45
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
eea3a52
5th PR revision, one suggestion from EdC and improving tests
fernanvr Jun 26, 2025
11d1429
including two more Runge-Kutta methods and improving tests: checking …
fernanvr Jul 1, 2025
4637ac2
changes to consider coupled Multistage equations
fernanvr Jul 16, 2025
ac1da7e
Improvements of the HORK_EXP
fernanvr Aug 15, 2025
e9b3533
Merge branch 'main' into multi-stage-time-integrator
fernanvr Aug 15, 2025
dc3dd77
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Oct 8, 2025
1fd4a02
tuples, improved class names, extensive tests
fernanvr Oct 8, 2025
a0c45c1
improving spacing in some tests
fernanvr Oct 8, 2025
93c6e3f
Add MFE time stepping Jupyter notebook
fernanvr Oct 23, 2025
ef8d1ac
Remove MFE_time_size.ipynb notebook
fernanvr Oct 23, 2025
fa5acac
Update multistage implementation and tests
fernanvr Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from devito.data.allocators import DataReference
from devito.logger import warning

__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims']
from devito.types.multistage import MultiStage

__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims']


def dimension_sort(expr):
Expand Down Expand Up @@ -95,6 +97,34 @@ def handle_indexed(indexed):
return ordering


def lower_multistage(expressions):
"""
Separating the multi-stage time-integrator scheme in stages:
* If the object is MultiStage, it creates the stages of the method.
"""
lowered = []
for i, eq in enumerate(as_tuple(expressions)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than having the the as_tuple here, why not have a dispatch for _lower_multistage that dispatches on iterable types as per _concretize_subdims?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, I think..

lowered.extend(_lower_multistage(eq, i))
return lowered


@singledispatch
def _lower_multistage(expr, index):
"""
Default handler for expressions that are not MultiStage.
Simply return them in a list.
"""
return [expr]


@_lower_multistage.register
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would personally tweak this for consistency with other uses of singledispatch in the codebase

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now it's done

def _(expr: MultiStage, index):
"""
Specialized handler for MultiStage expressions.
"""
return expr.method(expr.eq.rhs, expr.eq.lhs)._evaluate(eq_num=index)


def lower_exprs(expressions, subs=None, **kwargs):
"""
Lowering an expression consists of the following passes:
Expand Down
14 changes: 11 additions & 3 deletions devito/operations/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from devito.finite_differences.derivative import Derivative
from devito.tools import as_tuple

from devito.types.multistage import MultiStage

__all__ = ['solve', 'linsolve']


Expand All @@ -15,7 +17,7 @@ class SolveError(Exception):
pass


def solve(eq, target, **kwargs):
def solve(eq, target, method = None, eq_num = 0, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kwargs should not have spaces around them. Furthermore, can method and eq_num simply be folded into **kwargs?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
Algebraically rearrange an Eq w.r.t. a given symbol.

Expand Down Expand Up @@ -56,9 +58,15 @@ def solve(eq, target, **kwargs):

# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
if len(sols) > 1:
return target.new_from_mat(sols)
sols_temp=target.new_from_mat(sols)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have whitespace around operator. Same below

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

else:
sols_temp=sols[0]

if method is not None:
method_cls = MultiStage._resolve_method(method)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, this implies thatMultiStage._resolve_method should not be a class method. Perhaps the method kwarg should be dropped, and instead you should have something like:

method_cls = eq._resolve_method()  # Possibly make this a property?
if method_cls is None:
    return sols_temp

return method_cls(sols_temp, target)._evaluate(eq_num=eq_num)

or even just

return eq._resolve_method(sols_temp, target)._evaluate(eq_num=eq_num)

where _resolve_method is some abstract method of Eq and its subclasses which defaults to a no-op or similar.

As a side note, why is the eq_num kwarg required?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, eq_num it is for the counting in order to not repeat names. The think is that I'm having some troubles while implementing the SymbolRegistry

return method_cls(sols_temp, target)._evaluate(eq_num=eq_num)
else:
return sols[0]
return sols_temp


def linsolve(expr, target, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
InvalidOperator)
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
switch_log_level)
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims
from devito.ir.clusters import ClusterGroup, clusterize
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols,
MetaCall, derive_parameters, iet_build)
Expand All @@ -36,7 +36,6 @@
disk_layer)
from devito.types.dimension import Thickness


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run the linter (flake8) 🙂

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

__all__ = ['Operator']


Expand Down Expand Up @@ -327,6 +326,8 @@ def _lower_exprs(cls, expressions, **kwargs):
* Apply substitution rules;
* Shift indices for domain alignment.
"""
expressions=lower_multistage(expressions)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whitespace missing around operator - check throughout the PR for this

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


expand = kwargs['options'].get('expand', True)

# Specialization is performed on unevaluated expressions
Expand Down
2 changes: 2 additions & 0 deletions devito/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@
from .relational import * # noqa
from .sparse import * # noqa
from .tensor import * # noqa

from .multistage import *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need a # noqa

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

186 changes: 186 additions & 0 deletions devito/types/multistage.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file should be moved to somewhere like devito/timestepping/rungekutta.py or devito/timestepping/explicitmultistage.py that way additional timesteppers can be contributed as new files. (I'm thinking about implicit multistage, backward difference formulae etc...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the class to HighOrderRungeKuttaExponential. I realize the name might be confusing since this particular Runge-Kutta is explicit, but “EXP” was intended to highlight the exponential aspect. I’ve also updated the other class names based on your suggestions.

Regarding the file location, it’s currently in /types as recommended by @mloubout (see suggestion). Personally, I think both /timestepping and /types are reasonable options. Perhaps we can discuss this with @EdCaunt and @FabioLuporini to reach a consensus.

Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# from devito import Function, Eq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

from .equation import Eq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make these imports absolute (devito.types.equation)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

from .dense import Function
from devito.symbolics import uxreplace

from .array import Array # Trying Array


class MultiStage(Eq):
"""
Abstract base class for multi-stage time integration methods
(e.g., Runge-Kutta schemes) in Devito.

This class wraps a symbolic equation of the form `target = rhs` and
provides a mechanism to associate a time integration scheme via the
`method` argument. Subclasses must implement the `_evaluate` method to
generate stage-wise update expressions.

Parameters
----------
rhs : expr-like
The right-hand side of the equation to integrate.
target : Function
The time-updated symbol on the left-hand side, e.g., `u` or `u.forward`.
method : str or None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be either a class or a callable imo. Alternatively, it should be entirely omitted and set by defining some method/_evaluate method. Of these two, I prefer the latter as it results in cleaner code and a simpler API.

In general, if you are using a string comparison, there is probably a better (and safer) way to achieve your aim.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for that, I removed 'method' from the class, but didn't update the docstring... fixing it now

A string identifying the time integration method (e.g., 'RK44'),
which must correspond to a class defined in the global scope and
implementing `_evaluate`. If None, no method is applied.

Attributes
----------
eq : Eq
The symbolic equation `target = rhs`.
method : class
The integration method class resolved from the `method` string.
"""

def __new__(cls, rhs, target, method=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to strip subdomain information etc, and the API is inconsistent with the standard Eq. The following would likely be better:

def __new__(cls, lhs, rhs=0, method=None, subdomain=None, coefficients=None, implicit_dims=None, **kwargs):
    obj = super().__new__(lhs, rhs=rhs, subdomain=subdomain, coefficients=coefficients, implicit_dims=implicit_dims, **kwargs)
    obj._method = method  # NOTE: Have `_resolve_method` as a cached_property or similar based on some processing of `_method`
    return obj

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

eq = Eq(target, rhs)
obj = Eq.__new__(cls, eq.lhs, eq.rhs)
obj._eq = eq
obj._method = cls._resolve_method(method)
return obj

@classmethod
def _resolve_method(cls, method):
try:
if cls is MultiStage:
return globals()[method]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generally try to avoid globals. To me this implies that this functionality is misplaced

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

else:
return cls
except KeyError:
raise ValueError(f"The time integrator '{method}' is not implemented.")

@property
def eq(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be dropped with the restructured __new__ detailed above

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return self._eq

@property
def method(self):
return self._method

def _evaluate(self, expand=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably take **kwargs for consistency with Eq._evaluate()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

raise NotImplementedError(
f"_evaluate() must be implemented in the subclass {self.__class__.__name__}")


class RK(MultiStage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class RK(MultiStage):
class RungeKutta(MultiStage):

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
Base class for explicit Runge-Kutta (RK) time integration methods defined
via a Butcher tableau.

This class handles the general structure of RK schemes by using
the Butcher coefficients (`a`, `b`, `c`) to expand a single equation into
a series of intermediate stages followed by a final update. Subclasses
must define `a`, `b`, and `c` as class attributes.

Parameters
----------
a : list of list of float
The coefficient matrix representing stage dependencies.
b : list of float
The weights for the final combination step.
c : list of float
The time shifts for each intermediate stage (often the row sums of `a`).

Attributes
----------
a : list[list[float]]
Butcher tableau `a` coefficients (stage coupling).
b : list[float]
Butcher tableau `b` coefficients (weights for combining stages).
c : list[float]
Butcher tableau `c` coefficients (stage time positions).
s : int
Number of stages in the RK method, inferred from `b`.
"""

def __init__(self, *args):
self.a = getattr(self, 'a', None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems strangely formed. I would have a, b, and c as either positional args or kwargs. If self.s is constructed from knowledge about these parameters, then it should be a cached_property

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worked on this. Can you confirm if this is what you meant?

self.b = getattr(self, 'b', None)
self.c = getattr(self, 'c', None)
self.s = len(self.b) if self.b is not None else 0 # Number of stages

self._validate()

def _validate(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error handling should happen at the point where these values are first supplied.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worked on this too. Can you confirm if this is what you meant?

assert self.a is not None and self.b is not None and self.c is not None, \
f"RK subclass must define class attributes a, b, and c"
assert len(self.a) == self.s, f"'a'={a} must have {self.s} rows"
assert len(self.c) == self.s, f"'c'={c} must have {self.s} elements"

def _evaluate(self, eq_num=0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should have **kwargs for consistency with Eq._evaluate()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
Generate the stage-wise equations for a Runge-Kutta time integration method.

This method takes a single equation of the form `Eq(u.forward, rhs)` and
expands it into a sequence of intermediate stage evaluations and a final
update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`.

Parameters
----------
eq_num : int, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need counters like this. Use the SymbolRegistry (called sregistry) for this operator build and purge this kwarg

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to incorporate this suggestion, but I could use a bit of help. I think I should create a variable like
sym_registry=SymbolRegistry()
and then call
sym_registry.make_name(prefix='k')
However, I’m not sure where exactly I should declare the variable, or if this is the best approach overall.

An identifier index used to uniquely name the intermediate stage variables
(`k{eq_num}i`) in case of multiple equations being expanded.

Returns
-------
list of Eq
A list of SymPy Eq objects representing:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: they will be Devito Eq objects

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""
base_eq=self.eq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the __new__ specified above, these would just be u = self.lhs.function etc, which would clean things up

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

u = base_eq.lhs.function
rhs = base_eq.rhs
grid = u.grid
t = grid.time_dim
dt = t.spacing

# Create temporary Functions to hold each stage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: these are Array now

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right!

# k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are internal to Devito, should not appear in operator arguments, and should not be touched by the user, and so should use Array, not Function

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is, when I try using Array, the second test fails, but if I use Functions, it works without any errors. I'm not sure if I'm doing something wrong.

for i in range(self.s)]

stage_eqs = []

# Build each stage
for i in range(self.s):
u_temp = u + dt * sum(aij * kj for aij, kj in zip(self.a[i][:i], k[:i]))
t_shift = t + self.c[i] * dt

# Evaluate RHS at intermediate value
stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift})
stage_eqs.append(Eq(k[i], stage_rhs))

# Final update: u.forward = u + dt * sum(b_i * k_i)
u_next = u + dt * sum(bi * ki for bi, ki in zip(self.b, k))
stage_eqs.append(Eq(u.forward, u_next))

return stage_eqs


class RK44(RK):
"""
Classic 4th-order Runge-Kutta (RK4) time integration method.

This class implements the classic explicit Runge-Kutta method of order 4 (RK44).
It uses four intermediate stages and specific Butcher coefficients to achieve
high accuracy while remaining explicit.

Attributes
----------
a : list[list[float]]
Coefficients of the `a` matrix for intermediate stage coupling.
b : list[float]
Weights for final combination.
c : list[float]
Time positions of intermediate stages.
"""
a = [[0, 0, 0, 0],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would set these as tuples in the __init__. Definitely should not be mutable if set on a class level.

I would personally instead have a

def __init__(self):
    a = (...
    b = (...
    c = (...
    super.__init__(a=a, b=b, c=c)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did something like that... could you check?

[1/2, 0, 0, 0],
[0, 1/2, 0, 0],
[0, 0, 1, 0]]
b = [1/6, 1/3, 1/3, 1/6]
c = [0, 1/2, 1/2, 1]
Loading