Skip to content

Commit 62067cc

Browse files
handle callable types in init mark
1 parent d1e7777 commit 62067cc

File tree

2 files changed

+42
-34
lines changed

2 files changed

+42
-34
lines changed

deepspeed/__init__.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -66,48 +66,45 @@ def _parse_version(version_str):
6666
dist = None
6767

6868

69-
def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
69+
def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
7070
"""Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
71-
if hasattr(trainobj, 'ds_is_inited'):
72-
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
73-
return
74-
7571
trainobj.ds_is_inited = True
7672

7773

78-
def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
74+
def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
7975
"""Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
80-
if hasattr(trainobj, 'ds_is_inited'):
81-
# we shouldn't hit the assert below, but just in case
82-
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
83-
return True
84-
return False
76+
return getattr(trainobj, 'ds_is_inited', False)
8577

8678

87-
def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Optimizer],
88-
lr_scheduler: Optional[_LRScheduler]):
79+
def _assert_trainobjs_not_inited(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
80+
DeepSpeedOptimizerCallable]],
81+
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
8982
"""Enforce the model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call."""
90-
if _is_initialized(model):
83+
if _is_ds_initialized(model):
9184
raise ValueError(
9285
"Model has already been initialized, please make sure to only call deepspeed.initialize on a model once.")
93-
if optimizer is not None and _is_initialized(optimizer):
86+
if optimizer is not None and isinstance(optimizer, Optimizer) and _is_ds_initialized(optimizer):
9487
raise ValueError(
9588
"Optimizer has already been initialized, please make sure to only call deepspeed.initialize on an optimizer once."
9689
)
97-
if lr_scheduler is not None and _is_initialized(lr_scheduler):
90+
if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler) and _is_ds_initialized(lr_scheduler):
9891
raise ValueError(
9992
"LR scheduler has already been initialized, please make sure to only call deepspeed.initialize on an LR scheduler once."
10093
)
10194

10295

103-
def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Optimizer],
104-
lr_scheduler: Optional[_LRScheduler]):
105-
"""Mark the model, optimizer, and lr_scheduler as initialized."""
106-
_mark_initialized(model)
107-
if optimizer is not None:
108-
_mark_initialized(optimizer)
109-
if lr_scheduler is not None:
110-
_mark_initialized(lr_scheduler)
96+
def _mark_trainobjs_initialized(model: torch.nn.Module, optimizer: Optional[Union[Optimizer,
97+
DeepSpeedOptimizerCallable]],
98+
lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]]):
99+
"""Mark the model, optimizer, and lr_scheduler as initialized.
100+
Note that callables of type DeepSpeedOptimizerCallable and DeepSpeedSchedulerCallable are not marked
101+
as they are not stateful and reuse should be permissible.
102+
"""
103+
_mark_ds_initialized(model)
104+
if optimizer is not None and isinstance(optimizer, Optimizer):
105+
_mark_ds_initialized(optimizer)
106+
if lr_scheduler is not None and isinstance(lr_scheduler, _LRScheduler):
107+
_mark_ds_initialized(lr_scheduler)
111108

112109

113110
def initialize(args=None,

tests/unit/runtime/test_ds_initialize.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from deepspeed.utils.torch import required_torch_version
2222
from deepspeed.accelerator import get_accelerator
2323
from deepspeed.ops.op_builder import FusedAdamBuilder
24-
from deepspeed import _assert_trainobjs_not_inited, _is_initialized
24+
from deepspeed import _assert_trainobjs_not_inited, _is_ds_initialized
2525

2626

2727
@pytest.mark.parametrize('zero_stage', [0, 3])
@@ -441,12 +441,22 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
441441
class TestNoRepeatedInitializationAllowed(DistributedTest):
442442
world_size = 1
443443

444-
def test_no_repeated_init(self):
444+
@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
445+
def test(self, optimizer_type):
445446
hidden_dim = 10
446447
model = SimpleModel(hidden_dim)
447-
client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
448-
# Initialize DeepSpeed configurations for fp16
448+
449+
def _optimizer_callable(params) -> Optimizer:
450+
return AdamW(params=params)
451+
449452
config_dict = {'train_batch_size': 1}
453+
if optimizer_type is None:
454+
client_optimizer = None
455+
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
456+
elif optimizer_type is Optimizer:
457+
client_optimizer = Adam(model.parameters())
458+
else:
459+
client_optimizer = _optimizer_callable
450460

451461
# Initialize DeepSpeed engine
452462
_assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
@@ -455,12 +465,13 @@ def test_no_repeated_init(self):
455465
config_params=config_dict)
456466

457467
# arguments should be marked as initialized now
458-
assert _is_initialized(model), "Client model should be marked as initialized"
459-
assert _is_initialized(client_optimizer), "Client optimizer should be marked as initialized"
468+
assert _is_ds_initialized(model), "Client model should be marked as initialized"
469+
if optimizer_type is Optimizer:
470+
assert _is_ds_initialized(client_optimizer), "Client optimizer should be marked as initialized"
460471

461472
# return values should also be marked as initialized
462-
assert _is_initialized(model_engine), "Model engine should be marked as initialized"
463-
assert _is_initialized(optim), "Optimizer should be marked as initialized"
473+
assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
474+
assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"
464475

465476
exception_raised = False
466477
try:
@@ -480,15 +491,15 @@ def test_no_repeated_init(self):
480491

481492
exception_raised = False
482493
try:
483-
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
494+
deepspeed.initialize(model=model, optimizer=optim, config_params=config_dict)
484495
except ValueError:
485496
exception_raised = True
486497

487498
assert exception_raised, "Initialization on ds types should raise an exception"
488499

489500
exception_raised = False
490501
try:
491-
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
502+
deepspeed.initialize(model=model_engine, optimizer=optim, config_params=config_dict)
492503
except ValueError:
493504
exception_raised = True
494505
assert exception_raised, "Initialization on ds types should raise an exception"

0 commit comments

Comments
 (0)