Skip to content

Commit 029e0a3

Browse files
tjruwase and loadams
authored
Use ds-specific module id to avoid conflicts (#6847)
Fix #6772 --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
1 parent 4fea41f commit 029e0a3

File tree

4 files changed

+60
-24
lines changed

4 files changed

+60
-24
lines changed

deepspeed/runtime/zero/parameter_offload.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _start_of_forward_hook(module, *args):
243243
self.module.register_forward_pre_hook(_start_of_forward_hook)
244244

245245
#likely one of them should be enough but just to be safe
246-
self._register_hooks_recursively(self.module)
246+
self._register_deepspeed_module(self.module)
247247

248248
# Add top module to stack trace
249249
global FWD_MODULE_STACK
@@ -269,19 +269,19 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):
269269

270270
return persistent_params
271271

272-
def _register_hooks_recursively(self, module, count=[0]):
272+
def _register_deepspeed_module(self, module, count=[0]):
273273
my_count = count[0]
274-
module.id = my_count
274+
module.ds_id = my_count
275275

276-
#print(f"{module.__class__} : {module.id}")
276+
#print(f"{module.__class__} : {module.ds_id}")
277277

278278
if z3_leaf_module(module):
279279
for param in module.parameters():
280280
param.ds_z3_leaf_module = module
281281
else:
282282
for child in module.children():
283283
count[0] = count[0] + 1
284-
self._register_hooks_recursively(child, count=count)
284+
self._register_deepspeed_module(child, count=count)
285285

286286
@instrument_w_nvtx
287287
def _pre_forward_module_hook(module, *args):
@@ -466,14 +466,16 @@ def pre_sub_module_forward_function(self, sub_module):
466466

467467
@torch.no_grad()
468468
def post_sub_module_forward_function(self, sub_module):
469-
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
470-
force=False)
469+
see_memory_usage(
470+
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
471+
force=False)
471472

472473
param_coordinator = self.get_param_coordinator()
473474
param_coordinator.release_sub_module(sub_module)
474475

475-
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
476-
force=False)
476+
see_memory_usage(
477+
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
478+
force=False)
477479

478480
@torch.no_grad()
479481
def pre_sub_module_backward_function(self, sub_module):
@@ -488,13 +490,13 @@ def pre_sub_module_backward_function(self, sub_module):
488490
def post_sub_module_backward_function(self, sub_module):
489491
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
490492
see_memory_usage(
491-
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
493+
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
492494
force=False)
493495

494496
self.get_param_coordinator().release_sub_module(sub_module)
495497

496498
see_memory_usage(
497-
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
499+
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
498500
force=False)
499501

500502
def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):

deepspeed/runtime/zero/partitioned_param_coordinator.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -175,17 +175,17 @@ def trace_prologue(self, sub_module: Module) -> None:
175175
# sub_module must match expectation else invalidate trace cache
176176
if len(self.__submodule_order) <= self.__step_id:
177177
print_rank_0(
178-
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
178+
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
179179
f"cache has only {len(self.__submodule_order)} modules",
180180
force=True)
181181
self._invalidate_trace()
182182
return
183183

184184
if sub_module != self.__submodule_order[self.__step_id]:
185-
expected_module_id = self.__submodule_order[self.__step_id].id
185+
expected_module_id = self.__submodule_order[self.__step_id].ds_id
186186
print_rank_0(
187187
f"Invalidate trace cache @ step {self.__step_id}: "
188-
f"expected module {expected_module_id}, but got module {sub_module.id}",
188+
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
189189
force=True)
190190
self._invalidate_trace()
191191

@@ -199,7 +199,7 @@ def record_module(self, sub_module: Module) -> None:
199199
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
200200

201201
self.__submodule_order.append(sub_module)
202-
self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
202+
self.__step_id_module_fetched_for[sub_module.ds_id].append(self.__step_id)
203203

204204
def record_parameters(self, sub_module: Module) -> None:
205205
if is_compiling():
@@ -208,7 +208,7 @@ def record_parameters(self, sub_module: Module) -> None:
208208
if not self.is_record_trace():
209209
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")
210210

211-
step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
211+
step_id = self.__step_id_module_fetched_for[sub_module.ds_id].popleft()
212212
for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
213213
self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))
214214

@@ -228,7 +228,7 @@ def reset_step(self) -> None:
228228

229229
if not self.is_complete_trace(): # not self.trace_complete:
230230
# Make sure that recorded submodule orders are identical across ranks
231-
assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
231+
assert_ints_same_as_other_ranks([m.ds_id for m in self.__submodule_order])
232232

233233
if self.is_record_trace():
234234
# Successfully recorded a trace
@@ -241,7 +241,7 @@ def reset_step(self) -> None:
241241
self.__param_order = tuple(self.__param_order) # freeze
242242
self.__trace_mode = ZeRoTraceMode.COMPLETE
243243
print_rank_0(
244-
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
244+
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.ds_id for m in self.__submodule_order]}",
245245
force=False)
246246
else:
247247
# Enable trace recording for next forward/backward pass
@@ -284,7 +284,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
284284
"""
285285
if logger.isEnabledFor(logging.DEBUG):
286286
debug_rank0(
287-
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
287+
f"{self.__step_id}: M{current_submodule.ds_id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
288288
+ str({
289289
"avail": f"{self.__n_available_params:.1e}",
290290
"queue_sz": f"{len(self.__param_queue or [])}",
@@ -297,7 +297,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
297297

298298
if fetch_numel > 0:
299299
event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
300-
self._dump_param_ids(event_name, current_submodule.id,
300+
self._dump_param_ids(event_name, current_submodule.ds_id,
301301
[p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
302302
self.__profiler.start_event(event_name)
303303
# kick off all gather for params in the immediately required submodule
@@ -314,7 +314,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
314314
fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule)
315315
# wait for parameters in the immediately needed submodule to become available
316316
for param in params_to_fetch:
317-
param.ds_active_sub_modules.add(current_submodule.id)
317+
param.ds_active_sub_modules.add(current_submodule.ds_id)
318318
if logger.isEnabledFor(logging.DEBUG):
319319
debug_rank0(f"-wait: {param.ds_summary()}")
320320
if param in self.__inflight_param_registry:
@@ -358,7 +358,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
358358
if discarded_from_prefetch_queue != params_not_already_fetched:
359359
raise RuntimeError(
360360
f"tracing error at step {self.__step_id}: \n"
361-
f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
361+
f"module id: {current_submodule.ds_id}, training: {current_submodule.training}\n"
362362
f"expected the next {len(params_not_already_fetched)} parameters in the "
363363
f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
364364
f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.")
@@ -425,7 +425,7 @@ def release_sub_module(self, submodule: Module) -> None:
425425
empty_buffer = torch.empty(1, device=get_accelerator().current_device())
426426

427427
for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
428-
param.ds_active_sub_modules.discard(submodule.id)
428+
param.ds_active_sub_modules.discard(submodule.ds_id)
429429
if param.ds_id in params_to_release and not param.is_external_param:
430430
self.__release_param(param, free_data)
431431
if not free_data:

deepspeed/runtime/zero/stage3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def unwrap_model_for_generation(model):
102102
optimizer_offload = model.optimizer.parameter_offload
103103
elif model.optimizer is not None:
104104
optimizer_offload = model.optimizer
105-
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
105+
optimizer_offload._register_deepspeed_module(optimizer_offload.module)
106106
return
107107

108108

tests/unit/runtime/zero/test_zero.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,3 +1673,37 @@ def test(self, prefetch_ratio, zero_stage=3):
16731673
with torch.no_grad():
16741674
for batch in data_loader:
16751675
loss = model(batch[0], batch[1])
1676+
1677+
1678+
# Avoid overwriting client module id
1679+
# https://github.com/microsoft/DeepSpeed/issues/6772
1680+
class TestZero3ClientModuleID(DistributedTest):
1681+
world_size = 2
1682+
1683+
def test_client_module_id(self):
1684+
config_dict = {
1685+
"train_micro_batch_size_per_gpu": 1,
1686+
"steps_per_print": 1,
1687+
"optimizer": {
1688+
"type": "Adam",
1689+
},
1690+
"zero_optimization": {
1691+
"stage": 3
1692+
},
1693+
}
1694+
1695+
class MyModel(torch.nn.Module):
1696+
1697+
def __init__(self):
1698+
super().__init__()
1699+
self.id = 3 # ID arbitrary client usage, e.g. GPU placement
1700+
self.fc = Linear(128, 128)
1701+
1702+
def forward(self, x):
1703+
return self.fc(x)
1704+
1705+
model = MyModel()
1706+
pre_init_m_id = model.id
1707+
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
1708+
post_init_m_id = model.id
1709+
assert pre_init_m_id == post_init_m_id

0 commit comments

Comments (0)