[Bugfix] [Offloading] Save disk-offloaded buffers, Save converted weights#46902
[Bugfix] [Offloading] Save disk-offloaded buffers, Save converted weights#46902kylesayrs wants to merge 10 commits into
Conversation
fbe41de to
54dfbf8
Compare
SunMarc
left a comment
There was a problem hiding this comment.
Thanks, left a couple of comments but I think @Cyrilvallez might have better ideas on how to deal with those as he's the one who coded this !
| filename = os.path.join(save_directory, shard_file) | ||
| shard_state_dict = {} | ||
| for tensor_name in tensor_names: | ||
| for tensor_name in sorted(tensor_names): |
There was a problem hiding this comment.
Note that
load_offloaded_parametermay load multiple weights for a single tensor.
While it is possible to overload CPU memory by loading parameters in a bad order,
in practicesplit_torch_state_dict_into_shardspreserves weight locality
sorting helps reduce the chances that bad load ordering occurs.
An example of bad load ordering would be
"layers.0.experts.0.up_proj" -> loads "layers.0.experts.gate_up_proj"
"layers.1.experts.0.up_proj" -> loads "layers.1.experts.gate_up_proj"
"layers.2.experts.0.up_proj" -> loads "layers.2.experts.gate_up_proj"
In this scenario, 3 separate gate_up_proj weights have been loaded onto cpu, but only 3 shard weights have been consumed by state_dict.pop.
Sorting reduces the chances that split_torch_state_dict_into_shards gives an adversarially bad ordering. It doesn't fix adversarially bad ordering between shards, but there's not much we can do about that.
f8b104d to
41e9caf
Compare
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
7d663f8 to
dbc8c39
Compare
CI recapDashboard: View test results in Grafana |
|
Hey @kylesayrs! I took the liberty to open #47018 to fix the issue, I believe it is simpler and more robust in general! |
Purpose
offload_buffers=TrueChanges
get_parametercall withget_parameter_or_buffercallload_offloaded_parameterto load a state dict of all weights associated with the checkpoint weightis_offloadedflagTesting
Able to save disk-offloaded models with conversion mappings now
Used these changes to quantize RedHatAI/DeepSeek-V4-Pro-NVFP4-FP8 and RedHatAI/GLM-5.2-NVFP4-FP8
Suggested Reviewers