Skip to content

Commit 5a773e6

Browse files
authored
Add cleanup function for associative memory lookup (#30)
* Add cleanup function * Add cleanup to __all__
1 parent ea8b50a commit 5a773e6

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

docs/_templates/class.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77

88
.. autoclass:: {{ name }}
99
:members:
10-
:special-members:
10+
:special-members:
11+
:exclude-members: __weakref__

docs/functional.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Operations
2828
bind
2929
bundle
3030
permute
31+
cleanup
3132
soft_quantize
3233
hard_quantize
3334

torchhd/functional.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"bind",
1515
"bundle",
1616
"permute",
17+
"cleanup",
1718
"hard_quantize",
1819
"soft_quantize",
1920
"hamming_similarity",
@@ -543,3 +544,27 @@ def index_to_value(
543544
544545
"""
545546
return map_range(input.float(), 0, index_length - 1, out_min, out_max)
547+
548+
549+
def cleanup(input: Tensor, memory: Tensor, threshold=0.0) -> Tensor:
550+
"""Returns a copy of the most similar hypervector in memory.
551+
552+
If the cosine similarity is less than threshold, raises a KeyError.
553+
554+
Args:
555+
input (Tensor): The hypervector to cleanup
556+
memory (Tensor): The `n` hypervectors in memory of shape (n, d)
557+
558+
Returns:
559+
Tensor: output tensor
560+
"""
561+
scores = cosine_similarity(input, memory)
562+
value, index = torch.max(scores, dim=-1)
563+
564+
if value.item() < threshold:
565+
raise KeyError(
566+
"Hypervector with the highest similarity is less similar than the provided threshold"
567+
)
568+
569+
# Copying prevents manipulating the memory tensor
570+
return torch.clone(memory[index])

0 commit comments

Comments
 (0)