def _thresholds_are_uniform(thresholds: np.ndarray) -> bool:
    """Return True when consecutive thresholds are evenly spaced.

    The spacing between neighbouring thresholds is compared against the first
    spacing. Two tolerances are combined so that grids produced by floating
    point arithmetic (for example ``np.linspace``) are still recognised:

    * a step-based tolerance, ``1e-12 * max(1.0, abs(first_step))``;
    * a rounding tolerance proportional to the largest absolute threshold,
      because each stored value carries a rounding error that scales with
      its own magnitude rather than with the step.

    Args:
        thresholds: Array-like of threshold values; converted to a float
            array before checking.

    Returns:
        True for empty or single-element input, for constant input, and for
        input whose spacings all agree with the first spacing within
        tolerance. False otherwise, including any input of two or more
        values that contains NaN or infinity.

    Raises:
        ValueError: If ``thresholds`` is not 1-dimensional.
    """
    values = np.asarray(thresholds, dtype=float)
    if values.ndim != 1:
        raise ValueError("thresholds must be a 1-dimensional array")
    if values.size <= 1:
        return True
    # NaN/inf make the spacing comparison meaningless; report non-uniform.
    if not np.all(np.isfinite(values)):
        return False

    steps = np.diff(values)
    first = steps[0]
    if first == 0.0:
        # A zero first step is only "uniform" if every step is exactly zero.
        return bool(np.all(steps == 0.0))

    step_tol = 1e-12 * max(1.0, abs(first))
    # Each value is only accurate to about eps * |value|, so differences of
    # differences can be off by a few of those units even on an exact grid.
    rounding_tol = 16.0 * np.finfo(float).eps * float(np.max(np.abs(values)))
    tol = max(step_tol, rounding_tol)
    return bool(np.all(np.abs(steps - first) <= tol))
@@ -55,6 +70,22 @@ def __init__( self.bound_radius = bound_radius self.thresholds = thresholds self.dtype = dtype + self._thresholds_validated = False + if self.thresholds is not None: + self.thresholds = np.asarray(self.thresholds, dtype=float) + if self.thresholds.ndim != 1: + raise ValueError("thresholds must be a 1-dimensional array") + self._thresholds_validated = True + if num_thresh is not None: + self.is_uniform = True + elif self.thresholds is not None: + self.is_uniform = False + if not _thresholds_are_uniform(self.thresholds): + raise ValueError( + "thresholds must be uniform if num_thresh is not provided" + ) + else: + self.is_uniform = True def _ensure_directions(self, graph_dim, theta=None): """Ensures directions is a valid Directions object of correct dimension""" @@ -97,11 +128,14 @@ def _ensure_thresholds(self, graph, override_bound_radius=None): or graph.get_bounding_radius() ) self.thresholds = np.linspace(-radius, radius, self.num_thresh, dtype=float) + self.is_uniform = True + self._thresholds_validated = True else: - # validate and convert existing thresholds - self.thresholds = np.asarray(self.thresholds, dtype=float) - if self.thresholds.ndim != 1: - raise ValueError("thresholds must be a 1-dimensional array") + if not self._thresholds_validated: + self.thresholds = np.asarray(self.thresholds, dtype=float) + if self.thresholds.ndim != 1: + raise ValueError("thresholds must be a 1-dimensional array") + self._thresholds_validated = True def calculate( self, @@ -132,14 +166,25 @@ def _compute_ect( H = X @ V.T # (N, m) H_T = np.ascontiguousarray(H.T) # (m, N) for contiguous per-direction rows - out64 = _ect_all_dirs( - H_T, - cell_vertex_pointers, - cell_vertex_indices_flat, - cell_euler_signs, - thresholds, - N, - ) + is_uniform = bool(self.is_uniform) and thresholds[0] != thresholds[-1] + if is_uniform: + out64 = _ect_all_dirs_uniform( + H_T, + cell_vertex_pointers, + cell_vertex_indices_flat, + cell_euler_signs, + thresholds, + N, + ) + else: 
@njit(cache=True, parallel=True)
def _ect_all_dirs_uniform(
    heights_by_direction,
    cell_vertex_pointers,
    cell_vertex_indices_flat,
    cell_euler_signs,
    threshold_values,
    num_vertices,
):
    """Compute the ECT for every direction using arithmetic threshold lookup.

    For each vertex the index of the first threshold ``>=`` its height is
    guessed from the grid spacing and then corrected with exact comparisons
    against ``threshold_values``. The correction keeps the result identical
    to a lower-bound search on a nondecreasing grid even when floating point
    rounding puts the guess one slot off, and it makes the kernel safe for
    empty, single-valued or constant grids (no division by zero).

    Args:
        heights_by_direction: 2-D array indexed as ``[direction, vertex]``.
        cell_vertex_pointers: Offsets into ``cell_vertex_indices_flat``; cell
            ``c`` owns the slice ``[pointers[c], pointers[c + 1])``.
        cell_vertex_indices_flat: Concatenated vertex indices of all cells.
        cell_euler_signs: Per-cell contribution added when the cell enters.
        threshold_values: 1-D array of thresholds, assumed nondecreasing.
        num_vertices: Number of vertices read from each heights row.

    Returns:
        ``int64`` array of shape ``(num_directions, num_thresholds)``.
    """
    num_directions = heights_by_direction.shape[0]
    num_thresholds = threshold_values.shape[0]
    num_cells = cell_vertex_pointers.shape[0] - 1

    t_min = threshold_values[0] if num_thresholds > 0 else 0.0
    t_max = threshold_values[-1] if num_thresholds > 0 else 0.0
    span = t_max - t_min

    # The arithmetic guess is only meaningful on a grid with positive extent;
    # otherwise start from index 0 and rely on the exact correction below.
    use_guess = num_thresholds > 1 and span > 0.0
    steps_per_unit = 0.0
    if use_guess:
        steps_per_unit = (num_thresholds - 1) / span

    ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)

    for dir_idx in prange(num_directions):
        heights = heights_by_direction[dir_idx]

        # diff[j] holds the change of the Euler characteristic at threshold j
        diff = np.zeros(num_thresholds, dtype=np.int64)
        vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)

        for v in range(num_vertices):
            h = heights[v]

            idx = 0
            if np.isnan(h):
                # A NaN height never satisfies h <= t, so it never enters.
                idx = num_thresholds
            elif use_guess:
                # Clamp in float space so huge heights never reach int().
                pos = np.ceil((h - t_min) * steps_per_unit)
                if pos >= num_thresholds:
                    idx = num_thresholds
                elif pos > 0.0:
                    idx = int(pos)

            # Exact correction: move to the first index with threshold >= h.
            while idx > 0 and threshold_values[idx - 1] >= h:
                idx -= 1
            while idx < num_thresholds and threshold_values[idx] < h:
                idx += 1

            vertex_thresh_index[v] = idx
            if idx < num_thresholds:
                diff[idx] += 1  # every vertex contributes +1 when it enters

        for cell_idx in range(num_cells):
            start = cell_vertex_pointers[cell_idx]
            end = cell_vertex_pointers[cell_idx + 1]

            # a cell enters once its last (highest-index) vertex has entered
            entrance_idx = -1
            for k in range(start, end):
                v = cell_vertex_indices_flat[k]
                t_idx = vertex_thresh_index[v]
                if t_idx > entrance_idx:
                    entrance_idx = t_idx

            if 0 <= entrance_idx < num_thresholds:
                diff[entrance_idx] += cell_euler_signs[cell_idx]

        # prefix sum turns per-threshold changes into cumulative values
        running = 0
        for j in range(num_thresholds):
            running += diff[j]
            ect_values[dir_idx, j] = running

    return ect_values


@njit(cache=True, parallel=True)
def _ect_all_dirs_search(
    heights_by_direction,
    cell_vertex_pointers,
    cell_vertex_indices_flat,
    cell_euler_signs,
    threshold_values,
    num_vertices,
):
    """Compute the ECT for every direction using binary search per vertex.

    Each vertex is assigned the index of the first threshold ``>=`` its
    height via a lower-bound binary search, so no assumption about the
    spacing of the thresholds is needed.

    Args:
        heights_by_direction: 2-D array indexed as ``[direction, vertex]``.
        cell_vertex_pointers: Offsets into ``cell_vertex_indices_flat``; cell
            ``c`` owns the slice ``[pointers[c], pointers[c + 1])``.
        cell_vertex_indices_flat: Concatenated vertex indices of all cells.
        cell_euler_signs: Per-cell contribution added when the cell enters.
        threshold_values: 1-D array of thresholds, assumed nondecreasing.
        num_vertices: Number of vertices read from each heights row.

    Returns:
        ``int64`` array of shape ``(num_directions, num_thresholds)``.
    """
    num_directions = heights_by_direction.shape[0]
    num_thresholds = threshold_values.shape[0]
    num_cells = cell_vertex_pointers.shape[0] - 1

    ect_values = np.empty((num_directions, num_thresholds), dtype=np.int64)

    for dir_idx in prange(num_directions):
        heights = heights_by_direction[dir_idx]

        # diff[j] holds the change of the Euler characteristic at threshold j
        diff = np.zeros(num_thresholds, dtype=np.int64)
        vertex_thresh_index = np.empty(num_vertices, dtype=np.int64)

        for v in range(num_vertices):
            h = heights[v]

            # lower bound: first index whose threshold is >= h
            left = 0
            right = num_thresholds
            while left < right:
                mid = (left + right) // 2
                if threshold_values[mid] >= h:
                    right = mid
                else:
                    left = mid + 1
            idx = left

            vertex_thresh_index[v] = idx
            if idx < num_thresholds:
                diff[idx] += 1  # every vertex contributes +1 when it enters

        for cell_idx in range(num_cells):
            start = cell_vertex_pointers[cell_idx]
            end = cell_vertex_pointers[cell_idx + 1]

            # a cell enters once its last (highest-index) vertex has entered
            entrance_idx = -1
            for k in range(start, end):
                v = cell_vertex_indices_flat[k]
                t_idx = vertex_thresh_index[v]
                if t_idx > entrance_idx:
                    entrance_idx = t_idx

            if 0 <= entrance_idx < num_thresholds:
                diff[entrance_idx] += cell_euler_signs[cell_idx]

        # prefix sum turns per-threshold changes into cumulative values
        running = 0
        for j in range(num_thresholds):
            running += diff[j]
            ect_values[dir_idx, j] = running

    return ect_values