Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ect"
version = "1.2.4"
version = "1.3.0"
authors = [
{ name="Liz Munch", email="muncheli@msu.edu" },
]
Expand Down
222 changes: 158 additions & 64 deletions src/ect/ect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,23 @@
from .results import ECTResult


def _thresholds_are_uniform(thresholds: np.ndarray) -> bool:
    """
    Return whether ``thresholds`` form an evenly spaced 1-dimensional grid.

    Arrays with zero or one element are considered uniform. If the first step
    is exactly zero, every step must be exactly zero. Otherwise every
    consecutive step must match the first step within a tolerance.

    Args:
        thresholds: Array-like of threshold values; converted to a float array.

    Returns:
        True if all consecutive differences agree with the first difference.

    Raises:
        ValueError: If ``thresholds`` is not 1-dimensional.
    """
    thresholds = np.asarray(thresholds, dtype=float)
    if thresholds.ndim != 1:
        raise ValueError("thresholds must be a 1-dimensional array")
    if thresholds.size <= 1:
        return True

    diffs = np.diff(thresholds)
    first = diffs[0]
    if first == 0.0:
        return bool(np.all(diffs == 0.0))

    # Tolerance relative to the step size (floored at an absolute 1e-12).
    step_tol = 1e-12 * max(1.0, abs(first))

    # Each stored threshold carries rounding error proportional to its own
    # magnitude, not to the step. A grid such as linspace(1e8, 1e8 + 1, 11)
    # therefore has steps that differ by about one ulp of 1e8 (~1.5e-8), far
    # above step_tol. Allow for that noise so such grids are still uniform.
    magnitude = float(np.max(np.abs(thresholds)))
    if np.isfinite(magnitude):
        rounding_tol = 8.0 * np.finfo(float).eps * magnitude
    else:
        # NaN/inf entries must not widen the tolerance.
        rounding_tol = 0.0

    tol = max(step_tol, rounding_tol)
    return bool(np.all(np.abs(diffs - first) <= tol))


class ECT:
"""
r"""
A class to calculate the Euler Characteristic Transform (ECT) from an input :class:`ect.embed_complex.EmbeddedComplex`,
using a set of directions to project the complex onto and thresholds to filter the projections.

Expand Down Expand Up @@ -55,6 +70,22 @@ def __init__(
self.bound_radius = bound_radius
self.thresholds = thresholds
self.dtype = dtype
self._thresholds_validated = False
if self.thresholds is not None:
self.thresholds = np.asarray(self.thresholds, dtype=float)
if self.thresholds.ndim != 1:
raise ValueError("thresholds must be a 1-dimensional array")
self._thresholds_validated = True
if num_thresh is not None:
self.is_uniform = True
elif self.thresholds is not None:
self.is_uniform = False
if not _thresholds_are_uniform(self.thresholds):
raise ValueError(
"thresholds must be uniform if num_thresh is not provided"
)
else:
self.is_uniform = True

def _ensure_directions(self, graph_dim, theta=None):
"""Ensures directions is a valid Directions object of correct dimension"""
Expand Down Expand Up @@ -97,11 +128,14 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
or graph.get_bounding_radius()
)
self.thresholds = np.linspace(-radius, radius, self.num_thresh, dtype=float)
self.is_uniform = True
self._thresholds_validated = True
else:
# validate and convert existing thresholds
self.thresholds = np.asarray(self.thresholds, dtype=float)
if self.thresholds.ndim != 1:
raise ValueError("thresholds must be a 1-dimensional array")
if not self._thresholds_validated:
self.thresholds = np.asarray(self.thresholds, dtype=float)
if self.thresholds.ndim != 1:
raise ValueError("thresholds must be a 1-dimensional array")
self._thresholds_validated = True

def calculate(
self,
Expand Down Expand Up @@ -132,14 +166,25 @@ def _compute_ect(
H = X @ V.T # (N, m)
H_T = np.ascontiguousarray(H.T) # (m, N) for contiguous per-direction rows

out64 = _ect_all_dirs(
H_T,
cell_vertex_pointers,
cell_vertex_indices_flat,
cell_euler_signs,
thresholds,
N,
)
is_uniform = bool(self.is_uniform) and thresholds[0] != thresholds[-1]
if is_uniform:
out64 = _ect_all_dirs_uniform(
H_T,
cell_vertex_pointers,
cell_vertex_indices_flat,
cell_euler_signs,
thresholds,
N,
)
else:
out64 = _ect_all_dirs_search(
H_T,
cell_vertex_pointers,
cell_vertex_indices_flat,
cell_euler_signs,
thresholds,
N,
)
if dtype == np.int32:
return out64.astype(np.int32)
return out64
Expand Down Expand Up @@ -176,74 +221,123 @@ def _compute_simplex_projections(self, graph: EmbeddedComplex, directions):


@njit(cache=True, parallel=True)
def _ect_all_dirs(
heights_by_direction, # shape (num_directions, num_vertices)
cell_vertex_pointers, # shape (num_cells + 1,)
cell_vertex_indices_flat, # concatenated vertex indices for all cells
cell_euler_signs, # per-cell sign: (+1) for even-dim, (-1) for odd-dim
threshold_values, # shape (num_thresholds,), assumed nondecreasing
def _ect_all_dirs_uniform(
    heights_by_direction,
    cell_vertex_pointers,
    cell_vertex_indices_flat,
    cell_euler_signs,
    threshold_values,
    num_vertices,
):
    """
    Calculate the Euler Characteristic Transform for every direction, using
    arithmetic bucketing on an evenly spaced, nondecreasing threshold grid.

    Each vertex is assigned the index of the first threshold that is >= its
    height. A cell enters at the largest such index among its vertices. The
    per-threshold contributions are accumulated into a difference array whose
    prefix sum is the ECT along that direction.

    Args:
        heights_by_direction: Vertex heights, indexed [direction, vertex]
        cell_vertex_pointers: Offsets into cell_vertex_indices_flat; cell i owns
            the entries from pointers[i] up to pointers[i + 1]
        cell_vertex_indices_flat: Concatenated vertex indices for all cells
        cell_euler_signs: Per-cell amount added when the cell enters
        threshold_values: The thresholds to calculate the ECT for
        num_vertices: The number of vertices

    Returns:
        int64 array of shape (num_directions, num_thresholds).
    """
    num_directions = heights_by_direction.shape[0]
    num_thresholds = threshold_values.shape[0]

    ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)
    if num_thresholds == 0:
        return ect_values

    t_min = threshold_values[0]
    t_max = threshold_values[-1]
    span = t_max - t_min
    # A degenerate grid (single or all-equal thresholds) has zero span; avoid
    # dividing by it. The index correction below still yields exact results.
    if span != 0.0:
        inv_span = 1.0 / span
    else:
        inv_span = 0.0
    n_minus_1 = num_thresholds - 1
    num_cells = cell_vertex_pointers.shape[0] - 1

    for dir_idx in prange(num_directions):
        heights = heights_by_direction[dir_idx]

        diff = np.zeros(num_thresholds, dtype=np.int64)
        vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)

        for v in range(num_vertices):
            h = heights[v]
            pos = np.ceil((h - t_min) * inv_span * n_minus_1)

            # Clamp in floating point before casting so that NaN/inf never
            # reach int(). The first test is also true for NaN.
            if not (pos < num_thresholds):
                idx = num_thresholds
            elif pos <= 0.0:
                idx = 0
            else:
                idx = int(pos)

            # The arithmetic guess can be off by one because of rounding.
            # Move it to the exact first index whose threshold is >= h.
            while idx > 0 and threshold_values[idx - 1] >= h:
                idx -= 1
            while idx < num_thresholds and threshold_values[idx] < h:
                idx += 1

            vertex_thresh_index[v] = idx
            if idx < num_thresholds:
                # each vertex contributes +1 once its threshold is reached
                diff[idx] += 1

        for cell_idx in range(num_cells):
            start = cell_vertex_pointers[cell_idx]
            end = cell_vertex_pointers[cell_idx + 1]

            # a cell enters when its last (highest-index) vertex has entered
            entrance_idx = -1
            for k in range(start, end):
                t_idx = vertex_thresh_index[cell_vertex_indices_flat[k]]
                if t_idx > entrance_idx:
                    entrance_idx = t_idx

            if 0 <= entrance_idx < num_thresholds:
                diff[entrance_idx] += cell_euler_signs[cell_idx]

        # prefix-sum the difference array to get the value at each threshold
        running = 0
        for j in range(num_thresholds):
            running += diff[j]
            ect_values[dir_idx, j] = running

    return ect_values


@njit(cache=True, parallel=True)
def _ect_all_dirs_search(
heights_by_direction,
cell_vertex_pointers,
cell_vertex_indices_flat,
cell_euler_signs,
threshold_values,
num_vertices,
):
"""
Calculate the Euler Characteristic Transform (ECT) for a given direction and thresholds.

Args:
heights_by_direction: The heights of the vertices for each direction
cell_vertex_pointers: The pointers to the vertices for each cell
cell_vertex_indices_flat: The indices of the vertices for each cell
cell_euler_signs: The signs of the cells
threshold_values: The thresholds to calculate the ECT for
num_vertices: The number of vertices in the graph
"""
num_directions = heights_by_direction.shape[0]
num_thresholds = threshold_values.shape[0]

ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)

for dir_idx in prange(num_directions):
heights = heights_by_direction[dir_idx]

sort_order = np.argsort(heights)
diff = np.zeros(num_thresholds, dtype=np.int64)
vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)

for v in range(num_vertices):
h = heights[v]

# calculate what position each vertex is in the sorted heights starting from 1 (the rank)
vertex_rank_1based = np.empty(num_vertices, dtype=np.int32)
for rnk in range(num_vertices):
vertex_index = sort_order[rnk]
vertex_rank_1based[vertex_index] = rnk + 1
left = 0
right = num_thresholds
while left < right:
mid = (left + right) // 2
if threshold_values[mid] >= h:
right = mid
else:
left = mid + 1
idx = left

# euler char can only jump at each vertex value
# we know vertices add +1 so wait until end to add
#
jump_amount = np.zeros(num_vertices + 1, dtype=np.int64)
vertex_thresh_index[v] = idx
if idx < num_thresholds:
diff[idx] += 1

# each pair of pointers defines a cell, so we iterate over them
num_cells = cell_vertex_pointers.shape[0] - 1

for cell_idx in range(num_cells):
start = cell_vertex_pointers[cell_idx]
end = cell_vertex_pointers[cell_idx + 1]
# cells come in when the highest vertex enters
entrance_rank = 0

entrance_idx = -1
for k in range(start, end):
v = cell_vertex_indices_flat[k]
rnk = vertex_rank_1based[v]
if rnk > entrance_rank:
entrance_rank = rnk
# record at what rank the cell enters and how much the euler char changes
jump_amount[entrance_rank] += cell_euler_signs[cell_idx]

# calculate euler char at the moment each vertex enters
euler_prefix = np.empty(num_vertices + 1, dtype=np.int64)
running_sum = 0
for r in range(num_vertices + 1):
running_sum += jump_amount[r]
euler_prefix[r] = running_sum + r # +r because vertices add +1

# now find euler char at each threshold wrt the sorted heights
sorted_heights = heights[sort_order]
rank_cursor = 0 # equals r(t) = # { i : h_i <= t }
for thresh_idx in range(num_thresholds):
t = threshold_values[thresh_idx]
while rank_cursor < num_vertices and sorted_heights[rank_cursor] <= t:
rank_cursor += 1
ect_values[dir_idx, thresh_idx] = euler_prefix[rank_cursor]
t_idx = vertex_thresh_index[v]
if t_idx > entrance_idx:
entrance_idx = t_idx

if 0 <= entrance_idx < num_thresholds:
diff[entrance_idx] += cell_euler_signs[cell_idx]

running = 0
for j in range(num_thresholds):
running += diff[j]
ect_values[dir_idx, j] = running

return ect_values
Loading