Skip to content

Commit 2f8257a

Browse files
committed
conflict
1 parent a98ccde commit 2f8257a

File tree

3 files changed

+93
-16
lines changed

3 files changed

+93
-16
lines changed

firedrake/assemble.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,17 +1107,6 @@ def local_kernels(self):
11071107
each possible combination.
11081108
11091109
"""
1110-
try:
1111-
topology, = set(d.topology.submesh_ancesters[-1] for d in self._form.ufl_domains())
1112-
except ValueError:
1113-
raise NotImplementedError("All integration domains must share a mesh topology")
1114-
1115-
for o in itertools.chain(self._form.arguments(), self._form.coefficients()):
1116-
domains = extract_domains(o)
1117-
for domain in domains:
1118-
if domain is not None and domain.topology.submesh_ancesters[-1] != topology:
1119-
raise NotImplementedError("Assembly with multiple meshes is not supported")
1120-
11211110
if isinstance(self._form, ufl.Form):
11221111
kernels = tsfc_interface.compile_form(
11231112
self._form, "form", diagonal=self.diagonal,

tests/firedrake/regression/test_multiple_domains.py

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,6 @@ def test_mismatching_meshes_real_space(mesh1, mesh3):
5858
project(donor, target)
5959

6060

61-
def test_mismatching_topologies(mesh1, mesh3):
62-
with pytest.raises(NotImplementedError):
63-
assemble(1*dx(domain=mesh1) + 2*dx(domain=mesh3))
64-
65-
6661
def test_functional(mesh1, mesh2):
6762
c = Constant(1)
6863

@@ -123,3 +118,90 @@ def test_two_form(mesh1, mesh2, form, expect):
123118
val = assemble(form).M.values
124119

125120
assert np.allclose(val, expect)
121+
122+
123+
def test_multi_domain_solve():
124+
mesh1 = UnitSquareMesh(7, 7, quadrilateral=True)
125+
x1, y1 = SpatialCoordinate(mesh1)
126+
mesh2 = UnitSquareMesh(8, 8)
127+
x2, y2 = SpatialCoordinate(mesh2)
128+
V1 = FunctionSpace(mesh1, "Q", 3)
129+
V2 = FunctionSpace(mesh2, "CG", 2)
130+
V = V1 * V2
131+
132+
u1, u2 = TrialFunctions(V)
133+
v1, v2 = TestFunctions(V)
134+
135+
a = (
136+
inner(grad(u1), grad(v1))*dx(domain=mesh1)
137+
+ inner(grad(u2), grad(v2))*dx(domain=mesh2)
138+
)
139+
140+
u_exact_expr1 = sin(pi * x1) * sin(pi * y1)
141+
u_exact_expr2 = x2 * y2 * (1 - x2) * (1 - y2)
142+
f1 = -div(grad(u_exact_expr1))
143+
f2 = -div(grad(u_exact_expr2))
144+
145+
L = (
146+
inner(f1, v1)*dx(domain=mesh1)
147+
+ inner(f2, v2)*dx(domain=mesh2)
148+
)
149+
150+
bc1 = DirichletBC(V.sub(0), 0, "on_boundary")
151+
bc2 = DirichletBC(V.sub(1), 0, "on_boundary")
152+
u_sol = Function(V)
153+
solve(a == L, u_sol, bcs=[bc1, bc2])
154+
u1_sol, u2_sol = u_sol.subfunctions
155+
156+
u_exact = Function(V)
157+
u1_exact, u2_exact = u_exact.subfunctions
158+
u1_exact.interpolate(u_exact_expr1)
159+
u2_exact.interpolate(u_exact_expr2)
160+
161+
err1 = errornorm(u1_exact, u1_sol)
162+
assert err1 < 1e-5
163+
err2 = errornorm(u2_exact, u2_sol)
164+
assert err2 < 1e-5
165+
166+
167+
def test_multi_domain_assemble():
168+
mesh1 = UnitSquareMesh(1, 1, quadrilateral=True)
169+
mesh2 = UnitSquareMesh(2, 2)
170+
V1 = FunctionSpace(mesh1, "Q", 1)
171+
V2 = FunctionSpace(mesh2, "CG", 1)
172+
V = V1 * V2
173+
174+
u = TrialFunctions(V)
175+
v = TestFunctions(V)
176+
f = split(Function(V))
177+
178+
for i, j in [(0, 1), (1, 0)]:
179+
a1 = inner(u[i], v[j])*dx(domain=mesh1)
180+
with pytest.raises(NotImplementedError):
181+
assemble(a1)
182+
a2 = inner(u[i], v[j])*dx(domain=mesh2)
183+
with pytest.raises(NotImplementedError):
184+
assemble(a2)
185+
l1 = inner(f[i], v[j])*dx(domain=mesh1)
186+
with pytest.raises(NotImplementedError):
187+
assemble(l1)
188+
l2 = inner(f[i], v[j])*dx(domain=mesh2)
189+
with pytest.raises(NotImplementedError):
190+
assemble(l2)
191+
192+
for i, j in [(0, 0), (1, 1)]:
193+
a = inner(u[i], v[j])*dx(domain=mesh1)
194+
if i == 1:
195+
with pytest.raises(NotImplementedError):
196+
assemble(a)
197+
continue
198+
A = assemble(a)
199+
assert A.M.values.shape == (V.dim(), V.dim())
200+
201+
a = inner(u[0], v[0])*dx(domain=mesh1) + inner(u[0], v[1])*dx(domain=mesh2)
202+
with pytest.raises(NotImplementedError):
203+
assemble(a)
204+
205+
a = inner(u[0], v[0])*dx(domain=mesh1) + inner(u[1], v[1])*dx(domain=mesh2)
206+
A = assemble(a)
207+
assert A.M.values.shape == (V.dim(), V.dim())

tsfc/driver.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
145145
domain_integral_type_map.update(dict.fromkeys(coefficient_meshes, "cell"))
146146
domain_integral_type_map.update(integral_data.domain_integral_type_map)
147147

148+
for arg in arguments:
149+
if domain_integral_type_map[extract_unique_domain(arg)] is None:
150+
raise NotImplementedError("Assembly of forms over unrelated meshes is not supported. "
151+
"Try using Submeshes or cross-mesh interpolation.")
152+
148153
integral_data_info = TSFCIntegralDataInfo(
149154
domain=integral_data.domain,
150155
integral_type=integral_data.integral_type,
@@ -156,6 +161,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, *, diagonal=F
156161
coefficient_split=coefficient_split,
157162
coefficient_numbers=coefficient_numbers,
158163
)
164+
159165
builder = firedrake_interface_loopy.KernelBuilder(
160166
integral_data_info,
161167
scalar_type,

0 commit comments

Comments
 (0)