22 changes: 22 additions & 0 deletions neural_compressor/jax/quantization/layers_static.py
@@ -243,6 +243,18 @@ def post_quantization_cleanup(self):
        None: Cleans up observers and sets quantized call.
        """
        self._tracker.unlock()
        if not self._is_quantized:
            # Clean up observer only if it exists
            if hasattr(self, "input_observer"):
                if hasattr(self, "_layers") and self.input_observer in self._layers:
                    self._layers.remove(self.input_observer)
            # Set call to pass-through/original
            if hasattr(self, "call"):
                # pass through
                pass
            self._const_variables = []
            self._tracker.lock()
            return
Comment on lines +246 to +257
Copilot AI Apr 15, 2026
The new `if not self._is_quantized: ... return` short-circuits `post_quantization_cleanup()` in the normal successful quantization path, because `_is_quantized` is only set to `True` at the end of this method. This prevents switching `self.call` to `call_symmetric`/`call_asymmetric` and prevents converting const vars, so static quantization will effectively never activate (and deserialized models will still have `call()` pointing at `input_observer`, which may not exist). Consider removing this early return and instead gating the skip path on a separate flag set by `convert()` when calibration fails (e.g., `_skip_quantization=True`), where you also explicitly set `call` to a pass-through implementation and delete/clear `input_observer` consistently.

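A minimal sketch of the flag-based alternative the review suggests. All names here (`_skip_quantization`, `_call_passthrough`, `_call_symmetric`, the stand-in layer class) are illustrative, not the actual `neural_compressor` API; the point is only that the skip decision is made by `convert()` on calibration failure, while the success path still flips `_is_quantized` and swaps `call`:

```python
import math


class StaticQuantLayerSketch:
    """Hypothetical stand-in for the layer under review."""

    def __init__(self):
        self._is_quantized = False
        self._skip_quantization = False  # set by convert() when calibration fails
        self.call = self._call_observed  # observer-instrumented path during calibration

    def _call_observed(self, x):
        return x  # placeholder for the observer path

    def _call_passthrough(self, x):
        return x  # original fp path, restored when quantization is skipped

    def _call_symmetric(self, x):
        return x  # placeholder for the quantized call

    def convert(self, scale):
        # Gate the skip on calibration failure, not on _is_quantized.
        if math.isinf(scale):
            self._skip_quantization = True

    def post_quantization_cleanup(self):
        if self._skip_quantization:
            # Restore a pass-through call and stop; _is_quantized stays False.
            self.call = self._call_passthrough
            return
        # Normal successful path: switch to the quantized call.
        self.call = self._call_symmetric
        self._is_quantized = True
```

With this shape, a failed calibration leaves the layer runnable on the fp path, while a successful one still activates the quantized call.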
Comment on lines +246 to +257
Copilot AI Apr 15, 2026
This change is meant to handle the "activation scale is inf (e.g., single-sample calibration) + const_scale/const_weight" scenario, but there is no regression test covering that failure mode. Since the repo already has JAX quantization pytest coverage (e.g., `test/jax/test_save_load.py`), please add a unit/integration test that calibrates with a 1-sample dataset that triggers the inf scale and asserts the layer/model still runs correctly (no observer-dependent call path, no const var conversion when quantization is skipped).

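A self-contained sketch of the regression test being asked for, assuming a symmetric min-max scale (`qmax / amax`) that goes to inf when the single calibration sample is all zeros. `compute_scale`, `Layer`, and the fallback logic are hypothetical stand-ins, not the repo's actual calibration code:

```python
import math


def compute_scale(samples):
    # Symmetric min-max scale: qmax / amax; inf when calibration data is all zeros.
    amax = max(abs(v) for sample in samples for v in sample)
    return 127.0 / amax if amax else math.inf


class Layer:
    """Hypothetical layer: falls back to the fp call when the scale is inf."""

    def __init__(self):
        self._skip_quantization = False
        self.call = self._fp_call

    def _fp_call(self, x):
        return [2.0 * v for v in x]  # stand-in for the fp computation

    def calibrate_and_convert(self, samples):
        scale = compute_scale(samples)
        if math.isinf(scale):
            self._skip_quantization = True  # keep the fp call path
            return
        self.call = self._make_int8_call(scale)

    def _make_int8_call(self, scale):
        def call(x):
            q = [round(v * scale) for v in x]  # fake quantize
            return [2.0 * (v / scale) for v in q]  # dequantize + compute
        return call


def test_single_sample_inf_scale_falls_back_to_fp():
    layer = Layer()
    # 1-sample, all-zero calibration dataset -> inf scale -> skip quantization.
    layer.calibrate_and_convert([[0.0, 0.0]])
    assert layer._skip_quantization
    assert layer.call([1.0, -1.0]) == [2.0, -2.0]
```

A real test in the repo would drive this through the actual calibration/convert entry points rather than these stand-ins.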
Contributor
The early return just disables converting the weight and scale to const attributes (line 264 and later).

if hasattr(self, "_layers") and hasattr(self, "input_observer"):
    if self.input_observer in self._layers:
        self._layers.remove(self.input_observer)
@@ -447,6 +459,16 @@ def post_quantization_cleanup(self):
        None: Cleans up observers and original weights.
        """
        self._tracker.unlock()
        if not self._is_quantized:
            if hasattr(self, "input_observer"):
                if hasattr(self, "_layers") and self.input_observer in self._layers:
                    self._layers.remove(self.input_observer)
            # Set call to pass-through/original
            if hasattr(self, "call"):
                pass
            self._const_variables = []
            self._tracker.lock()
            return
Comment on lines +462 to +471
Copilot AI Apr 15, 2026
Same issue here: `post_quantization_cleanup()` now returns immediately whenever `_is_quantized` is `False`, but `_is_quantized` is only set to `True` at the end of this method (and `convert()` does not set it on success). This means the cleanup never switches `self.call` to `call_fp8`/`call_int8` and will break both the static quantization workflow and `prepare_deserialized_quantized_model()` (which doesn't create `input_observer`, so leaving `call()` pointing at the observer path will raise at runtime). Instead of checking `_is_quantized` here, use a dedicated "skip quantization" flag set when calibration fails, and only take the early return in that case, after restoring the original call path.

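A sketch of the cleanup shape this comment asks for: the skip path must work both when `input_observer` exists (normal calibration) and when it doesn't (deserialized models), and must always restore a runnable call. The class, the `with_observer` flag, and `_fp_call` are illustrative assumptions, not the actual layer:

```python
class QuantLayerCleanupSketch:
    """Hypothetical layer showing observer-safe skip cleanup."""

    def __init__(self, with_observer=True):
        self._layers = []
        self._skip_quantization = False
        if with_observer:
            # Observer only exists on the calibration path, never after
            # deserialization.
            self.input_observer = object()
            self._layers.append(self.input_observer)
        self.call = self._fp_call

    def _fp_call(self, x):
        return x  # original fp path

    def post_quantization_cleanup(self):
        if self._skip_quantization:
            # Remove the observer consistently whether or not it was created.
            obs = getattr(self, "input_observer", None)
            if obs is not None:
                if obs in self._layers:
                    self._layers.remove(obs)
                del self.input_observer
            # Never leave call() pointing at the observer path.
            self.call = self._fp_call
            return
```

The `getattr(..., None)` guard is what keeps the deserialized case (no `input_observer` attribute) from raising.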
        if hasattr(self, "_kernel") and self._kernel in self._trainable_variables:
            self._trainable_variables.remove(self._kernel)
            del self._kernel