-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvolution.py
More file actions
64 lines (54 loc) · 2.15 KB
/
convolution.py
File metadata and controls
64 lines (54 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torchvision.datasets
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader
# MNIST train/test splits. ToTensor converts each PIL image to a float tensor.
# download=True fetches the files into ./data on first run; without it the
# constructor raises when the data is not already on disk.
BATCH_SIZE = 32
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transforms.ToTensor(), download=True
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, transform=transforms.ToTensor(), download=True
)
# Shuffle only the training data; evaluation order does not affect accuracy.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
class convolution(nn.Module):
    """Small LeNet-style CNN classifier for single-channel images.

    Two conv + max-pool stages feed a two-layer fully connected head.
    The flattened feature size ``16 * 5 * 5`` assumes 28x28 inputs:
    28 -conv(3)-> 26 -pool(2)-> 13 -conv(3)-> 11 -pool(2)-> 5.

    Args:
        inputsize: stored on the instance; not used to size any layer.
        outputsize: number of output logits (classes).
    """

    def __init__(self, inputsize, outputsize):
        super().__init__()
        self.inputsize = inputsize
        self.outputsize = outputsize
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(16 * 5 * 5, 120)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(120, outputsize)

    def forward(self, x):
        """Return raw (un-softmaxed) logits of shape (batch, outputsize).

        Expects ``x`` shaped (batch, 1, 28, 28) so the flattened feature
        size matches ``linear1``.
        """
        # NOTE(review): no ReLU after the conv layers; only max-pooling adds
        # nonlinearity in the feature extractor.
        a = self.pool1(self.conv1(x))
        a = self.pool2(self.conv2(a))
        # Flatten per sample, keeping the batch dimension. reshape(-1, 400)
        # would silently regroup features across samples for mis-sized
        # inputs; flatten makes linear1 raise a shape error instead.
        a = torch.flatten(a, start_dim=1)
        a = self.linear1(a)
        return self.linear2(self.relu1(a))
# Model, optimizer and loss. CrossEntropyLoss expects raw logits, which is
# what the model's forward returns (it applies no softmax).
model = convolution(28 * 28, 10)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
n = len(train_dataset)  # number of training samples (not used below)

NUM_EPOCHS = 10
model.train()
for i in range(NUM_EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for image, label in train_loader:
        optimizer.zero_grad()
        out = model(image)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the whole epoch; the last mini-batch's loss
    # alone is a noisy estimate of training progress.
    print(i, running_loss / max(num_batches, 1))
# Top-1 accuracy on the test split. eval() switches layers to inference
# behavior; no_grad() skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        outputs = model(images)
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()
    acc = 100.0 * n_correct / n_samples
    # Report the number of images actually evaluated rather than a fixed count.
    print(f'Accuracy of the network on the {n_samples} test images: {acc} %')