Skip to content

Commit ccb1901

Browse files
committed
dsl: Simplify index extraction
1 parent e825ab8 commit ccb1901

File tree

2 files changed

+43
-37
lines changed

2 files changed

+43
-37
lines changed

devito/petsc/iet/passes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,9 @@ def assign_time_iters(iet, struct):
271271

272272
mapper = {}
273273
for iter in time_iters:
274-
common_dims = [dim for dim in iter.dimensions if dim in struct.fields]
274+
common_dims = [d for d in iter.dimensions if d in struct.fields]
275275
common_dims = [
276-
DummyExpr(FieldFromComposite(dim, struct), dim) for dim in common_dims
276+
DummyExpr(FieldFromComposite(d, struct), d) for d in common_dims
277277
]
278278
iter_new = iter._rebuild(nodes=List(body=tuple(common_dims)+iter.nodes))
279279
mapper.update({iter: iter_new})
@@ -282,15 +282,15 @@ def assign_time_iters(iet, struct):
282282

283283

284284
def retrieve_time_dims(iters):
285-
time_iter = [i for i in iters if any(dim.is_Time for dim in i.dimensions)]
285+
time_iter = [i for i in iters if any(d.is_Time for d in i.dimensions)]
286286
mapper = {}
287287
if not time_iter:
288288
return mapper
289-
for dim in time_iter[0].dimensions:
290-
if dim.is_Modulo:
291-
mapper[dim.origin] = dim
292-
elif dim.is_Time:
293-
mapper[dim] = dim
289+
for d in time_iter[0].dimensions:
290+
if d.is_Modulo:
291+
mapper[d.origin] = d
292+
elif d.is_Time:
293+
mapper[d] = d
294294
return mapper
295295

296296

devito/petsc/solve.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from devito.finite_differences.differentiable import Mul
66
from devito.finite_differences.derivative import Derivative
7-
from devito.types import Eq, Symbol, SteppingDimension
7+
from devito.types import Eq, Symbol, SteppingDimension, TimeFunction
88
from devito.types.equation import InjectSolveEq
99
from devito.operations.solve import eval_time_derivatives
1010
from devito.symbolics import retrieve_functions
@@ -87,73 +87,79 @@ def separate_eqn(eqn, target):
8787
zeroed_eqn = Eq(eqn.lhs - eqn.rhs, 0)
8888
zeroed_eqn = eval_time_derivatives(zeroed_eqn.lhs)
8989
target_funcs = set(generate_targets(zeroed_eqn, target))
90-
b, F_target = remove_target(zeroed_eqn, target_funcs)
90+
b, F_target = remove_targets(zeroed_eqn, target_funcs)
9191
return -b, F_target, target_funcs
9292

9393

9494
def generate_targets(eq, target):
9595
"""
96-
Extract all the functions that share the same time index as the target,
96+
Extract all the functions that share the same time index as the target
9797
but may have different spatial indices.
9898
"""
9999
funcs = retrieve_functions(eq)
100-
if any(dim.is_Time for dim in target.dimensions):
101-
time_idx = [
102-
i for i, d in zip(target.indices, target.dimensions) if d.is_Time
103-
]
100+
if isinstance(target, TimeFunction):
101+
time_idx = target.indices[target.time_dim]
104102
targets = [
105-
func for func in funcs
106-
if func.function is target.function and time_idx[0]
107-
in func.indices
103+
f for f in funcs if f.function is target.function and time_idx
104+
in f.indices
108105
]
109106
else:
110-
targets = [
111-
func for func in funcs
112-
if func.function is target.function
113-
]
107+
targets = [f for f in funcs if f.function is target.function]
114108
return targets
115109

116110

117111
def targets_to_arrays(array, targets):
112+
"""
113+
Map each target in `targets` to a corresponding array generated from `array`,
114+
matching the spatial indices of the target.
115+
116+
Example:
117+
--------
118+
>>> array
119+
vec_u(x, y)
120+
121+
>>> targets
122+
{u(t + dt, x + h_x, y), u(t + dt, x - h_x, y), u(t + dt, x, y)}
123+
124+
>>> targets_to_arrays(array, targets)
125+
{u(t + dt, x - h_x, y): vec_u(x - h_x, y),
126+
u(t + dt, x + h_x, y): vec_u(x + h_x, y),
127+
u(t + dt, x, y): vec_u(x, y)}
128+
"""
118129
space_indices = [
119-
tuple(
120-
i for i, d in zip(func.indices, func.dimensions) if d.is_Space
121-
) for func in targets
130+
tuple(f.indices[d] for d in f.space_dimensions) for f in targets
122131
]
123132
array_targets = [
124-
array.subs(
125-
{arr_idx: target_idx for arr_idx, target_idx in zip(array.indices, indices)}
126-
)
127-
for indices in space_indices
133+
array.subs(dict(zip(array.indices, i))) for i in space_indices
128134
]
129-
return {target: array for target, array in zip(targets, array_targets)}
135+
return dict(zip(targets, array_targets))
130136

131137

132138
@singledispatch
133-
def remove_target(expr, targets):
139+
def remove_targets(expr, targets):
134140
return (0, expr) if expr in targets else (expr, 0)
135141

136142

137-
@remove_target.register(sympy.Add)
143+
@remove_targets.register(sympy.Add)
138144
def _(expr, targets):
139145
if not any(expr.has(t) for t in targets):
140146
return (expr, 0)
141147

142-
args_b, args_F = zip(*(remove_target(a, targets) for a in expr.args))
148+
args_b, args_F = zip(*(remove_targets(a, targets) for a in expr.args))
143149
return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False))
144150

145151

146-
@remove_target.register(Mul)
152+
@remove_targets.register(Mul)
147153
def _(expr, targets):
148154
if not any(expr.has(t) for t in targets):
149155
return (expr, 0)
150156

151-
args_b, args_F = zip(*[remove_target(a, targets) if any(a.has(t) for t in targets)
157+
args_b, args_F = zip(*[remove_targets(a, targets) if any(a.has(t) for t in targets)
152158
else (a, a) for a in expr.args])
153159
return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False))
154160

155161

156-
@remove_target.register(Derivative)
162+
@remove_targets.register(Derivative)
157163
def _(expr, targets):
158164
return (0, expr) if any(expr.has(t) for t in targets) else (expr, 0)
159165

@@ -219,4 +225,4 @@ def generate_time_mapper(funcs):
219225
if d.is_Time
220226
})
221227
tau_symbs = [Symbol('tau%d' % i) for i in range(len(time_indices))]
222-
return {time: tau for time, tau in zip(time_indices, tau_symbs)}
228+
return dict(zip(time_indices, tau_symbs))

0 commit comments

Comments
 (0)