|
4 | 4 |
|
5 | 5 | from devito.finite_differences.differentiable import Mul |
6 | 6 | from devito.finite_differences.derivative import Derivative |
7 | | -from devito.types import Eq, Symbol, SteppingDimension |
| 7 | +from devito.types import Eq, Symbol, SteppingDimension, TimeFunction |
8 | 8 | from devito.types.equation import InjectSolveEq |
9 | 9 | from devito.operations.solve import eval_time_derivatives |
10 | 10 | from devito.symbolics import retrieve_functions |
@@ -87,73 +87,79 @@ def separate_eqn(eqn, target): |
87 | 87 | zeroed_eqn = Eq(eqn.lhs - eqn.rhs, 0) |
88 | 88 | zeroed_eqn = eval_time_derivatives(zeroed_eqn.lhs) |
89 | 89 | target_funcs = set(generate_targets(zeroed_eqn, target)) |
90 | | - b, F_target = remove_target(zeroed_eqn, target_funcs) |
| 90 | + b, F_target = remove_targets(zeroed_eqn, target_funcs) |
91 | 91 | return -b, F_target, target_funcs |
92 | 92 |
|
93 | 93 |
|
94 | 94 | def generate_targets(eq, target): |
95 | 95 | """ |
96 | | - Extract all the functions that share the same time index as the target, |
| 96 | + Extract all the functions that share the same time index as the target |
97 | 97 | but may have different spatial indices. |
98 | 98 | """ |
99 | 99 | funcs = retrieve_functions(eq) |
100 | | - if any(dim.is_Time for dim in target.dimensions): |
101 | | - time_idx = [ |
102 | | - i for i, d in zip(target.indices, target.dimensions) if d.is_Time |
103 | | - ] |
| 100 | + if isinstance(target, TimeFunction): |
| 101 | + time_idx = target.indices[target.time_dim] |
104 | 102 | targets = [ |
105 | | - func for func in funcs |
106 | | - if func.function is target.function and time_idx[0] |
107 | | - in func.indices |
| 103 | + f for f in funcs if f.function is target.function and time_idx |
| 104 | + in f.indices |
108 | 105 | ] |
109 | 106 | else: |
110 | | - targets = [ |
111 | | - func for func in funcs |
112 | | - if func.function is target.function |
113 | | - ] |
| 107 | + targets = [f for f in funcs if f.function is target.function] |
114 | 108 | return targets |
115 | 109 |
|
116 | 110 |
|
117 | 111 | def targets_to_arrays(array, targets): |
| 112 | + """ |
| 113 | + Map each target in `targets` to a corresponding array generated from `array`, |
| 114 | + matching the spatial indices of the target. |
| 115 | +
|
| 116 | + Example: |
| 117 | + -------- |
| 118 | + >>> array |
| 119 | + vec_u(x, y) |
| 120 | +
|
| 121 | + >>> targets |
| 122 | + {u(t + dt, x + h_x, y), u(t + dt, x - h_x, y), u(t + dt, x, y)} |
| 123 | +
|
| 124 | + >>> targets_to_arrays(array, targets) |
| 125 | + {u(t + dt, x - h_x, y): vec_u(x - h_x, y), |
| 126 | + u(t + dt, x + h_x, y): vec_u(x + h_x, y), |
| 127 | + u(t + dt, x, y): vec_u(x, y)} |
| 128 | + """ |
118 | 129 | space_indices = [ |
119 | | - tuple( |
120 | | - i for i, d in zip(func.indices, func.dimensions) if d.is_Space |
121 | | - ) for func in targets |
| 130 | + tuple(f.indices[d] for d in f.space_dimensions) for f in targets |
122 | 131 | ] |
123 | 132 | array_targets = [ |
124 | | - array.subs( |
125 | | - {arr_idx: target_idx for arr_idx, target_idx in zip(array.indices, indices)} |
126 | | - ) |
127 | | - for indices in space_indices |
| 133 | + array.subs(dict(zip(array.indices, i))) for i in space_indices |
128 | 134 | ] |
129 | | - return {target: array for target, array in zip(targets, array_targets)} |
| 135 | + return dict(zip(targets, array_targets)) |
130 | 136 |
|
131 | 137 |
|
132 | 138 | @singledispatch |
133 | | -def remove_target(expr, targets): |
| 139 | +def remove_targets(expr, targets): |
134 | 140 | return (0, expr) if expr in targets else (expr, 0) |
135 | 141 |
|
136 | 142 |
|
137 | | -@remove_target.register(sympy.Add) |
| 143 | +@remove_targets.register(sympy.Add) |
138 | 144 | def _(expr, targets): |
139 | 145 | if not any(expr.has(t) for t in targets): |
140 | 146 | return (expr, 0) |
141 | 147 |
|
142 | | - args_b, args_F = zip(*(remove_target(a, targets) for a in expr.args)) |
| 148 | + args_b, args_F = zip(*(remove_targets(a, targets) for a in expr.args)) |
143 | 149 | return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False)) |
144 | 150 |
|
145 | 151 |
|
146 | | -@remove_target.register(Mul) |
| 152 | +@remove_targets.register(Mul) |
147 | 153 | def _(expr, targets): |
148 | 154 | if not any(expr.has(t) for t in targets): |
149 | 155 | return (expr, 0) |
150 | 156 |
|
151 | | - args_b, args_F = zip(*[remove_target(a, targets) if any(a.has(t) for t in targets) |
| 157 | + args_b, args_F = zip(*[remove_targets(a, targets) if any(a.has(t) for t in targets) |
152 | 158 | else (a, a) for a in expr.args]) |
153 | 159 | return (expr.func(*args_b, evaluate=False), expr.func(*args_F, evaluate=False)) |
154 | 160 |
|
155 | 161 |
|
156 | | -@remove_target.register(Derivative) |
| 162 | +@remove_targets.register(Derivative) |
157 | 163 | def _(expr, targets): |
158 | 164 | return (0, expr) if any(expr.has(t) for t in targets) else (expr, 0) |
159 | 165 |
|
@@ -219,4 +225,4 @@ def generate_time_mapper(funcs): |
219 | 225 | if d.is_Time |
220 | 226 | }) |
221 | 227 | tau_symbs = [Symbol('tau%d' % i) for i in range(len(time_indices))] |
222 | | - return {time: tau for time, tau in zip(time_indices, tau_symbs)} |
| 228 | + return dict(zip(time_indices, tau_symbs)) |
0 commit comments