Skip to content

Commit 3c86694

Browse files
committed
[Update] classmethod, from_config
1 parent 9f0d6b5 commit 3c86694

File tree

5 files changed

+26
-14
lines changed

5 files changed

+26
-14
lines changed

datasets/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from .mnist_dataset import MNIST_Dataset
22

3-
def load_dataset(**cfg):
3+
def load_dataset(cfg):
44
if cfg['dataset'] == 'MNIST':
5-
return MNIST_Dataset(root=cfg['root'],
6-
download=cfg['download'],
7-
mode=cfg['mode'])
5+
return MNIST_Dataset.from_config(cfg)
86

97
else:
108
raise Exception(f"Dataset: {cfg['dataset']} is not supported.")

datasets/mnist_dataset.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,10 @@ def __getitem__(self, idx):
3737

3838
label = torch.tensor(label, dtype=torch.int64)
3939

40-
return image, label
40+
return image, label
41+
42+
@classmethod
43+
def from_config(cls, cfg):
44+
return cls(root=cfg['root'],
45+
download=cfg['download'],
46+
mode=cfg['mode'])

models/ConvNet/convnet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,10 @@ def forward(self, x):
3333

3434
x = self.li(x)
3535

36-
return x
36+
return x
37+
38+
@classmethod
39+
def from_config(cls, cfg):
40+
return cls(in_channels=cfg['in_channels'],
41+
layers=cfg['layers'],
42+
class_num=cfg['class_num'])

models/ConvNet2/convnet2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,10 @@ def forward(self, x):
3333

3434
x = self.li(x)
3535

36-
return x
36+
return x
37+
38+
@classmethod
39+
def from_config(cls, cfg):
40+
return cls(in_channels=cfg['in_channels'],
41+
layers=cfg['layers'],
42+
class_num=cfg['class_num'])

models/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from .ConvNet.convnet import ConvNet
22
from .ConvNet2.convnet2 import ConvNet2
33

4-
def load_model(**cfg):
4+
def load_model(cfg):
55
if cfg['name'] == 'ConvNet':
6-
return ConvNet(in_channels=cfg['in_channels'],
7-
layers=cfg['layers'],
8-
class_num=cfg['class_num'])
6+
return ConvNet.from_config(cfg)
97

108
elif cfg['name'] == 'ConvNet2':
11-
return ConvNet2(in_channels=cfg['in_channels'],
12-
layers=cfg['layers'],
13-
class_num=cfg['class_num'])
9+
return ConvNet2.from_config(cfg)
1410

1511
else:
1612
raise Exception(f"Model: {cfg['name']} is not supported.")

0 commit comments

Comments
 (0)