@@ -175,17 +175,17 @@ def trace_prologue(self, sub_module: Module) -> None:
175175 # sub_module must match expectation else invalidate trace cache
176176 if len (self .__submodule_order ) <= self .__step_id :
177177 print_rank_0 (
178- f"Invalidate trace cache @ step { self .__step_id } and module { sub_module .id } : "
178+ f"Invalidate trace cache @ step { self .__step_id } and module { sub_module .ds_id } : "
179179 f"cache has only { len (self .__submodule_order )} modules" ,
180180 force = True )
181181 self ._invalidate_trace ()
182182 return
183183
184184 if sub_module != self .__submodule_order [self .__step_id ]:
185- expected_module_id = self .__submodule_order [self .__step_id ].id
185+ expected_module_id = self .__submodule_order [self .__step_id ].ds_id
186186 print_rank_0 (
187187 f"Invalidate trace cache @ step { self .__step_id } : "
188- f"expected module { expected_module_id } , but got module { sub_module .id } " ,
188+ f"expected module { expected_module_id } , but got module { sub_module .ds_id } " ,
189189 force = True )
190190 self ._invalidate_trace ()
191191
@@ -199,7 +199,7 @@ def record_module(self, sub_module: Module) -> None:
199199 raise RuntimeError (f"attempted to record trace when status = { self .__trace_mode } " )
200200
201201 self .__submodule_order .append (sub_module )
202- self .__step_id_module_fetched_for [sub_module .id ].append (self .__step_id )
202+ self .__step_id_module_fetched_for [sub_module .ds_id ].append (self .__step_id )
203203
204204 def record_parameters (self , sub_module : Module ) -> None :
205205 if is_compiling ():
@@ -208,7 +208,7 @@ def record_parameters(self, sub_module: Module) -> None:
208208 if not self .is_record_trace ():
209209 raise RuntimeError (f"attempted to record trace when status = { self .__trace_mode } " )
210210
211- step_id = self .__step_id_module_fetched_for [sub_module .id ].popleft ()
211+ step_id = self .__step_id_module_fetched_for [sub_module .ds_id ].popleft ()
212212 for param in sorted (set (iter_params (sub_module , recurse = z3_leaf_module (sub_module ))), key = lambda p : p .ds_id ):
213213 self .__param_order .append (__class__ .__ParamInTrace (param = param , step_id_last_used_at = step_id ))
214214
@@ -228,7 +228,7 @@ def reset_step(self) -> None:
228228
229229 if not self .is_complete_trace (): # not self.trace_complete:
230230 # Make sure that recorded submodule orders are identical across ranks
231- assert_ints_same_as_other_ranks ([m .id for m in self .__submodule_order ])
231+ assert_ints_same_as_other_ranks ([m .ds_id for m in self .__submodule_order ])
232232
233233 if self .is_record_trace ():
234234 # Successfully recorded a trace
@@ -241,7 +241,7 @@ def reset_step(self) -> None:
241241 self .__param_order = tuple (self .__param_order ) # freeze
242242 self .__trace_mode = ZeRoTraceMode .COMPLETE
243243 print_rank_0 (
244- f"completed record trace of { len (self .__submodule_order )} sub modules: { [m .id for m in self .__submodule_order ]} " ,
244+ f"completed record trace of { len (self .__submodule_order )} sub modules: { [m .ds_id for m in self .__submodule_order ]} " ,
245245 force = False )
246246 else :
247247 # Enable trace recording for next forward/backward pass
@@ -284,7 +284,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
284284 """
285285 if logger .isEnabledFor (logging .DEBUG ):
286286 debug_rank0 (
287- f"{ self .__step_id } : M{ current_submodule .id } ({ type (current_submodule ).__name__ } ) P{ [p .ds_id for p in iter_params (current_submodule , recurse = z3_leaf_module (current_submodule ))]} "
287+ f"{ self .__step_id } : M{ current_submodule .ds_id } ({ type (current_submodule ).__name__ } ) P{ [p .ds_id for p in iter_params (current_submodule , recurse = z3_leaf_module (current_submodule ))]} "
288288 + str ({
289289 "avail" : f"{ self .__n_available_params :.1e} " ,
290290 "queue_sz" : f"{ len (self .__param_queue or [])} " ,
@@ -297,7 +297,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
297297
298298 if fetch_numel > 0 :
299299 event_name = __class__ .FORWARD_FETCH_SUBMIT if forward else __class__ .BACKWARD_FETCH_SUBMIT
300- self ._dump_param_ids (event_name , current_submodule .id ,
300+ self ._dump_param_ids (event_name , current_submodule .ds_id ,
301301 [p .ds_id for p in params_to_fetch if p .ds_status == ZeroParamStatus .NOT_AVAILABLE ])
302302 self .__profiler .start_event (event_name )
303303 # kick off all gather for params in the immediately required submodule
@@ -314,7 +314,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
314314 fast_fetch = self .fast_sharding_for_leaf_module and z3_leaf_module (current_submodule )
315315 # wait for parameters in the immediately needed submodule to become available
316316 for param in params_to_fetch :
317- param .ds_active_sub_modules .add (current_submodule .id )
317+ param .ds_active_sub_modules .add (current_submodule .ds_id )
318318 if logger .isEnabledFor (logging .DEBUG ):
319319 debug_rank0 (f"-wait: { param .ds_summary ()} " )
320320 if param in self .__inflight_param_registry :
@@ -358,7 +358,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
358358 if discarded_from_prefetch_queue != params_not_already_fetched :
359359 raise RuntimeError (
360360 f"tracing error at step { self .__step_id } : \n "
361- f"module id: { current_submodule .id } , training: { current_submodule .training } \n "
361+ f"module id: { current_submodule .ds_id } , training: { current_submodule .training } \n "
362362 f"expected the next { len (params_not_already_fetched )} parameters in the "
363363 f"parameter fetch queue to be { tuple (p .ds_summary (use_debug_name = True ) for p in params_not_already_fetched )} \n "
364364 f"but got \n { tuple (p .ds_summary (use_debug_name = True ) for p in discarded_from_prefetch_queue )} ." )
@@ -425,7 +425,7 @@ def release_sub_module(self, submodule: Module) -> None:
425425 empty_buffer = torch .empty (1 , device = get_accelerator ().current_device ())
426426
427427 for param in iter_params (submodule , recurse = z3_leaf_module (submodule )):
428- param .ds_active_sub_modules .discard (submodule .id )
428+ param .ds_active_sub_modules .discard (submodule .ds_id )
429429 if param .ds_id in params_to_release and not param .is_external_param :
430430 self .__release_param (param , free_data )
431431 if not free_data :
0 commit comments