From 2ddfb21fea8c3185273c8437a9c82869be42b0dc Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 12 Mar 2026 14:03:10 +0000 Subject: [PATCH] latents: fix cache path handling and typing --- delphi/latents/cache.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 4dfb75aa..6e9a3c68 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -282,18 +282,14 @@ def run(self, n_tokens: int, tokens: token_tensor_type): latents ) self.cache.add(sae_latents, batch, batch_number, hookpoint) - firing_counts = (sae_latents > 0).sum((0, 1)) + firing_counts = (sae_latents.cpu() > 0).sum((0, 1)) if self.width is None: self.width = sae_latents.shape[2] if hookpoint not in self.hookpoint_firing_counts: - self.hookpoint_firing_counts[hookpoint] = ( - firing_counts.cpu() - ) + self.hookpoint_firing_counts[hookpoint] = firing_counts else: - self.hookpoint_firing_counts[ - hookpoint - ] += firing_counts.cpu() + self.hookpoint_firing_counts[hookpoint] += firing_counts # Update the progress bar pbar.update(1)