Skip to content

Commit d98d9b8

Browse files
committed
HRR documentation update
1 parent e9c0b48 commit d98d9b8

File tree

4 files changed

+10
-17
lines changed

4 files changed

+10
-17
lines changed

torchhd/fhrr.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,8 @@ def random_hv(
175175
size = (num_vectors, dimensions)
176176
angle = torch.empty(size, dtype=dtype, device=device)
177177
angle.uniform_(-math.pi, math.pi, generator=generator)
178-
magnitude = torch.ones_like(angle)
179178

180-
result = torch.polar(magnitude, angle)
179+
result = torch.complex(angle.cos(), angle.sin())
181180
result.requires_grad = requires_grad
182181
return result.as_subclass(cls)
183182

torchhd/functional.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,13 @@ def circular_hv(
438438
Implements circular-hypervectors based on level-hypervectors as described in `An Extension to Basis-Hypervectors for Learning from Circular Data in Hyperdimensional Computing <https://arxiv.org/abs/2205.07920>`_.
439439
Any hypervector is quasi-orthogonal to the hypervector on the opposite side of the circle.
440440
441+
.. note::
442+
Circular hypervectors cannot be created directly with Holographic Reduced Representations (HRR) because of the numerical imprecision inherent to HRR.
443+
One way around this is to use FHRR for the creation of circular hypervectors and then transform them to HRR vectors. Example:
444+
445+
>>> hv = torchhd.circular_hv(10, 6, torchhd.FHRR)
446+
>>> hv = torch.real(torch.fft.ifft(hv)).as_subclass(torchhd.HRR)
447+
441448
Args:
442449
num_vectors (int): the number of hypervectors to generate.
443450
dimensions (int): the dimensionality of the hypervectors.

torchhd/hrr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22
from torch import Tensor
3-
import torch.nn.functional as F
43
from torch.fft import fft, ifft
54
from typing import Set
5+
import math
66

77
from torchhd.base import VSA_Model
88

@@ -164,7 +164,7 @@ def random_hv(
164164

165165
size = (num_vectors, dimensions)
166166
result = torch.empty(size, dtype=dtype, device=device)
167-
result.normal_(0, 1.0 / dimensions, generator=generator)
167+
result.normal_(0, 1.0 / math.sqrt(dimensions), generator=generator)
168168

169169
result.requires_grad = requires_grad
170170
return result.as_subclass(cls)

torchhd/tests/basis_hv/test_circular_hv.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,6 @@ def test_value(self, dtype, model):
6969
elif model == torchhd.MAP:
7070
assert torch.all((hv == -1) | (hv == 1)).item()
7171

72-
elif model == torchhd.HRR:
73-
std, mean = torch.std_mean(hv)
74-
assert torch.allclose(
75-
mean, torch.tensor(0.0, dtype=mean.dtype), atol=0.0001
76-
)
77-
7872
elif model == torchhd.FHRR:
7973
mag = hv.abs()
8074
assert torch.allclose(mag, torch.tensor(1.0, dtype=mag.dtype))
@@ -140,14 +134,10 @@ def test_uses_default_dtype(self):
140134
torch.set_default_dtype(torch.float32)
141135
hv = functional.circular_hv(3, 52, torchhd.MAP)
142136
assert hv.dtype == torch.float32
143-
# hv = functional.circular_hv(3, 52, torchhd.HRR)
144-
# assert hv.dtype == torch.float32
145137

146138
torch.set_default_dtype(torch.float64)
147139
hv = functional.circular_hv(3, 52, torchhd.MAP)
148140
assert hv.dtype == torch.float64
149-
# hv = functional.circular_hv(3, 52, torchhd.HRR)
150-
# assert hv.dtype == torch.float64
151141

152142
hv = functional.circular_hv(3, 52, torchhd.FHRR)
153143
assert hv.dtype == torch.complex64
@@ -156,8 +146,5 @@ def test_requires_grad(self):
156146
hv = functional.circular_hv(3, 52, torchhd.MAP, requires_grad=True)
157147
assert hv.requires_grad == True
158148

159-
# hv = functional.circular_hv(3, 52, torchhd.HRR, requires_grad=True)
160-
# assert hv.requires_grad == True
161-
162149
hv = functional.circular_hv(3, 52, torchhd.FHRR, requires_grad=True)
163150
assert hv.requires_grad == True

0 commit comments

Comments
 (0)