Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit c4664fd

Browse files
KSGulinbfineran
authored andcommitted
Mask creator int -> round (#589)
* Fix: round for pruning index to match _threshold_from_sparsity
1 parent 35eab64 commit c4664fd

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/sparseml/pytorch/optim/mask_creator_pruning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def create_sparsity_masks(
176176
local_rng = random.Random(42)
177177
local_rng.shuffle(rand_indices)
178178
num_elem = tensor.numel()
179-
num_mask = int(num_elem * sparsity_target)
179+
num_mask = round(num_elem * sparsity_target)
180180
rand_indices = rand_indices[:num_mask]
181181
rand_indices = tensor.new_tensor(rand_indices, dtype=torch.int64)
182182
zero_indices = zero_indices[rand_indices, :]

src/sparseml/pytorch/sparsification/pruning/mask_creator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def create_sparsity_masks(
140140
local_rng = random.Random(42)
141141
local_rng.shuffle(rand_indices)
142142
num_elem = tensor.numel()
143-
num_mask = int(num_elem * sparsity_target)
143+
num_mask = round(num_elem * sparsity_target)
144144
rand_indices = rand_indices[:num_mask]
145145
rand_indices = tensor.new_tensor(rand_indices, dtype=torch.int64)
146146
zero_indices = zero_indices[rand_indices, :]

0 commit comments

Comments
 (0)