diff --git a/.azure-pipelines/scripts/ut/run_3x_pt_hpu.sh b/.azure-pipelines/scripts/ut/run_3x_pt_hpu.sh
index 22166dd5e86..e19a7982767 100644
--- a/.azure-pipelines/scripts/ut/run_3x_pt_hpu.sh
+++ b/.azure-pipelines/scripts/ut/run_3x_pt_hpu.sh
@@ -14,6 +14,8 @@ echo "##[group]set up UT env..."
 export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
 sed -i '/^auto-round/d;/^torchvision/d' /neural-compressor/test/torch/requirements.txt
 pip install -r /neural-compressor/test/torch/requirements.txt
+pip install deepspeed@git+https://github.com/HabanaAI/DeepSpeed.git@main --no-deps
+pip install msgpack hjson ninja # deepspeed dependency
 pip install auto-round-hpu
 pip install pytest-cov pytest-html pytest-html-merger beautifulsoup4==4.13.5
 echo "##[endgroup]"
diff --git a/examples/helloworld/fp8_example/b2b_unitest_2_steps.py b/examples/helloworld/fp8_example/b2b_unitest_2_steps.py
new file mode 100644
index 00000000000..c8aa22b7da9
--- /dev/null
+++ b/examples/helloworld/fp8_example/b2b_unitest_2_steps.py
@@ -0,0 +1,94 @@
+
+import argparse
+import math
+
+import torch
+import habana_frameworks.torch.core as htcore
+from torch.nn import Parameter, init
+
+# Initialize HPU environment (must be called before HPU operations)
+htcore.hpu_set_env()
+
+from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare
+
+torch.manual_seed(1)
+
+
+class B2BMatmul(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y, **kwargs):
+        return torch.matmul(x, y, **kwargs)
+
+
+
+
+class M(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.matmul = B2BMatmul()
+
+    def forward(self, inp0, inp1):
+        res = self.matmul(inp0, inp1)
+
+        return res
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Habana FP8 sample code with B2BMatmul.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("--quant_config", type=str, help="JSON file of quantization config")
+    args = parser.parse_args()
+
+    # Build model & load config
+    model = M().eval()
+    config = FP8Config.from_json_file(args.quant_config)
+
+    # Optional calibration preparation
+    if config.measure:
+        model = prepare(model, config)
+
+    # Optional quantization
+    if config.quantize:
+        htcore.hpu_initialize()
+        model = convert(model, config)
+        print(model)
+
+    # Create inputs and run
+
+    with torch.no_grad():
+        model.to("hpu")
+
+        B = 6
+        N = 100
+
+        inp0= torch.tensor([
+            [1,0,0,0,0,0], # row 0 <- X[0]
+            [0,0,0,1,0,0], # row 1 <- X[3]
+            [0,1,0,0,0,0], # row 2 <- X[1]
+            [0,0,0,0,1,0], # row 3 <- X[4]
+            [0,0,0,0,0,0], # row 4 <- X[2]
+            [0,0,0,0,0,0], # row 5 <- X[5]
+        ], dtype=torch.float32).to("hpu")
+
+        # Input for Matmul: [B, D] -> now [6, 100]
+        inp1 = torch.randn(B, N)
+        inp1[2, :] = 1000
+        inp1[5, :] = 1000
+
+
+        # Run the model
+        output = model(inp0, inp1)
+        print("Output shape:", output.shape)
+        print(output)
+
+
+    # Finalize calibration if measuring
+    if config.measure:
+        finalize_calibration(model)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/requirements.txt b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/requirements.txt
index 008c711a563..90f91438095 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/requirements.txt
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/requirements.txt
@@ -1,3 +1,5
@@ loguru hf_transfer -transformers==4.57.3 \ No newline at end of file +transformers==4.57.3 +# pip install git+https://github.com/yiliu30/long-bench-eval +long-bench-eval @ git+https://github.com/yiliu30/long-bench-eval \ No newline at end of file diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/setup.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/setup.sh index 85db8544575..21614f6d579 100644 --- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/setup.sh +++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen/setup.sh @@ -77,4 +77,4 @@ else echo "Unsupported device: $DEVICE. Supported devices are gpu and xpu." usage exit 1 -fi \ No newline at end of file +fi diff --git a/neural_compressor/torch/algorithms/autoround/autoround.py b/neural_compressor/torch/algorithms/autoround/autoround.py index 89fd552bb7d..0a0cfaac85d 100644 --- a/neural_compressor/torch/algorithms/autoround/autoround.py +++ b/neural_compressor/torch/algorithms/autoround/autoround.py @@ -218,7 +218,10 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): model.autoround_config = weight_config return rounder.save_quantized(output_dir=self.output_dir, inplace=True) else: # pragma: no cover - rounder.quantize_and_save(output_dir=self.output_dir, format=self.export_format, inplace=True) + _, quantized_model_path = rounder.quantize_and_save( + output_dir=self.output_dir, format=self.export_format, inplace=True + ) + self.output_dir = quantized_model_path model = rounder.model model.autoround_config = rounder.layer_config @@ -236,8 +239,8 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): import transformers # pylint: disable=E0401 model = transformers.AutoModelForCausalLM.from_pretrained(self.output_dir) - except: - pass + except Exception as e: + logger.error(f"Error reloading model: {e}") return model diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py index 6ddeabdf9b8..f0b165472bc 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py @@ -48,7 +48,7 @@ def dequant_original_fp8_weight_if_needed(mod: torch.nn.Module, param: torch.Ten else: raise RuntimeError(f"Got fp8 weight for {mod}, but dequant function is None, please check.") else: - RuntimeError(f"Got fp8 weight for {mod}, but dequant function is not found, please check.") + raise RuntimeError(f"Got fp8 weight for {mod}, but dequant function is not found, please check.") return param @@ -326,14 +326,27 @@ def get_device_type_for_scales(mod): return config["device_for_scales"] -@lru_cache -def is_runtime_scale_patching(): - """Check whether runtime scale patching is enabled via environment variable. +class RuntimeState(Enum): + STATIC = 0 + RUNTIME_SCALE_PATCHING = 1 + DYNAMIC_QUANTIZATION = 2 - Returns: - bool: True when runtime patching is enabled. 
- """ - return os.getenv("RUNTIME_SCALE_PATCHING", "False").lower() in ["true", "1"] + +_runtime_state = RuntimeState.STATIC + +@lru_cache() +def set_runtime_state(is_dynamic_quantization): + global _runtime_state + if is_dynamic_quantization: + _runtime_state = RuntimeState.DYNAMIC_QUANTIZATION + elif (os.getenv("RUNTIME_SCALE_PATCHING", "False").lower() in ["true", "1"]): + _runtime_state = RuntimeState.RUNTIME_SCALE_PATCHING + else: + _runtime_state = RuntimeState.STATIC + + +def is_runtime_scale_patching(): + return _runtime_state == RuntimeState.RUNTIME_SCALE_PATCHING #TODO [SW-224612]: Use cguid to calc scales and remove the check @lru_cache diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py index 7de353d5555..f7a96c811ab 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/measure.py @@ -131,6 +131,8 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N d_shapes (dict, optional): Defaults to None. """ top_level_config = get_hqt_config(model) + if top_level_config is None: + raise ValueError("HQT config is not initialized on the model.") config = top_level_config.cfg setup_calibration_counter(model, config) skip_outputs_measurements = config["measure_exclude"] & (MeasureExclude.OUTPUT | MeasureExclude.ALL) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py index 539513f60de..75affa55454 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/patching_common.py @@ -57,6 +57,7 @@ def create_mod_info_recursion(parent): "linear": ModuleType(1, ["weight"], 1, False), "row_parallel_linear": ModuleType(1, ["weight"], 2, True), "matmul": ModuleType(2, [], 1, False), + "b2b_matmul": ModuleType(2, [], 1, True), "kv_cache": ModuleType(1, [], 1, False), "softmax": ModuleType(1, [], 1, True), "fused_sdpa": ModuleType(3, [], 2, True), @@ -66,7 +67,8 @@ def create_mod_info_recursion(parent): _mod_default_dict = { - "Matmul": ModuleInfo("matmul", PatchedMatmul), + "Matmul": ModuleInfo("matmul", PatchedMatmul, supports_dynamic_quantization=True), + "B2BMatmul": ModuleInfo("b2b_matmul", PatchedMatmul, supports_dynamic_quantization=True), "Linear": ModuleInfo("linear", PatchedLinear, supports_dynamic_quantization=True), "ParallelLMHead": ModuleInfo("linear", PatchedParallelLMHead, supports_dynamic_quantization=True), "RowParallelLinear": ModuleInfo("row_parallel_linear", PatchedRowParallelLinear, supports_dynamic_quantization=True), @@ -75,7 +77,7 @@ def create_mod_info_recursion(parent): "QKVParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear, supports_dynamic_quantization=True), "FalconLinear": ModuleInfo("linear", PatchedLinear, supports_dynamic_quantization=True), "KVCache": ModuleInfo("kv_cache", PatchedKVCache), - "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache), + "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache, supports_dynamic_quantization=True), "Conv2d": ModuleInfo("linear", PatchedConv2d), "LoRACompatibleLinear": ModuleInfo("linear", PatchedLoRACompatibleLinear, supports_dynamic_quantization=True), "LoRACompatibleConv": ModuleInfo("linear", PatchedLoRACompatibleConv), diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py 
b/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py index 51b2a159b67..23673c00e0a 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py @@ -18,6 +18,7 @@ from abc import abstractmethod from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator from .quantized_func_wrappers import get_quantized_func_wrapper, OP_TYPE +from .fp_utils import invert_scale cur_accelerator = auto_detect_accelerator() @@ -69,23 +70,23 @@ def extra_repr(self) -> str: class QuantDequantNone(QuantDequantBase): def __init__(self, lp_dtype, hp_dtype, *args, **kwargs): - super(QuantDequantNone, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + super().__init__(lp_dtype, hp_dtype, *args, **kwargs) def forward(self, *args, **kwargs): return args[0] def extra_repr(self) -> str: - repr = super(QuantDequantNone, self).extra_repr() + repr = super().extra_repr() return f"{repr}, doesn't quantize nor dequantize" class QuantInput(QuantDequantBase): def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs): - super(QuantInput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + super().__init__(lp_dtype, hp_dtype, *args, **kwargs) scale_inv = scale_inv.unsqueeze(1) if (scale_inv.numel() > 1 and not self.use_qdq) else scale_inv self.register_scale("scale_inv", scale_inv, self.scale_format) if self.use_qdq: - self.register_scale("scale", 1 / self.scale_inv, self.scale_format) + self.register_scale("scale", invert_scale(self.scale_inv), self.scale_format) op_type = OP_TYPE.QUANT_PC if self.scale_format == ScaleFormat.CONST and self.scale.numel() > 1 else OP_TYPE.QUANT else: op_type = OP_TYPE.CAST_TO_FP8 @@ -106,40 +107,40 @@ def forward_qdq(self, x): ) def extra_repr(self) -> str: - repr = super(QuantInput, self).extra_repr() + repr = super().extra_repr() dtype = get_scale_dtype(self.scale_inv) return f"{repr}, scale_inv dtype={dtype}" class QuantDynamicInput(QuantDequantBase): def __init__(self, input_scales_creator, lp_dtype, hp_dtype, *args, **kwargs): - super(QuantDynamicInput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + super().__init__(lp_dtype, hp_dtype, *args, **kwargs) self.input_scales_creator = input_scales_creator - self.cast_to_op = get_quantized_func_wrapper(OP_TYPE.CAST_TO_FP8, self.scale_format) - def calculate_scales(self, x): - scale = self.input_scales_creator.calc_scales(x, QuantTensorType.DYNAMIC) + def calculate_scales(self, x, in_scale = None): + if in_scale is None: + scale = self.input_scales_creator.calc_scales(x, QuantTensorType.DYNAMIC) + else: + scale = in_scale scale_inv = self.input_scales_creator.invert_scales(scale) return scale, scale_inv - def forward(self, x): - scale, scale_inv = self.calculate_scales(x) - + def forward(self, x, in_scale=None): + scale, scale_inv = self.calculate_scales(x, in_scale) ret = self.cast_to_op(x, scale_inv, False, False, self.lp_dtype) - return ret, scale #TODO [SW-224609]: implement forward qdq def extra_repr(self) -> str: - repr = super(QuantDynamicInput, self).extra_repr() + repr = super().extra_repr() return f"{repr} input_scales_creator={self.input_scales_creator}" class DequantOutput(QuantDequantBase): def __init__(self, scale, lp_dtype, hp_dtype, *args, **kwargs): - super(DequantOutput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + super().__init__(lp_dtype, hp_dtype, *args, **kwargs) self.register_scale("scale", scale, self.scale_format) if self.use_qdq: op_type = OP_TYPE.DEQUANT_PC if 
self.scale_format == ScaleFormat.CONST and self.scale.numel() > 1 else OP_TYPE.DEQUANT @@ -163,16 +164,25 @@ def forward_qdq(self, x): ) def extra_repr(self) -> str: - repr = super(DequantOutput, self).extra_repr() + repr = super().extra_repr() dtype = get_scale_dtype(self.scale) return f"{repr}, scale dtype={dtype}" +class DequantDynamicOutput(QuantDequantBase): + def __init__(self, lp_dtype, hp_dtype, *args, **kwargs): + super().__init__(lp_dtype, hp_dtype, *args, **kwargs) + self.cast_from_op = get_quantized_func_wrapper(OP_TYPE.CAST_FROM_FP8, self.scale_format) + + def forward(self, x, scale): + return self.cast_from_op(x, scale, self.hp_dtype) + + class QuantDequant(QuantDequantBase): def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs): - super(QuantDequant, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + super().__init__(lp_dtype, hp_dtype, *args, **kwargs) self.register_scale("scale_inv", scale_inv, self.scale_format) - self.register_scale("scale", 1 / scale_inv, self.scale_format) + self.register_scale("scale", invert_scale(scale_inv), self.scale_format) self.quantize_op = ( get_quantized_func_wrapper(OP_TYPE.QUANT, self.scale_format) if self.use_qdq @@ -215,5 +225,5 @@ def forward_qdq(self, x, *args, **kwargs): return z def extra_repr(self) -> str: - repr = super(QuantDequant, self).extra_repr() + repr = super().extra_repr() return f"{repr}, Quantize, and then dequantize" diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py index 2620310c35f..7945897ea5f 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py @@ -268,7 +268,8 @@ def quantize(model, mod_list): elif config.cfg["mode"] == QuantMode.LOAD: # no measurement and scale file scale_method_config = {CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type=ScaleValueType.DUMMY_SCALES), - CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.DUMMY_SCALES)} + CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.DUMMY_SCALES, + granularity=scale_method_config[CfgStr.DEFAULT][CfgStr.WEIGHT].granularity)} prepare_model_with_dummy_measurement(model, mod_list, scale_method_config, scale_config) else: raise Exception("unexpected mode, expected QuantMode.QUANTIZE or QuantMode.LOAD") \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/hpu/hpu_quantized_func_wrapper.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/hpu/hpu_quantized_func_wrapper.py index a5a6380dc39..15a92d30fb6 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/hpu/hpu_quantized_func_wrapper.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/hpu/hpu_quantized_func_wrapper.py @@ -43,6 +43,8 @@ def get_default_quantized_func(self): raise NotImplementedError() def get_scalar_quantized_func(self): + if is_runtime_scale_patching(): + return self.get_default_quantized_func() return self.get_default_quantized_func().scalar def get_dynamic_scalar_quantized_func(self): @@ -64,8 +66,10 @@ def get_quantized_func(self, scale_format, is_dynamic=False): else: if is_runtime_scale_patching() or scale_format == ScaleFormat.CONST: return self.get_default_quantized_func() - else: + elif scale_format == ScaleFormat.SCALAR: return self.get_scalar_quantized_func() + else: + return self.get_default_quantized_func() 
def __call__(self, *args, **kwargs): return self._quantized_func_(*args, **kwargs) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/xpu/xpu_quantized_func_wrapper.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/xpu/xpu_quantized_func_wrapper.py index 89341f5127b..025535b0183 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/xpu/xpu_quantized_func_wrapper.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/xpu/xpu_quantized_func_wrapper.py @@ -23,7 +23,7 @@ class QuantizedXPUFuncWrapperBase(QuantizedFuncWrapperBase, metaclass=ABCMeta): """ - Placeholder for base class for XPU quantized func wrapper. + Placeholder for base class for XPU (Falcon/Jaguar Shores) quantized func wrapper. """ def __init__(self, scale_format, is_dynamic=False): self._quantized_func_ = self.get_default_quantized_func() diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py index 564878ef349..8b014283f51 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/ops_quantizer.py @@ -17,7 +17,7 @@ from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import get_hqt_config, is_supported_dynamic_op from .scale_method_factory import ScaleMethodFactory, QuantTensorName, ScaleValueType from ..common import ModuleConfig, QuantTensorType -from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput, QuantDynamicInput +from ..quant_dequant import DequantOutput, QuantDequant, QuantDequantNone, QuantInput, QuantDynamicInput, DequantDynamicOutput from ...utils.logger import logger from neural_compressor.torch.algorithms.fp8_quant._core.common import dequant_original_fp8_weight_if_needed from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator @@ -37,16 +37,15 @@ def __init__(self, config, mod, measurement, params, mod_type_str): self.output_scales_creators = [] self.params_scales_creators = [] self.is_dynamic = hqt_config["dynamic_quantization"] and is_supported_dynamic_op(mod_type_str) - + self.update_module_configuration() logger.debug("%s %s", self.__class__.__name__, self.__dict__) - def get_module_configuration(self): - scale_format = get_hqt_config(self.mod).cfg["scale_format"] - use_qdq = get_hqt_config(self.mod).cfg["use_qdq"] - fake_quant = get_hqt_config(self.mod).cfg["fake_quant"] - lp_dtype = self.params["lp_dtype"] - hp_dtype = self.params["hp_dtype"] - return scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype + def update_module_configuration(self): + self.scale_format = get_hqt_config(self.mod).cfg["scale_format"] + self.use_qdq = get_hqt_config(self.mod).cfg["use_qdq"] + self.fake_quant = get_hqt_config(self.mod).cfg["fake_quant"] + self.lp_dtype = self.params["lp_dtype"] + self.hp_dtype = self.params["hp_dtype"] @abstractmethod def get_scales_module_config(self) -> ModuleConfig: @@ -83,10 +82,10 @@ def calc_output_scales(self): output_scales = self.output_scales_creators[0].calc_scales(output_measurement, QuantTensorType.MEASUREMENTS) return (output_scales,) - def init_input_config(self, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant): - if use_qdq or fake_quant: + def init_input_config(self, scales_inv): + if self.use_qdq or self.fake_quant: 
input_config = [ - QuantDequant(s_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq) + QuantDequant(s_inv, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format, use_qdq=self.use_qdq) for s_inv in scales_inv ] else: @@ -94,10 +93,10 @@ def init_input_config(self, scales_inv, lp_dtype, hp_dtype, scale_format, use_qd for input_scales_creator, s_inv in zip(self.inputs_scales_creators, scales_inv): if self.is_dynamic: input_config.append( - QuantDynamicInput(input_scales_creator, lp_dtype, hp_dtype, scale_format=scale_format) + QuantDynamicInput(input_scales_creator, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) ) else: - input_config.append(QuantInput(s_inv, lp_dtype, hp_dtype, scale_format=scale_format)) + input_config.append(QuantInput(s_inv, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format)) return input_config @@ -114,47 +113,52 @@ def __init__(self, config, mod, measurement, params, mod_type_str): self.weight_ich_scale_calc = self.scales_method_factory.get_scale_method(QuantTensorName.WEIGHT_IN_CH) self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT, self.is_dynamic)) + def calc_output_scales(self, input_0, input_1): + output_measurement = self.measurement.outputs[0] if self.measurement is not None else [] + output_scales = None + if not self.is_dynamic: + output_scales = self.output_scales_creators[0].calc_scales( + output_measurement, QuantTensorType.MEASUREMENTS, input0=input_0, input1=input_1) + return (output_scales,) + def get_scales_module_config(self): input_scales = self.calc_input_scales(num_of_inputs=1) - output_measurement = self.measurement.outputs[0] if self.measurement is not None else [] rescaled_weight = self.mod.weight if hasattr(self.mod, 'weight') else None if rescaled_weight is not None: rescaled_weight = dequant_original_fp8_weight_if_needed(self.mod, rescaled_weight) - if self.scales_method_factory.scale_method_config_map[QuantTensorName.WEIGHT_IN_CH].scale_value_type != ScaleValueType.DUMMY_SCALES: - # Calculating weight in hpu to support scale calculation CGUID torch.ops.hpu.calculate_scale_for_cast - rescaled_weight = rescaled_weight.to(cur_device) - if self.weight_ich_scale_calc is not None: - weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST) - rescaled_weight = torch.div(rescaled_weight, weight_scales_in_ch.reshape([1, -1])) - weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST) - + if self.scales_method_factory.scale_method_config_map[QuantTensorName.WEIGHT_IN_CH].scale_value_type != ScaleValueType.DUMMY_SCALES: + # Calculating weight in hpu to support scale calculation CGUID torch.ops.hpu.calculate_scale_for_cast + rescaled_weight = rescaled_weight.to(cur_device) + extra_kwargs = {} + if self.weight_ich_scale_calc is not None and rescaled_weight is not None: + weights_input_channel_dim_size = rescaled_weight.size(1) + weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST, in_channel_size=weights_input_channel_dim_size) + if self.scales_method_factory.scale_method_config_map[QuantTensorName.WEIGHT_IN_CH].scale_value_type == ScaleValueType.DUMMY_SCALES: + extra_kwargs = {"out_channel_size": rescaled_weight.size(0)} + weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(rescaled_weight, QuantTensorType.CONST, **extra_kwargs) params_config = ( {"weight": weights_scales_out_ch} if (self.weight_ich_scale_calc 
is None) else {"weight": {0: weights_scales_out_ch, 1: weight_scales_in_ch}} ) - output_scales = None - if not self.is_dynamic: - output_scales = self.output_scales_creators[0].calc_scales( - output_measurement, QuantTensorType.MEASUREMENTS, input0=weights_scales_out_ch, input1=input_scales[0] - ) + output_scales = self.calc_output_scales(weights_scales_out_ch, input_scales[0]) return ModuleConfig( input_scales, - (output_scales,), + output_scales, params_config, ) - def init_weight_config(self, scales, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant): - if use_qdq: + def init_weight_config(self, scales, scales_inv): + if self.use_qdq: # to ensure the weights to be loaded to the device in fp8 weight_config = [ - QuantInput(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq), - DequantOutput(scales, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq), + QuantInput(scales_inv, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format, use_qdq=self.use_qdq), + DequantOutput(scales, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format, use_qdq=self.use_qdq), ] - elif fake_quant: - weight_config = [QuantDequant(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format)] + elif self.fake_quant: + weight_config = [QuantDequant(scales_inv, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format)] else: - weight_config = [QuantInput(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format)] + weight_config = [QuantInput(scales_inv, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format)] return weight_config def init_weights_from_module(self, params_config): @@ -164,38 +168,25 @@ def init_weights_from_module(self, params_config): else: self.weight_och_scale_calc.scale = params_config - def get_output_config(self, lp_dtype, hp_dtype, scale_format): - output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)] + def get_output_config(self): + output_config = [QuantDequantNone(self.lp_dtype, self.hp_dtype, scale_format=self.scale_format)] return output_config def scales_module_config_to_q_and_dq(self, module): self.init_scales_from_module_config(module) self.init_weights_from_module(module.params["weight"]) - scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = self.get_module_configuration() - input_config = super().init_input_config( - (self.inputs_scales_creators[0].calc_invert_scales(),), - lp_dtype, - hp_dtype, - scale_format, - use_qdq, - fake_quant, - ) + input_config = super().init_input_config((self.inputs_scales_creators[0].calc_invert_scales(),)) # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here - output_config = self.get_output_config(lp_dtype, hp_dtype, scale_format=scale_format) + output_config = self.get_output_config() weight_config = self.init_weight_config( self.weight_och_scale_calc.scale, - self.weight_och_scale_calc.calc_invert_scales(), - lp_dtype, - hp_dtype, - scale_format, - use_qdq, - fake_quant, + self.weight_och_scale_calc.calc_invert_scales() ) params_config = {"weight": weight_config} if hasattr(self.mod, "bias") and (getattr(self.mod, "bias") is not None): # In PatchedLinear the bias is added to the output of gemm. # The output is expected to be descaled and in bf16, so we don't need to touch the bias. 
- bias_config = [QuantDequantNone(lp_dtype, hp_dtype)] + bias_config = [QuantDequantNone(self.lp_dtype, self.hp_dtype)] params_config.update({"bias": bias_config}) return ModuleConfig(input_config, output_config, params_config) @@ -226,19 +217,19 @@ def get_scales_module_config(self): module_config.outputs = (module_config.outputs[0], output_scales,) return module_config - def get_output_config(self, lp_dtype, hp_dtype, scale_format): + def get_output_config(self): if not self.allreduce_quantization_enabled: - return super().get_output_config(lp_dtype, hp_dtype, scale_format) + return super().get_output_config() scale_0 = self.output_scales_creators[0].scale inv_scale_0 = self.output_scales_creators[0].calc_invert_scales() - output_config_dq_scatter_output = DequantOutput(scale_0, lp_dtype, hp_dtype, scale_format=scale_format) - output_config_q_scatter_input = QuantInput(inv_scale_0, lp_dtype, hp_dtype, scale_format=scale_format) + output_config_dq_scatter_output = DequantOutput(scale_0, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) + output_config_q_scatter_input = QuantInput(inv_scale_0, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) output_config = [output_config_dq_scatter_output, output_config_q_scatter_input] inv_scale_1 = self.output_scales_creators[1].calc_invert_scales() scale_1 = self.output_scales_creators[1].scale - output_config_q_gather_input = QuantInput(inv_scale_1, lp_dtype, hp_dtype, scale_format=scale_format) - output_config_dq_gather_output = DequantOutput(scale_1, lp_dtype, hp_dtype, scale_format=scale_format) + output_config_q_gather_input = QuantInput(inv_scale_1, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) + output_config_dq_gather_output = DequantOutput(scale_1, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) output_config.extend([output_config_q_gather_input, output_config_dq_gather_output]) return output_config @@ -246,37 +237,68 @@ class MatmulOpQuantizer(BaseOpQuantizer): def __init__(self, config, mod, measurement, params, mod_type_str): super().__init__(config, mod, measurement, params, mod_type_str) - self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT)) - self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT)) - self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT)) + self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT, self.is_dynamic, scale_dim_index=-1)) + self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT, self.is_dynamic, scale_dim_index=-2)) + self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT, self.is_dynamic)) + def get_scales_module_config(self): input_scales = self.calc_input_scales(num_of_inputs=2) - output_scales = input_scales[0] * input_scales[1] + output_scales = None + if not self.is_dynamic and input_scales[0] is not None and input_scales[1] is not None: + output_scales = input_scales[0] * input_scales[1] return ModuleConfig(input_scales, (output_scales,), {}) def scales_module_config_to_q_and_dq(self, module): self.init_scales_from_module_config(module) - scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration() input_config = super().init_input_config( - (self.inputs_scales_creators[0].calc_invert_scales(), 
self.inputs_scales_creators[1].calc_invert_scales()), - lp_dtype, - hp_dtype, - scale_format, - use_qdq, - fake_quant, + (self.inputs_scales_creators[0].calc_invert_scales(), self.inputs_scales_creators[1].calc_invert_scales()) ) # 4bit->8bit inputs, no need to quant if hasattr(self.mod, "no_input_quant"): - input_config[1] = QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format) + input_config[1] = QuantDequantNone(self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here - output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)] + output_config = [QuantDequantNone(self.lp_dtype, self.hp_dtype, scale_format=self.scale_format)] return ModuleConfig(input_config, output_config) +## Batch2Block and Block2Batch Matmul Op Quantizer, need special handling for input scales +## First input is matrix of 0 and 1, so its scale is always 1 +## Second input is corrupt with garbage values, so its scale is taken from output scale +class B2B_MatmulOpQuantizer(MatmulOpQuantizer): + def __init__(self, config, mod, measurement, params, mod_type_str): + super().__init__(config, mod, measurement, params, mod_type_str) + + def get_scales_module_config(self): + input_scales = [] + ## first input is matrix of 0 and 1 + input_measurement = self.measurement.inputs[0] if self.measurement is not None else [] + input_scale = None + if not self.is_dynamic: + input_scale = self.inputs_scales_creators[0].calc_scales( + input_measurement, QuantTensorType.MEASUREMENTS + ) + input_scales.append(input_scale) + ## second input is corrupt with garbage values - use measurement from output + input_measurement = self.measurement.outputs[0] if self.measurement is not None else [] + input_scale = None + if not self.is_dynamic: + input_scale = self.inputs_scales_creators[1].calc_scales( + input_measurement, QuantTensorType.MEASUREMENTS + ) + input_scales.append(input_scale) + + + output_scales = None + if not self.is_dynamic and input_scales[0] is not None and input_scales[1] is not None: + output_scales = input_scales[0] * input_scales[1] + return ModuleConfig(input_scales, (output_scales,), {}) + + + class SoftmaxOpQuantizer(BaseOpQuantizer): def __init__(self, config, mod, measurement, params, mod_type_str): @@ -290,9 +312,8 @@ def get_scales_module_config(self): def scales_module_config_to_q_and_dq(self, module): self.init_scales_from_module_config(module) - scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration() output_config = [ - DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format) + DequantOutput(self.output_scales_creators[0].scale, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) ] return ModuleConfig([], output_config, {}) @@ -322,15 +343,12 @@ def get_scales_module_config(self): def scales_module_config_to_q_and_dq(self, module): self.init_scales_from_module_config(module) - scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration() input_scales_inv = [ self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators)) ] - input_config = super().init_input_config( - input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant - ) + input_config = super().init_input_config(input_scales_inv) output_config = [ - DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format) + 
DequantOutput(self.output_scales_creators[0].scale, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) ] return ModuleConfig(input_config, output_config, {}) @@ -339,39 +357,42 @@ class KVCacheOpQuantizer(BaseOpQuantizer): def __init__(self, config, mod, measurement, params, mod_type_str): super().__init__(config, mod, measurement, params, mod_type_str) - self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT)) - self.output_scales_creators.append(self.inputs_scales_creators[0]) + self.inputs_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.INPUT, self.is_dynamic)) + self.output_scales_creators.append(self.scales_method_factory.get_scale_method(QuantTensorName.OUTPUT, self.is_dynamic)) # TODO: Remove after implementing lp_dtype in OHF. - def init_input_config(self, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant): - input_config = super().init_input_config(scales_inv, lp_dtype, hp_dtype, scale_format, False, fake_quant) - if use_qdq: + def init_input_config(self, scales_inv): + input_config = super().init_input_config(scales_inv) + if self.use_qdq: input_config.extend([ - QuantDequant(s_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq) + QuantDequant(s_inv, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format, use_qdq=self.use_qdq) for s_inv in scales_inv ]) return input_config def get_scales_module_config(self): input_scales = self.calc_input_scales(num_of_inputs=1) - self.output_scales_creators[0].scale = self.inputs_scales_creators[0].scale - output_scales = [self.output_scales_creators[0].scale] + if self.is_dynamic: + output_scales = self.calc_output_scales() + else: + self.output_scales_creators[0].scale = self.inputs_scales_creators[0].scale + output_scales = [self.output_scales_creators[0].scale] return ModuleConfig(input_scales, output_scales, {}) + def get_output_config(self): + if self.is_dynamic: + output_config = [DequantDynamicOutput(self.lp_dtype, self.hp_dtype, scale_format=self.scale_format) for s in self.output_scales_creators] + else: + output_config = [DequantOutput(self.output_scales_creators[0].scale, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format, use_qdq=self.use_qdq)] + return output_config + def scales_module_config_to_q_and_dq(self, module): self.init_scales_from_module_config(module) - scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration() input_scales_inv = [ self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators)) ] - # TODO: After implementing lp_dtype in OHF can call: - # `super().init_input_config(scales_inv, lp_dtype, hp_dtype, scale_format, False, fake_quant)` - input_config = self.init_input_config( - input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant - ) - output_config = [ - DequantOutput(self.output_scales_creators[0].scale, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=False) - ] + input_config = self.init_input_config(input_scales_inv) + output_config = self.get_output_config() return ModuleConfig(input_config, output_config) @@ -386,7 +407,7 @@ def __init__(self, config, mod, measurement, params, mod_type_str): num_of_experts = self.mod.num_experts else: num_of_experts = 8 - + self.inputs_scales_creators = [ self.scales_method_factory.get_scale_method(QuantTensorName.INPUT, is_dynamic=self.is_dynamic) for i in range(num_of_inputs + num_of_experts) @@ -414,17 +435,13 @@ def 
get_scales_module_config(self): def scales_module_config_to_q_and_dq(self, module): self.init_scales_from_module_config(module) - scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = super().get_module_configuration() input_scales_inv = [ self.inputs_scales_creators[i].calc_invert_scales() for i in range(len(self.inputs_scales_creators)) ] - input_config = super().init_input_config( - input_scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant - ) - output_config = [QuantDequantNone(lp_dtype, hp_dtype, scale_format=scale_format)] + input_config = super().init_input_config(input_scales_inv) + output_config = [QuantDequantNone(self.lp_dtype, self.hp_dtype, scale_format=self.scale_format)] return ModuleConfig(input_config, output_config) - class EmbeddingOpQuantizer(BaseOpQuantizer): @@ -440,7 +457,8 @@ def get_scales_module_config(self): input_scales = self.calc_input_scales(num_of_inputs=1) if self.weight_ich_scale_calc is not None: - weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST) + weights_input_channel_dim_size = weight.size(1) if weight is not None else None + weight_scales_in_ch = self.weight_ich_scale_calc.calc_scales(input_scales[0], QuantTensorType.CONST, in_channel_size=weights_input_channel_dim_size) weight = torch.div(weight, weight_scales_in_ch.reshape([1, -1])) weights_scales_out_ch = self.weight_och_scale_calc.calc_scales(weight, QuantTensorType.CONST) @@ -455,12 +473,12 @@ def get_scales_module_config(self): params_config, ) - def init_weight_config(self, scales, scales_inv, lp_dtype, hp_dtype, scale_format, use_qdq, fake_quant): - if use_qdq: + def init_weight_config(self, scales, scales_inv): + if self.use_qdq: # to ensure the weights to be loaded to the device in fp8 weight_config = [ - QuantInput(scales_inv, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq), - DequantOutput(scales, lp_dtype, hp_dtype, scale_format=scale_format, use_qdq=use_qdq), + QuantInput(scales_inv, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format, use_qdq=self.use_qdq), + DequantOutput(scales, self.lp_dtype, self.hp_dtype, scale_format=self.scale_format, use_qdq=self.use_qdq), ] else: raise ValueError("For FP8 quantization, {} only supports QDQ mode now!".format(self.mod.__class__.__name__)) @@ -476,22 +494,20 @@ def init_weights_from_module(self, params_config): def scales_module_config_to_q_and_dq(self, module): self.init_scales_from_module_config(module) self.init_weights_from_module(module.params["weight"]) - scale_format, use_qdq, fake_quant, lp_dtype, hp_dtype = self.get_module_configuration() weight_config = self.init_weight_config( self.weight_och_scale_calc.scale, self.weight_och_scale_calc.calc_invert_scales(), - lp_dtype, - hp_dtype, - scale_format, - use_qdq, - fake_quant, ) params_config = {"weight": weight_config} return ModuleConfig([], [], params_config) + + + ops_quantizer_map = {"linear": LinearOpQuantizer, "matmul": MatmulOpQuantizer, + "b2b_matmul": B2B_MatmulOpQuantizer, "fused_sdpa": FsdpaOpQuantizer, "softmax": SoftmaxOpQuantizer, "kv_cache": KVCacheOpQuantizer, diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/round_scales_function.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/round_scales_function.py index 8aa8bfef3cc..fa5216e2211 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/round_scales_function.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/round_scales_function.py @@ 
-24,7 +24,7 @@ def decorator(cls): scale_round_method_registry[name] = cls return cls return decorator - + @register_scale_round_method("POW2") class ScaleToPow2: def __init__(self): diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_config.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_config.py index b8ef8b27232..82e7c672b2d 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_config.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_config.py @@ -32,6 +32,7 @@ class ScaleMethodString(Enum): MAXABS_POW2_OPT_WEIGHT = auto() MAXABS_ARBITRARY = auto() ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW = auto() + MAXABS_PCS_POW2 = auto() class ScaleGranularity(Enum): PTS = auto() @@ -66,7 +67,7 @@ class CfgStr(Enum): LAYERS_SLASH_PATTERN= r"layers/(\d+)" class ScaleMethodConfig: - def __init__(self, + def __init__(self, granularity=ScaleGranularity.PTS, scale_value_type=ScaleValueType.MAXABS, rounding_method=ScaleRoundMethod.IDENTITY, @@ -88,11 +89,11 @@ def __hash__(self): self.scale_value_type, self.rounding_method )) - + def __eq__(self, other): if not isinstance(other, ScaleMethodConfig): return False - + # Only check the three fields that define uniqueness return (self.granularity == other.granularity and self.scale_value_type == other.scale_value_type and @@ -101,63 +102,68 @@ def __eq__(self, other): scale_method_config_mapping = { ScaleMethodString.UNIT_SCALE: { - CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method= ScaleRoundMethod.SCALE_UNIT), - CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method= ScaleRoundMethod.SCALE_UNIT) + CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method=ScaleRoundMethod.SCALE_UNIT), + CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method=ScaleRoundMethod.SCALE_UNIT) }, ScaleMethodString.HW_ALIGNED_SINGLE_SCALE: { - CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method= ScaleRoundMethod.HW_ALIGNED_FIXED), - CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method= ScaleRoundMethod.HW_ALIGNED_FIXED) + CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method=ScaleRoundMethod.HW_ALIGNED_FIXED), + CfgStr.ACTIVATION: ScaleMethodConfig(scale_value_type=ScaleValueType.FIXED_VALUE, rounding_method=ScaleRoundMethod.HW_ALIGNED_FIXED) }, ScaleMethodString.MAXABS_HW: { - CfgStr.WEIGHT: ScaleMethodConfig(rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.5), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(rounding_method=ScaleRoundMethod.HW_ALIGNED, backoff=0.5), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.HW_ALIGNED, backoff=0.25) }, ScaleMethodString.MAXABS_POW2: { - CfgStr.WEIGHT: ScaleMethodConfig(rounding_method= ScaleRoundMethod.POW2, backoff= 0.5), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.POW2, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(rounding_method=ScaleRoundMethod.POW2, backoff=0.5), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.POW2, backoff=0.25) }, ScaleMethodString.MAXABS_ARBITRARY: { - CfgStr.WEIGHT: ScaleMethodConfig(backoff= 0.5), - 
CfgStr.ACTIVATION: ScaleMethodConfig(backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(backoff=0.5), + CfgStr.ACTIVATION: ScaleMethodConfig(backoff=0.25) }, ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW: { - CfgStr.WEIGHT: ScaleMethodConfig(rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.5), - CfgStr.ACTIVATION: ScaleMethodConfig(granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2) + CfgStr.WEIGHT: ScaleMethodConfig(rounding_method=ScaleRoundMethod.HW_ALIGNED, backoff=0.5), + CfgStr.ACTIVATION: ScaleMethodConfig(granularity=ScaleGranularity.PCS, rounding_method=ScaleRoundMethod.POW2) + }, + ScaleMethodString.MAXABS_PCS_POW2: + { + CfgStr.WEIGHT: ScaleMethodConfig(granularity=ScaleGranularity.PCS, rounding_method=ScaleRoundMethod.POW2, backoff=0.5), + CfgStr.ACTIVATION: ScaleMethodConfig(granularity=ScaleGranularity.PCS, rounding_method=ScaleRoundMethod.POW2) }, ScaleMethodString.MAXABS_HW_OPT_WEIGHT: { - CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type= ScaleValueType.OPT, rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.5, params={"weight_scales": get_fp8_hw_alligned_scales(torch.float8_e4m3fn)}), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.OPT, rounding_method=ScaleRoundMethod.HW_ALIGNED, backoff=0.5, params={"weight_scales": get_fp8_hw_alligned_scales(torch.float8_e4m3fn)}), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.HW_ALIGNED, backoff=0.25) }, ScaleMethodString.MAXABS_POW2_OPT_WEIGHT: { - CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type= ScaleValueType.OPT, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5, params={"weight_scales": [2.0**s for s in range(-10, 10)]}), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.POW2, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.OPT, rounding_method=ScaleRoundMethod.POW2, backoff=0.5, params={"weight_scales": [2.0**s for s in range(-10, 10)]}), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.POW2, backoff=0.25) }, ScaleMethodString.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2: { - CfgStr.WEIGHT: ScaleMethodConfig(granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(granularity=ScaleGranularity.PCS, rounding_method=ScaleRoundMethod.POW2, backoff=0.5), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.HW_ALIGNED, backoff=0.25) }, ScaleMethodString.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2: { - CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type = ScaleValueType.OPT, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5, params={"weight_scales": [2.0**s for s in range(-3, 5)]}), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.HW_ALIGNED, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.OPT, granularity=ScaleGranularity.PCS, rounding_method=ScaleRoundMethod.POW2, backoff=0.5, params={"weight_scales": [2.0**s for s in range(-3, 5)]}), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.HW_ALIGNED, backoff=0.25) }, ScaleMethodString.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2: { - CfgStr.WEIGHT: ScaleMethodConfig(granularity= ScaleGranularity.PCS, rounding_method= 
ScaleRoundMethod.POW2, backoff= 0.5), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.POW2, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(granularity=ScaleGranularity.PCS, rounding_method=ScaleRoundMethod.POW2, backoff=0.5), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.POW2, backoff=0.25) }, ScaleMethodString.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2: { - CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type = ScaleValueType.OPT, granularity= ScaleGranularity.PCS, rounding_method= ScaleRoundMethod.POW2, backoff= 0.5, params={"weight_scales": [2.0**s for s in range(-3, 5)]}), - CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method= ScaleRoundMethod.POW2, backoff= 0.25) + CfgStr.WEIGHT: ScaleMethodConfig(scale_value_type=ScaleValueType.OPT, granularity=ScaleGranularity.PCS, rounding_method=ScaleRoundMethod.POW2, backoff=0.5, params={"weight_scales": [2.0**s for s in range(-3, 5)]}), + CfgStr.ACTIVATION: ScaleMethodConfig(rounding_method=ScaleRoundMethod.POW2, backoff=0.25) }, } diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py index 164c6458507..53d84c1ec16 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scale_method_factory.py @@ -50,8 +50,8 @@ def __init__(self, config, params, mod, op_type): logger.trace("%s %s", self.__class__.__name__, self.__dict__) - def get_scale_method(self, tensor_type, is_dynamic=False): - backoff = 1.0 if is_dynamic else self.scale_method_config_map[tensor_type].backoff + def get_scale_method(self, tensor_type, is_dynamic=False, scale_dim_index=-1): + backoff = 0.5 if is_dynamic else self.scale_method_config_map[tensor_type].backoff scale_round_method = self.scale_method_config_map[tensor_type].rounding_method scale_value_type = self.scale_method_config_map[tensor_type].scale_value_type scale_granularity = self.scale_method_config_map[tensor_type].granularity @@ -95,13 +95,12 @@ def get_scale_method(self, tensor_type, is_dynamic=False): return MaxAbsPts(scale_round_method, self.params, self.device_for_scales, backoff) ## maxabs/opt in channel PCS case (_, ScaleGranularity.PCS, QuantTensorName.WEIGHT_IN_CH, _)\ - if scale_value_type in {ScaleValueType.MAXABS, ScaleValueType.OPT}: - in_channel_size = self.mod.weight.shape[1] - return InputChannelScale(scale_round_method, self.params, self.device_for_scales, in_channel_size) + if scale_value_type in {ScaleValueType.MAXABS, ScaleValueType.OPT, ScaleValueType.DUMMY_SCALES}: + return InputChannelScale(scale_round_method, self.params, self.device_for_scales) ## maxabs PCS case (ScaleValueType.MAXABS, ScaleGranularity.PCS, _, _): if is_dynamic: - return MaxAbsDynamicPcs(scale_round_method, self.params, self.device_for_scales, backoff) + return MaxAbsDynamicPcs(scale_round_method, self.params, self.device_for_scales, backoff, dim=scale_dim_index) return MaxAbsPcs(scale_round_method, self.params, self.device_for_scales, backoff) ## opt PTS case (ScaleValueType.OPT, ScaleGranularity.PTS, _, _): diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scales_method.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scales_method.py index be03b80e93d..43319306c6b 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scales_method.py +++ 
b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/scales_method.py @@ -102,6 +102,23 @@ def calc_scale_from_const_tensor(self, tensor): return scale_tensor +class MaxAbsDynamicPts(MaxAbsPts): + def __init__(self, round_scale_method, params, device_for_scales, backoff, fullscale=None): + super().__init__(round_scale_method, params, device_for_scales, backoff, fullscale, is_dynamic=True) + logger.trace("%s %s",self.__class__.__name__, self.__dict__) + + def get_scale_funcs_dict(self): + scale_funcs_dict = super().get_scale_funcs_dict() + scale_funcs_dict[QuantTensorType.DYNAMIC] = self.calc_scale_from_const_tensor + return scale_funcs_dict + + def calc_scales(self, tensor, tensor_type, **additional_kwargs): + # In dynamic quantization the scale is changed each time, + # and setting scale as a member is not supported in hpu graphs and torch.compile + # (it can break the graph) + return self._calculate_maxabs_scale(tensor, tensor_type, **additional_kwargs) + + ## MulAdditionalScales Get 2 input scales, and return their multiplication. # used for linear and matmul outputs class MulAdditionalScales(ScalesMethod): @@ -140,12 +157,15 @@ def calc_scales(self, tensor, tensor_type, **additional_kwargs): # used when running with dummy measurement (prepare_model_with_dummy_measurement) class DummyScales(ScalesMethod): def calc_scales(self, tensor, tensor_type, **additional_kwargs): - self.scale = torch.tensor(1.0).to(self.device) + out_channel_size = additional_kwargs.get("out_channel_size") + if out_channel_size is None: + self.scale = torch.tensor(1.0).to(self.device) + else: + self.scale = torch.ones([out_channel_size], dtype=self.hp_dtype, device=self.device) return self.scale class MaxAbsPcs(MaxAbsMethod): - def __init__(self, round_scale_method, params, device_for_scales, backoff, fullscale=None, dim=1, keepdim=False, is_dynamic=False): super().__init__(round_scale_method, params, device_for_scales, backoff, fullscale, is_dynamic=is_dynamic) self.dim = dim @@ -177,18 +197,36 @@ def calc_scale_from_const_tensor(self, tensor): return scale_tensor +class MaxAbsDynamicPcs(MaxAbsPcs): + def __init__(self, round_scale_method, params, device_for_scales, backoff, fullscale=None, dim=-1): + super().__init__(round_scale_method, params, device_for_scales, backoff, fullscale, dim=dim, keepdim=True, is_dynamic=True) + logger.trace("%s %s", self.__class__.__name__, self.__dict__) + + def get_scale_funcs_dict(self): + scale_funcs_dict = super().get_scale_funcs_dict() + scale_funcs_dict[QuantTensorType.DYNAMIC] = self.calc_scale_from_const_tensor_no_reshape + return scale_funcs_dict + + def calc_scales(self, tensor, tensor_type, **additional_kwargs): + # In dynamic quantization the scale is changed each time, + # and setting scale as a member is not supported in hpu graphs and torch.compile + # (it can break the graph) + return self._calculate_maxabs_scale(tensor, tensor_type, **additional_kwargs) + + ## InputChannelScale used for input channel in PCS mode class InputChannelScale(ScalesMethod): - def __init__(self, round_scale_method, params, device_for_scales, in_channel_size): + def __init__(self, round_scale_method, params, device_for_scales): super().__init__(round_scale_method, params, device_for_scales) - self.in_channel_size = in_channel_size def calc_scales(self, tensor, tensor_type, **additional_kwargs): - input_in_ch = torch.ones([self.in_channel_size, 1], dtype=self.hp_dtype, device=self.device) - return input_in_ch.flatten() + in_channel_size = 
additional_kwargs.get("in_channel_size") + if in_channel_size is None: + raise ValueError("Missing 'in_channel_size' in additional_kwargs") + + return torch.ones([in_channel_size], dtype=self.hp_dtype, device=self.device) class FixedScale(ScalesMethod): - def __init__(self, round_scale_method, params, device_for_scales): super().__init__(round_scale_method, params, device_for_scales) self.round_scale_method = round_scale_method @@ -199,7 +237,6 @@ def calc_scales(self, tensor, tensor_type, **additional_kwargs): class OptScalesPts(ScalesMethod): - def __init__(self, round_scale_method, optional_scales_list, params, device_for_scales, backoff): super().__init__(round_scale_method, params, device_for_scales) self.round_scale_method = round_scale_method @@ -210,6 +247,7 @@ def calc_scales(self, tensor, tensor_type, **additional_kwargs): self.scale = self.round_scale_method.calc(mmse_scale(tensor, self.optional_scales_list, self.lp_dtype, self.hp_dtype)) return self.scale + class OptScalesPcs(ScalesMethod): def __init__(self, round_scale_method, optional_scales_list, params, device_for_scales, backoff): super().__init__(round_scale_method, params, device_for_scales) @@ -230,39 +268,3 @@ def calc_scales(self, tensor, tensor_type, **additional_kwargs): ).unsqueeze(1) self.scale = self.round_scale_method.calc(const_opt_scale_out_ch).flatten() return self.scale - - -class MaxAbsDynamicPcs(MaxAbsPcs): - - def __init__(self, round_scale_method, params, device_for_scales, backoff, fullscale=None): - super().__init__(round_scale_method, params, device_for_scales, backoff, fullscale, -1, True, True) - logger.trace("%s %s", self.__class__.__name__, self.__dict__) - - def get_scale_funcs_dict(self): - scale_funcs_dict = super().get_scale_funcs_dict() - scale_funcs_dict[QuantTensorType.DYNAMIC] = self.calc_scale_from_const_tensor_no_reshape - return scale_funcs_dict - - def calc_scales(self, tensor, tensor_type, **additional_kwargs): - # In dynamic quantization the scale is changed each time, - # and setting scale as a member is not supported in hpu graphs and torch.compile - # (it can break the graph) - return self._calculate_maxabs_scale(tensor, tensor_type) - - -class MaxAbsDynamicPts(MaxAbsPts): - - def __init__(self, round_scale_method, params, device_for_scales, backoff, fullscale=None): - super().__init__(round_scale_method, params, device_for_scales, backoff, fullscale, True) - logger.trace("%s %s",self.__class__.__name__, self.__dict__) - - def get_scale_funcs_dict(self): - scale_funcs_dict = super().get_scale_funcs_dict() - scale_funcs_dict[QuantTensorType.DYNAMIC] = self.calc_scale_from_const_tensor - return scale_funcs_dict - - def calc_scales(self, tensor, tensor_type, **additional_kwargs): - # In dynamic quantization the scale is changed each time, - # and setting scale as a member is not supported in hpu graphs and torch.compile - # (it can break the graph) - return self._calculate_maxabs_scale(tensor, tensor_type) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py index b5b2d319ea4..8b993f11a18 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py @@ -20,7 +20,7 @@ from .measure import prepare_model as prepare_model_for_measure from .quantize import quantize from .scale_methods.scale_method_config import get_scale_method_from_config, ScaleMethodString, CfgStr -from .common import is_runtime_scale_patching +from .common 
import is_runtime_scale_patching, set_runtime_state from neural_compressor.torch.utils.auto_accelerator import is_any_gaudi_accelerator import os import re @@ -93,9 +93,12 @@ def quantize_dynamic_op(config, mod_type): runtime_scale_patching_supported_methods_list = [method for method in scaling_methods_list if not any(substr in method for substr in exclude_substrings)] -def set_runtime_scale_patching_mode(scale_method_config): - import habana_frameworks.torch.utils.experimental as htexp # importing in local scope since it is gaudi specific - scale_method = get_scale_method_from_config(scale_method_config[CfgStr.DEFAULT]) +def set_gaudi_modes_and_attributes(cfg_dict): + is_dynamic_quantization = cfg_dict['dynamic_quantization'] + import habana_frameworks.torch.utils.experimental as htexp # importing in local scope since it is gaudi specific. + set_runtime_state(is_dynamic_quantization) + htexp._set_is_dynamic_quantization(is_dynamic_quantization) + scale_method = get_scale_method_from_config(cfg_dict["scale_method"][CfgStr.DEFAULT]) if is_runtime_scale_patching() and hasattr(htexp, "_set_scale_attributes"): assert ( scale_method.name in runtime_scale_patching_supported_methods_list @@ -127,5 +130,5 @@ def prepare_model(model): return prepare_model_for_measure(model, mod_list) elif config.cfg["mode"] in [QuantMode.QUANTIZE, QuantMode.LOAD]: if is_any_gaudi_accelerator(config.cfg["device_type"]): - set_runtime_scale_patching_mode(config.cfg["scale_method"]) + set_gaudi_modes_and_attributes(config.cfg) return quantize(model, mod_list) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py index 5ee8073767e..4a161535254 100755 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -17,14 +17,14 @@ import torch.nn as nn import types import functools +import os -from .._core.quant_dequant import QuantDequant as qdq, QuantDynamicInput +from .._core.quant_dequant import QuantDequant as qdq from .._core.quantized_func_wrappers import get_quantized_func_wrapper, OP_TYPE from .quant_config import QuantMode, get_hqt_config from ..patched_module_base import PatchedModuleBase, get_call_wrapper from .._core.scale_handler import get_scale_dtype, ScaleFormat from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator -cur_accelerator = auto_detect_accelerator() class BMM(nn.Module): @@ -107,6 +107,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: self.quant_input_0 = self._mod_extra_config.inputs[0] self.quant_input_1 = self._mod_extra_config.inputs[1] + if not self.use_qdq and not self.fake_quant: self.register_scale("scale_input", mod_extra_config.scale.inputs[0], self.scale_format) self.register_scale("scale_other", mod_extra_config.scale.inputs[1], self.scale_format) @@ -114,28 +115,39 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): # in DPQ we want to use the scales measured in the quantization process if hasattr(parent, 'scale_bf16_to_fp8') and parent.scale_bf16_to_fp8 > 0: self.scale_other = torch.nn.Parameter(parent.scale_bf16_to_fp8) + if self.is_dynamic_quantization: + self.forward = self.forward_dynamic - def forward_quant(self, input, other): + def forward_quant(self, input, other, out=None): qinput = self.quant_input_0(input) qother = 
self.quant_input_1(other) - output = self.matmul_fp8(qinput, - qother, - out_dtype=self._mod_extra_config.config_params["hp_dtype"], - scale_input_inv=self.scale_input, - scale_other_inv=self.scale_other) - return output - - def forward_qdq(self, input, other): + return self.forward_impl(qinput, self.scale_input, qother, self.scale_other, out) + + def forward_dynamic(self, input, other, out=None): + qinput, scale_input = self.quant_input_0(input) + qother, scale_other = self.quant_input_1(other) + return self.forward_impl(qinput, scale_input, qother, scale_other, out) + + def forward_impl(self, qinput, scale_input, qother, scale_other, out=None): + out = self.matmul_fp8(qinput, + qother, + out=out, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=scale_input, + scale_other_inv=scale_other) + return out + + def forward_qdq(self, input, other, out=None): qinput = self.quant_input_0(input) qother = self.quant_input_1(other) - output = torch.matmul(qinput, qother) - return output + out = torch.matmul(qinput, qother, out=out) + return out - def forward_measure(self, input, other): + def forward_measure(self, input, other, out=None): measure_input((input, other), observer=self._mod_extra_config.inputs) - output = self.orig_mod(input, other) - measure_output((output,), self._mod_extra_config.outputs) - return output + out = self.orig_mod(input, other, out=out) + measure_output((out,), self._mod_extra_config.outputs) + return out def extra_repr(self) -> str: return extra_representation( @@ -158,13 +170,16 @@ def init_mixture_of_experts_linears(instance): class PatchedLinearBase(PatchedModuleBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) + self.init_linear(mod_extra_config) - # TODO [SW-224538]: Move init_linear to PatchedLinearBase __init__ def init_linear(self, mod_extra_config): if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: # When offloading weights to disk using device_map, the module forward is overridden. # __dict__.update call again overrides the PatchedLinear forward with the forward that device_map planted. # So need to set PatchedLinear forward to be the right forward. + #TODO [GAUDISW-246018]: find a good solution for weights with 3 dims + if self.weight.dim() == 3: + self.weight.squeeze_() self.weight = nn.Parameter(self.weight.t().contiguous()) self.quant_input = self._mod_extra_config.inputs[0] self.dequant_output = self._mod_extra_config.outputs[0] @@ -187,7 +202,6 @@ def init_linear(self, mod_extra_config): # only ScaleFormat.CONST is supported for per-channel scale now. 
self.register_scale("scale_weight", mod_extra_config.scale.params["weight"][0], ScaleFormat.CONST) - self.is_dynamic_quantization = isinstance(self.quant_input, QuantDynamicInput) self.quant_input_func = self.quant_input_and_get_scale_dynamic if self.is_dynamic_quantization else self.quant_input_and_get_scale_static elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): init_mixture_of_experts_linears(self) @@ -235,7 +249,6 @@ def extra_repr(self) -> str: class PatchedLinear(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) - self.init_linear(mod_extra_config) def forward_measure(self, input): measure_input((input,), observer=self._mod_extra_config.inputs) @@ -254,7 +267,6 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): # ParallelLMHead's forward method should not be called because LMHead's weights should be used # in the sampler. (The forward itself throws RuntimeError exception) # So in order to quantize that quant_method we patch only the "apply" method. - self.init_linear(mod_extra_config) self.orig_linear_quant_apply = self.orig_mod.quant_method.apply if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: if self.use_qdq or self.fake_quant: @@ -280,7 +292,6 @@ def apply_measure(self, layer, x, bias): class PatchedReplicatedLinear(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) - self.init_linear(mod_extra_config) def forward_qdq(self, input): bias = self.bias if not self.skip_bias_add else None @@ -304,7 +315,6 @@ def forward_measure(self, input): class PatchedLinearAllReduce(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) - self.init_linear(mod_extra_config) self.scoped_version = mod.__class__.__name__ == "ScopedLinearAllReduce" def forward_qdq(self, input): @@ -349,34 +359,41 @@ def post_all_reduce(self, input): class PatchedRowParallelLinear(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): - super().__init__(mod, parent, mod_extra_config, *args, **kwargs) from .._core.external_func_impl import get_external_row_parallel_collective_func self.row_parallel_collective_func = get_external_row_parallel_collective_func() # TODO [SW-224403]: Enable dynamic quantization in row parallel allreduce - allreduce_quantization_enable = get_hqt_config(mod).cfg["row_parallel_linear_allreduce_quantization"] + self.allreduce_quantization_enable = get_hqt_config(mod).cfg["row_parallel_linear_allreduce_quantization"] + + super().__init__(mod, parent, mod_extra_config, *args, **kwargs) + + # Finalize initialization after init_linear has been called + if self.quantization_mode == QuantMode.QUANTIZE: + if self.allreduce_quantization_enable: + self.dequant_scatter_output = self._mod_extra_config.outputs[0] + self.quant_scatter_input = self._mod_extra_config.outputs[1] + self.quant_gather_input = self._mod_extra_config.outputs[2] + self.dequant_gather_output = self._mod_extra_config.outputs[3] + from torch import distributed as dist + self.world_size = dist.get_world_size() + + def init_linear(self, mod_extra_config): + # Set up forward method before calling parent's init_linear if self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE): self.forward = self.forward_measure_reduce 
if self.reduce_results and self.tp_size > 1 else self.forward_measure_no_reduce - elif self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: if self.fake_quant or self.use_qdq: self.forward = self.forward_qdq else: if self.reduce_results and self.tp_size > 1: - if allreduce_quantization_enable: + if self.allreduce_quantization_enable: self.forward = self.forward_quant_reduce_in_lp else: self.forward = self.forward_quant_reduce_in_hp else: self.forward = self.forward_quant_no_reduce - self.init_linear(mod_extra_config) - if self.quantization_mode == QuantMode.QUANTIZE: - if allreduce_quantization_enable: - self.dequant_scatter_output = self._mod_extra_config.outputs[0] - self.quant_scatter_input = self._mod_extra_config.outputs[1] - self.quant_gather_input = self._mod_extra_config.outputs[2] - self.dequant_gather_output = self._mod_extra_config.outputs[3] - from torch import distributed as dist - self.world_size = dist.get_world_size() + + # Call parent's init_linear which sets up quantization parameters + super().init_linear(mod_extra_config) def resolve_input(self, input_): """ @@ -493,7 +510,6 @@ def bias_add(self, output): class PatchedColumnParallelLinear(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) - self.init_linear(mod_extra_config) from .._core.external_func_impl import get_external_column_parallel_collective_func self.column_parallel_collective_func = get_external_column_parallel_collective_func() @@ -534,7 +550,6 @@ def add_bias(self, output): class PatchedLmHeadLinearAllreduce(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) - self.init_linear(mod_extra_config) def forward_qdq(self, input): assert ( @@ -641,13 +656,12 @@ class PatchedMixtralMoE(PatchedModuleBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) # remove the MoE weights that are quanted by PatchedMoeMatmul - if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: - delattr(mod, "w13_weight") - delattr(mod, "w2_weight") - setattr(mod, "w13_weight", None) - setattr(mod, "w2_weight", None) - setattr(self, "w13_weight", None) - setattr(self, "w2_weight", None) + delattr(mod, "w13_weight") + delattr(mod, "w2_weight") + setattr(mod, "w13_weight", None) + setattr(mod, "w2_weight", None) + setattr(self, "w13_weight", None) + setattr(self, "w2_weight", None) self.forward = self.forward_orig # copied from https://github.com/HabanaAI/vllm-fork/blob/93b8bad8478451349d0c76b3116d3ad863a3b48e/vllm/model_executor/layers/fused_moe/layer.py#L1429 @@ -678,11 +692,9 @@ def extra_repr(self) -> str: class PatchedMoeMatmul(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) - self.init_linear(mod_extra_config) + self.weight = torch.nn.Parameter(self.weight.squeeze(), requires_grad=False) if (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): measure_input((torch.tensor(0),), observer=self._mod_extra_config.inputs) - else: - self.weight = torch.nn.Parameter(self.weight.squeeze(), requires_grad=False) def forward_qdq(self, input, *args, **kwargs): return self.run_linear_qdq(input, None) @@ -911,12 +923,12 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): 
[mod_extra_config.scale.inputs[x] for x in range(1, self.experts_used + 1)], self.scale_format, ) - self.is_dynamic_quantization = isinstance(self.quant_input, QuantDynamicInput) self.dynamic_moe_op = get_quantized_func_wrapper( OP_TYPE.DYNAMIC_MOE_FUSED_WEIGHTS, scale_format=self.scale_format, is_dynamic=self.is_dynamic_quantization ) if self.is_dynamic_quantization: self.forward = self.forward_dynamic_quant + self.dispatch_fn = self._get_dispatch_func() def _get_extra_kwargs(self, tokens_num: int): kwargs = {} @@ -924,20 +936,31 @@ def _get_extra_kwargs(self, tokens_num: int): kwargs = self.orig_mod._get_extra_kwargs(tokens_num) return kwargs + # For vLLM Data Parallel https://github.com/vllm-project/vllm-gaudi/pull/684 + def _get_dispatch_func(self): + def identity(x): + return x + if hasattr(self.orig_mod, "_get_dispatch_func"): + fn = self.orig_mod._get_dispatch_func() + if fn is not None: + return fn + return identity + def forward_quant(self, hidden_states, expert_routing_table, router_weights, permuted_weights=True, activation="silu"): - tokens_num, hidden_dim = hidden_states.shape - extra_kwargs = self._get_extra_kwargs(tokens_num) experts_range = range(self.experts_used) w1_list = [self.w13_list[i].weight for i in experts_range] w2_list = [self.w2_list[i].weight for i in experts_range] scale_w1 = [self.w13_list[i].scale_weight for i in experts_range] scale_w2 = [self.w2_list[i].scale_weight for i in experts_range] qinput = self.quant_input(hidden_states) + qinput = self.dispatch_fn(qinput) + tokens_num, hidden_dim = qinput.shape + extra_kwargs = self._get_extra_kwargs(tokens_num) output = self.dynamic_moe_op( hidden_states=qinput, expert_routing_table=expert_routing_table, @@ -969,6 +992,7 @@ def forward_dynamic_quant( scale_w1 = [self.w13_list[i].scale_weight for i in experts_range] scale_w2 = [self.w2_list[i].scale_weight for i in experts_range] qinput_fp8, input_scale = self.quant_input(hidden_states) + qinput_fp8 = self.dispatch_fn(qinput_fp8) output = self.dynamic_moe_op( hidden_states=qinput_fp8, expert_routing_table=expert_routing_table, @@ -1158,6 +1182,21 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): if self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: self.quant_input = self._mod_extra_config.inputs[0] self.dequant_output = self._mod_extra_config.outputs[0] + if self.is_dynamic_quantization: + assert os.getenv("VLLM_DYNAMIC_KV_QUANT") is not None, "VLLM_DYNAMIC_KV_QUANT env var must be set for dynamic kv cache quantization" + if hasattr(mod, 'is_v_cache') and mod.is_v_cache: + from .._core.fp_utils import get_fullscale + self.is_v_cache = True + device_type = auto_detect_accelerator().get_inc_accelerator_type() + self.full_scale = get_fullscale(self.lp_dtype, device_type) + else: + self.is_v_cache = False + self.forward = self.forward_quant_dynamic + self.fetch_from_cache = self.fetch_from_cache_dynamic + self.cur_device = auto_detect_accelerator().current_device_name() + else: + self.forward = self.forward_quant_static + self.fetch_from_cache = self.fetch_from_cache_static elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): self.fetch_from_cache = mod.fetch_from_cache @@ -1166,27 +1205,48 @@ def forward_qdq(self, input, *args, **kwargs): output_cache = self.orig_mod(qinput, *args, **kwargs) return output_cache - def forward_quant(self, input, cache, *args, **kwargs): + def forward_quant_static(self, input, cache, slot_mapping, *args, **kwargs): if input is not None: qinput = 
self.quant_input(input) - output_cache = self.orig_mod(qinput, cache, *args, **kwargs) + output_cache = self.orig_mod(qinput, cache, slot_mapping, *args, **kwargs) else: # In cross-attention during decode stage kv cache isn't updated # so input is None and we don't store it output_cache = cache return self.dequant_output(output_cache) - def forward_measure(self, input, cache, *args, **kwargs): + # input is a new Key/Value, input.size => (batch_size, num_kv_heads, head_size) + def forward_quant_dynamic(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, *args, **kwargs): + if input is not None: + qinput, scale = self.quant_input(input) + if scales is not None: + if self.is_v_cache and block_size is not None: + # in v cache scales is a tuple: (scales_on_token_dim, scales_on_hidden_dim) + scales[0].index_copy_(0, slot_mapping, scale) + self.update_scales_on_hidden(input, scales, slot_mapping, block_size, is_prompt) + else: + scales.index_copy_(0, slot_mapping, scale) + output_cache = self.orig_mod(qinput, cache, slot_mapping, scales, *args, **kwargs) + else: + # In cross-attention during decode stage kv cache isn't updated + # so input is None and we don't store it + output_cache = cache + return output_cache + + def forward_measure(self, input, cache, slot_mapping, *args, **kwargs): # In cross-attention during decode stage kv cache isn't updated # so input is None and we don't measure it if input is None: return cache measure_input((input, ), self._mod_extra_config.inputs) - output_cache = self.orig_mod(input, cache, *args, **kwargs) + output_cache = self.orig_mod(input, cache, slot_mapping, *args, **kwargs) measure_output((output_cache, ), self._mod_extra_config.outputs) return output_cache - def fetch_from_cache(self, cache, blocks): + # cache.size => (num_blocks, block_size, num_kv_heads, head_size) + # scale_on_token_dim.size => (num_blocks, block_size, num_kv_heads, 1) + # scale_on_hidden_dim.size => (num_blocks, num_kv_heads, 1) + def fetch_from_cache_static(self, cache, blocks, scales=None): if cache.dtype != self.lp_dtype: quant_cache = self.quant_input(cache) else: @@ -1194,6 +1254,68 @@ def fetch_from_cache(self, cache, blocks): output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks) return self.dequant_output(output_cache) + def fetch_from_cache_dynamic(self, cache, blocks, scales=None): + if cache.dtype != self.lp_dtype: + cur_cache, cur_scales = self.fetch_cache_and_scales(cache, blocks, scales) + quant_cache, _ = self.quant_input(cur_cache, cur_scales) + else: + quant_cache, cur_scales = self.fetch_cache_and_scales(cache, blocks, scales) + cache_dequanted = self.dequant_output(quant_cache, cur_scales) + if self.is_v_cache and scales is not None: + cache_dequanted = self.convert_on_hidden(cache_dequanted, scales, blocks) + return cache_dequanted + + def fetch_cache_and_scales(self, cache, blocks, scales=None): + if self.is_v_cache and scales is not None: + cur_scales = scales[0] + else: + cur_scales = scales + if self.orig_mod.use_contiguous_pa: + cur_scales = cur_scales[:blocks.size(0)] if cur_scales is not None else None + cur_cache = cache[:blocks.size(0)] + else: + cur_scales = cur_scales.index_select(0, blocks) if cur_scales is not None else None + cur_cache = cache.index_select(0, blocks) + return cur_cache, cur_scales + + def update_scales_on_hidden(self, input, scales, slot_mapping, block_size, is_prompt=False): + from .._core.fp_utils import calculate_scale_maxabs_with_cguid, calculate_scale_rounding_with_cguid, 
ScaleCalculationMaxMode, ScaleCalculationRoundingMode + scale_tensor = calculate_scale_maxabs_with_cguid( + input.unsqueeze(1), + ScaleCalculationMaxMode.MAX_ABS_PCS_CALCULATION, + reduceAxis=1, + reduceKeepdim=False, + fullscale=self.full_scale, + backoff=0.5, + ) + pow2_tensor = calculate_scale_rounding_with_cguid(scale_tensor, ScaleCalculationRoundingMode.SCALE_TO_POW2_ROUNDING) + block_mapping = slot_mapping // block_size + if is_prompt: + # for the prompt, getting the max scale of its tokens and assign it to its blocks as the scale on hidden dim + max_scale = torch.max(pow2_tensor, dim=0) + max_pow2_tensor = max_scale[0].expand(block_mapping.size(0), pow2_tensor.size(1), pow2_tensor.size(2)) + scales[1].index_copy_(0, block_mapping, max_pow2_tensor) + else: + new_blocks_locations = (slot_mapping % block_size) == 0 + saved_scales = scales[1].index_select(0, block_mapping) + # zero out new sequence blocks so that the scale will be taken from the input tensor and not the old saved value + saved_scales[new_blocks_locations] = 0.0 + scales[1].index_copy_(0, block_mapping, torch.maximum(pow2_tensor, saved_scales)) + + def convert_on_hidden(self, cache_dequanted, scales, blocks): + # invalid blocks have idx of -1 or greater than max blocks in kv-cache so that they are not read/written by the device + # getting the invalid blocks indexes so that we can set them in the read scale tensor with valid scale values + invalid_block_idx = (blocks >= scales[1].size(0)) | (blocks == -1) + + if self.orig_mod.use_contiguous_pa: + cur_scale_on_h = scales[1][:blocks.size(0)] + else: + cur_scale_on_h = scales[1].index_select(0, blocks) + # set valid values in invalid scale tensor locations to avoid zero scales errors + cur_scale_on_h[invalid_block_idx] = 1.0 + cache_quanted_on_h, dq_scales = self.quant_input(cache_dequanted, cur_scale_on_h.unsqueeze(1)) + return self.dequant_output(cache_quanted_on_h, dq_scales) + def extra_repr(self) -> str: return extra_representation( self.extra_repr_org(), @@ -1201,6 +1323,39 @@ def extra_repr(self) -> str: get_current_repr(self), ) +##TODO SW-242485 - move this function to base conv class +def compute_padding(padding, kernel_size, stride=(1, 1), input_size=None): + """ + Compute padding values based on padding type and stride. + + Args: + padding: controls the amount of padding applied to the input. + It can be either a string {‘valid’, ‘same’} or an int / a tuple of + ints giving the amount of implicit padding applied on both sides. 
+ kernel_size (tuple): (ch_out, ch_in, kernel_height, kernel_width) - the conv weight shape + stride (tuple): (stride_height, stride_width) + input_size (tuple): (batch_size, ch_in, input_height, input_width), optional for exact SAME padding + + Returns: + Tuple: (pad_h, pad_w) - pad_h and pad_w are half of the total padding_h and padding_w; this padding is added to both sides of each spatial dimension of the input + """ + + if isinstance(padding, list) or isinstance(padding, tuple): + return padding + elif padding == "valid": + return 0 + elif padding == "same": + in_h, in_w = input_size[2], input_size[3] + sh, sw = stride + kh, kw = kernel_size[2], kernel_size[3] + out_h = (in_h + sh - 1) // sh + out_w = (in_w + sw - 1) // sw + pad_h_total = max((out_h - 1) * sh + kh - in_h, 0) + pad_w_total = max((out_w - 1) * sw + kw - in_w, 0) + return (int(pad_h_total/2), int(pad_w_total/2)) + else: + raise ValueError(f"Unsupported padding type: {padding}") + def init_conv(instance, mod_extra_config): if instance.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: @@ -1219,6 +1374,7 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): init_conv(self, mod_extra_config) def forward_qdq(self, input): + padding = compute_padding(self.padding, self.weight.size(), self.stride, input.size()) qweight = self.dequant_weights(self.weight, ) qinput = self.quant_input(input) output = torch.nn.functional.conv2d( qinput, qweight, self.bias, self.stride, - self.padding, + padding, self.dilation, self.groups, ) return output def forward_quant(self, input): + padding = compute_padding(self.padding, self.weight.size(), self.stride, input.size()) qinput = self.quant_input(input) output = self.conv2d_fp8(qinput, self.weight, self.bias, self.stride, - self.padding, + padding, self.dilation, self.groups, out_dtype=self._mod_extra_config.config_params["hp_dtype"], @@ -1247,6 +1404,8 @@ return output def forward_measure(self, input): + padding = compute_padding(self.padding, self.weight.size(), self.stride, input.size()) + self.padding = padding measure_input((input,), observer=self._mod_extra_config.inputs) output = self.orig_mod(input) measure_output((output,), self._mod_extra_config.outputs) @@ -1337,7 +1496,6 @@ def forward_measure(self, attn, block_bias, block_groups, batch_size, global_max class PatchedLoRACompatibleLinear(PatchedLinearBase): def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): super().__init__(mod, parent, mod_extra_config, *args, **kwargs) - self.init_linear(mod_extra_config) def forward_qdq(self, input, scale: float = 1.0): output = self.run_linear_qdq(input, self.bias) @@ -1389,11 +1547,12 @@ def forward_quant(self, input, scale: float = 1.0): # TODO SW-174899 support lora layer quantization _raise_lora_layer_error(self.class_name_org) else: + padding = compute_padding(self.padding, self.weight.size(), self.stride, input.size()) output = self.conv2d_fp8(qinput, self.weight, self.bias, self.stride, - self.padding, + padding, self.dilation, self.groups, out_dtype=self._mod_extra_config.config_params["hp_dtype"], @@ -1402,6 +1561,8 @@ return output def forward_measure(self, input, scale: float = 1.0): + padding = compute_padding(self.padding, self.weight.size(), self.stride, input.size()) + self.padding = padding measure_input((input,), observer=self._mod_extra_config.inputs) output = self.orig_mod(input, scale) measure_output((output,), self._mod_extra_config.outputs) diff --git 
a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py index b3bb98957f2..8c171ad140a 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py @@ -29,6 +29,12 @@ from .._core.scale_methods.scale_method_parser import parse_scale_method, validate_and_populate_scale_method, convert_scale_method_strings_to_enum from .._core.scale_methods.scale_method_config import get_scale_method_from_config, check_scale_method_fields, ScaleMethodString, CfgStr, ScaleGranularity, ScaleValueType, ScaleRoundMethod +# Scale methods that are only supported in dynamic quantization +SUPPORTED_DYNAMIC_QUANTIZATION_SCALES = [ + ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW, + ScaleMethodString.MAXABS_PCS_POW2 +] + class QuantMode(Enum): NONE = 0 @@ -253,6 +259,8 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg: dynamic_quantization = measured_global_config["dynamic_quantization"] # TODO [SW-217814]: get dynamic methods in a better way, or support file handling in dynamic mode if dynamic_quantization: + if measured_global_config["use_qdq"] or measured_global_config["fake_quant"]: + raise ValueError("Currently dynamic quantization is not supported for qdq and fake quant.") if auto_detect_accelerator().current_device_name() == "cpu": raise ValueError("Currently CPU device doesn't support dynamic quantization") logger.info(f"NOTE: Using dynamic scale method, only supported ops will be quantized.") @@ -274,10 +282,12 @@ def parse(custom_config: Mapping[str, str]) -> Fp8cfg: if measured_global_config["row_parallel_linear_allreduce_quantization"]: raise ValueError(f"Dynamic quantization is not supported when using row_parallel_linear_allreduce_quantization") else: - if check_scale_method_fields(scale_method_config, scale_method= ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW, reducer=any): - raise ValueError( - f"Unsupported config: scale_method ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW is supported only in dynamic quantization" - ) + # Check if any of the dynamic-only scale methods are being used + for scale_method in SUPPORTED_DYNAMIC_QUANTIZATION_SCALES: + if check_scale_method_fields(scale_method_config, scale_method=scale_method, reducer=any): + raise ValueError( + f"Unsupported config: scale_method {scale_method.name} is supported only in dynamic quantization" + ) if (dynamic_quantization or check_scale_method_fields(scale_method_config, scale_value_type_activation= ScaleValueType.FIXED_VALUE, scale_value_type_weight= ScaleValueType.FIXED_VALUE, reducer=all)) and \ diff --git a/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py b/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py index 4b6cf3d20ae..7273c99c835 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py +++ b/neural_compressor/torch/algorithms/fp8_quant/patched_module_base.py @@ -158,6 +158,7 @@ def __init__( set_attrs_from_orig_model(self, mod, parent, mod_extra_config, *func_names) add_scale_registry(self) self.mod_extra_config = mod_extra_config + self.is_dynamic_quantization = get_hqt_config(mod).cfg["dynamic_quantization"] if self.quantization_mode in (QuantMode.MEASURE, QuantMode.SHAPE): self.forward = self.forward_measure elif self.quantization_mode in [QuantMode.QUANTIZE, QuantMode.LOAD]: diff --git 
a/neural_compressor/torch/algorithms/fp8_quant/save_load.py b/neural_compressor/torch/algorithms/fp8_quant/save_load.py index df599a87009..4eac14febff 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/save_load.py +++ b/neural_compressor/torch/algorithms/fp8_quant/save_load.py @@ -435,7 +435,8 @@ def load_empty_raw_model(model_name_or_path, **kwargs): # fp8 model provided by neuralmagic. if ( - "quant_method" in quantization_config + isinstance(quantization_config, dict) + and "quant_method" in quantization_config and quantization_config["quant_method"] in ["fp8", "compressed-tensors"] ): from_neuralmagic = True diff --git a/neural_compressor/torch/algorithms/smooth_quant/utility.py b/neural_compressor/torch/algorithms/smooth_quant/utility.py index 2c7aa8c8181..80a89c3232d 100644 --- a/neural_compressor/torch/algorithms/smooth_quant/utility.py +++ b/neural_compressor/torch/algorithms/smooth_quant/utility.py @@ -2271,6 +2271,7 @@ def _parse_absorb_to_layers(self, op_types, folding): # Check if input_maxes match self.absorb_to_layer # (due to self._get_all_layer_names use layer tree instead of forward_path) if not folding and self.need_calibration: + input_mins, input_maxes = self.input_mins, self.input_maxes if len(self.input_mins) == 0: ##there are some modules not used in forward calib = Calibration(self.model, self.dataloader, self.q_func, self.device) ## input_mins, input_maxes = calib.calibrate( @@ -2616,6 +2617,8 @@ def _calculate_qparams(self, input_scale, input_minmax, dtype=torch.quint8): """ if dtype == torch.quint8: quant_min, quant_max = 0, 255 + else: + raise ValueError(f"Unsupported dtype for quantization parameters: {dtype}") min_val = torch.min(input_minmax[0] * input_scale) max_val = torch.max(input_minmax[1] * input_scale) # work when min_val bigger than zero. 
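Note on the guarded quint8 branch above: once the calibrated min/max values are known, the quantization parameters follow from the standard affine (asymmetric) mapping. A minimal standalone sketch of that math, for illustration only (quint8_qparams is a made-up name here, not the INC helper itself):

import torch

def quint8_qparams(min_val: torch.Tensor, max_val: torch.Tensor):
    # quint8 quantization range, matching the guarded branch above
    quant_min, quant_max = 0, 255
    # make sure zero is representable, covering the "min_val bigger than zero" case
    min_val = torch.clamp(min_val, max=0.0)
    max_val = torch.clamp(max_val, min=0.0)
    scale = (max_val - min_val) / (quant_max - quant_min)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    zero_point = (quant_min - torch.round(min_val / scale)).clamp(quant_min, quant_max).to(torch.int64)
    return scale, zero_point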
diff --git a/neural_compressor/torch/algorithms/weight_only/modules.py b/neural_compressor/torch/algorithms/weight_only/modules.py index d31838f5aa1..d5aac7f4bcb 100644 --- a/neural_compressor/torch/algorithms/weight_only/modules.py +++ b/neural_compressor/torch/algorithms/weight_only/modules.py @@ -45,9 +45,9 @@ def __init__( """Init the Matmul object.""" super().__init__() - def forward(self, X, Y): + def forward(self, X, Y, **kwargs): """Forward function.""" - return torch.matmul(X, Y) + return torch.matmul(X, Y, **kwargs) class QDQLayer(torch.nn.Module): diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py index 2640312e781..33a8ffbf593 100644 --- a/neural_compressor/torch/algorithms/weight_only/utility.py +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -367,7 +367,7 @@ def quant_tensor( weight2, scale2, zp2 = weight2 weight = torch.cat([weight1, weight2], dim=1) scale = torch.cat([scale1, scale2], dim=1) - zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1) + zp = None if (zp1 is None or zp2 is None) else torch.cat([zp1, zp2], dim=1) accelerator.synchronize() orig_weight.copy_(weight) return orig_weight, scale, zp diff --git a/neural_compressor/transformers/utils/quantization_config.py b/neural_compressor/transformers/utils/quantization_config.py index 185de2b9076..c40c978e320 100644 --- a/neural_compressor/transformers/utils/quantization_config.py +++ b/neural_compressor/transformers/utils/quantization_config.py @@ -347,15 +347,9 @@ def __init__( self.device = kwargs.get("device", "auto") self.scheme = "sym" if self.sym else "asym" - if isinstance(compute_dtype, torch.dtype): - self.compute_dtype = compute_dtype - else: - self.compute_dtype = compute_dtype - - if isinstance(scale_dtype, torch.dtype): - self.scale_dtype = scale_dtype - else: - self.scale_dtype = scale_dtype + self.compute_dtype = compute_dtype + + self.scale_dtype = scale_dtype self.post_init_gptq() diff --git a/test/torch/algorithms/fp8_quant/tester.py b/test/torch/algorithms/fp8_quant/tester.py index 52ee9bf583e..c0067b7c451 100644 --- a/test/torch/algorithms/fp8_quant/tester.py +++ b/test/torch/algorithms/fp8_quant/tester.py @@ -20,13 +20,12 @@ ScaleFormat, ScaleMethodString, get_hqt_config, + SUPPORTED_DYNAMIC_QUANTIZATION_SCALES, ) from neural_compressor.torch.quantization import FP8Config, convert, prepare # user level API from .test_hpu_utils import get_device_name -SUPPORTED_DYNAMIC_SCALES = [ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW] - HW_ALIGNED_SCALE_METHODS = [ ScaleMethodString.MAXABS_HW, ScaleMethodString.MAXABS_HW_OPT_WEIGHT, diff --git a/test/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py b/test/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py index 2e4ad7936c1..d67c683ad9e 100644 --- a/test/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py +++ b/test/torch/algorithms/fp8_quant/unit_tests/test_functions/test_config_json.py @@ -10,7 +10,8 @@ from neural_compressor.torch.algorithms.fp8_quant._core.scale_methods.scale_method_config import ScaleMethodString from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import Matmul from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import QuantMode - +from neural_compressor.torch.algorithms.fp8_quant._core.scale_methods.scale_method_config import ScaleMethodString +from ...tester import run_with_raised_exception, 
get_internal_config, SCALE_METHODS_QUANT_ONLY, SUPPORTED_DYNAMIC_QUANTIZATION_SCALES from ...test_hpu_utils import * from ...tester import SCALE_METHODS_QUANT_ONLY, get_internal_config, run_with_raised_exception @@ -52,7 +53,7 @@ def run_predefined_config(): prepare_model._prep_model_with_predefined_config(model, config=config) fp8_quant.finish_measurements(model) - if scale_method == ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW: + if scale_method in SUPPORTED_DYNAMIC_QUANTIZATION_SCALES: return run_with_raised_exception(run_predefined_config, ValueError, "Unsupported config: scale_method") # This is an expected exception, as test is not measuring before elif scale_method not in SCALE_METHODS_QUANT_ONLY: diff --git a/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py b/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py index d15fd9c2f68..cf85a1ac683 100644 --- a/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py +++ b/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_conv2d.py @@ -22,9 +22,8 @@ def get_test_vectors( @pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn], ids=["fp8_e4m3fn"]) @pytest.mark.parametrize("scale_method", ScaleMethodString) @pytest.mark.parametrize("device_type", device_type) -def test_conv2d_accuracy( - hp_dtype: torch.dtype, lp_dtype: torch.dtype, scale_method: ScaleMethodString, device_type: str -): +@pytest.mark.parametrize("padding_type",["valid", "same", [1,1]], ids=["valid", "same", "int_padding"]) +def test_conv2d_accuracy(hp_dtype: torch.dtype, lp_dtype: torch.dtype, scale_method: ScaleMethodString, device_type: str, padding_type): # TODO [SW-196641]: fix the following issues: if scale_method in SCALE_METHODS_SEGFAULT: pytest.skip("Not supported") @@ -49,7 +48,7 @@ def run(): "in_channels": C_in, "out_channels": C_out, "kernel_size": K, - "padding": 1, + "padding": padding_type, "bias": False, "device": "hpu", "dtype": hp_dtype, @@ -66,8 +65,6 @@ def run(): elif device_type_id[device_type] != get_device_type(): if not (device_type_id[device_type] == get_gaudi2_type() and is_gaudi3()): return run_with_raised_exception(run, ValueError, "Unsupported config: device_for_scales=") - elif scale_method == ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW: - return run_with_raised_exception( - run, ValueError, "Unsupported config: scale_method ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW" - ) + elif scale_method in SUPPORTED_DYNAMIC_QUANTIZATION_SCALES: + return run_with_raised_exception(run, ValueError, f"Unsupported config: scale_method {scale_method.name}") return run() diff --git a/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py b/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py index a0c242a7f8a..308c18e38ca 100644 --- a/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py +++ b/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_linear.py @@ -2,33 +2,19 @@ import pytest import torch +import types from neural_compressor.torch.algorithms.fp8_quant._core.quant_dequant import QuantDynamicInput from neural_compressor.torch.algorithms.fp8_quant._core.scale_handler import scale_to_scalar from neural_compressor.torch.algorithms.fp8_quant._core.scale_methods.scale_method_config import ScaleMethodString from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleFormat +from neural_compressor.torch.algorithms.fp8_quant._core.scale_handler import scale_to_scalar from 
...test_hpu_utils import * from ...tester import * -# Test Class to support restoration of calculated scale during runtime with dynamic quantization to test it correctness. -# This is a workaround to avoid saving the scale in the original QuantDynamicInput class as scale saving may cause unwanted graph breaks in torch.compile or issues with hpu_graph. -class TestQuantDynamicInput(QuantDynamicInput): - def __init__(self, input_scales_creator, lp_dtype, hp_dtype, *args, **kwargs): - super(TestQuantDynamicInput, self).__init__(input_scales_creator, lp_dtype, hp_dtype, *args, **kwargs) - self.input_scale = None - - def forward(self, x): - ret, scale = super().forward(x) - # We save the calculated scale during this forward pass to test it correctness. - self.input_scale = scale - return ret, scale - - -def get_test_vectors( - *, dtype: torch.dtype, N: int, D_in: int, atol: float = 0.02, rtol: float = 0.01 -) -> typing.Iterable[TestVector]: +def get_test_vectors(*, dtype: torch.dtype, N: int, D_in: int, atol: float = 0.02, rtol: float = 0.01) -> typing.Iterable[TestVector]: yield TestVector( inputs=[torch.ones(N, D_in, dtype=dtype, device="hpu", requires_grad=False)], atol=atol, @@ -131,8 +117,8 @@ def run(): if scale_method in HW_ALIGNED_SCALE_METHODS or scale_method in QUANT_ONLY_SCALE_METHODS: # When in dynamic quantization we don't support hw aligned scale methods and unit scale return run_with_raised_exception(run, ValueError, "Unsupported config: scale_method") - else: - if scale_method in SUPPORTED_DYNAMIC_SCALES: + else : + if scale_method in SUPPORTED_DYNAMIC_QUANTIZATION_SCALES: # When in static quantization we don't support dynamic scale method return run_with_raised_exception(run, ValueError, "Unsupported config: scale_method") return run() @@ -181,18 +167,17 @@ def run(): **module_kwargs, ) previous_input_dynamic_scale = 0 - test_quant_dynamic_input = TestQuantDynamicInput( - dynamic_quantized_model.inner.quant_input.input_scales_creator, - dynamic_quantized_model.inner.quant_input.lp_dtype, - dynamic_quantized_model.inner.quant_input.hp_dtype, - ) - dynamic_quantized_model.inner.quant_input = test_quant_dynamic_input + def wrapForward(self, input): + _,scale = self.inner.quant_input_func(input) + self.input_scale = scale + return self.inner(input) + dynamic_quantized_model.forward = types.MethodType(wrapForward, dynamic_quantized_model) for vector in test_vectors: dynamic_quantized_output = dynamic_quantized_model(*(input.clone() for input in vector.inputs)).to(float) # We save the calculated scale after the dynamic_quantized_model run the current input and calculates new scale. # In next iteration, we will have a new scale stored in the class. 
- current_input_dynamic_scale = dynamic_quantized_model.inner.quant_input.input_scale + current_input_dynamic_scale = dynamic_quantized_model.input_scale if isinstance(current_input_dynamic_scale, torch.Tensor): current_input_dynamic_scale = scale_to_scalar(current_input_dynamic_scale) diff --git a/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py b/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py index 75119a5c1b3..756811539ea 100644 --- a/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py +++ b/test/torch/algorithms/fp8_quant/unit_tests/test_layers/test_matmul.py @@ -2,14 +2,13 @@ import pytest import torch +import types from neural_compressor.torch.algorithms.fp8_quant._core.scale_methods.scale_method_config import ScaleMethodString from ...test_hpu_utils import * from ...tester import * -SUPPORTED_DYNAMIC_SCALES = [ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW] - def get_test_vectors(*, dtype: torch.dtype, atol) -> typing.Iterable[TestVector]: yield TestVector( @@ -45,8 +44,8 @@ class Matmul(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x, y): - return torch.matmul(x, y) + def forward(self, x, y, **kwargs): + return torch.matmul(x, y, **kwargs) @pytest.mark.parametrize("hp_dtype", [torch.bfloat16, torch.float32], ids=["bf16", "fp32"]) @@ -67,7 +66,8 @@ def test_matmul_accuracy( quant_modes = QUANT_MODES_QUANT_ONLY if scale_method == ScaleMethodString.HW_ALIGNED_SINGLE_SCALE: atol = 1.0 - + if scale_method == ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW and hp_dtype == torch.float32: + atol = 0.7 def run(): run_accuracy_test( module_class=Matmul, @@ -92,8 +92,65 @@ def run(): if scale_method in HW_ALIGNED_SCALE_METHODS or scale_method in QUANT_ONLY_SCALE_METHODS: # When in dynamic quantization we don't support hw aligned scale methods and unit scale return run_with_raised_exception(run, ValueError, "Unsupported config: scale_method") - else: - if scale_method in SUPPORTED_DYNAMIC_SCALES: + else : + if scale_method in SUPPORTED_DYNAMIC_QUANTIZATION_SCALES: # When in static quantization we don't support dynamic scale method return run_with_raised_exception(run, ValueError, "Unsupported config: scale_method") return run() + + +@pytest.mark.parametrize("hp_dtype", [torch.bfloat16, torch.float32], ids=["bf16", "fp32"]) +@pytest.mark.parametrize("lp_dtype", [torch.float8_e4m3fn], ids=["fp8_e4m3fn"]) +@pytest.mark.parametrize("scale_method", ScaleMethodString) +@pytest.mark.parametrize("device_type", device_type) +@pytest.mark.parametrize("scale_format", ScaleFormat) +@pytest.mark.parametrize("use_hpu_graphs", [True, False], ids=["use_hpu_graphs", "no_hpu_graphs"]) +def test_matmul_dynamic_quantization( + hp_dtype: torch.dtype, + lp_dtype: torch.dtype, + scale_method: ScaleMethodString, + device_type: str, + scale_format: ScaleFormat, + use_hpu_graphs: bool +): + atol = 0.2 + if scale_method == ScaleMethodString.HW_ALIGNED_SINGLE_SCALE: + atol = 1.0 + module_class=Matmul + module_kwargs={} + def run(): + test_vectors=get_test_vectors(dtype=hp_dtype, atol=atol) + dynamic_quantized_model = WrapModel(module_class, None, **module_kwargs) + dynamic_quantized_model = setup_quantization( + dynamic_quantized_model, + QuantMode.QUANTIZE, + lp_dtype, + scale_method, + device_type, + scale_format, + True, + use_hpu_graphs, + **module_kwargs, + ) + previous_input_dynamic_scale = torch.Tensor([[0.0],[0.0]]) + def wrapForward(self, input, other): + _, scale = 
self.inner.quant_input_0(input) + self.input_scale = scale + return self.inner(input, other) + dynamic_quantized_model.forward = types.MethodType(wrapForward, dynamic_quantized_model) + + for vector in test_vectors: + dynamic_quantized_output = dynamic_quantized_model(*(input.clone() for input in vector.inputs)).to(float) + # We save the calculated scale after the dynamic_quantized_model run the current input and calculates new scale. + # In next iteration, we will have a new scale stored in the class. + current_input_dynamic_scale = dynamic_quantized_model.input_scale + if scale_method not in SCALE_METHODS_QUANT_ONLY: + assert not torch.equal(previous_input_dynamic_scale, current_input_dynamic_scale), f"input scales in dynamic quantization should differ in different tensors {previous_input_dynamic_scale=} {current_input_dynamic_scale=}" + previous_input_dynamic_scale = current_input_dynamic_scale.clone() + + if (device_type_id[device_type] == get_gaudi3_type() and is_gaudi2() and scale_method == ScaleMethodString.MAXABS_HW): + return run_with_raised_exception(run, ValueError, "Unsupported config: device_for_scales=") + if (get_device_type() != device_type_id[device_type]) or scale_method in HW_ALIGNED_SCALE_METHODS or scale_method in QUANT_ONLY_SCALE_METHODS: + return run_with_raised_exception(run, ValueError, "Unsupported config: scale_method") + + return run() \ No newline at end of file diff --git a/test/torch/algorithms/fp8_quant/unit_tests/test_runtime_scale_patching.py b/test/torch/algorithms/fp8_quant/unit_tests/test_runtime_scale_patching.py index 635c2a817e0..99d88060aa7 100644 --- a/test/torch/algorithms/fp8_quant/unit_tests/test_runtime_scale_patching.py +++ b/test/torch/algorithms/fp8_quant/unit_tests/test_runtime_scale_patching.py @@ -7,7 +7,8 @@ import pytest import torch -from neural_compressor.torch.algorithms.fp8_quant._core.common import is_runtime_scale_patching +from ..tester import RUNTIME_SCALE_PATCHING_SUPPORTED_METHODS_LIST, run_with_raised_exception, SUPPORTED_DYNAMIC_QUANTIZATION_SCALES +from neural_compressor.torch.algorithms.fp8_quant._core.common import set_runtime_state from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import ScaleMethodString from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare @@ -83,7 +84,7 @@ def test_no_assert(scale_method, scale_format, dynamic_scale_patching, temp_dire def run_convert(): convert(inference_model, quant_config) - is_runtime_scale_patching.cache_clear() + set_runtime_state.cache_clear() os.environ["RUNTIME_SCALE_PATCHING"] = "0" model = prepare(model, measure_config) @@ -91,7 +92,7 @@ def run_convert(): model(input) finalize_calibration(model) - if scale_method == ScaleMethodString.ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW: + if scale_method in SUPPORTED_DYNAMIC_QUANTIZATION_SCALES: return run_with_raised_exception(run_convert, ValueError, "Unsupported config: scale_method") if dynamic_scale_patching: os.environ["RUNTIME_SCALE_PATCHING"] = "1" diff --git a/test/torch/algorithms/fp8_quant/unit_tests/test_save_load.py b/test/torch/algorithms/fp8_quant/unit_tests/test_save_load.py index a8dafbcba8b..67b797c8956 100644 --- a/test/torch/algorithms/fp8_quant/unit_tests/test_save_load.py +++ b/test/torch/algorithms/fp8_quant/unit_tests/test_save_load.py @@ -7,7 +7,7 @@ htcore.hpu_set_env() from transformers import LlamaConfig, LlamaForCausalLM - +from habana_frameworks.torch.utils.version_checker import is_pytorch_at_least from 
neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import Matmul from neural_compressor.torch.quantization import FP8Config, convert, load, prepare, save diff --git a/test/torch/algorithms/fp8_quant/unit_tests/test_scale_method_config.py b/test/torch/algorithms/fp8_quant/unit_tests/test_scale_method_config.py index 967d448ab70..d0a8ac0ad27 100644 --- a/test/torch/algorithms/fp8_quant/unit_tests/test_scale_method_config.py +++ b/test/torch/algorithms/fp8_quant/unit_tests/test_scale_method_config.py @@ -52,7 +52,7 @@ def check_tests_to_skip(scale_method, scale_value_type_weight=None, scale_value_ or scale_value_type_activation == ScaleValueType.DUMMY_SCALES ): pytest.xfail("Dummy scales is not a scale method") - if scale_method in SUPPORTED_DYNAMIC_SCALES: + if scale_method in SUPPORTED_DYNAMIC_QUANTIZATION_SCALES: pytest.xfail("Key error") diff --git a/test/torch/algorithms/fp8_quant_xpu/unit_tests/test_xpu_basic.py b/test/torch/algorithms/fp8_quant_xpu/unit_tests/test_xpu_basic.py index d9627bcc071..16cd2453ddb 100644 --- a/test/torch/algorithms/fp8_quant_xpu/unit_tests/test_xpu_basic.py +++ b/test/torch/algorithms/fp8_quant_xpu/unit_tests/test_xpu_basic.py @@ -37,8 +37,8 @@ class Matmul(torch.nn.Module): def __init__(self, dim=16): super().__init__() - def forward(self, input, other): - return torch.matmul(input, other) + def forward(self, input, other, **kwargs): + return torch.matmul(input, other, **kwargs) class MyModelMatmul(torch.nn.Module): @@ -83,10 +83,7 @@ def test_xpu_basic_mamtul(): # verify that we actually did the checks assert verified_matmul_quantized_func_wrapper and verified_quant_input_quantized_func_wrapper - -pytest.mark.xfail(reason="PYTORCHDGQ-6840 - enable once low-precision casting custom XPU ops are supported") - - +@pytest.mark.skip(reason="PYTORCHDGQ-6840 - enable once low-precision casting custom XPU ops are supported") def test_xpu_quantized_func_wrapper(): # test for verifying xpu quantized wrapper logic my_model = MyModel() diff --git a/test/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py b/test/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py index 9a643b3ff73..d6f8954b380 100644 --- a/test/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py +++ b/test/torch/algorithms/pt2e_quant/test_pt2e_w8a8.py @@ -67,8 +67,8 @@ def test_quantizer_on_simple_model(self): config.freezing = True opt_model = torch.compile(converted_model) out = opt_model(*example_inputs) - logger.warning("out shape is %s", out.shape) assert out is not None + logger.warning("out shape is %s", out.shape) @pytest.mark.skipif(get_torch_version() <= TORCH_VERSION_2_2_2, reason="Requires torch>=2.3.0") def test_quantizer_on_llm(self): diff --git a/test/torch/quantization/fp8_quant/test_layer_wise.py b/test/torch/quantization/fp8_quant/test_layer_wise.py index be0baa39c3a..43020a938b1 100644 --- a/test/torch/quantization/fp8_quant/test_layer_wise.py +++ b/test/torch/quantization/fp8_quant/test_layer_wise.py @@ -1,19 +1,22 @@ import habana_frameworks.torch.core as htcore import pytest import torch +from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare from neural_compressor.torch.utils import get_used_cpu_mem_MB -@pytest.mark.skip(reason="https://github.com/huggingface/transformers/issues/43159") def test_two_step_layer_wise(): # layer-wise is based on memory mapping technique and 
https://github.com/huggingface/transformers/pull/31771 # Workaround of [SW-208658]: torch.use_deterministic_algorithms(True) will break memory mapping tmp_memory_flag = torch.utils.deterministic.fill_uninitialized_memory torch.utils.deterministic.fill_uninitialized_memory = False model_name = "facebook/opt-125m" + # Pre-download all model files to local cache before memory measurement, + # so that from_pretrained can use memory mapping on local files. + snapshot_download(repo_id=model_name) config = AutoConfig.from_pretrained(model_name) # requires transformers >= 4.43.0, torch_dtype=config.torch_dtype # facebook/opt-125m parameters on disk is in torch.float16 dtype @@ -21,7 +24,6 @@ def test_two_step_layer_wise(): model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=config.torch_dtype, use_safetensors=True) cpu_mem1 = get_used_cpu_mem_MB() assert (cpu_mem1 - cpu_mem0) < 100, "model with memory mapping should use no more than 100MiB." - qconfig = FP8Config() model = prepare(model, qconfig) diff --git a/test/torch/quantization/fp8_quant/test_save_load.py b/test/torch/quantization/fp8_quant/test_save_load.py index 50db223b1fd..706447cd57f 100644 --- a/test/torch/quantization/fp8_quant/test_save_load.py +++ b/test/torch/quantization/fp8_quant/test_save_load.py @@ -5,6 +5,7 @@ import torch import transformers +from habana_frameworks.torch.utils.version_checker import is_pytorch_at_least from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedLinear from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import get_local_rank, get_world_size from neural_compressor.torch.quantization import FP8Config, convert, load, prepare, save diff --git a/test/torch/quantization/test_pt2e_quant.py b/test/torch/quantization/test_pt2e_quant.py index ae5b2674dca..9921e0c7159 100644 --- a/test/torch/quantization/test_pt2e_quant.py +++ b/test/torch/quantization/test_pt2e_quant.py @@ -35,7 +35,6 @@ def _is_ipex_imported(): monkeypatch.setattr("neural_compressor.torch.quantization.algorithm_entry.is_ipex_imported", _is_ipex_imported) monkeypatch.setattr("neural_compressor.torch.export.pt2e_export.is_ipex_imported", _is_ipex_imported) - class TestPT2EQuantization: def teardown_class(self): shutil.rmtree("saved_results", ignore_errors=True) @@ -132,6 +131,8 @@ def calib_fn(model): from neural_compressor.torch.quantization import load loaded_quantized_model = load("./saved_results") + if loaded_quantized_model is None: + logger.error("loaded_quantized_model is None") loaded_q_model_out = loaded_quantized_model(*example_inputs) assert torch.equal(loaded_q_model_out, q_model_out) @@ -172,6 +173,8 @@ def calib_fn(model): from torch._inductor import config config.freezing = True + if q_model is None: + logger.error("q_model is None") q_model_out = q_model(*example_inputs) assert torch.allclose(float_model_output, q_model_out, atol=1e-2), "Quantization failed!" 
opt_model = torch.compile(q_model) @@ -222,9 +225,9 @@ def test_prepare_and_convert_on_llm(self, force_not_import_ipex): attention_mask = inputs.attention_mask input_ids = inputs.input_ids - from transformers import DynamicCache - from transformers.integrations.executorch import export_with_dynamic_cache + from transformers.integrations.executorch import export_with_dynamic_cache + from transformers import DynamicCache ep = export_with_dynamic_cache(model, input_ids, attention_mask) model = ep.module() model._exported = True @@ -249,10 +252,10 @@ def test_prepare_and_convert_on_llm(self, force_not_import_ipex): config.freezing = True opt_model = torch.compile(converted_model) out = opt_model( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=DynamicCache(config=model_config), - use_cache=True, + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=DynamicCache(config=model_config), + use_cache=True, ) assert out.logits is not None @@ -321,8 +324,8 @@ def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name_or_ # Just make sure the pattern matches, not the accuracy. # config1: int8 for all # config2: half precision for linear/conv - from neural_compressor.torch.quantization.autotune import TuningConfig, autotune from neural_compressor.torch.quantization.config import INT8StaticQuantConfig + from neural_compressor.torch.quantization.autotune import autotune, TuningConfig config1 = INT8StaticQuantConfig() config2 = INT8StaticQuantConfig().set_local( diff --git a/test/torch/requirements.txt b/test/torch/requirements.txt index c7ccf1418f7..76d4e7cdf56 100644 --- a/test/torch/requirements.txt +++ b/test/torch/requirements.txt @@ -1,8 +1,8 @@ +accelerate auto-round @ git+https://github.com/intel/auto-round.git@main auto-round-lib compressed-tensors >= 0.15.0 datasets -deepspeed @ git+https://github.com/HabanaAI/DeepSpeed.git@1.23.0 expecttest numpy peft
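Closing note on the 'same' padding arithmetic used by the compute_padding helper added earlier in this patch: the output spatial size is ceil(input / stride), and half of the required total padding is applied to each side, with odd totals truncated. A small self-contained sketch of the same computation (same_padding is only an illustrative name, independent of the patched modules):

def same_padding(in_h, in_w, kh, kw, sh=1, sw=1):
    # output spatial size for 'same' padding is ceil(input / stride)
    out_h = (in_h + sh - 1) // sh
    out_w = (in_w + sw - 1) // sw
    pad_h_total = max((out_h - 1) * sh + kh - in_h, 0)
    pad_w_total = max((out_w - 1) * sw + kw - in_w, 0)
    # half of the total padding goes to each side; odd totals are truncated
    return pad_h_total // 2, pad_w_total // 2

# a 3x3 kernel with stride 1 on a 32x32 input needs a pad of 1 on each side
assert same_padding(32, 32, 3, 3) == (1, 1)
# with stride 2 the required total padding can be odd and is truncated to 0
assert same_padding(224, 224, 3, 3, 2, 2) == (0, 0)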