diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/patched_helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/patched_helper_modules.py index 51a94b2cf0e..99c80a72c6f 100644 --- a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/patched_helper_modules.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/patched_helper_modules.py @@ -126,6 +126,11 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs): self.fp8_apc_fsdpa_impl = impl_mapping[qkv_slice_impl] self.slice_causal = os.getenv("VLLM_HPU_FSDPA_SLICE_CAUSAL", "0") in ("1", "true") + self.with_mark_step = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_MARK_STEP", "0") in ("1", "true") + if self.with_mark_step: + import habana_frameworks.torch as ht + self.mark_step = ht.core.mark_step + def fp8_fsdpa_fwd( self, @@ -199,6 +204,9 @@ def fp8_apc_fsdpa_split_kv( prefix_linv = prefix_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0) prefix_out = self.dequant_output(prefix_out).to(torch.float32) + if self.with_mark_step: + self.mark_step() + # calculate the causal part causal_k = k[..., prefix_len:, :] causal_v = v[..., prefix_len:, :] @@ -255,6 +263,9 @@ def fp8_apc_fsdpa_slice_causal( prefix_linv = prefix_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0) prefix_out = self.dequant_output(prefix_out).to(torch.float32) + if self.with_mark_step: + self.mark_step() + # calculate the causal part chunk_outputs = [] num_chunks = (q_len + self.qkv_chunk_size - 1) // self.qkv_chunk_size @@ -284,6 +295,19 @@ def fp8_apc_fsdpa_slice_causal( if kv_chunk_idx == 0 and not is_causal_chunk else None ) + + if self.with_mark_step: + # mark_step() cannot break the tensor slicing, use clone to isolate the graph + q_chunk = q_chunk.clone() + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + if mask_chunk is not None: + mask_chunk = mask_chunk.clone() + last_out = last_out.clone() + last_m = last_m.clone() + last_linv = last_linv.clone() + self.mark_step() + chunk_res = self.fp8_fsdpa_fwd( q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, softmax_mode ) @@ -301,6 +325,9 @@ def fp8_apc_fsdpa_slice_causal( chunk_linv_rescaled * last_linv ) * chunk_out last_m = new_m + + if self.with_mark_step: + self.mark_step() chunk_outputs.append(last_out) chunk_outputs = list(reversed(chunk_outputs)) return torch.cat(chunk_outputs, dim=-2) @@ -352,6 +379,12 @@ def fp8_apc_fsdpa_slice_qkv( k_chunk = k[..., kv_start:kv_end, :] v_chunk = v[..., kv_start:kv_end, :] + if self.with_mark_step: + q_chunk = q_chunk.clone() + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + self.mark_step() + chunk_res = self.fp8_fsdpa_fwd( q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode ) @@ -373,6 +406,9 @@ def fp8_apc_fsdpa_slice_qkv( chunk_linv_rescaled * last_linv ) * chunk_out last_m = new_m + + if self.with_mark_step: + self.mark_step() for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx): kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.qkv_chunk_size @@ -389,6 +425,15 @@ def fp8_apc_fsdpa_slice_qkv( if kv_chunk_idx == 0 and not is_causal_chunk else None ) + + if self.with_mark_step: + q_chunk = q_chunk.clone() + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + if mask_chunk is not None: + mask_chunk = mask_chunk.clone() + self.mark_step() + chunk_res = self.fp8_fsdpa_fwd( q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, softmax_mode ) @@ -410,6 +455,10 @@ def fp8_apc_fsdpa_slice_qkv( last_out = (last_linv_rescaled * last_linv) * last_out + \ (chunk_linv_rescaled * last_linv) * chunk_out last_m = new_m + + if self.with_mark_step: + self.mark_step() + chunk_outputs.append(last_out) chunk_outputs = list(reversed(chunk_outputs)) return torch.cat(chunk_outputs, dim=-2) @@ -469,6 +518,13 @@ def fp8_causal_fsdpa_slice_qkv( mask_chunk = (1.0 - torch.tril(torch.ones(mask_shape, dtype=self.hp_dtype, device=q_chunk.device))) * -3e38 else: mask_chunk = None + + if self.with_mark_step: + q_chunk = q_chunk.clone() + k_chunk = k_chunk.clone() + v_chunk = v_chunk.clone() + self.mark_step() + chunk_res = self.fp8_fsdpa_fwd( q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, softmax_mode ) @@ -490,6 +546,10 @@ def fp8_causal_fsdpa_slice_qkv( last_out = (last_linv_rescaled * last_linv) * last_out + \ (chunk_linv_rescaled * last_linv) * chunk_out last_m = new_m + + if self.with_mark_step: + self.mark_step() + chunk_outputs.append(last_out) chunk_outputs = list(reversed(chunk_outputs)) return torch.cat(chunk_outputs, dim=-2)