@@ -101,34 +101,42 @@ def _ensure_thresholds(self, graph, override_bound_radius=None):
101101 def calculate (
102102 self ,
103103 graph : EmbeddedComplex ,
104- theta : Optional [ float ] = None ,
105- override_bound_radius : Optional [ float ] = None ,
104+ theta : float = None ,
105+ override_bound_radius : float = None ,
106106 ):
107- """Calculate Euler Characteristic Transform (ECT) for a given graph and direction theta
108-
109- Args:
110- graph (EmbeddedComplex):
111- The input complex to calculate the ECT for.
112- theta (float):
113- The angle in :math:`[0,2\pi]` for the direction to calculate the ECT.
114- override_bound_radius (float):
115- If None, uses the following in order: (i) the bounding radius stored in the class; or if not available (ii) the bounding radius of the given graph. Otherwise, should be a positive float :math:`R` where the ECC will be computed at thresholds in :math:`[-R,R]`. Default is None.
116- """
117107 self ._ensure_directions (graph .dim , theta )
118108 self ._ensure_thresholds (graph , override_bound_radius )
119-
120- # override with theta if provided
121109 directions = (
122110 self .directions if theta is None else Directions .from_angles ([theta ])
123111 )
112+ ect_matrix = self ._compute_ect (graph , directions , self .thresholds , self .dtype )
124113
125- simplex_projections = self . _compute_simplex_projections ( graph , directions )
114+ return ECTResult ( ect_matrix , directions , self . thresholds )
126115
127- ect_matrix = self ._compute_directional_transform (
128- simplex_projections , self .thresholds , self .dtype
116+ def _compute_ect (
117+ self , graph , directions , thresholds : np .ndarray , dtype = np .int32
118+ ) -> np .ndarray :
119+ cell_vertex_pointers , cell_vertex_indices_flat , cell_euler_signs , N = (
120+ graph ._build_incidence_csr ()
129121 )
130-
131- return ECTResult (ect_matrix , directions , self .thresholds )
122+ thresholds = np .asarray (thresholds , dtype = np .float64 )
123+
124+ V = directions .vectors
125+ X = graph .coord_matrix
126+ H = X @ V if V .shape [0 ] == X .shape [1 ] else X @ V .T # (N, m)
127+ H_T = np .ascontiguousarray (H .T ) # (m, N) for contiguous per-direction rows
128+
129+ out64 = _ect_all_dirs (
130+ H_T ,
131+ cell_vertex_pointers ,
132+ cell_vertex_indices_flat ,
133+ cell_euler_signs ,
134+ thresholds ,
135+ N ,
136+ )
137+ if dtype == np .int32 :
138+ return out64 .astype (np .int32 )
139+ return out64
132140
133141 def _compute_simplex_projections (self , graph : EmbeddedComplex , directions ):
134142 """Compute projections of each k-cell for all dimensions"""
@@ -160,41 +168,79 @@ def _compute_simplex_projections(self, graph: EmbeddedComplex, directions):
160168
161169 return simplex_projections
162170
163- @staticmethod
164- @njit (parallel = True , fastmath = True )
165- def _compute_directional_transform (
166- simplex_projections_list , thresholds , dtype = np .int32
167- ):
168- """Compute ECT by counting simplices below each threshold - VECTORIZED VERSION
169-
170- Args:
171- simplex_projections_list: List of arrays containing projections for each simplex type
172- [vertex_projections, edge_projections, face_projections]
173- thresholds: Array of threshold values to compute ECT at
174- dtype: Data type for output array (default: np.int32)
175171
176- Returns:
177- Array of shape (num_directions, num_thresholds) containing Euler characteristics
178- """
179- num_dir = simplex_projections_list [0 ].shape [1 ]
180- num_thresh = thresholds .shape [0 ]
181- result = np .empty ((num_dir , num_thresh ), dtype = dtype )
182-
183- sorted_projections = List ()
184- for proj in simplex_projections_list :
185- sorted_proj = np .empty_like (proj )
186- for i in prange (num_dir ):
187- sorted_proj [:, i ] = np .sort (proj [:, i ])
188- sorted_projections .append (sorted_proj )
189-
190- for i in prange (num_dir ):
191- chi = np .zeros (num_thresh , dtype = dtype )
192- for k in range (len (sorted_projections )):
193- projs = sorted_projections [k ][:, i ]
194-
195- count = np .searchsorted (projs , thresholds , side = "right" )
196-
197- sign = - 1 if k % 2 else 1
198- chi += sign * count
199- result [i ] = chi
200- return result
172+ @njit (cache = True , parallel = True )
173+ def _ect_all_dirs (
174+ heights_by_direction , # shape (num_directions, num_vertices)
175+ cell_vertex_pointers , # shape (num_cells + 1,)
176+ cell_vertex_indices_flat , # concatenated vertex indices for all cells
177+ cell_euler_signs , # per-cell sign: (+1) for even-dim, (-1) for odd-dim
178+ threshold_values , # shape (num_thresholds,), assumed nondecreasing
179+ num_vertices ,
180+ ):
181+ """
182+ Calculate the Euler Characteristic Transform (ECT) for a given direction and thresholds.
183+
184+ Args:
185+ heights_by_direction: The heights of the vertices for each direction
186+ cell_vertex_pointers: The pointers to the vertices for each cell
187+ cell_vertex_indices_flat: The indices of the vertices for each cell
188+ cell_euler_signs: The signs of the cells
189+ threshold_values: The thresholds to calculate the ECT for
190+ num_vertices: The number of vertices in the graph
191+ """
192+ num_directions = heights_by_direction .shape [0 ]
193+ num_thresholds = threshold_values .shape [0 ]
194+ ect_values = np .empty ((num_directions , num_thresholds ), dtype = np .int64 )
195+
196+ for dir_idx in prange (num_directions ):
197+ heights = heights_by_direction [dir_idx ]
198+
199+ sort_order = np .argsort (heights )
200+
201+ # calculate what position each vertex is in the sorted heights starting from 1 (the rank)
202+ vertex_rank_1based = np .empty (num_vertices , dtype = np .int32 )
203+ for rnk in range (num_vertices ):
204+ vertex_index = sort_order [rnk ]
205+ vertex_rank_1based [vertex_index ] = rnk + 1
206+
207+ # euler char can only jump at each vertex value
208+ jump_amount = np .zeros (num_vertices + 1 , dtype = np .int64 )
209+
210+ # 0-cells add +1 at their entrance ranks
211+ for v in range (num_vertices ):
212+ rank_v = vertex_rank_1based [v ]
213+ jump_amount [rank_v ] += 1
214+
215+ # each pair of pointers defines a cell, so we iterate over them
216+ num_cells = cell_vertex_pointers .shape [0 ] - 1
217+ for cell_idx in range (num_cells ):
218+ start = cell_vertex_pointers [cell_idx ]
219+ end = cell_vertex_pointers [cell_idx + 1 ]
220+ # cells come in when the highest vertex enters
221+ entrance_rank = 0
222+ for k in range (start , end ):
223+ v = cell_vertex_indices_flat [k ]
224+ rnk = vertex_rank_1based [v ]
225+ if rnk > entrance_rank :
226+ entrance_rank = rnk
227+ # record at what rank the cell enters and how much the euler char changes
228+ jump_amount [entrance_rank ] += cell_euler_signs [cell_idx ]
229+
230+ # calculate euler char at the moment each vertex enters
231+ euler_prefix = np .empty (num_vertices + 1 , dtype = np .int64 )
232+ running_sum = 0
233+ for r in range (num_vertices + 1 ):
234+ running_sum += jump_amount [r ]
235+ euler_prefix [r ] = running_sum
236+
237+ # now find euler char at each threshold wrt the sorted heights
238+ sorted_heights = heights [sort_order ]
239+ rank_cursor = 0 # equals r(t) = # { i : h_i <= t }
240+ for thresh_idx in range (num_thresholds ):
241+ t = threshold_values [thresh_idx ]
242+ while rank_cursor < num_vertices and sorted_heights [rank_cursor ] <= t :
243+ rank_cursor += 1
244+ ect_values [dir_idx , thresh_idx ] = euler_prefix [rank_cursor ]
245+
246+ return ect_values
0 commit comments