Skip to content

Commit b6ae3a5

Browse files
committed
Fix cosine similarity for batch processing
1 parent 5188462 commit b6ae3a5

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

torchhd/tensors/fhrr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,8 @@ def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor:
366366
others_dot = torch.sum(torch.real(others * torch.conj(others)), dim=-1)
367367
others_mag = torch.sqrt(others_dot)
368368

369-
if self.dim() > 1:
370-
magnitude = self_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
369+
if self.dim() >= 2:
370+
magnitude = self_mag.unsqueeze(-1) * others_mag.unsqueeze(-2)
371371
else:
372372
magnitude = self_mag * others_mag
373373

torchhd/tensors/hrr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor:
353353
others_dot = torch.sum(others * others, dim=-1)
354354
others_mag = torch.sqrt(others_dot)
355355

356-
if self.dim() > 1:
357-
magnitude = self_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
356+
if self.dim() >= 2:
357+
magnitude = self_mag.unsqueeze(-1) * others_mag.unsqueeze(-2)
358358
else:
359359
magnitude = self_mag * others_mag
360360

torchhd/tensors/map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def cosine_similarity(self, others: "MAPTensor", *, eps=1e-08) -> Tensor:
335335
others_mag = torch.sqrt(others_dot)
336336

337337
if self.dim() >= 2:
338-
magnitude = self_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
338+
magnitude = self_mag.unsqueeze(-1) * others_mag.unsqueeze(-2)
339339
else:
340340
magnitude = self_mag * others_mag
341341

0 commit comments

Comments
 (0)