|
31 | 31 | n_estimators : int |
32 | 32 | The number of base estimators in the ensemble. |
33 | 33 | estimator_args : dict, default=None |
34 | | - The dictionary of parameters used to instantiate base estimators. |
| 34 | + The dictionary of hyper-parameters used to instantiate base |
| 35 | + estimators (Optional). |
35 | 36 | shrinkage_rate : float, default=1 |
36 | | - The shrinkage rate in gradient boosting. |
| 37 | + The shrinkage rate used in gradient boosting. |
37 | 38 | cuda : bool, default=True |
38 | 39 |
|
39 | 40 | - If ``True``, use GPU to train and evaluate the ensemble. |
|
50 | 51 | Parameters |
51 | 52 | ---------- |
52 | 53 | train_loader : torch.utils.data.DataLoader |
53 | | - A :mod:`DataLoader` container that contains the training data. |
| 54 | + A :mod:`torch.utils.data.DataLoader` container that contains the |
| 55 | + training data. |
54 | 56 | epochs : int, default=100 |
55 | | - The number of training epochs. |
| 57 | + The number of training epochs per base estimator. |
56 | 58 | log_interval : int, default=100 |
57 | | - The number of batches to wait before printting the training status. |
| 59 | + The number of batches to wait before logging the training status. |
58 | 60 | test_loader : torch.utils.data.DataLoader, default=None |
59 | | - A :mod:`DataLoader` container that contains the evaluating data. |
| 61 | + A :mod:`torch.utils.data.DataLoader` container that contains the |
| 62 | + evaluating data. |
60 | 63 |
|
61 | | - - If ``None``, no validation is conducted after each training |
62 | | - epoch. |
| 64 | + - If ``None``, no validation is conducted after each base |
| 65 | + estimator being trained. |
63 | 66 | - If not ``None``, the ensemble will be evaluated on this |
64 | | - dataloader after each training epoch. |
| 67 | + dataloader after each base estimator being trained. |
65 | 68 | early_stopping_rounds : int, default=2 |
66 | 69 | Specify the number of tolerant rounds for early stopping. When the |
67 | 70 | validation performance of the ensemble does not improve after |
68 | 71 | adding the base estimator fitted in current iteration, the internal |
69 | 72 | counter on early stopping will increase by one. When the value of |
70 | 73 | the internal counter reaches ``early_stopping_rounds``, the |
71 | | - training stage will terminate early. |
| 74 | + training stage will terminate instantly. |
72 | 75 | save_model : bool, default=True |
73 | | - Whether to save the model. |
| 76 | + Specify whether to save the model parameters. |
74 | 77 |
|
75 | 78 | - If test_loader is ``None``, the ensemble containing |
76 | 79 | ``n_estimators`` base estimators will be saved. |
77 | 80 | - If test_loader is not ``None``, the ensemble with the best |
78 | 81 | validation performance will be saved. |
79 | 82 | save_dir : string, default=None |
80 | | - Specify where to save the model. |
| 83 | + Specify where to save the model parameters. |
81 | 84 |
|
82 | 85 | - If ``None``, the model will be saved in the current directory. |
83 | 86 | - If not ``None``, the model will be saved in the specified |
|
0 commit comments