Skip to content

Commit 20bf64d

Browse files
author
pereverges
committed
accelerator examples
1 parent 63bd0e6 commit 20bf64d

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

examples/mnist_hugging_face.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import torchvision
5+
from torchvision.datasets import MNIST
6+
7+
# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
8+
import torchmetrics
9+
from tqdm import tqdm
10+
from accelerate import Accelerator
11+
12+
from torchhd import functional
13+
from torchhd import embeddings
14+
15+
accelerator = Accelerator()
16+
device = accelerator.device
17+
print("Using {} device".format(device))
18+
19+
DIMENSIONS = 10000
20+
IMG_SIZE = 28
21+
NUM_LEVELS = 1000
22+
BATCH_SIZE = 1 # for GPUs with enough memory we can process multiple images at ones
23+
24+
transform = torchvision.transforms.ToTensor()
25+
26+
train_ds = MNIST("../data", train=True, transform=transform, download=True)
27+
train_ld = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
28+
29+
test_ds = MNIST("../data", train=False, transform=transform, download=True)
30+
test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
31+
32+
33+
class Model(nn.Module):
34+
def __init__(self, num_classes, size):
35+
super(Model, self).__init__()
36+
37+
self.flatten = torch.nn.Flatten()
38+
39+
self.position = embeddings.Random(size * size, DIMENSIONS)
40+
self.value = embeddings.Level(NUM_LEVELS, DIMENSIONS)
41+
42+
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
43+
self.classify.weight.data.fill_(0.0)
44+
45+
def encode(self, x):
46+
x = self.flatten(x)
47+
sample_hv = functional.bind(self.position.weight, self.value(x))
48+
sample_hv = functional.multiset(sample_hv)
49+
return functional.hard_quantize(sample_hv)
50+
51+
def forward(self, x):
52+
enc = self.encode(x)
53+
logit = self.classify(enc)
54+
return logit
55+
56+
57+
model = Model(len(train_ds.classes), IMG_SIZE)
58+
model.to(device)
59+
60+
model, train_ld, test_ld = accelerator.prepare(
61+
model, train_ld, test_ld
62+
)
63+
64+
with torch.no_grad():
65+
for samples, labels in tqdm(train_ld, desc="Training"):
66+
samples = samples.to(device)
67+
labels = labels.to(device)
68+
69+
samples_hv = model.encode(samples)
70+
model.classify.weight[labels] += samples_hv
71+
72+
model.classify.weight[:] = F.normalize(model.classify.weight)
73+
74+
accuracy = torchmetrics.Accuracy()
75+
76+
77+
with torch.no_grad():
78+
for samples, labels in tqdm(test_ld, desc="Testing"):
79+
samples = samples.to(device)
80+
81+
outputs = model(samples)
82+
predictions = torch.argmax(outputs, dim=-1)
83+
accuracy.update(predictions.cpu(), labels)
84+
85+
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")

examples/mnist_torch_lightning.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import torchvision
5+
from torchvision.datasets import MNIST
6+
7+
# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
8+
import torchmetrics
9+
from tqdm import tqdm
10+
import pytorch_lightning as pl
11+
from torchhd import functional
12+
from torchhd import embeddings
13+
14+
15+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16+
print("Using {} device".format(device))
17+
18+
DIMENSIONS = 10000
19+
IMG_SIZE = 28
20+
NUM_LEVELS = 1000
21+
BATCH_SIZE = 1 # for GPUs with enough memory we can process multiple images at ones
22+
23+
transform = torchvision.transforms.ToTensor()
24+
25+
train_ds = MNIST("../data", train=True, transform=transform, download=True)
26+
train_ld = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
27+
28+
test_ds = MNIST("../data", train=False, transform=transform, download=True)
29+
test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
30+
31+
32+
class Model(pl.LightningModule):
33+
def __init__(self, num_classes, size):
34+
super(Model, self).__init__()
35+
36+
self.flatten = torch.nn.Flatten()
37+
38+
self.position = embeddings.Random(size * size, DIMENSIONS)
39+
self.value = embeddings.Level(NUM_LEVELS, DIMENSIONS)
40+
41+
self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
42+
self.classify.weight.data.fill_(0.0)
43+
44+
def encode(self, x):
45+
x = self.flatten(x)
46+
sample_hv = functional.bind(self.position.weight, self.value(x))
47+
sample_hv = functional.multiset(sample_hv)
48+
return functional.hard_quantize(sample_hv)
49+
50+
def forward(self, x):
51+
enc = self.encode(x)
52+
logit = self.classify(enc)
53+
return logit
54+
55+
def training_step(self, batch):
56+
return
57+
58+
def configure_optimizers(self):
59+
return
60+
61+
62+
model = Model(len(train_ds.classes), IMG_SIZE)
63+
trainer = pl.Trainer(
64+
accelerator="cpu",
65+
devices=1,
66+
)
67+
trainer.fit(model, train_ld, test_ld)
68+
69+
with torch.no_grad():
70+
for samples, labels in tqdm(train_ld, desc="Training"):
71+
samples = samples.to(device)
72+
labels = labels.to(device)
73+
74+
samples_hv = model.encode(samples)
75+
model.classify.weight[labels] += samples_hv
76+
77+
model.classify.weight[:] = F.normalize(model.classify.weight)
78+
79+
accuracy = torchmetrics.Accuracy()
80+
81+
82+
with torch.no_grad():
83+
for samples, labels in tqdm(test_ld, desc="Testing"):
84+
samples = samples.to(device)
85+
86+
outputs = model(samples)
87+
predictions = torch.argmax(outputs, dim=-1)
88+
accuracy.update(predictions.cpu(), labels)
89+
90+
print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")

0 commit comments

Comments
 (0)