Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion neural_compressor/jax/quantization/layers_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
from jax import numpy as jnp
from keras import ops
from keras.layers import Dense, EinsumDense, MultiHeadAttention
from keras.layers import Conv2D, Dense, EinsumDense, MultiHeadAttention
from keras_hub.layers import ReversibleEmbedding
from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention
from keras_hub.src.models.gemma3.gemma3_vision_encoder import Gemma3VisionAttention
Expand Down Expand Up @@ -299,6 +299,34 @@ class QDynamicEinsumDense(QDynamicDenseMixin, EinsumDense):
verify_api(EinsumDense, QDynamicEinsumDense, "call")


class QDynamicConv2DMixin(QDynamicDenseMixin, Conv2D):
    """Mixin that adds dynamic quantization to Conv2D layers.

    Reuses the machinery of ``QDynamicDenseMixin`` (e.g. ``input_qdq``) and
    only overrides ``call`` so the input is quantize-dequantized before the
    convolution runs.
    """

    def call(self, inputs):
        """Apply quantized input processing before the convolution computation.

        Args:
            inputs (jnp.ndarray): Input tensor.

        Returns:
            jnp.ndarray: Layer output tensor.
        """
        # Dynamic quantize-dequantize of the activations; scales are computed
        # on the fly from `inputs` by the inherited input_qdq helper.
        x = self.input_qdq(inputs)
        # super(QDynamicDenseMixin, self) deliberately starts the MRO lookup
        # *after* QDynamicDenseMixin, so the plain Conv2D.call runs on the
        # already-processed input instead of the Dense mixin's call.
        x = super(QDynamicDenseMixin, self).call(x)
        return x


@register_dynamic_quantized_layer(Conv2D)
class QDynamicConv2D(QDynamicConv2DMixin, Conv2D):
    """Dynamically quantized Conv2D layer.

    Registered as the dynamic-quantization replacement for ``keras.layers.Conv2D``;
    all behavior comes from ``QDynamicConv2DMixin``.
    """

    pass


# Guard against API drift: the quantized layer's `call` must keep the same
# signature as the original Conv2D layer it replaces.
verify_api(Conv2D, QDynamicConv2D, "call")


@register_dynamic_quantized_layer(MultiHeadAttention)
class QDynamicMultiHeadAttention(SaveableLayerMixin, MultiHeadAttention):
"""Dynamically quantized MultiHeadAttention layer."""
Expand Down
70 changes: 69 additions & 1 deletion neural_compressor/jax/quantization/layers_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
from jax import numpy as jnp
from keras import ops
from keras.layers import Dense, EinsumDense, MultiHeadAttention
from keras.layers import Conv2D, Dense, EinsumDense, MultiHeadAttention
from keras_hub.layers import ReversibleEmbedding, RotaryEmbedding
from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention
from keras_hub.src.models.gemma3.gemma3_vision_encoder import Gemma3VisionAttention
Expand Down Expand Up @@ -563,6 +563,74 @@ class QStaticEinsumDense(QStaticDenseMixin, EinsumDense):
verify_api(EinsumDense, QStaticEinsumDense, "call")


class QStaticConv2DMixin(QStaticDenseMixin, Conv2D):
    """Mixin that adds static quantization to Conv2D layers.

    Reuses the calibration/quantization machinery of ``QStaticDenseMixin``
    and overrides the call paths so they wrap the Conv2D computation:
    ``call`` observes activations during calibration, while ``call_fp8`` /
    ``call_int8`` apply quantize-dequantize with the calibrated scales.
    """

    def call(self, inputs):
        """Run calibration observer before the convolution computation.

        Args:
            inputs (jnp.ndarray): Input tensor.

        Returns:
            jnp.ndarray: Layer output tensor.
        """
        # Record activation statistics for later scale computation.
        x = self.input_observer(inputs)
        # Start MRO lookup after QStaticDenseMixin so the plain Conv2D.call
        # runs, not the Dense mixin's call.
        x = super(QStaticDenseMixin, self).call(x)
        return x

    def call_fp8(self, inputs):
        """Apply FP8 quantize-dequantize before the convolution computation.

        Args:
            inputs (jnp.ndarray): Input tensor.

        Returns:
            jnp.ndarray: Layer output tensor.
        """
        # const_scale: scale is a plain constant; otherwise it is stored as a
        # variable and the current value is read via `.value`.
        if self.const_scale:
            a_scale = self.a_scale
        else:
            a_scale = self.a_scale.value
        # Quantize-dequantize the activations (FP8 is scale-only, no zero point).
        x = self.aquantfun(inputs, a_scale)
        x = self.adequantfun(x, a_scale)
        x = super(QStaticDenseMixin, self).call(x)
        return x

    def call_int8(self, inputs):
        """Apply int8 quantize-dequantize before the convolution computation.

        Args:
            inputs (jnp.ndarray): Input tensor.

        Returns:
            jnp.ndarray: Layer output tensor.
        """
        # int8 is asymmetric here: both a scale and a zero point are used.
        if self.const_scale:
            a_scale = self.a_scale
            a_zero_point = self.a_zero_point
        else:
            a_scale = self.a_scale.value
            a_zero_point = self.a_zero_point.value
        x = self.aquantfun(inputs, a_scale, a_zero_point)
        x = self.adequantfun(x, a_scale, a_zero_point)
        x = super(QStaticDenseMixin, self).call(x)
        return x


@register_static_quantized_layer(Conv2D)
class QStaticConv2D(QStaticConv2DMixin, Conv2D):
    """Statically quantized Conv2D layer.

    Registered as the static-quantization replacement for ``keras.layers.Conv2D``;
    all behavior comes from ``QStaticConv2DMixin``.
    """

    pass


# Backward-compatible alias: the class was originally published with a
# lowercase "d" (QStaticConv2d), inconsistent with QDynamicConv2D and the
# rest of the file's naming. Keep the old name importable.
QStaticConv2d = QStaticConv2D


# Guard against API drift: the quantized layer's `call` must keep the same
# signature as the original Conv2D layer it replaces.
verify_api(Conv2D, QStaticConv2D, "call")


@register_static_quantized_layer(MultiHeadAttention)
class QStaticMultiHeadAttention(SaveableLayerMixin, MultiHeadAttention):
"""Statically quantized MultiHeadAttention layer."""
Expand Down
Loading