@@ -190,6 +190,10 @@ def decompress_module(self, module: Module):
190190 for name , parameter in module .named_parameters ():
191191 compressed_data [name ] = parameter
192192
193+ # Save references to original parameters before decompression
194+ original_scale = compressed_data .get ("weight_scale" )
195+ original_zp = compressed_data .get ("weight_zero_point" )
196+
193197 # NOTE: decompress_weight may modify compressed_data dict in-place
194198 # This is subtle but allows us to update the module's qparams with
195199 # the unpacked values.
@@ -198,9 +202,15 @@ def decompress_module(self, module: Module):
198202 compressed_data = compressed_data , quantization_args = quantization_args
199203 ).to (device )
200204
201- # Update module's parameters if they were unpacked/upcast during decompression
202- for param_name in ["weight_zero_point" , "weight_scale" ]:
203- if param_name in compressed_data and hasattr (module , param_name ):
205+ # Update module's parameters only if they were actually modified during decompression
206+ for param_name , original_param in [
207+ ("weight_scale" , original_scale ),
208+ ("weight_zero_point" , original_zp ),
209+ ]:
210+ if (
211+ param_name in compressed_data
212+ and compressed_data [param_name ] is not original_param
213+ ):
204214 # Delete the old parameter and register the updated one
205215 delete_offload_parameter (module , param_name )
206216 offload_device = get_offloaded_device (module )
0 commit comments