Skip to content

Commit f071ff9

Browse files
authored
feat: relax tensorboard as a soft dependency (#65)
* add code * Update CHANGELOG.rst * Update requirements.txt * Update logging.py
1 parent 4070b7e commit f071ff9

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Changelog
44
Ver 0.1.*
55
---------
66

7+
* |Enhancement| Relax :mod:`tensorboard` as a soft dependency | `@xuyxu <https://github.com/xuyxu>`__
78
* |Enhancement| |API| Simplify the training workflow of :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu <https://github.com/xuyxu>`__
89
* |Feature| |API| Support TensorBoard logging in :meth:`set_logger` | `@zzzzwj <https://github.com/zzzzwj>`__
910
* |Enhancement| |API| Add ``use_reduction_sum`` parameter for :meth:`fit` of Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__

build_tools/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
flake8
22
pytest-cov
3-
black==20.8b1
3+
black==20.8b1
4+
tensorboard==2.*

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
torch>=1.4.0
22
torchvision>=0.2.2
3-
scikit-learn>=0.23.0
4-
tensorboard==2.*
3+
scikit-learn>=0.23.0

torchensemble/utils/logging.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,16 @@ def _get_level(level):
8383

8484
def init_tb_logger(log_dir):
8585
try:
86-
from torch.utils.tensorboard import SummaryWriter
86+
import tensorboard # noqa: F401
8787
except ModuleNotFoundError:
8888
msg = (
89-
"Cannot load the module torch when building the "
90-
"ImageScanner. Please make sure that tensorboard is"
91-
" installed."
89+
"Cannot load the module tensorboard. Please make sure that"
90+
" tensorboard is installed."
9291
)
9392
raise ModuleNotFoundError(msg)
9493

94+
from torch.utils.tensorboard import SummaryWriter
95+
9596
global _tb_logger
9697

9798
if not _tb_logger:

0 commit comments

Comments
 (0)