diff --git a/src/pquant/pruning_methods/fitcompress.py b/src/pquant/pruning_methods/fitcompress.py index f2b3b5c..7129006 100644 --- a/src/pquant/pruning_methods/fitcompress.py +++ b/src/pquant/pruning_methods/fitcompress.py @@ -14,9 +14,10 @@ def __init__(self, config, *args, **kwargs): self.is_finetuning = False def build(self, input_shape): - self.mask = self.add_weight(shape=input_shape, initializer="ones", trainable=False) + self.mask = self.add_weight(name="compression_mask",shape=input_shape,initializer="ones",trainable=False) super().build(input_shape) + @tf.function(reduce_retracing=True) def call(self, weight): return self.mask * weight