Skip to content

Commit ee99e1e

Browse files
committed
convert derivative nodes to fiat def, lint
1 parent 781d042 commit ee99e1e

File tree

7 files changed

+94
-78
lines changed

7 files changed

+94
-78
lines changed

fuse/cells.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,15 @@ def polygon(n):
164164

165165
return Point(2, edges, vertex_num=n)
166166

167+
167168
def firedrake_triangle():
168169
vertices = []
169170
for i in range(3):
170171
vertices.append(Point(0))
171172
edges = []
172-
edges.append(
173-
Point(1, [vertices[1], vertices[2]], vertex_num=2))
174-
edges.append(
175-
Point(1, [vertices[0], vertices[2]], vertex_num=2))
176-
edges.append(
177-
Point(1, [vertices[0], vertices[1]], vertex_num=2))
173+
edges.append(Point(1, [vertices[1], vertices[2]], vertex_num=2))
174+
edges.append(Point(1, [vertices[0], vertices[2]], vertex_num=2))
175+
edges.append(Point(1, [vertices[0], vertices[1]], vertex_num=2))
178176
tri = Point(2, edges, vertex_num=3, edge_orientations={1: [1, 0]})
179177
# tri = polygon(3)
180178
s3 = tri.group
@@ -360,7 +358,6 @@ def get_topology(self):
360358
self.topology[i] = {}
361359
self.topology_unrelabelled[i] = {}
362360
for node in dimension:
363-
neighbours = list(self.G.neighbors(node))
364361
self.topology[i][node - min_ids[i]] = tuple([relabelled_verts[vert] for vert in self.get_node(node).ordered_vertices()])
365362
self.topology_unrelabelled[i][node - min_ids[i]] = tuple([vert - min_ids[0] for vert in self.get_node(node).ordered_vertices()])
366363
return self.topology_unrelabelled
@@ -728,11 +725,9 @@ class CellComplexToFiatSimplex(Simplex):
728725
def __init__(self, cell, name=None):
729726
self.fe_cell = cell
730727
if name is not None:
731-
name = "IndiaDefCell"
728+
name = "FuseCell"
732729
self.name = name
733730

734-
735-
736731
# verts = [cell.get_node(v, return_coords=True) for v in cell.ordered_vertices()]
737732
verts = cell.vertices(return_coords=True)
738733
topology = cell.get_topology()
@@ -838,7 +833,7 @@ def to_fiat(self):
838833
return self.cell_complex.to_fiat(name=self.cellname())
839834

840835
def __repr__(self):
841-
return super(CellComplexToUFL, self).__repr__()
836+
return super(CellComplexToUFL, self).__repr__()
842837

843838
def reconstruct(self, **kwargs):
844839
"""Reconstruct this cell, overwriting properties by those in kwargs."""
@@ -866,4 +861,3 @@ def constructCellComplex(name):
866861
return make_tetrahedron().to_ufl(name)
867862
else:
868863
raise TypeError("Cell complex construction undefined for {}".format(str(name)))
869-

fuse/dof.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __call__(self, kernel, v, cell):
3737
def convert_to_fiat(self, ref_el, dof, interpolant_deg):
3838
pt = dof.eval(MyTestFunction(lambda *x: x))
3939
return PointEvaluation(ref_el, pt)
40-
40+
4141
def get_pts(self, ref_el, total_degree):
4242
entity = ref_el.construct_subelement(self.entity.dim())
4343
return [(0,) * entity.get_spatial_dimension()], [1], 1
@@ -272,24 +272,20 @@ def add_context(self, dof_gen, cell, space, g, overall_id=None, generator_id=Non
272272
self.sub_id = generator_id
273273

274274
def convert_to_fiat(self, ref_el, interpolant_degree):
275-
return self.pairing.convert_to_fiat(ref_el, self, interpolant_degree)
276-
277-
def convert_to_fiat_new(self, ref_el, interpolant_degree):
278275
total_degree = self.kernel.degree() + interpolant_degree
279276
pts, wts, jdet = self.pairing.get_pts(ref_el, total_degree)
280277
f_pts = self.kernel.tabulate(pts).T / jdet
281278
# TODO need to work out how i can discover the shape in a better way
282279
if isinstance(self.pairing, DeltaPairing):
283280
shp = tuple()
284-
pt_dict = {tuple(p) : [(w, tuple())] for (p, w) in zip(f_pts.T, wts)}
281+
pt_dict = {tuple(p): [(w, tuple())] for (p, w) in zip(f_pts.T, wts)}
285282
else:
286283
shp = tuple(f_pts.shape[:-1])
287284
weights = np.transpose(np.multiply(f_pts, wts), (-1,) + tuple(range(len(shp))))
288285
alphas = list(np.ndindex(shp))
289286
pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(pts, weights)}
290287

291-
return Functional(ref_el, shp, pt_dict, {}, self.__repr__())
292-
288+
return [Functional(ref_el, shp, pt_dict, {}, self.__repr__())]
293289

294290
def __repr__(self, fn="v"):
295291
return str(self.pairing).format(fn=fn, kernel=self.kernel)
@@ -329,24 +325,24 @@ def tabulate(self, Qpts):
329325
immersion = self.target_space.tabulate(Qpts, self.trace_entity, self.g)
330326
res = self.kernel.tabulate(Qpts)
331327
return immersion*res
332-
333-
def convert_to_fiat_new(self, ref_el, interpolant_degree):
328+
329+
def convert_to_fiat(self, ref_el, interpolant_degree):
334330
total_degree = self.kernel.degree() + interpolant_degree
335331
pts, wts, jdet = self.pairing.get_pts(ref_el, total_degree)
336332
f_pts = self.kernel.tabulate(pts, self.attachment)
337333
attached_pts = [self.attachment(*p) for p in pts]
338334
immersion = self.target_space.tabulate(f_pts, self.trace_entity, self.g)
339-
335+
340336
f_pts = (f_pts * immersion).T / jdet
341-
pt_dict, deriv_dict = self.target_space.convert_to_fiat(attached_pts, f_pts, wts)
337+
dict_list = self.target_space.convert_to_fiat(attached_pts, f_pts, wts)
342338

343339
# breakpoint()
344340
# TODO need to work out how i can discover the shape in a better way
345341
if isinstance(self.pairing, DeltaPairing):
346342
shp = tuple()
347343
else:
348344
shp = tuple(f_pts.shape[:-1])
349-
return Functional(ref_el, shp, pt_dict, deriv_dict, self.__repr__())
345+
return [Functional(ref_el, shp, pt_dict, deriv_dict, self.__repr__()) for (pt_dict, deriv_dict) in dict_list]
350346

351347
def __call__(self, g):
352348
permuted = self.cell.permute_entities(g, self.trace_entity.dim())

fuse/traces.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def tabulate(self, Qpts, trace_entity, g):
5555
return np.ones_like(Qpts)
5656

5757
def convert_to_fiat(self, qpts, pts, wts):
58-
pt_dict = {tuple(p) : [(w, tuple())] for (p, w) in zip(pts.T, wts)}
59-
return pt_dict, {}
58+
pt_dict = {tuple(p): [(w, tuple())] for (p, w) in zip(pts.T, wts)}
59+
return [(pt_dict, {})]
6060

6161
def __repr__(self):
6262
return "H1"
@@ -96,8 +96,7 @@ def convert_to_fiat(self, qpts, pts, wts):
9696
weights = np.transpose(np.multiply(f_at_qpts, wts), (-1,) + tuple(range(len(shp))))
9797
alphas = list(np.ndindex(shp))
9898
pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(qpts, weights)}
99-
return pt_dict, {}
100-
99+
return [(pt_dict, {})]
101100

102101
def tabulate(self, Qpts, trace_entity, g):
103102
entityBasis = np.array(trace_entity.basis_vectors())
@@ -135,14 +134,14 @@ def tabulate(self, Qpts, trace_entity, g):
135134
subEntityBasis = np.array(self.domain.basis_vectors(entity=trace_entity))
136135
result = np.matmul(tangent, subEntityBasis)
137136
return result
138-
137+
139138
def convert_to_fiat(self, qpts, pts, wts):
140139
f_at_qpts = pts
141140
shp = tuple(f_at_qpts.shape[:-1])
142141
weights = np.transpose(np.multiply(f_at_qpts, wts), (-1,) + tuple(range(len(shp))))
143142
alphas = list(np.ndindex(shp))
144143
pt_dict = {tuple(pt): [(wt[alpha], alpha) for alpha in alphas] for pt, wt in zip(qpts, weights)}
145-
return pt_dict, {}
144+
return [(pt_dict, {})]
146145

147146
def plot(self, ax, coord, trace_entity, g, **kwargs):
148147
permuted = self.domain.permute_entities(g, trace_entity.dimension)
@@ -179,10 +178,28 @@ def apply(*x):
179178
return tuple(result)
180179
return apply
181180

181+
def convert_to_fiat(self, qpts, pts, wts):
182+
shp = tuple(pts.shape[0:])
183+
alphas = []
184+
for i in range(pts.shape[0]):
185+
new = np.zeros_like(shp)
186+
new[i] = 1
187+
alphas += [tuple(new)]
188+
deriv_dicts = []
189+
for alpha in alphas:
190+
deriv_dicts += [{tuple(p): [(1.0, tuple(alpha), tuple())] for p in pts.T}]
191+
192+
# self.alpha = tuple(alpha)
193+
# self.order = sum(self.alpha)
194+
return [({}, d) for d in deriv_dicts]
195+
182196
def plot(self, ax, coord, trace_entity, g, **kwargs):
183197
circle1 = plt.Circle(coord, 0.075, fill=False, **kwargs)
184198
ax.add_patch(circle1)
185199

200+
def tabulate(self, Qpts, trace_entity, g):
201+
return np.ones_like(Qpts)
202+
186203
def __repr__(self):
187204
return "Grad"
188205

fuse/triples.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,10 @@ def to_fiat(self):
112112
dim = entity[0]
113113
for i in range(len(dofs)):
114114
if entity[1] == dofs[i].trace_entity.id - min_ids[dim]:
115-
entity_ids[dim][dofs[i].trace_entity.id - min_ids[dim]].append(counter)
116-
if hasattr(dofs[i], "convert_to_fiat_new"):
117-
nodes.append(dofs[i].convert_to_fiat_new(ref_el, degree))
118-
print("old")
119-
120-
print(dofs[i].convert_to_fiat(ref_el, degree).pt_dict)
121-
print(dofs[i].convert_to_fiat(ref_el, degree).target_shape)
122-
print("new")
123-
124-
print(dofs[i].convert_to_fiat_new(ref_el, degree).pt_dict)
125-
print(dofs[i].convert_to_fiat_new(ref_el, degree).target_shape)
126-
else:
127-
raise ValueError("using old")
128-
nodes.append(dofs[i].convert_to_fiat(ref_el, degree))
129-
counter += 1
115+
fiat_dofs = dofs[i].convert_to_fiat(ref_el, degree)
116+
nodes.extend(fiat_dofs)
117+
entity_ids[dim][dofs[i].trace_entity.id - min_ids[dim]].extend([counter + i for i in range(len(fiat_dofs))])
118+
counter += len(fiat_dofs)
130119
entity_perms, pure_perm = self.make_dof_perms(ref_el, entity_ids, nodes, poly_set)
131120

132121
form_degree = 1 if self.spaces[0].set_shape else 0

test/test_cells.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,11 @@ def test_oriented_verts():
111111
print(g, permuted)
112112
assert g.permute(tetra.ordered_vertices()) == oriented.ordered_vertices()
113113

114+
114115
def test_compare_cell_to_firedrake():
115116
tri1 = polygon(3)
116117
tri2 = default_simplex(2)
117-
118+
118119
n = 3
119120
vertices = []
120121
for i in range(n):
@@ -123,40 +124,38 @@ def test_compare_cell_to_firedrake():
123124
for i in range(n):
124125
edges.append(
125126
Point(1, [vertices[(i) % n], vertices[(i+1) % n]], vertex_num=2))
126-
from sympy.combinatorics.named_groups import SymmetricGroup
127-
s3 = SymmetricGroup(3)
127+
128128
cellS3 = S3.add_cell(tri1)
129129
for g in cellS3.members():
130130
print(g.perm.array_form)
131-
try:
132-
p = g.perm.array_form
133-
# tri3 = Point(2, [edges[p[0]], edges[p[1]], edges[p[2]]], vertex_num=n)
134-
print(tri1.orient(g).get_topology())
135-
except:
136-
print('FAIL')
137-
131+
p = g.perm.array_form
132+
tri3 = Point(2, [edges[p[0]], edges[p[1]], edges[p[2]]], vertex_num=n)
133+
print(tri1.orient(g).get_topology())
134+
print(tri3.get_topology())
138135

139136
# print(tri1.get_topology())
140137
print(tri2.get_topology())
141138
tri3 = firedrake_triangle()
142139
print(tri3.get_topology())
143140

141+
144142
@pytest.fixture
145143
def mock_cell_complex(mocker, expect):
146144
mocker.patch('firedrake.mesh.constructCellComplex', return_value=expect.to_ufl("triangle"))
147145

146+
148147
@pytest.mark.usefixtures("mock_cell_complex")
149-
@pytest.mark.parametrize(["expect"],[(firedrake_triangle(),), (polygon(3),)])
148+
@pytest.mark.parametrize(["expect"], [(firedrake_triangle(),), (polygon(3),)])
150149
def test_ref_els(expect):
151150
scale_range = range(3, 6)
152151

153152
diff2 = [0 for i in scale_range]
154153
for i in scale_range:
155-
mesh = UnitSquareMesh(2 ** i, 2 ** i)
154+
mesh = UnitSquareMesh(2 ** i, 2 ** i)
156155

157-
V = FunctionSpace(mesh, "CG", 3)
158-
res1 = helmholtz_solve(mesh, V)
159-
diff2[i-3] = res1
156+
V = FunctionSpace(mesh, "CG", 3)
157+
res1 = helmholtz_solve(mesh, V)
158+
diff2[i-3] = res1
160159

161160
print("firedrake l2 error norms:", diff2)
162161
diff2 = np.array(diff2)

test/test_convert_to_fiat.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def create_cg2_tri(cell):
133133
DOFGenerator(edge_xs, C3, S1)])
134134
return cg
135135

136+
136137
def create_hermite(tri):
137138
vert = tri.vertices()[0]
138139

@@ -145,14 +146,11 @@ def create_hermite(tri):
145146
v_derv_xs = [immerse(tri, dg0, TrGrad)]
146147
v_derv_dofs = DOFGenerator(v_derv_xs, S3/S2, S1)
147148

148-
v_derv2_xs = [immerse(tri, dg0, TrHess)]
149-
v_derv2_dofs = DOFGenerator(v_derv2_xs, S3/S2, S1)
150-
151149
i_xs = [DOF(DeltaPairing(), PointKernel((0, 0)))]
152150
i_dofs = DOFGenerator(i_xs, S1, S1)
153151

154152
her = ElementTriple(tri, (P3, CellH2, C0),
155-
[v_dofs, v_derv_dofs, v_derv2_dofs, i_dofs])
153+
[v_dofs, v_derv_dofs, i_dofs])
156154
return her
157155

158156

@@ -323,7 +321,6 @@ def test_helmholtz(elem_gen, elem_code, deg, conv_rate):
323321
diff2 = np.array(diff2)
324322
conv1 = np.log2(diff2[:-1] / diff2[1:])
325323
print("firedrake convergence order:", conv1)
326-
327324

328325
print("fuse l2 error norms:", diff)
329326
diff = np.array(diff)
@@ -351,10 +348,10 @@ def helmholtz_solve(mesh, V):
351348
a = (inner(grad(u), grad(v)) + inner(u, v)) * dx
352349
L = inner(f, v) * dx
353350
u = Function(V)
354-
l_a = assemble(L)
355-
elem = V.finat_element.fiat_equivalent
356-
W = VectorFunctionSpace(mesh, V.ufl_element())
357-
X = assemble(interpolate(mesh.coordinates, W))
351+
# l_a = assemble(L)
352+
# elem = V.finat_element.fiat_equivalent
353+
# W = VectorFunctionSpace(mesh, V.ufl_element())
354+
# X = assemble(interpolate(mesh.coordinates, W))
358355
solve(a == L, u)
359356
f.interpolate(cos(x*pi*2)*cos(y*pi*2))
360357
return sqrt(assemble(dot(u - f, u - f) * dx))
@@ -564,7 +561,8 @@ def test_project_3d(elem_gen, elem_code, deg):
564561

565562
assert np.allclose(out.dat.data, f.dat.data, rtol=1e-5)
566563

567-
@pytest.mark.xfail(reason='Derivative nodes to fiat')
564+
565+
@pytest.mark.xfail(reason='Handling generation of multiple fiat nodes from one in permutations')
568566
def test_create_hermite():
569567
deg = 3
570568
cell = polygon(3)
@@ -581,13 +579,13 @@ def test_create_hermite():
581579
Qpts, _ = Q.get_points(), Q.get_weights()
582580

583581
fiat_vals = fiat_elem.tabulate(0, Qpts)
584-
# my_vals = my_elem.tabulate(0, Qpts)
582+
my_vals = my_elem.tabulate(0, Qpts)
585583

586584
fiat_vals = flatten(fiat_vals[(0,) * sd])
587-
# my_vals = flatten(my_vals[(0,) * sd])
585+
my_vals = flatten(my_vals[(0,) * sd])
588586

589-
# (x, res, _, _) = np.linalg.lstsq(fiat_vals.T, my_vals.T)
590-
# x1 = np.linalg.inv(x)
591-
# assert np.allclose(np.linalg.norm(my_vals.T - fiat_vals.T @ x), 0)
592-
# assert np.allclose(np.linalg.norm(fiat_vals.T - my_vals.T @ x1), 0)
593-
# assert np.allclose(res, 0)
587+
(x, res, _, _) = np.linalg.lstsq(fiat_vals.T, my_vals.T)
588+
x1 = np.linalg.inv(x)
589+
assert np.allclose(np.linalg.norm(my_vals.T - fiat_vals.T @ x), 0)
590+
assert np.allclose(np.linalg.norm(fiat_vals.T - my_vals.T @ x1), 0)
591+
assert np.allclose(res, 0)

test/test_dofs.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from fuse import *
2-
from test_convert_to_fiat import create_cg1, create_dg1, construct_cg3, construct_rt, construct_nd
2+
from test_convert_to_fiat import create_cg1, create_dg1, construct_cg3, construct_rt, construct_nd, create_hermite
33
import sympy as sp
44
import numpy as np
55

@@ -106,3 +106,26 @@ def test_permute_nd():
106106
# print(g)
107107
# print(nd.cell.permute_entities(g, 0))
108108
# print(nd.cell.permute_entities(g, 1))
109+
110+
111+
def test_convert_dofs():
112+
cell = polygon(3)
113+
114+
cg3 = create_hermite(cell)
115+
116+
for dof in cg3.generate():
117+
print(dof)
118+
# print("old")
119+
# # old = dof.convert_to_fiat(cell.to_fiat(), 5).pt_dict
120+
# print(old)
121+
# print("new")
122+
new = dof.convert_to_fiat_new(cell.to_fiat(), 5)[0].pt_dict
123+
print(new)
124+
new = dof.convert_to_fiat_new(cell.to_fiat(), 5)[0].deriv_dict
125+
print(new)
126+
from FIAT.hermite import CubicHermite
127+
fiat_elem = CubicHermite(cell.to_fiat(), 3)
128+
129+
print(len([n.pt_dict for n in fiat_elem.dual.nodes]))
130+
print([n.pt_dict for n in fiat_elem.dual.nodes])
131+
print([n.deriv_dict for n in fiat_elem.dual.nodes])

0 commit comments

Comments
 (0)