Skip to content

Commit e16964e

Browse files
committed
api: add mul interp mode
1 parent 995fc55 commit e16964e

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

devito/finite_differences/derivative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def _eval_fd(self, expr, **kwargs):
594594
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
595595
self.deriv_order[0], weights=self.weights,
596596
side=self.side, matvec=self.transpose,
597-
x0=self.x0, expand=expand)
597+
x0=x0_deriv, expand=expand)
598598

599599
# Step 4: Apply substitutions
600600
for e in self._ppsubs:

devito/finite_differences/differentiable.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,51 @@ def _gather_for_diff(self):
629629

630630
return self.func(*new_args, evaluate=False)
631631

632+
def _eval_at(self, func):
633+
# No a basic a*b*c... expression, just defer to superclass
634+
if any(isinstance(f, DifferentiableOp) for f in self.args):
635+
return super()._eval_at(func)
636+
637+
# Split Derivative and Differentiable args
638+
derivs, other = split(self.args, lambda e: isinstance(e, sympy.Derivative))
639+
640+
if derivs:
641+
derivs = Differentiable._eval_at(self.func(*derivs), func)
642+
else:
643+
derivs = 1
644+
645+
if not other:
646+
return derivs
647+
elif len(other) > 1:
648+
expr = self.func(*other)._gather_for_diff
649+
else:
650+
expr = other[0]
651+
652+
# Non differentiable expr (e.g., number)
653+
if not isinstance(expr, Differentiable):
654+
return self.func(derivs, expr)
655+
656+
# Build mapper for dimensions that need to be interpolated
657+
mapper = {}
658+
for d in self.dimensions:
659+
try:
660+
if self.indices_ref[d] is not func.indices_ref[d]:
661+
mapper[d] = func.indices_ref[d]
662+
except KeyError:
663+
pass
664+
665+
# Nothing to interpolate
666+
if not mapper:
667+
return super()._eval_at(func)
668+
669+
# Interpolate expr at the required indices
670+
interp = expr.diff(*mapper.keys(), deriv_order=[0 for _ in mapper],
671+
fd_order=[self.interp_order for _ in mapper],
672+
x0=mapper)
673+
674+
# Return the full expression with Derivatives
675+
return self.func(derivs, interp)
676+
632677

633678
class Pow(DifferentiableOp, sympy.Pow):
634679
_fd_priority = 0

devito/types/dense.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ def __fd_setup__(self):
10701070

10711071
@cached_property
10721072
def _fd_priority(self):
1073-
return 1 if self.staggered.on_node else 2
1073+
return 2.1 if self.staggered.on_node else 2
10741074

10751075
def _eval_at(self, func):
10761076
if self.staggered == func.staggered:
@@ -1491,7 +1491,7 @@ def __shape_setup__(cls, **kwargs):
14911491

14921492
@cached_property
14931493
def _fd_priority(self):
1494-
return 2.1 if self.staggered.on_node else 2.2
1494+
return 2.3 if self.staggered.on_node else 2.2
14951495

14961496
@property
14971497
def time_order(self):

0 commit comments

Comments
 (0)