diff --git a/fuse/__init__.py b/fuse/__init__.py index 152f815..2e94a49 100644 --- a/fuse/__init__.py +++ b/fuse/__init__.py @@ -1,5 +1,5 @@ -from fuse.cells import Point, Edge, polygon, make_tetrahedron, constructCellComplex +from fuse.cells import Point, Edge, polygon, line, make_tetrahedron, constructCellComplex, TensorProductPoint from fuse.groups import S1, S2, S3, D4, Z3, Z4, C3, C4, S4, A4, tri_C3, tet_edges, tet_faces, sq_edges, GroupRepresentation, PermutationSetRepresentation, get_cyc_group, get_sym_group from fuse.dof import DeltaPairing, DOF, L2Pairing, FuseFunction, PointKernel, PolynomialKernel from fuse.triples import ElementTriple, DOFGenerator, immerse diff --git a/fuse/cells.py b/fuse/cells.py index fb33c9b..f8c43ff 100644 --- a/fuse/cells.py +++ b/fuse/cells.py @@ -148,6 +148,13 @@ def compute_scaled_verts(d, n): raise ValueError("Dimension {} not supported".format(d)) +def line(): + """ + Constructs the default 1D interval + """ + return Point(1, [Point(0), Point(0)], vertex_num=2) + + def polygon(n): """ Constructs the 2D default cell with n sides/vertices @@ -336,6 +343,7 @@ def compute_cell_group(self): """ verts = self.ordered_vertices() v_coords = [self.get_node(v, return_coords=True) for v in verts] + n = len(verts) max_group = SymmetricGroup(n) edges = [edge.ordered_vertices() for edge in self.edges()] @@ -407,6 +415,15 @@ def get_starter_ids(self): min_ids = [min(dimension) for dimension in structure] return min_ids + def local_id(self, node): + structure = [sorted(generation) for generation in nx.topological_generations(self.G)] + structure.reverse() + min_id = self.get_starter_ids() + for d in range(len(structure)): + if node.id in structure[d]: + return node.id - min_id[d] + raise ValueError("Node not found in cell") + def graph_dim(self): if self.oriented: dim = self.dimension + 1 @@ -449,6 +466,9 @@ def ordered_vertices(self, get_class=False): return self.oriented.permute(verts) return verts + def ordered_vertex_coords(self): + return [self.get_node(o, return_coords=True) for o in self.ordered_vertices()] + def d_entities_ids(self, d): return self.d_entities(d, get_class=False) @@ -471,7 +491,7 @@ def get_node(self, node, return_coords=False): if return_coords: top_level_node = self.d_entities_ids(self.graph_dim())[0] if self.dimension == 0: - return [()] + return () return self.attachment(top_level_node, node)() return self.G.nodes.data("point_class")[node] @@ -550,7 +570,6 @@ def basis_vectors(self, return_coords=True, entity=None): self_levels = [sorted(generation) for generation in nx.topological_generations(self.G)] vertices = entity_levels[entity.graph_dim()] if self.dimension == 0: - # return [[] raise ValueError("Dimension 0 entities cannot have Basis Vectors") top_level_node = self_levels[0][0] v_0 = vertices[0] @@ -715,8 +734,8 @@ def copy(self): def to_fiat(self, name=None): if len(self.vertices()) == self.dimension + 1: return CellComplexToFiatSimplex(self, name) - if len(self.vertices()) == 2 ** self.dimension: - return CellComplexToFiatHypercube(self, name) + # if len(self.vertices()) == 2 ** self.dimension: + # return CellComplexToFiatHypercube(self, name) raise NotImplementedError("Custom shape elements/ First class quads are not yet supported") def to_ufl(self, name=None): @@ -736,6 +755,13 @@ def dict_id(self): def _from_dict(o_dict): return Point(o_dict["dim"], o_dict["edges"], oriented=o_dict["oriented"], cell_id=o_dict["id"]) + def equivalent(self, other): + if self.dimension != other.dimension: + return False + if set(self.ordered_vertex_coords()) != set(other.ordered_vertex_coords()): + return False + return self.get_topology() == other.get_topology() + class Edge(): """ @@ -759,7 +785,7 @@ def __call__(self, *x): if hasattr(self.attachment, '__iter__'): res = [] for attach_comp in self.attachment: - if len(attach_comp.atoms(sp.Symbol)) == len(x): + if len(attach_comp.atoms(sp.Symbol)) <= len(x): res.append(sympy_to_numpy(attach_comp, syms, x)) else: res.append(attach_comp.subs({syms[i]: x[i] for i in range(len(x))})) @@ -794,17 +820,17 @@ def _from_dict(o_dict): class TensorProductPoint(): - def __init__(self, A, B, flat=False): + def __init__(self, A, B): self.A = A self.B = B self.dimension = self.A.dimension + self.B.dimension - self.flat = flat + self.flat = False def get_spatial_dimension(self): return self.dimension - def dimension(self): - return tuple(self.A.dimension, self.B.dimension) + def dim(self): + return (self.A.dimension, self.B.dimension) def d_entities(self, d, get_class=True): return self.A.d_entities(d, get_class) + self.B.d_entities(d, get_class) @@ -819,17 +845,91 @@ def vertices(self, get_class=True, return_coords=False): return verts def to_ufl(self, name=None): - if self.flat: - return CellComplexToUFL(self, "quadrilateral") return TensorProductCell(self.A.to_ufl(), self.B.to_ufl()) def to_fiat(self, name=None): - if self.flat: - return CellComplexToFiatHypercube(self, CellComplexToFiatTensorProduct(self, name)) return CellComplexToFiatTensorProduct(self, name) def flatten(self): - return TensorProductPoint(self.A, self.B, True) + assert self.A.equivalent(self.B) + return FlattenedPoint(self.A, self.B) + + +class FlattenedPoint(Point, TensorProductPoint): + + def __init__(self, A, B): + self.A = A + self.B = B + self.dimension = self.A.dimension + self.B.dimension + self.flat = True + fuse_edges = self.construct_fuse_rep() + super().__init__(self.dimension, fuse_edges) + + def to_ufl(self, name=None): + return CellComplexToUFL(self, "quadrilateral") + + def to_fiat(self, name=None): + # TODO this should check if it actually is a hypercube + fiat = CellComplexToFiatHypercube(self, CellComplexToFiatTensorProduct(self, name)) + return fiat + + def construct_fuse_rep(self): + sub_cells = [self.A, self.B] + dims = (self.A.dimension, self.B.dimension) + + points = {cell: {i: [] for i in range(max(dims) + 1)} for cell in sub_cells} + attachments = {cell: {i: [] for i in range(max(dims) + 1)} for cell in sub_cells} + + for d in range(max(dims) + 1): + for cell in sub_cells: + if d <= cell.dimension: + sub_ent = cell.d_entities(d, get_class=True) + points[cell][d].extend(sub_ent) + for s in sub_ent: + attachments[cell][d].extend(s.connections) + + prod_points = list(itertools.product(*[points[cell][0] for cell in sub_cells])) + # temp = prod_points[1] + # prod_points[1] = prod_points[2] + # prod_points[2] = temp + point_cls = [Point(0) for i in range(len(prod_points))] + edges = [] + + # generate edges of tensor product result + for a in prod_points: + for b in prod_points: + # of all combinations of point, take those where at least one changes and at least one is the same + if any(a[i] == b[i] for i in range(len(a))) and any(a[i] != b[i] for i in range(len(sub_cells))): + # ensure if they change, that edge exists in the existing topology + if all([a[i] == b[i] or (sub_cells[i].local_id(a[i]), sub_cells[i].local_id(b[i])) in list(sub_cells[i].topology[1].values()) for i in range(len(sub_cells))]): + edges.append((a, b)) + # hasse level 1 + edge_cls1 = {e: None for e in edges} + for i in range(len(sub_cells)): + for (a, b) in edges: + a_idx = prod_points.index(a) + b_idx = prod_points.index(b) + if a[i] != b[i]: + a_edge = [att for att in attachments[sub_cells[i]][1] if att.point == a[i]][0] + b_edge = [att for att in attachments[sub_cells[i]][1] if att.point == b[i]][0] + edge_cls1[(a, b)] = Point(1, [Edge(point_cls[a_idx], a_edge.attachment, a_edge.o), + Edge(point_cls[b_idx], b_edge.attachment, b_edge.o)]) + edge_cls2 = [] + # hasse level 2 + for i in range(len(sub_cells)): + for (a, b) in edges: + if a[i] == b[i]: + x = sp.Symbol("x") + a_edge = [att for att in attachments[sub_cells[i]][1] if att.point == a[i]][0] + if i == 0: + attach = (x,) + a_edge.attachment + else: + attach = a_edge.attachment + (x,) + edge_cls2.append(Edge(edge_cls1[(a, b)], attach, a_edge.o)) + return edge_cls2 + + def flatten(self): + return self class CellComplexToFiatSimplex(Simplex): @@ -1011,8 +1111,7 @@ def constructCellComplex(name): return polygon(3).to_ufl(name) # return firedrake_triangle().to_ufl(name) elif name == "quadrilateral": - interval = Point(1, [Point(0), Point(0)], vertex_num=2) - return TensorProductPoint(interval, interval).flatten().to_ufl(name) + return TensorProductPoint(line(), line()).flatten().to_ufl(name) # return firedrake_quad().to_ufl(name) # return polygon(4).to_ufl(name) elif name == "tetrahedron": diff --git a/fuse/groups.py b/fuse/groups.py index e6a5d51..cddadff 100644 --- a/fuse/groups.py +++ b/fuse/groups.py @@ -55,6 +55,9 @@ def compute_perm(self, base_val=None): return val, val_list def numeric_rep(self): + """ Uses a standard formula to number permutations in the group. + For the case where this doesn't automatically number from 0..n (ie the group is not the full symmetry group), + a mapping is constructed on group creation""" identity = self.group.identity.perm.array_form m_array = self.perm.array_form val = 0 @@ -62,6 +65,8 @@ def numeric_rep(self): loc = m_array.index(identity[i]) m_array.remove(identity[i]) val += loc * math.factorial(len(identity) - i - 1) + if self.group.group_rep_numbering is not None: + return self.group.group_rep_numbering[val] return val def __eq__(self, x): @@ -134,6 +139,11 @@ def __init__(self, perm_list, cell=None): self._members.append(p_rep) counter += 1 + self.group_rep_numbering = None + numeric_reps = [m.numeric_rep() for m in self.members()] + if sorted(numeric_reps) != list(range(len(numeric_reps))): + self.group_rep_numbering = {a: b for a, b in zip(sorted(numeric_reps), list(range(len(numeric_reps))))} + def add_cell(self, cell): return PermutationSetRepresentation(self.perm_list, cell=cell) @@ -224,6 +234,11 @@ def __init__(self, base_group, cell=None): self.identity = p_rep counter += 1 + self.group_rep_numbering = None + numeric_reps = [m.numeric_rep() for m in self.members()] + if sorted(numeric_reps) != list(range(len(numeric_reps))): + self.group_rep_numbering = {a: b for a, b in zip(sorted(numeric_reps), list(range(len(numeric_reps))))} + # this order produces simpler generator lists # self.generators.reverse() diff --git a/fuse/spaces/polynomial_spaces.py b/fuse/spaces/polynomial_spaces.py index bff2c95..47c775b 100644 --- a/fuse/spaces/polynomial_spaces.py +++ b/fuse/spaces/polynomial_spaces.py @@ -1,13 +1,16 @@ from FIAT.polynomial_set import ONPolynomialSet +from FIAT.expansions import morton_index2, morton_index3 from FIAT.quadrature_schemes import create_quadrature from FIAT.reference_element import cell_to_simplex from FIAT import expansions, polynomial_set, reference_element from itertools import chain -from fuse.utils import tabulate_sympy, max_deg_sp_mat +from fuse.utils import tabulate_sympy, max_deg_sp_expr import sympy as sp import numpy as np from functools import total_ordering +morton_index = {2: morton_index2, 3: morton_index3} + @total_ordering class PolynomialSpace(object): @@ -47,7 +50,6 @@ def degree(self): return self.maxdegree def to_ON_polynomial_set(self, ref_el, k=None): - # how does super/sub degrees work here if not isinstance(ref_el, reference_element.Cell): ref_el = ref_el.to_fiat() sd = ref_el.get_spatial_dimension() @@ -56,18 +58,25 @@ def to_ON_polynomial_set(self, ref_el, k=None): shape = (sd,) else: shape = tuple() + base_ON = ONPolynomialSet(ref_el, self.maxdegree, shape, scale="orthonormal") + indices = None if self.mindegree > 0: - base_ON = ONPolynomialSet(ref_el, self.maxdegree, shape, scale="orthonormal") dimPmin = expansions.polynomial_dimension(ref_el, self.mindegree) dimPmax = expansions.polynomial_dimension(ref_el, self.maxdegree) if self.set_shape: indices = list(chain(*(range(i * dimPmin, i * dimPmax) for i in range(sd)))) else: indices = list(range(dimPmin, dimPmax)) - restricted_ON = base_ON.take(indices) - return restricted_ON - return ONPolynomialSet(ref_el, self.maxdegree, shape, scale="orthonormal") + + if self.contains != self.maxdegree and self.contains != -1: + indices = [morton_index[sd](p, q) for p in range(self.contains + 1) for q in range(self.contains + 1)] + + if indices is None: + return base_ON + + restricted_ON = base_ON.take(indices) + return restricted_ON def __repr__(self): res = "" @@ -161,37 +170,49 @@ def to_ON_polynomial_set(self, ref_el): if not isinstance(ref_el, reference_element.Cell): ref_el = ref_el.to_fiat() k = max([s.maxdegree for s in self.spaces]) - space_poly_sets = [s.to_ON_polynomial_set(ref_el) for s in self.spaces] sd = ref_el.get_spatial_dimension() ref_el = cell_to_simplex(ref_el) - if all([w == 1 for w in self.weights]): - weighted_sets = space_poly_sets - # otherwise have to work on this through tabulation Q = create_quadrature(ref_el, 2 * (k + 1)) Qpts, Qwts = Q.get_points(), Q.get_weights() weighted_sets = [] - for (space, w) in zip(space_poly_sets, self.weights): + for (s, w) in zip(self.spaces, self.weights): + space = s.to_ON_polynomial_set(ref_el) + if s.set_shape: + shape = (sd,) + else: + shape = tuple() if not (isinstance(w, sp.Expr) or isinstance(w, sp.Matrix)): weighted_sets.append(space) else: - w_deg = max_deg_sp_mat(w) - Pkpw = ONPolynomialSet(ref_el, space.degree + w_deg, scale="orthonormal") - vec_Pkpw = ONPolynomialSet(ref_el, space.degree + w_deg, (sd,), scale="orthonormal") + if isinstance(w, sp.Expr): + w = sp.Matrix([[w]]) + vec = False + else: + vec = True + w_deg = max_deg_sp_expr(w) + Pkpw = ONPolynomialSet(ref_el, space.degree + w_deg, shape, scale="orthonormal") + # vec_Pkpw = ONPolynomialSet(ref_el, space.degree + w_deg, (sd,), scale="orthonormal") space_at_Qpts = space.tabulate(Qpts)[(0,) * sd] Pkpw_at_Qpts = Pkpw.tabulate(Qpts)[(0,) * sd] tabulated_expr = tabulate_sympy(w, Qpts).T - scaled_at_Qpts = space_at_Qpts[:, None, :] * tabulated_expr[None, :, :] + if s.set_shape or vec: + scaled_at_Qpts = space_at_Qpts[:, None, :] * tabulated_expr[None, :, :] + else: + # breakpoint() + scaled_at_Qpts = space_at_Qpts[:, None, :] * tabulated_expr[None, :, :] + scaled_at_Qpts = scaled_at_Qpts.squeeze() PkHw_coeffs = np.dot(np.multiply(scaled_at_Qpts, Qwts), Pkpw_at_Qpts.T) + # breakpoint() weighted_sets.append(polynomial_set.PolynomialSet(ref_el, space.degree + w_deg, space.degree + w_deg, - vec_Pkpw.get_expansion_set(), + Pkpw.get_expansion_set(), PkHw_coeffs)) combined_sets = weighted_sets[0] for i in range(1, len(weighted_sets)): diff --git a/fuse/triples.py b/fuse/triples.py index e047a72..30561d7 100644 --- a/fuse/triples.py +++ b/fuse/triples.py @@ -126,10 +126,12 @@ def to_fiat(self): entity_perms, pure_perm = self.make_dof_perms(ref_el, entity_ids, nodes, poly_set) self.matrices = self.make_overall_dense_matrices(ref_el, entity_ids, nodes, poly_set) form_degree = 1 if self.spaces[0].set_shape else 0 - print("my", [n.pt_dict for n in nodes]) + # print("my", [n.pt_dict for n in nodes]) + print(self.cell) + print(ref_el) + print("first class", ref_el.topology) + print(pure_perm) print(entity_perms) - print(entity_ids) - print(ref_el.vertices) print() # TODO: Change this when Dense case in Firedrake if pure_perm: @@ -378,7 +380,6 @@ def make_dof_perms(self, ref_el, entity_ids, nodes, poly_set): if g in dof_gen_class[dim].g1.members() or (pure_perm and len(dof_gen_class[dim].g1.members()) > 1): sub_mat = g.matrix_form() oriented_mats_overall[val][np.ix_(ent_dofs_ids, ent_dofs_ids)] = sub_mat.copy() - for val, mat in oriented_mats_overall.items(): cell_dofs = entity_ids[dim][0] flat_by_entity[dim][e_id][val] = perm_matrix_to_perm_array(mat[np.ix_(cell_dofs, cell_dofs)]) diff --git a/fuse/utils.py b/fuse/utils.py index b475cc4..ab98522 100644 --- a/fuse/utils.py +++ b/fuse/utils.py @@ -41,7 +41,7 @@ def tabulate_sympy(expr, pts): # expr: sp matrix expression in x,y,z for components of R^d # pts: n values in R^d # returns: evaluation of expr at pts - res = np.array(pts) + res = np.zeros((pts.shape[0],) + (expr.shape[-1],)) i = 0 syms = ["x", "y", "z"] for pt in pts: @@ -51,16 +51,21 @@ def tabulate_sympy(expr, pts): subbed = np.array(subbed).astype(np.float64) res[i] = subbed[0] i += 1 - final = res.squeeze() - return final + # final = res.squeeze() + return res -def max_deg_sp_mat(sp_mat): +def max_deg_sp_expr(sp_expr): degs = [] - for comp in sp_mat: - # only compute degree if component is a polynomial - if sp.sympify(comp).as_poly(): - degs += [sp.sympify(comp).as_poly().degree()] + if isinstance(sp_expr, sp.Matrix): + for comp in sp_expr: + # only compute degree if component is a polynomial + if sp.sympify(comp).as_poly(): + degs += [sp.sympify(comp).as_poly().degree()] + else: + if sp.sympify(sp_expr).as_poly(): + degs += [sp.sympify(sp_expr).as_poly().degree()] + return max(degs) diff --git a/test/test_cells.py b/test/test_cells.py index 900828d..577ba94 100644 --- a/test/test_cells.py +++ b/test/test_cells.py @@ -182,3 +182,13 @@ def test_comparison(): print(tensor_product >= tensor_product1) # print(tensor_product1 >= tensor_product) # print(tensor_product1 >= tensor_product1) + + +def test_self_equality(C): + assert C == C + + +@pytest.mark.parametrize(["A", "B", "res"], [(firedrake_triangle(), polygon(3), False), + (line(), line(), True),]) +def test_equivalence(A, B, res): + assert A.equivalent(B) == res diff --git a/test/test_convert_to_fiat.py b/test/test_convert_to_fiat.py index 4c00cf7..0d3ab76 100644 --- a/test/test_convert_to_fiat.py +++ b/test/test_convert_to_fiat.py @@ -94,13 +94,11 @@ def create_cg1(cell): def create_cg1_quad(): deg = 1 - # cell = polygon(4) - cell = constructCellComplex("quadrilateral").cell_complex - + cell = TensorProductPoint(line(), line()).flatten() + print(cell, type(cell)) vert_dg = create_dg1(cell.vertices()[0]) xs = [immerse(cell, vert_dg, TrH1)] - - Pk = PolynomialSpace(deg, deg + 1) + Pk = PolynomialSpace(deg + 1, deg) cg = ElementTriple(cell, (Pk, CellL2, C0), DOFGenerator(xs, get_cyc_group(len(cell.vertices())), S1)) return cg @@ -434,6 +432,7 @@ def helmholtz_solve(mesh, V): def run_test(r, elem, parameters={}, quadrilateral=False): # Create mesh and define function space m = UnitSquareMesh(2 ** r, 2 ** r, quadrilateral=quadrilateral) + x = SpatialCoordinate(m) V = FunctionSpace(m, elem) # Define variational problem @@ -453,6 +452,29 @@ def run_test(r, elem, parameters={}, quadrilateral=False): return sqrt(assemble(inner(u - f, u - f) * dx)) +def run_test_original(r, elem_code, deg, parameters={}, quadrilateral=False): + # Create mesh and define function space + m = UnitSquareMesh(2 ** r, 2 ** r, quadrilateral=quadrilateral) + + x = SpatialCoordinate(m) + V = FunctionSpace(m, elem_code, deg) + # Define variational problem + u = Function(V) + v = TestFunction(V) + a = inner(grad(u), grad(v)) * dx + + bcs = [DirichletBC(V, Constant(0), 3), + DirichletBC(V, Constant(42), 4)] + + # Compute solution + solve(a == 0, u, solver_parameters=parameters, bcs=bcs) + + f = Function(V) + f.interpolate(42*x[1]) + + return sqrt(assemble(inner(u - f, u - f) * dx)) + + @pytest.mark.parametrize(['params', 'elem_gen'], [(p, d) for p in [{}, {'snes_type': 'ksponly', 'ksp_type': 'preonly', 'pc_type': 'lu'}] @@ -464,17 +486,21 @@ def test_poisson_analytic(params, elem_gen): @pytest.mark.parametrize(['elem_gen'], - [(create_cg1_quad_tensor,), pytest.param(create_cg1_quad, marks=pytest.mark.xfail(reason='Need to allow generation on tensor product quads'))]) + [(create_cg1_quad_tensor,), pytest.param(create_cg1_quad, marks=pytest.mark.xfail(reason='Issue with cell/mesh'))]) def test_quad(elem_gen): elem = elem_gen() r = 0 - # m = UnitSquareMesh(2 ** r, 2 ** r, quadrilateral=True) ufl_elem = elem.to_ufl() assert (run_test(r, ufl_elem, parameters={}, quadrilateral=True) < 1.e-9) -def test_non_tensor_quad(): - create_cg1_quad() +# # @pytest.mark.xfail(reason="Issue with quad cell") +# def test_non_tensor_quad(): +# elem = create_cg1_quad() +# ufl_elem = elem.to_ufl() +# print(elem.to_fiat().entity_permutations()) +# # elem.cell.hasse_diagram(filename="cg1quad.png") +# assert (run_test_original(1, "CG", 1, parameters={}, quadrilateral=True) < 1.e-9) @pytest.mark.parametrize("elem_gen,elem_code,deg", [(create_cg2_tri, "CG", 2), @@ -562,3 +588,12 @@ def test_investigate_dpc(): U = FunctionSpace(mesh, "DPC", 1) print(U) + f = Function(U) + f.assign(1) + + out = Function(U) + u = TrialFunction(U) + v = TestFunction(U) + a = inner(u, v)*dx + L = inner(f, v)*dx + solve(a == L, out) diff --git a/test/test_groups.py b/test/test_groups.py index cc3765f..149d728 100644 --- a/test/test_groups.py +++ b/test/test_groups.py @@ -124,3 +124,10 @@ def test_perm_mat_conversion(): mat_form = g.matrix_form() array_form = perm_matrix_to_perm_array(mat_form) assert np.allclose(g.perm.array_form, array_form) + + +def test_numeric_reps(): + cell = polygon(4) + rot4 = get_cyc_group(4).add_cell(cell) + + assert sorted([m.numeric_rep() for m in rot4.members()]) == list(range(len(rot4.members()))) diff --git a/test/test_polynomial_space.py b/test/test_polynomial_space.py index cb1c769..9709ee9 100644 --- a/test/test_polynomial_space.py +++ b/test/test_polynomial_space.py @@ -41,6 +41,7 @@ def test_restriction(): res_on_set = restricted.to_ON_polynomial_set(cell) P3_on_set = P3.to_ON_polynomial_set(cell) + assert res_on_set.get_num_members() < P3_on_set.get_num_members() not_restricted = P3.restrict(0, 3) @@ -48,6 +49,16 @@ def test_restriction(): assert not_restricted.mindegree == 0 +def test_square_space(): + cell = polygon(3) + q2 = PolynomialSpace(3, 1) + + q2_on_set = q2.to_ON_polynomial_set(cell) + P3_on_set = P3.to_ON_polynomial_set(cell) + + assert q2_on_set.get_num_members() < P3_on_set.get_num_members() + + @pytest.mark.parametrize("deg", [1, 2, 3, 4]) def test_complete_space(deg): cell = polygon(3) diff --git a/test/test_tensor_prod.py b/test/test_tensor_prod.py index fa117e7..ea689d6 100644 --- a/test/test_tensor_prod.py +++ b/test/test_tensor_prod.py @@ -116,3 +116,16 @@ def test_quad_mesh_helmholtz(): conv = np.log2(res[:-1] / res[1:]) print("convergence order:", conv) assert (np.array(conv) > 1.8).all() + + +@pytest.mark.parametrize(["A", "B", "res"], [(Point(0), line(), False), + (line(), line(), True), + (polygon(3), line(), False),]) +def test_flattening(A, B, res): + tensor_cell = TensorProductPoint(A, B) + if not res: + with pytest.raises(AssertionError): + tensor_cell.flatten() + else: + cell = tensor_cell.flatten() + cell.construct_fuse_rep()