Skip to content

Commit e8531bf

Browse files
authored
Merge pull request #62 from MunchLab/sparse-alg
add new sparse ect algo
2 parents 734040e + e68bbdf commit e8531bf

4 files changed

Lines changed: 193 additions & 118 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "ect"
3-
version = "1.1.2"
3+
version = "1.2.0"
44
authors = [
55
{ name="Liz Munch", email="muncheli@msu.edu" },
66
]

src/ect/ect.py

Lines changed: 102 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -101,34 +101,42 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
101101
def calculate(
    self,
    graph: EmbeddedComplex,
    theta: Optional[float] = None,
    override_bound_radius: Optional[float] = None,
):
    r"""Calculate the Euler Characteristic Transform (ECT) of a complex.

    Args:
        graph (EmbeddedComplex):
            The input complex to calculate the ECT for.
        theta (float, optional):
            The angle in :math:`[0, 2\pi]` of a single direction to compute
            the ECT for. If None, the directions stored on this object are
            used.
        override_bound_radius (float, optional):
            If None, uses the following in order: (i) the bounding radius
            stored in the class; or if not available (ii) the bounding
            radius of the given graph. Otherwise, should be a positive
            float :math:`R` where the ECC will be computed at thresholds in
            :math:`[-R, R]`. Default is None.

    Returns:
        ECTResult: built from the ECT matrix, the directions that were
        actually used, and ``self.thresholds``.
    """
    self._ensure_directions(graph.dim, theta)
    self._ensure_thresholds(graph, override_bound_radius)

    # An explicit angle overrides the directions stored on the instance.
    directions = (
        self.directions if theta is None else Directions.from_angles([theta])
    )
    ect_matrix = self._compute_ect(graph, directions, self.thresholds, self.dtype)

    return ECTResult(ect_matrix, directions, self.thresholds)

127-
ect_matrix = self._compute_directional_transform(
128-
simplex_projections, self.thresholds, self.dtype
116+
def _compute_ect(
    self, graph, directions, thresholds: np.ndarray, dtype=np.int32
) -> np.ndarray:
    """Compute the ECT matrix for all directions with the sparse kernel.

    Vertex heights are obtained by projecting ``graph.coord_matrix`` onto
    ``directions.vectors``; the cell/vertex incidence returned by
    ``graph._build_incidence_csr()`` is then handed to ``_ect_all_dirs``.

    Args:
        graph: Complex exposing ``coord_matrix`` and
            ``_build_incidence_csr()``.
        directions: Object exposing the direction vectors as ``vectors``.
        thresholds: Threshold values; converted to ``float64`` before use.
        dtype: When equal to ``np.int32`` the result is cast to int32;
            any other value returns the kernel's int64 output unchanged.

    Returns:
        np.ndarray: the array produced by ``_ect_all_dirs`` (one row per
        direction, one column per threshold).
    """
    cell_vertex_pointers, cell_vertex_indices_flat, cell_euler_signs, N = (
        graph._build_incidence_csr()
    )
    thresholds = np.asarray(thresholds, dtype=np.float64)

    V = directions.vectors
    X = graph.coord_matrix

    # Choosing the orientation from ``V.shape[0] == X.shape[1]`` alone is
    # ambiguous when V is square (num_dirs == dim) and would skip the
    # transpose in that case. Test the row-per-direction layout first.
    # NOTE(review): assumes ``directions.vectors`` stores one direction per
    # row, i.e. shape (num_dirs, dim) -- confirm against Directions.
    if V.shape[-1] == X.shape[1]:
        H = X @ V.T  # (N, m)
    else:
        H = X @ V  # fallback for a column-per-direction (dim, m) layout
    H_T = np.ascontiguousarray(H.T)  # (m, N) for contiguous per-direction rows

    out64 = _ect_all_dirs(
        H_T,
        cell_vertex_pointers,
        cell_vertex_indices_flat,
        cell_euler_signs,
        thresholds,
        N,
    )
    if dtype == np.int32:
        return out64.astype(np.int32)
    return out64
132140

133141
def _compute_simplex_projections(self, graph: EmbeddedComplex, directions):
134142
"""Compute projections of each k-cell for all dimensions"""
@@ -160,41 +168,79 @@ def _compute_simplex_projections(self, graph: EmbeddedComplex, directions):
160168

161169
return simplex_projections
162170

163-
@staticmethod
164-
@njit(parallel=True, fastmath=True)
165-
def _compute_directional_transform(
166-
simplex_projections_list, thresholds, dtype=np.int32
167-
):
168-
"""Compute ECT by counting simplices below each threshold - VECTORIZED VERSION
169-
170-
Args:
171-
simplex_projections_list: List of arrays containing projections for each simplex type
172-
[vertex_projections, edge_projections, face_projections]
173-
thresholds: Array of threshold values to compute ECT at
174-
dtype: Data type for output array (default: np.int32)
175171

176-
Returns:
177-
Array of shape (num_directions, num_thresholds) containing Euler characteristics
178-
"""
179-
num_dir = simplex_projections_list[0].shape[1]
180-
num_thresh = thresholds.shape[0]
181-
result = np.empty((num_dir, num_thresh), dtype=dtype)
182-
183-
sorted_projections = List()
184-
for proj in simplex_projections_list:
185-
sorted_proj = np.empty_like(proj)
186-
for i in prange(num_dir):
187-
sorted_proj[:, i] = np.sort(proj[:, i])
188-
sorted_projections.append(sorted_proj)
189-
190-
for i in prange(num_dir):
191-
chi = np.zeros(num_thresh, dtype=dtype)
192-
for k in range(len(sorted_projections)):
193-
projs = sorted_projections[k][:, i]
194-
195-
count = np.searchsorted(projs, thresholds, side="right")
196-
197-
sign = -1 if k % 2 else 1
198-
chi += sign * count
199-
result[i] = chi
200-
return result
172+
@njit(cache=True, parallel=True)
def _ect_all_dirs(
    heights_by_direction,  # shape (num_directions, num_vertices)
    cell_vertex_pointers,  # shape (num_cells + 1,)
    cell_vertex_indices_flat,  # concatenated vertex indices for all cells
    cell_euler_signs,  # per-cell sign: (+1) for even-dim, (-1) for odd-dim
    threshold_values,  # shape (num_thresholds,), any order
    num_vertices,
):
    """
    Calculate the Euler Characteristic Transform (ECT) for every direction.

    For each direction the vertices are ranked by height, every cell is
    assigned the rank of its highest vertex (the moment it enters the
    sublevel set), and the Euler characteristic is read off a prefix sum of
    the per-rank jumps at each threshold.

    Args:
        heights_by_direction: The heights of the vertices for each direction
        cell_vertex_pointers: Offsets into ``cell_vertex_indices_flat``; cell
            ``i`` owns the slice ``[pointers[i], pointers[i + 1])``
        cell_vertex_indices_flat: The indices of the vertices for each cell
        cell_euler_signs: The signs of the cells
        threshold_values: The thresholds to calculate the ECT for; they do
            not need to be sorted
        num_vertices: The number of vertices in the graph

    Returns:
        int64 array of shape (num_directions, num_thresholds) whose entry
        ``[d, j]`` is the Euler characteristic of the cells whose vertices all
        have height <= ``threshold_values[j]`` in direction ``d``.
    """
    num_directions = heights_by_direction.shape[0]
    num_thresholds = threshold_values.shape[0]
    num_cells = cell_vertex_pointers.shape[0] - 1
    ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)

    # The sweep below needs thresholds in nondecreasing order. Visiting them
    # through an argsort (computed once, shared by all directions) keeps the
    # result correct for unsorted input while writing to the original columns.
    threshold_order = np.argsort(threshold_values)

    for dir_idx in prange(num_directions):
        heights = heights_by_direction[dir_idx]

        sort_order = np.argsort(heights)

        # position of each vertex in the sorted heights, starting from 1
        vertex_rank_1based = np.empty(num_vertices, dtype=np.int32)
        for rnk in range(num_vertices):
            vertex_rank_1based[sort_order[rnk]] = rnk + 1

        # euler char can only jump at each vertex value; the ranks are a
        # permutation of 1..num_vertices, so every rank >= 1 receives exactly
        # one vertex (+1) and rank 0 receives none
        jump_amount = np.ones(num_vertices + 1, dtype=np.int64)
        jump_amount[0] = 0

        # each pair of pointers defines a cell, so we iterate over them
        for cell_idx in range(num_cells):
            start = cell_vertex_pointers[cell_idx]
            end = cell_vertex_pointers[cell_idx + 1]
            # cells come in when the highest vertex enters
            entrance_rank = 0
            for k in range(start, end):
                rnk = vertex_rank_1based[cell_vertex_indices_flat[k]]
                if rnk > entrance_rank:
                    entrance_rank = rnk
            # record at what rank the cell enters and how much the euler char changes
            jump_amount[entrance_rank] += cell_euler_signs[cell_idx]

        # euler char at the moment each vertex enters
        euler_prefix = np.empty(num_vertices + 1, dtype=np.int64)
        running_sum = 0
        for r in range(num_vertices + 1):
            running_sum += jump_amount[r]
            euler_prefix[r] = running_sum

        # now find euler char at each threshold wrt the sorted heights
        sorted_heights = heights[sort_order]
        rank_cursor = 0  # equals r(t) = # { i : h_i <= t }
        for j in range(num_thresholds):
            thresh_idx = threshold_order[j]
            t = threshold_values[thresh_idx]
            while rank_cursor < num_vertices and sorted_heights[rank_cursor] <= t:
                rank_cursor += 1
            ect_values[dir_idx, thresh_idx] = euler_prefix[rank_cursor]

    return ect_values

src/ect/embed_complex.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,62 @@ def _get_nice_interval(self, range_size):
782782

783783
return nice_interval * magnitude
784784

785+
def _build_incidence_csr(self) -> tuple:
    """
    Build a compressed-sparse-row (CSR) style cell-to-vertex incidence,
    excluding 0-cells.

    Cells are ordered by increasing dimension. 1-cells are taken from
    ``self.edge_indices`` when that is present and non-empty; otherwise (and
    for every higher dimension) they come from ``self.cells``.

    Returns:
        tuple: ``(cell_vertex_pointers, cell_vertex_indices_flat,
        cell_euler_signs, n_vertices)`` where

        - ``cell_vertex_pointers`` (int64, length n_cells + 1): cell ``i``
          owns ``cell_vertex_indices_flat[pointers[i]:pointers[i + 1]]``
        - ``cell_vertex_indices_flat`` (int32): vertex indices of all cells,
          concatenated
        - ``cell_euler_signs`` (int32): +1 for even-dimensional cells,
          -1 for odd-dimensional cells
        - ``n_vertices``: ``len(self.node_list)``

    Example: a complex with 5 vertices, edges [(1, 3), (2, 4)] and the
    2-cell (1, 2, 3) gives
    ``([0, 2, 4, 7], [1, 3, 2, 4, 1, 2, 3], [-1, -1, 1], 5)``.
    """
    n_vertices = len(self.node_list)

    cells_by_dimension = {}

    edge_indices = getattr(self, "edge_indices", None)
    if edge_indices is not None:
        edge_rows = np.asarray(edge_indices)
        if edge_rows.size:
            cells_by_dimension[1] = [tuple(map(int, row)) for row in edge_rows]

    for dim, cells_of_dim in (getattr(self, "cells", None) or {}).items():
        if dim == 0:
            # vertices are accounted for separately via n_vertices
            continue
        if dim == 1 and 1 in cells_by_dimension:
            # edges were already taken from edge_indices; don't count them twice
            continue
        # arrays and plain sequences are iterated the same way, row by row
        cell_list = [tuple(map(int, cell)) for cell in cells_of_dim]
        if cell_list:
            cells_by_dimension[dim] = cell_list

    flat_indices = []
    cell_sizes = []
    signs = []
    for dim in sorted(cells_by_dimension):
        euler_sign = 1 if dim % 2 == 0 else -1
        for cell_vertices in cells_by_dimension[dim]:
            flat_indices.extend(cell_vertices)
            cell_sizes.append(len(cell_vertices))
            signs.append(euler_sign)

    # pointers are the running total of cell sizes, with a leading 0
    cell_vertex_pointers = np.zeros(len(cell_sizes) + 1, dtype=np.int64)
    cell_vertex_pointers[1:] = np.cumsum(cell_sizes, dtype=np.int64)

    return (
        cell_vertex_pointers,
        np.asarray(flat_indices, dtype=np.int32),
        np.asarray(signs, dtype=np.int32),
        n_vertices,
    )
840+
785841

786842
EmbeddedGraph = EmbeddedComplex
787843
EmbeddedCW = EmbeddedComplex

0 commit comments

Comments
 (0)