Skip to content

Commit f5ba002

Browse files
committed
FP16
1 parent 426ddcd commit f5ba002

File tree

8 files changed

+648
-0
lines changed

8 files changed

+648
-0
lines changed

debug/animation.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from manim import *
2+
3+
4+
class CreateCircle(Scene):
    """Debug scene that scatters randomly colored points on a white background.

    Despite its name, the scene draws no circle; it only adds ``Point``
    mobjects.
    """

    def construct(self):
        """Add two batches of 200 random points to the scene.

        Each coordinate is an integer drawn by ``np.random.randint(-4, 4)``
        (upper bound exclusive) multiplied by a per-axis step. The first
        batch uses steps (0.63, 0.37) for (x, y) and the second batch swaps
        them. Every point gets a color picked at random from a fixed palette.
        """
        self.camera.background_color = WHITE

        palette = [RED, GREEN, BLUE, YELLOW]
        # (x_step, y_step) per batch; the second batch transposes the grid.
        grid_steps = [(0.63, 0.37), (0.37, 0.63)]
        for x_step, y_step in grid_steps:
            for _ in range(200):
                # Draw order (x, then y, then color) is kept so a seeded
                # NumPy RNG produces the same picture as before.
                point = Point(
                    location=[
                        x_step * np.random.randint(-4, 4),
                        y_step * np.random.randint(-4, 4),
                        0,
                    ],
                    color=np.random.choice(palette),
                )
                # One add per point is enough; adding the same mobject a
                # second time does not create a second copy.
                self.add(point)

debug/fp16-2.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright 2021 Yan Yan
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
import argparse
17+
import torch
18+
import spconv.pytorch as spconv
19+
import torch.nn as nn
20+
import torch.nn.functional as F
21+
import torch.optim as optim
22+
from torchvision import datasets, transforms
23+
from torch.optim.lr_scheduler import StepLR
24+
import contextlib
25+
import torch.cuda.amp
26+
import time
27+
28+
29+
@contextlib.contextmanager
def identity_ctx():
    """No-op context manager: runs the ``with`` body and binds ``None``."""
    yield None
32+
33+
34+
class Net(nn.Module):
    """Minimal sparse-convolution model used to inspect output dtypes.

    Only ``self.net`` takes part in ``forward``. The linear and dropout
    layers are constructed and registered on the module but are never
    called by ``forward``.
    """

    def __init__(self):
        """Build the sparse conv stack followed by the (unused) dense layers."""
        super().__init__()
        # Positional arguments are passed through unchanged: (1, 32, 3, 1).
        sparse_conv = spconv.SubMConv2d(1, 32, 3, 1)
        self.net = spconv.SparseSequential(sparse_conv)
        self.fc1 = nn.Linear(14 * 14 * 64, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)

    def forward(self, x: torch.Tensor):
        """Convert ``x`` to a sparse tensor, run ``self.net``, return features.

        Args:
            x: Dense input that is reshaped to ``[N, 28, 28, 1]`` (NHWC).

        Returns:
            The ``features`` tensor of the sparse output, flattened from
            dimension 1 onward.
        """
        # x: [N, 28, 28, 1], must be NHWC tensor
        dense_nhwc = x.reshape(-1, 28, 28, 1)
        sparse_in = spconv.SparseConvTensor.from_dense(dense_nhwc)
        sparse_out = self.net(sparse_in)
        return sparse_out.features.flatten(1)
51+
52+
53+
def train(args, model, device, train_loader, optimizer, epoch):
    """Run a forward-only pass over ``train_loader`` and print output dtypes.

    This is a debugging loop: no loss is computed and neither ``backward``
    nor ``optimizer.step`` is called. For every batch it moves the tensors
    to ``device``, clears the optimizer gradients, runs the model (under
    autocast when ``args.fp16`` is set) and prints ``output.dtype``.

    Args:
        args: Namespace-like object; only ``args.fp16`` is read.
        model: Module to evaluate; it is switched to training mode.
        device: Target device for each batch.
        train_loader: Iterable yielding ``(data, target)`` pairs.
        optimizer: Optimizer whose gradients are zeroed before each forward.
        epoch: Accepted for interface compatibility; not used.
    """
    model.train()
    # The previously constructed GradScaler was never used (nothing is
    # scaled or stepped here), so it is no longer created on every call.
    amp_ctx = torch.cuda.amp.autocast() if args.fp16 else contextlib.nullcontext()

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with amp_ctx:
            output = model(data)
            print(output.dtype)
66+
67+
68+
def _build_arg_parser():
    """Create the command-line parser for the FP16 debug run."""
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for training (default: 1000)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1.0,
        metavar="LR",
        help="learning rate (default: 1.0)",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        default=False,
        help="For mixed precision training",
    )
    return parser


def _mnist_loader(train, batch_size, loader_kwargs, download=False):
    """Return a shuffled ``DataLoader`` over the MNIST split under ``../data``."""
    dataset = datasets.MNIST(
        "../data",
        train=train,
        download=download,
        # Normalization is deliberately left out so that background pixels
        # stay exactly zero and the input remains sparse.
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
    )


def main():
    """Parse CLI options, run the debug training passes, and print the elapsed time.

    Builds the MNIST loaders, the model, optimizer and LR scheduler, calls
    ``train`` once per epoch, prints the elapsed seconds, and optionally
    saves the model's ``state_dict`` to ``mnist_cnn.pt``.
    """
    args = _build_arg_parser().parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    loader_kwargs = {"num_workers": 20, "pin_memory": True} if use_cuda else {}
    train_loader = _mnist_loader(True, args.batch_size, loader_kwargs, download=True)
    # NOTE: the test loader and the scheduler are constructed but not
    # consumed by the current debug loop.
    test_loader = _mnist_loader(False, args.test_batch_size, loader_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # perf_counter is the monotonic clock intended for measuring intervals.
    t0 = time.perf_counter()
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)

    # Wait for queued GPU work so the timing covers it. Skipped on CPU runs,
    # where touching the CUDA stream would raise.
    if use_cuda:
        torch.cuda.current_stream().synchronize()
    t1 = time.perf_counter()

    print(t1 - t0)

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)