
Commit 6f14992

Add tests for the models module (#129)
* Test models module
* [github-action] formatting fixes
* Add shape test
* [github-action] formatting fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 63c44d6 commit 6f14992

File tree

2 files changed: +127 −4 lines changed


torchhd/models.py

Lines changed: 16 additions & 4 deletions
@@ -51,6 +51,7 @@ class Centroid(nn.Module):
         out_features (int): Size of the output, typically the number of classes.
         device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
         dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
+        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
 
     Shape:
         - Input: :math:`(*, d)` where :math:`*` means any number of
@@ -76,7 +77,12 @@ class Centroid(nn.Module):
     weight: Tensor
 
     def __init__(
-        self, in_features: int, out_features: int, device=None, dtype=None
+        self,
+        in_features: int,
+        out_features: int,
+        device=None,
+        dtype=None,
+        requires_grad=False,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super(Centroid, self).__init__()
@@ -85,7 +91,7 @@ def __init__(
         self.out_features = out_features
 
         weight = torch.empty((out_features, in_features), **factory_kwargs)
-        self.weight = Parameter(weight)
+        self.weight = Parameter(weight, requires_grad=requires_grad)
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -161,6 +167,7 @@ class IntRVFL(nn.Module):
         kappa (int, optional): Parameter of the clipping function limiting the range of values; used as the part of transforming input data.
         device (``torch.device``, optional): the desired device of the weights. Default: if ``None``, uses the current device for the default tensor type (see ``torch.set_default_tensor_type()``). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
         dtype (``torch.dtype``, optional): the desired data type of the weights. Default: if ``None``, uses ``torch.get_default_dtype()``.
+        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.
 
     Shape:
         - Input: :math:`(*, d)` where :math:`*` means any number of
@@ -189,6 +196,7 @@ def __init__(
         kappa: Optional[int] = None,
         device=None,
         dtype=None,
+        requires_grad=False,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super(IntRVFL, self).__init__()
@@ -202,8 +210,12 @@ def __init__(
             in_features, self.dimensions, **factory_kwargs
         )
 
-        weight = torch.zeros((out_features, dimensions), **factory_kwargs)
-        self.weight = Parameter(weight)
+        weight = torch.empty((out_features, dimensions), **factory_kwargs)
+        self.weight = Parameter(weight, requires_grad=requires_grad)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        init.zeros_(self.weight)
 
     def encode(self, x):
         encodings = self.encoding(x)
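The new ``requires_grad`` argument is simply forwarded to ``torch.nn.Parameter``, so the model weights can participate in autograd when requested while the default behaviour stays the same. A minimal sketch of how the flag could be used (not part of this commit; the sizes are arbitrary example values):

import torch
from torchhd import models

# Default: centroid weights are excluded from autograd, as before.
model = models.Centroid(10000, 5)
assert model.weight.requires_grad is False

# With the new flag, operations on the class prototypes are recorded,
# e.g. for fine-tuning them with a gradient-based optimizer.
model = models.Centroid(10000, 5, requires_grad=True)
assert model.weight.requires_grad is True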

torchhd/tests/test_models.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+#
+# MIT License
+#
+# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+import pytest
+import torch
+import torch.nn.functional as F
+import torchhd
+from torchhd import models
+from torchhd import MAPTensor
+
+from .utils import (
+    torch_dtypes,
+    vsa_tensors,
+    supported_dtype,
+)
+
+
+class TestCentroid:
+    @pytest.mark.parametrize("dtype", torch_dtypes)
+    def test_initialization(self, dtype):
+        if dtype not in MAPTensor.supported_dtypes:
+            return
+
+        model = models.Centroid(1245, 12, dtype=dtype)
+        assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
+        assert model.weight.dtype == dtype
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        model = models.Centroid(1245, 12, dtype=dtype, device=device)
+        assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
+        assert model.weight.dtype == dtype
+        assert model.weight.device == device
+
+    def test_add(self):
+        samples = torch.randn(4, 12)
+        targets = torch.tensor([0, 1, 2, 2])
+
+        model = models.Centroid(12, 3)
+        model.add(samples, targets)
+
+        c = samples[:-1].clone()
+        c[-1] += samples[-1]
+
+        assert torch.allclose(model(samples), torchhd.cos(samples, c))
+        assert torch.allclose(model(samples, dot=True), torchhd.dot(samples, c))
+
+        model.normalize()
+        print(model(samples, dot=True))
+        print(torchhd.cos(samples, c))
+        assert torch.allclose(
+            model(samples, dot=True), torchhd.dot(samples, F.normalize(c))
+        )
+
+    def test_add_online(self):
+        samples = torch.randn(10, 12)
+        targets = torch.randint(0, 3, (10,))
+
+        model = models.Centroid(12, 3)
+        model.add_online(samples, targets)
+
+        logits = model(samples)
+        assert logits.shape == (10, 3)
+
+
+class TestIntRVFL:
+    @pytest.mark.parametrize("dtype", torch_dtypes)
+    def test_initialization(self, dtype):
+        if dtype not in MAPTensor.supported_dtypes:
+            return
+
+        model = models.IntRVFL(5, 1245, 12, dtype=dtype)
+        assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
+        assert model.weight.dtype == dtype
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        model = models.IntRVFL(5, 1245, 12, dtype=dtype, device=device)
+        assert torch.allclose(model.weight, torch.zeros(12, 1245, dtype=dtype))
+        assert model.weight.dtype == dtype
+        assert model.weight.device == device
+
+    def test_fit_ridge_regression(self):
+        samples = torch.randn(10, 12)
+        targets = torch.randint(0, 3, (10,))
+
+        model = models.IntRVFL(12, 1245, 3)
+        model.fit_ridge_regression(samples, targets)
+
+        logits = model(samples)
+        assert logits.shape == (10, 3)
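For reference, the workflow these tests exercise can be sketched as a small standalone script. This is only an illustration, not part of the commit; the dimensions and class counts are arbitrary, and the method names (``add``, ``normalize``, ``fit_ridge_regression``) are the ones used in the tests above.

import torch
from torchhd import models

samples = torch.randn(4, 12)          # four 12-dimensional input vectors
targets = torch.tensor([0, 1, 2, 2])  # class label for each sample

# Centroid: bundle each sample into its class prototype, then score by similarity.
centroid = models.Centroid(12, 3)
centroid.add(samples, targets)
centroid.normalize()
print(centroid(samples, dot=True).shape)  # torch.Size([4, 3])

# IntRVFL: fit the readout weights with ridge regression, then predict.
rvfl = models.IntRVFL(12, 1245, 3)
rvfl.fit_ridge_regression(samples, targets)
print(rvfl(samples).shape)  # torch.Size([4, 3])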
