From cd75314329d4970e5539921c8eabc88ea7ca5e71 Mon Sep 17 00:00:00 2001 From: Nitesh Kumar <166297874+niteshg97@users.noreply.github.com> Date: Wed, 20 May 2026 22:08:32 +0530 Subject: [PATCH] Rename weight variable to compression_mask and added TensorFlow graph compilation Apply element wise compression masking. Adds tensorFlow graph compilation to FITCompress.call() to reduce eager execution overhead during repeated masking operations in training/inference workloads. --- src/pquant/pruning_methods/fitcompress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pquant/pruning_methods/fitcompress.py b/src/pquant/pruning_methods/fitcompress.py index f2b3b5c..7129006 100644 --- a/src/pquant/pruning_methods/fitcompress.py +++ b/src/pquant/pruning_methods/fitcompress.py @@ -14,9 +14,10 @@ def __init__(self, config, *args, **kwargs): self.is_finetuning = False def build(self, input_shape): - self.mask = self.add_weight(shape=input_shape, initializer="ones", trainable=False) + self.mask = self.add_weight(name="compression_mask",shape=input_shape,initializer="ones",trainable=False) super().build(input_shape) + @tf.function(reduce_retracing=True) def call(self, weight): return self.mask * weight