Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 1dcbeed

Browse files
authored
sparseml.transformers.export_onnx CPU only support (#516) (#517)
1 parent 38fd61d commit 1dcbeed

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/sparseml/transformers/utils/export.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ def export_transformer_to_onnx(
159159
_LOGGER.warning(f"recipe not found under {recipe_path}")
160160

161161
# load weights
162-
state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME))
162+
load_kwargs = {} if torch.cuda.is_available() else {"map_location": "cpu"}
163+
state_dict = torch.load(os.path.join(model_path, WEIGHTS_NAME), **load_kwargs)
163164
model.load_state_dict(state_dict)
164165

165166
# create fake model input

0 commit comments

Comments
 (0)