@@ -546,15 +546,10 @@ def _setup_for_real_optimizer(self):
             self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

         offset = 0
-        max_partition_numel = 0
         for param in all_params:
             self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
                 0, offset, param.partition_numel())
             offset += param.partition_numel()
-            max_partition_numel = max(max_partition_numel, param.partition_numel())
-        if self.offload_optimizer:
-            self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
-                torch.empty(max_partition_numel, device=self.device))

     def _link_all_hp_params(self):
         for p in self.module.parameters():
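For context, a minimal standalone sketch of the flat-buffer pattern the retained setup code above relies on: one (optionally pinned) gradient buffer is carved into per-parameter views with narrow(). The sizes and names below are illustrative assumptions, not values from DeepSpeed.

    import torch

    # Illustrative partition sizes for three parameters (not DeepSpeed's values).
    partition_numels = {0: 1024, 1: 512, 2: 2048}

    # One flat CPU buffer holding all gradient partitions back to back.
    flat_buffer = torch.zeros(sum(partition_numels.values()), dtype=torch.float16)
    if torch.cuda.is_available():
        # Pinning enables faster, async host<->device copies; skip on CPU-only builds.
        flat_buffer = flat_buffer.pin_memory()

    # Map each parameter id to a zero-copy view into the flat buffer.
    param_id_to_grad_partition = {}
    offset = 0
    for param_id, numel in partition_numels.items():
        param_id_to_grad_partition[param_id] = flat_buffer.narrow(0, offset, numel)
        offset += numel

    # The view for parameter 1 starts right after parameter 0's 1024 elements.
    assert param_id_to_grad_partition[1].data_ptr() == flat_buffer.data_ptr() + 1024 * flat_buffer.element_size()
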
@@ -1510,13 +1505,9 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                         offload_fp32_gradients[i].append(grad_buffer.float())
                         offload_fp32_offsets[i].append(dest_offset)
                     else:
-                        buffer_numel = grad_buffer.numel()
                         fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
-                            0, dest_offset, buffer_numel)
-                        self.pinned_grad_buffer[:buffer_numel].copy_(
-                            grad_buffer.to(dtype=torch.float32, non_blocking=True))
-                        get_accelerator().synchronize()
-                        fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)
+                            0, dest_offset, grad_buffer.numel())
+                        fp32_grad_tensor.copy_(grad_buffer.float())

             # free the gradient
             if not get_accelerator().is_synchronized_device():
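A minimal sketch of the direct-copy pattern in the new else branch above: the reduced low-precision gradient partition is upcast and written straight into the matching slice of the flat fp32 gradient, with no intermediate staging buffer. Tensor names and sizes here are illustrative assumptions, not DeepSpeed's.

    import torch

    # Stand-ins for self.fp32_partitioned_groups_flat[i].grad and one reduced gradient partition.
    fp32_grad_flat = torch.zeros(4096, dtype=torch.float32)
    grad_buffer = torch.randn(1024, dtype=torch.float16)
    dest_offset = 2048

    # narrow() returns a view, so copy_() writes directly into the flat fp32 gradient.
    fp32_grad_tensor = fp32_grad_flat.narrow(0, dest_offset, grad_buffer.numel())
    fp32_grad_tensor.copy_(grad_buffer.float())

    assert torch.equal(fp32_grad_flat[dest_offset:dest_offset + grad_buffer.numel()], grad_buffer.float())
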
@@ -2661,11 +2652,9 @@ def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True):
             self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
             self._clear_fp32_optimizer_param_groups()

-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()

         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
@@ -2782,11 +2771,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             else:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor

-        if self.swap_optimizer or self.params_in_nvme_and_cpu:
+        if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
-            for swap_info in self.optimizer_swapper.swap_params_info.values():
-                swap_info.tensors = [swap_info.tensors[0]]
-                swap_info.has_state_tensors = False
+            self.optimizer_swapper.purge_state()

         if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
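For reference, a minimal sketch of what a purge_state() helper on the optimizer swapper could look like if it simply wraps the inline logic removed in the two hunks above; this is an assumption based on the deleted lines, not DeepSpeed's actual implementation.

    from types import SimpleNamespace

    class OptimizerSwapper:
        """Hypothetical minimal swapper holding per-parameter swap bookkeeping."""

        def __init__(self):
            # Maps a parameter id to its swap bookkeeping entry (assumed shape).
            self.swap_params_info = {}

        def purge_state(self):
            # Mirrors the inline purge removed above: keep only the parameter tensor
            # and drop optimizer-state tensors, so checkpoint state loaded afterwards
            # is not mixed with the freshly initialized state.
            for swap_info in self.swap_params_info.values():
                swap_info.tensors = [swap_info.tensors[0]]
                swap_info.has_state_tensors = False

    # Tiny usage example with a fake swap entry.
    swapper = OptimizerSwapper()
    swapper.swap_params_info[0] = SimpleNamespace(tensors=["param", "exp_avg", "exp_avg_sq"],
                                                  has_state_tensors=True)
    swapper.purge_state()
    assert swapper.swap_params_info[0].tensors == ["param"]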