diff --git a/neural_compressor/jax/quantization/layers_static.py b/neural_compressor/jax/quantization/layers_static.py
index 01c86bb9bda..91d4e9a647b 100644
--- a/neural_compressor/jax/quantization/layers_static.py
+++ b/neural_compressor/jax/quantization/layers_static.py
@@ -243,6 +243,14 @@ def post_quantization_cleanup(self):
         None: Cleans up observers and sets quantized call.
     """
     self._tracker.unlock()
+    if not self._is_quantized:
+        # Layer was never quantized: drop the observer (if attached), reset
+        # const variables, re-lock the tracker, and keep the original call.
+        if hasattr(self, "input_observer"):
+            if hasattr(self, "_layers") and self.input_observer in self._layers:
+                self._layers.remove(self.input_observer)
+        self._const_variables = []
+        self._tracker.lock()
+        return
     if hasattr(self, "_layers") and hasattr(self, "input_observer"):
         if self.input_observer in self._layers:
             self._layers.remove(self.input_observer)
@@ -447,6 +455,13 @@ def post_quantization_cleanup(self):
         None: Cleans up observers and original weights.
     """
     self._tracker.unlock()
+    if not self._is_quantized:
+        # Not quantized: same early-exit cleanup as the activation layer above.
+        if hasattr(self, "input_observer"):
+            if hasattr(self, "_layers") and self.input_observer in self._layers:
+                self._layers.remove(self.input_observer)
+        self._const_variables = []
+        self._tracker.lock()
+        return
     if hasattr(self, "_kernel") and self._kernel in self._trainable_variables:
         self._trainable_variables.remove(self._kernel)
         del self._kernel