diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 4dfb75aa..6e9a3c68 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -282,18 +282,14 @@ def run(self, n_tokens: int, tokens: token_tensor_type): latents ) self.cache.add(sae_latents, batch, batch_number, hookpoint) - firing_counts = (sae_latents > 0).sum((0, 1)) + firing_counts = (sae_latents.cpu() > 0).sum((0, 1)) if self.width is None: self.width = sae_latents.shape[2] if hookpoint not in self.hookpoint_firing_counts: - self.hookpoint_firing_counts[hookpoint] = ( - firing_counts.cpu() - ) + self.hookpoint_firing_counts[hookpoint] = firing_counts else: - self.hookpoint_firing_counts[ - hookpoint - ] += firing_counts.cpu() + self.hookpoint_firing_counts[hookpoint] += firing_counts # Update the progress bar pbar.update(1)