Skip to content

Commit 2a12aae

Browse files
authored
Allow batch dimensions in dot and cos similarity (#119)
1 parent 7601e6c commit 2a12aae

File tree

5 files changed

+19
-8
lines changed

5 files changed

+19
-8
lines changed

torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
dot_similarity,
3535
hamming_similarity,
3636
multiset,
37+
multibundle,
3738
multibind,
3839
bundle_sequence,
3940
bind_sequence,
@@ -83,6 +84,7 @@
8384
"dot_similarity",
8485
"hamming_similarity",
8586
"multiset",
87+
"multibundle",
8688
"multibind",
8789
"bundle_sequence",
8890
"bind_sequence",

torchhd/bsc.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,14 +406,17 @@ def permute(self, shifts: int = 1) -> "BSC":
406406
def dot_similarity(self, others: "BSC") -> Tensor:
407407
"""Inner product with other hypervectors."""
408408
dtype = torch.get_default_dtype()
409+
device = self.device
409410

410-
min_one = torch.tensor(-1.0, dtype=dtype)
411-
plus_one = torch.tensor(1.0, dtype=dtype)
411+
min_one = torch.tensor(-1.0, dtype=dtype, device=device)
412+
plus_one = torch.tensor(1.0, dtype=dtype, device=device)
412413

413414
self_as_bipolar = torch.where(self.bool(), min_one, plus_one)
414415
others_as_bipolar = torch.where(others.bool(), min_one, plus_one)
415416

416-
return F.linear(self_as_bipolar, others_as_bipolar)
417+
if others.dim() >= 2:
418+
others_as_bipolar = others_as_bipolar.mT
419+
return torch.matmul(self_as_bipolar, others_as_bipolar)
417420

418421
def cos_similarity(self, others: "BSC") -> Tensor:
419422
"""Cosine similarity with other hypervectors."""

torchhd/fhrr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,9 @@ def permute(self, shifts: int = 1) -> "FHRR":
354354

355355
def dot_similarity(self, others: "FHRR") -> Tensor:
356356
"""Inner product with other hypervectors"""
357-
return F.linear(self, others.conj()).real
357+
if others.dim() >= 2:
358+
others = others.mT
359+
return torch.matmul(self, others.conj()).real
358360

359361
def cos_similarity(self, others: "FHRR", *, eps=1e-08) -> Tensor:
360362
"""Cosine similarity with other hypervectors"""

torchhd/hrr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,9 @@ def permute(self, shifts: int = 1) -> "HRR":
339339

340340
def dot_similarity(self, others: "HRR") -> Tensor:
341341
"""Inner product with other hypervectors"""
342-
return F.linear(self, others)
342+
if others.dim() >= 2:
343+
others = others.mT
344+
return torch.matmul(self, others)
343345

344346
def cos_similarity(self, others: "HRR", *, eps=1e-08) -> Tensor:
345347
"""Cosine similarity with other hypervectors"""

torchhd/map.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,9 @@ def clipping(self, kappa) -> "MAP":
320320
def dot_similarity(self, others: "MAP") -> Tensor:
321321
"""Inner product with other hypervectors"""
322322
dtype = torch.get_default_dtype()
323-
return F.linear(self.to(dtype), others.to(dtype))
323+
if others.dim() >= 2:
324+
others = others.mT
325+
return torch.matmul(self.to(dtype), others.to(dtype))
324326

325327
def cos_similarity(self, others: "MAP", *, eps=1e-08) -> Tensor:
326328
"""Cosine similarity with other hypervectors"""
@@ -332,10 +334,10 @@ def cos_similarity(self, others: "MAP", *, eps=1e-08) -> Tensor:
332334
others_dot = torch.sum(others * others, dim=-1, dtype=dtype)
333335
others_mag = others_dot.sqrt()
334336

335-
if self.dim() > 1:
337+
if self.dim() >= 2:
336338
magnitude = self_mag.unsqueeze(-1) * others_mag.unsqueeze(0)
337339
else:
338340
magnitude = self_mag * others_mag
339341

340342
magnitude = magnitude.clamp(min=eps)
341-
return F.linear(self.to(dtype), others.to(dtype)) / magnitude
343+
return self.dot_similarity(others) / magnitude

0 commit comments

Comments
 (0)