Skip to content

Commit 1b15bea

Browse files
add: split TestNoRepeatedInitializationAllowed test into two separate ones
1 parent b1d4330 commit 1b15bea

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

tests/unit/runtime/test_ds_initialize.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ class TestNoRepeatedInitializationAllowed(DistributedTest):
442442
world_size = 1
443443

444444
@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
445-
def test(self, optimizer_type):
445+
def test_objs_marked_ds_inited(self, optimizer_type):
446446
hidden_dim = 10
447447
model = SimpleModel(hidden_dim)
448448

@@ -472,6 +472,27 @@ def _optimizer_callable(params) -> Optimizer:
472472
assert _is_ds_initialized(model_engine), "Model engine should be marked as initialized"
473473
assert _is_ds_initialized(optim), "Optimizer should be marked as initialized"
474474

475+
@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
476+
def test_repeated_initialization_raises_error(self, optimizer_type):
477+
hidden_dim = 10
478+
model = SimpleModel(hidden_dim)
479+
480+
def _optimizer_callable(params) -> Optimizer:
481+
return AdamW(params=params)
482+
483+
config_dict = {'train_batch_size': 1}
484+
if optimizer_type is None:
485+
client_optimizer = None
486+
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
487+
elif optimizer_type is Optimizer:
488+
client_optimizer = Adam(model.parameters())
489+
else:
490+
client_optimizer = _optimizer_callable
491+
492+
# Initialize DeepSpeed engine
493+
model_engine, optim, _, _ = deepspeed.initialize(model=model,
494+
optimizer=client_optimizer,
495+
config_params=config_dict)
475496
err_msg_pattern = "has already been initialized"
476497
with pytest.raises(ValueError, match=err_msg_pattern):
477498
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)

0 commit comments

Comments
 (0)