diff --git a/chelombus/clustering/PyQKmeans.py b/chelombus/clustering/PyQKmeans.py index 6684e8a..e08091d 100644 --- a/chelombus/clustering/PyQKmeans.py +++ b/chelombus/clustering/PyQKmeans.py @@ -51,7 +51,7 @@ def _predict_numba(pq_codes, centers, dtables): n = pq_codes.shape[0] m = pq_codes.shape[1] n_centers = centers.shape[0] - labels = np.empty(n, dtype=np.int64) + labels = np.empty(n, dtype=np.int32) for i in prange(n): best_dist = np.inf best_label = 0 @@ -62,7 +62,7 @@ def _predict_numba(pq_codes, centers, dtables): if dist < best_dist: best_dist = dist best_label = c - labels[i] = best_label + labels[i] = np.int32(best_label) return labels @@ -101,7 +101,7 @@ def _update_centers( hist = np.zeros(K * k_cb, dtype=np.int64) for start in range(0, N, chunk_size): end = min(start + chunk_size, N) - flat = (labels[start:end] * k_cb + flat = (labels[start:end].astype(np.int64) * k_cb + pq_codes[start:end, s].astype(np.int64)) hist += np.bincount(flat, minlength=K * k_cb) diff --git a/chelombus/clustering/_gpu_predict.py b/chelombus/clustering/_gpu_predict.py index 218702f..06d805a 100644 --- a/chelombus/clustering/_gpu_predict.py +++ b/chelombus/clustering/_gpu_predict.py @@ -136,7 +136,7 @@ def predict_gpu( verbose: Print per-batch progress (useful for billion-scale runs). Returns: - (N,) int64 cluster labels (same dtype as CPU path). + (N,) int32 cluster labels (same dtype as CPU path). """ import time as _time @@ -157,7 +157,7 @@ def predict_gpu( if batch_size <= 0: batch_size = _auto_batch_size(N, M) - labels_out = np.empty(N, dtype=np.int64) + labels_out = np.empty(N, dtype=np.int32) # Adaptive BLOCK_K: larger M means more registers per subvector, # so reduce BLOCK_K to avoid register spill.