
Commit ee72f18

[DOC] Improve docstrings for better readability

* [DOC] Improve docstrings
* [MNT] Update unit tests
* flake8 formatting

1 parent 04e615b

File tree: 8 files changed (+50 −36 lines)

torchensemble/_base.py (1 addition, 1 deletion)

@@ -74,7 +74,7 @@ def __len__(self):
         """
         Return the number of base estimators in the ensemble. The real number
         of base estimators may not match `self.n_estimators` because of the
-        early stopping stage in several ensembles.
+        early stopping stage in several ensembles such as Gradient Boosting.
         """
         return len(self.estimators_)
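The distinction the new wording draws is easy to see in use. A minimal sketch, assuming a placeholder MLP base estimator and data loaders; the constructor and fit arguments mirror the docstrings in this commit:

    from torchensemble import GradientBoostingClassifier

    ensemble = GradientBoostingClassifier(estimator=MLP, n_estimators=10)
    ensemble.fit(train_loader, test_loader=test_loader, early_stopping_rounds=2)

    print(ensemble.n_estimators)  # 10, the requested ensemble size
    print(len(ensemble))          # may be smaller if early stopping kicked in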

torchensemble/_constants.py (1 addition, 1 deletion)

@@ -83,7 +83,7 @@
         - If ``None``, the model will be saved in the current directory.
         - If not ``None``, the model will be saved in the specified
-          directory.
+          directory: ``save_dir``.
     """
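A short sketch of the two saving modes documented here; the ensemble, loader, and checkpoint path are placeholders:

    ensemble.fit(train_loader, save_model=True)                     # saved in the current directory
    ensemble.fit(train_loader, save_model=True, save_dir="./ckpt")  # saved under ./ckpt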

torchensemble/adversarial_training.py (10 additions, 8 deletions)

@@ -31,30 +31,32 @@
     Parameters
     ----------
     train_loader : torch.utils.data.DataLoader
-        A :mod:`DataLoader` container that contains the training data.
+        A :mod:`torch.utils.data.DataLoader` container that contains the
+        training data.
     epochs : int, default=100
         The number of training epochs.
-    epsilon : float, defaul=0.01
+    epsilon : float, default=0.01
         The step used to generate adversarial samples in the fast gradient
         sign method (FGSM), which should be in the range [0, 1].
     log_interval : int, default=100
-        The number of batches to wait before printting the training status.
+        The number of batches to wait before logging the training status.
     test_loader : torch.utils.data.DataLoader, default=None
-        A :mod:`DataLoader` container that contains the evaluating data.
+        A :mod:`torch.utils.data.DataLoader` container that contains the
+        evaluating data.

         - If ``None``, no validation is conducted after each training
           epoch.
         - If not ``None``, the ensemble will be evaluated on this
           dataloader after each training epoch.
     save_model : bool, default=True
-        Whether to save the model.
+        Specify whether to save the model parameters.

-        - If test_loader is ``None``, the ensemble containing
-          ``n_estimators`` base estimators will be saved.
+        - If test_loader is ``None``, the ensemble fully trained will be
+          saved.
         - If test_loader is not ``None``, the ensemble with the best
           validation performance will be saved.
     save_dir : string, default=None
-        Specify where to save the model.
+        Specify where to save the model parameters.

         - If ``None``, the model will be saved in the current directory.
         - If not ``None``, the model will be saved in the specified
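For readers unfamiliar with the ``epsilon`` parameter, here is a minimal sketch of how FGSM perturbs one batch. The names model, criterion, data, and target are placeholders, and this is a sketch of the technique, not the library's internal implementation:

    import torch

    def fgsm_sample(model, criterion, data, target, epsilon=0.01):
        # Enable gradients with respect to the inputs.
        data = data.clone().detach().requires_grad_(True)
        loss = criterion(model(data), target)
        loss.backward()
        # Take a single step of size epsilon along the sign of the input gradient.
        return (data + epsilon * data.grad.sign()).detach()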

torchensemble/gradient_boosting.py (15 additions, 12 deletions)

@@ -31,9 +31,10 @@
     n_estimators : int
         The number of base estimators in the ensemble.
     estimator_args : dict, default=None
-        The dictionary of parameters used to instantiate base estimators.
+        The dictionary of hyper-parameters used to instantiate base
+        estimators (Optional).
     shrinkage_rate : float, default=1
-        The shrinkage rate in gradient boosting.
+        The shrinkage rate used in gradient boosting.
     cuda : bool, default=True

         - If ``True``, use GPU to train and evaluate the ensemble.

@@ -50,34 +51,36 @@
     Parameters
     ----------
     train_loader : torch.utils.data.DataLoader
-        A :mod:`DataLoader` container that contains the training data.
+        A :mod:`torch.utils.data.DataLoader` container that contains the
+        training data.
     epochs : int, default=100
-        The number of training epochs.
+        The number of training epochs per base estimator.
     log_interval : int, default=100
-        The number of batches to wait before printting the training status.
+        The number of batches to wait before logging the training status.
     test_loader : torch.utils.data.DataLoader, default=None
-        A :mod:`DataLoader` container that contains the evaluating data.
+        A :mod:`torch.utils.data.DataLoader` container that contains the
+        evaluating data.

-        - If ``None``, no validation is conducted after each training
-          epoch.
+        - If ``None``, no validation is conducted after each base
+          estimator being trained.
         - If not ``None``, the ensemble will be evaluated on this
-          dataloader after each training epoch.
+          dataloader after each base estimator being trained.
     early_stopping_rounds : int, default=2
         Specify the number of tolerant rounds for early stopping. When the
         validation performance of the ensemble does not improve after
         adding the base estimator fitted in current iteration, the internal
         counter on early stopping will increase by one. When the value of
         the internal counter reaches ``early_stopping_rounds``, the
-        training stage will terminate early.
+        training stage will terminate instantly.
     save_model : bool, default=True
-        Whether to save the model.
+        Specify whether to save the model parameters.

         - If test_loader is ``None``, the ensemble containing
           ``n_estimators`` base estimators will be saved.
         - If test_loader is not ``None``, the ensemble with the best
           validation performance will be saved.
     save_dir : string, default=None
-        Specify where to save the model.
+        Specify where to save the model parameters.

         - If ``None``, the model will be saved in the current directory.
         - If not ``None``, the model will be saved in the specified
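The ``early_stopping_rounds`` wording above can be made concrete with a small sketch of the counter logic it describes; evaluate and the loop body are hypothetical stand-ins for the library's internals:

    counter, best_score = 0, float("-inf")
    for idx in range(n_estimators):
        # ... fit the idx-th base estimator, then validate the ensemble ...
        score = evaluate(ensemble, test_loader)  # hypothetical helper
        if score > best_score:
            best_score, counter = score, 0       # improvement resets the counter
        else:
            counter += 1
        if counter == early_stopping_rounds:
            break                                # terminate the training stage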

torchensemble/snapshot_ensemble.py (9 additions, 8 deletions)

@@ -44,23 +44,24 @@
     epochs : int, default=100
         The number of training epochs.
     log_interval : int, default=100
-        The number of batches to wait before printting the training status.
+        The number of batches to wait before logging the training status.
     test_loader : torch.utils.data.DataLoader, default=None
-        A :mod:`DataLoader` container that contains the evaluating data.
+        A :mod:`torch.utils.data.DataLoader` container that contains the
+        evaluating data.

-        - If ``None``, no validation is conducted after each training
-          epoch.
+        - If ``None``, no validation is conducted after each snapshot model
+          being generated.
         - If not ``None``, the ensemble will be evaluated on this
-          dataloader after each training epoch.
+          dataloader after each snapshot model being generated.
     save_model : bool, default=True
-        Whether to save the model.
+        Specify whether to save the model parameters.

-        - If test_loader is ``None``, the ensemble containing
+        - If test_loader is ``None``, the ensemble with
           ``n_estimators`` base estimators will be saved.
         - If test_loader is not ``None``, the ensemble with the best
           validation performance will be saved.
     save_dir : string, default=None
-        Specify where to save the model.
+        Specify where to save the model parameters.

         - If ``None``, the model will be saved in the current directory.
         - If not ``None``, the model will be saved in the specified
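A usage sketch of the per-snapshot validation described above, assuming a placeholder MLP base estimator and data loaders; the fit arguments follow this docstring:

    from torchensemble import SnapshotEnsembleClassifier

    ensemble = SnapshotEnsembleClassifier(estimator=MLP, n_estimators=5)
    ensemble.fit(
        train_loader,
        epochs=100,
        test_loader=test_loader,  # evaluated after each snapshot model is generated
        save_model=True,          # keeps the snapshot ensemble with the best score
        save_dir="./ckpt",
    )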

torchensemble/tests/test_set_optimizer.py (5 additions, 2 deletions)

@@ -39,9 +39,12 @@ def test_set_optimizer_normal(optimizer_name):

 def test_set_optimizer_Unknown():
     model = MLP()
-    with pytest.raises(NotImplementedError) as excinfo:
+
+    err_msg = ("Unrecognized optimizer: {}, should be one of"
+               " {{Adadelta, Adagrad, Adam, AdamW, Adamax, ASGD, RMSprop,"
+               " Rprop, SGD}}.").format("Unknown")
+    with pytest.raises(NotImplementedError, match=err_msg):
         torchensemble.utils.set_module.set_optimizer(model, "Unknown")
-    assert "Unknown name of the optimizer" in str(excinfo.value)


 def test_update_lr():
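One caveat with the rewritten test, and equally with the scheduler test in the next file: pytest's ``match`` argument is a regular expression applied via ``re.search``, not a verbatim comparison, so the formatted message must also work as a pattern. A self-contained check of why it does here:

    import re

    err_msg = ("Unrecognized optimizer: {}, should be one of"
               " {{Adadelta, Adagrad, Adam, AdamW, Adamax, ASGD, RMSprop,"
               " Rprop, SGD}}.").format("Unknown")
    actual = ("Unrecognized optimizer: Unknown, should be one of"
              " {Adadelta, Adagrad, Adam, AdamW, Adamax, ASGD, RMSprop,"
              " Rprop, SGD}.")
    # The braces do not form a valid regex quantifier, so Python's re
    # treats them literally and the pattern matches the real message.
    assert re.search(err_msg, actual)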

torchensemble/tests/test_set_scheduler.py (7 additions, 2 deletions)

@@ -79,6 +79,11 @@ def test_set_scheduler_ReduceLROnPlateau():

 def test_set_scheduler_Unknown():
     model = MLP()
     optimizer = torch.optim.Adam(model.parameters())
-    with pytest.raises(NotImplementedError) as excinfo:
+
+    err_msg = ("Unrecognized scheduler: {}, should be one of"
+               " {{LambdaLR, MultiplicativeLR, StepLR, MultiStepLR,"
+               " ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau,"
+               " CyclicLR, OneCycleLR, CosineAnnealingWarmRestarts}}.")
+    err_msg = err_msg.format("Unknown")
+    with pytest.raises(NotImplementedError, match=err_msg):
         torchensemble.utils.set_module.set_scheduler(optimizer, "Unknown")
-    assert "Unknown name of the scheduler" in str(excinfo.value)

torchensemble/utils/set_module.py (2 additions, 2 deletions)

@@ -28,7 +28,7 @@ def set_optimizer(model, optimizer_name, **kwargs):
     elif optimizer_name == "SGD":
         optimizer = torch.optim.SGD(model.parameters(), **kwargs)
     else:
-        msg = ("Unknown name of the optimizer {}, should be one of"
+        msg = ("Unrecognized optimizer: {}, should be one of"
                " {{Adadelta, Adagrad, Adam, AdamW, Adamax, ASGD, RMSprop,"
                " Rprop, SGD}}.")
         raise NotImplementedError(msg.format(optimizer_name))

@@ -77,7 +77,7 @@ def set_scheduler(optimizer, scheduler_name, **kwargs):
     elif scheduler_name == "ReduceLROnPlateau":
         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, **kwargs)
     else:
-        msg = ("Unknown name of the scheduler {}, should be one of"
+        msg = ("Unrecognized scheduler: {}, should be one of"
                " {{LambdaLR, MultiplicativeLR, StepLR, MultiStepLR,"
                " ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau,"
                " CyclicLR, OneCycleLR, CosineAnnealingWarmRestarts}}.")
