diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index c8693ce6db..3c22e96673 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -206,7 +206,7 @@ def __init__(self, expr: Interpolate): # Delay calling .unique() because MixedInterpolator is fine with MeshSequence self.target_mesh = self.target_space.mesh() """The domain we are interpolating into.""" - self.source_mesh = extract_unique_domain(operand) or self.target_mesh + self.source_mesh = extract_unique_domain(operand, expand_mesh_sequence=False) or self.target_mesh """The domain we are interpolating from.""" # Interpolation options @@ -434,6 +434,7 @@ class CrossMeshInterpolator(Interpolator): def __init__(self, expr: Interpolate): super().__init__(expr) self.target_mesh = self.target_mesh.unique() + self.source_mesh = self.source_mesh.unique() if self.access and self.access != op2.WRITE: raise NotImplementedError( "Access other than op2.WRITE not implemented for cross-mesh interpolation." @@ -616,6 +617,7 @@ class SameMeshInterpolator(Interpolator): def __init__(self, expr): super().__init__(expr) self.target_mesh = self.target_mesh.unique() + self.source_mesh = self.source_mesh.unique() subset = self.subset if subset is None: target = self.target_mesh.topology @@ -1697,7 +1699,7 @@ def _build_aij( def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None): mat_type = mat_type or "aij" - sub_mat_type = sub_mat_type or "baij" + sub_mat_type = sub_mat_type or "aij" Isub = self._get_sub_interpolators(bcs=bcs) V_dest = self.ufl_interpolate.function_space() or self.target_space f = tensor or Function(V_dest) diff --git a/tests/firedrake/regression/test_interpolate_cross_mesh.py b/tests/firedrake/regression/test_interpolate_cross_mesh.py index a669618ed3..63b43aff0a 100644 --- a/tests/firedrake/regression/test_interpolate_cross_mesh.py +++ b/tests/firedrake/regression/test_interpolate_cross_mesh.py @@ -1,6 +1,8 @@ from firedrake import * from firedrake.petsc import DEFAULT_PARTITIONER from firedrake.ufl_expr import extract_unique_domain +from firedrake.mesh import Mesh, plex_from_cell_list +from firedrake.formmanipulation import split_form import numpy as np import pytest from ufl import product @@ -613,8 +615,8 @@ def test_line_integral(): # Create a 1D line mesh in 2D from (0, 0) to (1, 1) with 1 cell cells = np.asarray([[0, 1]]) vertex_coords = np.asarray([[0.0, 0.0], [1.0, 1.0]]) - plex = mesh.plex_from_cell_list(1, cells, vertex_coords, comm=m.comm) - line = mesh.Mesh(plex, dim=2) + plex = plex_from_cell_list(1, cells, vertex_coords, comm=m.comm) + line = Mesh(plex, dim=2) x, y = SpatialCoordinate(line) V_line = FunctionSpace(line, "CG", 2) f_line = Function(V_line).interpolate(x * y) @@ -623,8 +625,8 @@ def test_line_integral(): # Create a 1D line around the unit square (2D) with 4 cells cells = np.asarray([[0, 1], [1, 2], [2, 3], [3, 0]]) vertex_coords = np.asarray([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]) - plex = mesh.plex_from_cell_list(1, cells, vertex_coords, comm=m.comm) - line_square = mesh.Mesh(plex, dim=2) + plex = plex_from_cell_list(1, cells, vertex_coords, comm=m.comm) + line_square = Mesh(plex, dim=2) x, y = SpatialCoordinate(line_square) V_line_square = FunctionSpace(line_square, "CG", 2) f_line_square = Function(V_line_square).interpolate(x * y) @@ -750,3 +752,40 @@ def test_interpolate_cross_mesh_interval(periodic): f_dest = Function(V_dest).interpolate(f_src) x_dest, = SpatialCoordinate(m_dest) assert abs(assemble((f_dest - (-(x_dest - .5) ** 2)) ** 2 * dx)) < 1.e-16 + + +def test_mixed_interpolator_cross_mesh(): + # Tests assembly of mixed interpolator across meshes + mesh1 = UnitSquareMesh(4, 4) + mesh2 = UnitSquareMesh(3, 3, quadrilateral=True) + mesh3 = UnitDiskMesh(2) + mesh4 = UnitTriangleMesh(3) + V1 = FunctionSpace(mesh1, "CG", 1) + V2 = FunctionSpace(mesh2, "CG", 2) + V3 = FunctionSpace(mesh3, "CG", 3) + V4 = FunctionSpace(mesh4, "CG", 4) + + W = V1 * V2 + U = V3 * V4 + + w = TrialFunction(W) + w0, w1 = split(w) + expr = as_vector([w0 + w1, w0 + w1]) + mixed_interp = interpolate(expr, U, allow_missing_dofs=True) # Interpolating from W to U + + # The block matrix structure is + # | V1 -> V3 V2 -> V3 | + # | V1 -> V4 V2 -> V4 | + + res = assemble(mixed_interp, mat_type="nest") + assert isinstance(res, AssembledMatrix) + assert res.petscmat.type == "nest" + + split_interp = dict(split_form(mixed_interp)) + + for i in range(2): + for j in range(2): + interp_ij = split_interp[(i, j)] + assert isinstance(interp_ij, Interpolate) + res_block = assemble(interpolate(TrialFunction(W.sub(j)), U.sub(i), allow_missing_dofs=True)) + assert np.allclose(res.petscmat.getNestSubMatrix(i, j)[:, :], res_block.petscmat[:, :]) diff --git a/tests/firedrake/regression/test_interpolator_types.py b/tests/firedrake/regression/test_interpolator_types.py index 3ccdae844a..c77900b781 100644 --- a/tests/firedrake/regression/test_interpolator_types.py +++ b/tests/firedrake/regression/test_interpolator_types.py @@ -163,8 +163,8 @@ def test_mixed_same_mesh_mattype(value_shape, mat_type, sub_mat_type): # Always seqaij for scalar assert sub_mat.type == "seqaij" else: - # matnest sub_mat_type defaults to baij - assert sub_mat.type == "seq" + (sub_mat_type if sub_mat_type else "baij") + # matnest sub_mat_type defaults to aij + assert sub_mat.type == "seq" + (sub_mat_type if sub_mat_type else "aij") with pytest.raises(NotImplementedError): assemble(interp, mat_type="baij")