class SimplifiedLayerNormToRMSNorm(ProtoSurgeon):
    """Replace SimplifiedLayerNormalization/SkipSimplifiedLayerNormalization with an RMSNorm subgraph.

    The replacement is built from elementwise ONNX ops:

        RMS(x) = Sqrt(ReduceMean(Pow(x, 2), axes=[-1], keepdims=1) + eps)
        y      = Mul(Div(x, RMS(x)), gamma)

    eps is read from the node's own `epsilon` attribute when present; otherwise
    DEFAULT_EPSILON is used.

    For SkipSimplifiedLayerNormalization, s = Add(input, skip) is computed first and
    used as x. If the original node exposes a residual-sum output, its consumers are
    rewired to s so graph behavior is preserved.

    IMPORTANT: the ReduceMean schema changed across opsets:
      - opset < 18 : axes is an ATTRIBUTE
      - opset >= 18: axes is an INPUT tensor (int64); keepdims remains an attribute.
    """

    # Fallback when the node carries no explicit `epsilon` attribute.
    DEFAULT_EPSILON = 1e-06

    def __call__(self, model: onnx.ModelProto):
        from onnx import numpy_helper
        from onnx.helper import tensor_dtype_to_np_dtype

        dag = OnnxDAG(model)

        # Determine the default ONNX opset for the main domain ("", "ai.onnx").
        # It decides how ReduceMean must be constructed (see class docstring).
        default_opset = None
        for imp in model.opset_import:
            if imp.domain in ("", "ai.onnx"):
                default_opset = imp.version
                break
        if default_opset is None:
            # Fall back defensively; most models have a default import.
            default_opset = 13

        use_axes_input_for_reduce_mean = default_opset >= 18

        modified = 0

        # Materialize the name list: nodes are added and removed inside the loop.
        for node_name in list(dag.get_node_names()):
            op_type = dag.get_node_op_type(node_name)
            if op_type not in {"SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"}:
                continue

            graph_idx = dag.get_graph_idx(node_name)
            inputs = dag.get_node_inputs(node_name, True)
            outputs = dag.get_node_outputs(node_name, True)

            # Honor the node's own epsilon instead of hard-coding one; both contrib
            # ops define an `epsilon` attribute. Falling back to DEFAULT_EPSILON
            # preserves the previous behavior for nodes without the attribute.
            eps_value = dag.get_node_attributes(node_name).get("epsilon", self.DEFAULT_EPSILON)

            # ---------------------------
            # Build the input to be normalized: ln_input
            # ---------------------------
            if op_type == "SkipSimplifiedLayerNormalization":
                # Expect inputs: [input, skip, gamma]. Nodes with the optional bias
                # input (4 inputs) are left untouched.
                if len(inputs) != 3:
                    continue
                root1, root2, gamma = inputs

                # Add(input, skip) => skip_add_out
                skip_add_name = self.create_new_name(node_name, op_type, "Add")
                skip_add_out = f"{skip_add_name}_out"
                skip_add_node = onnx.helper.make_node(
                    "Add",
                    inputs=[root1, root2],
                    outputs=[skip_add_out],
                    name=skip_add_name,
                )
                dag.add_node(skip_add_node, graph_idx)

                ln_input = skip_add_out
            else:
                # SimplifiedLayerNormalization: inputs = [x, gamma]
                if len(inputs) != 2:
                    continue
                ln_input, gamma = inputs

            # The original primary output (normalized tensor)
            ln_output = outputs[0]

            ln_elem_type = dag.get_io_elem_type(inputs[0]) or onnx.TensorProto.FLOAT
            ln_np_dtype = tensor_dtype_to_np_dtype(ln_elem_type)

            # ---------------------------
            # Step 1: Pow(x, 2)
            # ---------------------------
            pow_name = self.create_new_name(node_name, op_type, "Pow")
            pow_out = f"{pow_name}_out"
            pow_const = numpy_helper.from_array(np.array([2.0], dtype=ln_np_dtype), name=f"{pow_name}_const")
            dag.add_initializer(pow_const, graph_idx)
            pow_node = onnx.helper.make_node(
                "Pow",
                inputs=[ln_input, pow_const.name],
                outputs=[pow_out],
                name=pow_name,
            )
            dag.add_node(pow_node, graph_idx)

            # ---------------------------
            # Step 2: ReduceMean over last dim, keepdims=1
            #   - opset < 18 : axes is an attribute
            #   - opset >= 18: axes is an input tensor (INT64)
            # ---------------------------
            mean_name = self.create_new_name(node_name, op_type, "ReduceMean")
            mean_out = f"{mean_name}_out"

            if use_axes_input_for_reduce_mean:
                axes_init = numpy_helper.from_array(np.array([-1], dtype=np.int64), name=f"{mean_name}_axes")
                dag.add_initializer(axes_init, graph_idx)

                mean_node = onnx.helper.make_node(
                    "ReduceMean",
                    inputs=[pow_out, axes_init.name],
                    outputs=[mean_out],
                    name=mean_name,
                    keepdims=1,
                )
            else:
                # Older schema: axes is an attribute
                mean_node = onnx.helper.make_node(
                    "ReduceMean",
                    inputs=[pow_out],
                    outputs=[mean_out],
                    name=mean_name,
                    axes=[-1],
                    keepdims=1,
                )
            dag.add_node(mean_node, graph_idx)

            # ---------------------------
            # Step 3: Add epsilon (value taken from the node attribute above)
            # ---------------------------
            add_eps_name = self.create_new_name(node_name, op_type, "AddEps")
            add_eps_out = f"{add_eps_name}_out"

            eps_const = numpy_helper.from_array(
                np.array([eps_value], dtype=ln_np_dtype), name=f"{add_eps_name}_const"
            )
            dag.add_initializer(eps_const, graph_idx)

            add_eps_node = onnx.helper.make_node(
                "Add",
                inputs=[mean_out, eps_const.name],
                outputs=[add_eps_out],
                name=add_eps_name,
            )
            dag.add_node(add_eps_node, graph_idx)

            # ---------------------------
            # Step 4: Sqrt
            # ---------------------------
            sqrt_name = self.create_new_name(node_name, op_type, "Sqrt")
            sqrt_out = f"{sqrt_name}_out"
            sqrt_node = onnx.helper.make_node(
                "Sqrt",
                inputs=[add_eps_out],
                outputs=[sqrt_out],
                name=sqrt_name,
            )
            dag.add_node(sqrt_node, graph_idx)

            # ---------------------------
            # Step 5: Div (x / sqrt(...))
            # ---------------------------
            div_name = self.create_new_name(node_name, op_type, "Div")
            div_out = f"{div_name}_out"
            div_node = onnx.helper.make_node(
                "Div",
                inputs=[ln_input, sqrt_out],
                outputs=[div_out],
                name=div_name,
            )
            dag.add_node(div_node, graph_idx)

            # ---------------------------
            # Step 6: Mul with gamma
            # ---------------------------
            mul_name = self.create_new_name(node_name, op_type, "Mul")
            mul_out = f"{mul_name}_out"
            mul_node = onnx.helper.make_node(
                "Mul",
                inputs=[div_out, gamma],
                outputs=[mul_out],
                name=mul_name,
            )
            dag.add_node(mul_node, graph_idx)

            # ---------------------------
            # Rewire consumers of the original main output. Copy the consumer list:
            # replace_node_input mutates connectivity while we iterate.
            # ---------------------------
            for consumer in list(dag.get_consumers(ln_output)):
                dag.replace_node_input(consumer, ln_output, mul_out)

            # ---------------------------
            # For SkipSimplifiedLayerNormalization that had two (used) outputs:
            #   - Output 1 is typically the residual sum (input_skip_bias_sum)
            #   - Redirect its consumers to the skip-sum Add output
            # ---------------------------
            if op_type == "SkipSimplifiedLayerNormalization" and len(outputs) == 2:
                second_output = outputs[1]

                # Carry the value info over to the new name so shape/type survive.
                second_vi = dag.get_value_info_proto(second_output)
                if second_vi is not None:
                    new_vi = onnx.ValueInfoProto()
                    new_vi.CopyFrom(second_vi)
                    new_vi.name = skip_add_out
                    dag.add_value_info(new_vi, graph_idx)

                # Redirect all consumers of the second output
                for consumer in list(dag.get_consumers(second_output)):
                    dag.replace_node_input(consumer, second_output, skip_add_out)

            dag.remove_node(node_name)
            modified += 1

        if modified > 0:
            logger.debug(
                "Replaced %d Simplified/SkipSimplifiedLayerNormalization nodes with RMSNorm subgraphs", modified
            )

        dag.update()
        return dag.model
def check_rmsnorm(
    original_model_path: str,
    modified_model_path: str,
    hidden_size: int,
    expected_num_nodes: int,
    has_skip: bool = False,
):
    """Verify numerical equivalence and expected structure of the RMSNorm rewrite.

    Runs the original and modified models on the same random input and compares all
    outputs, then checks the node count and that the RMSNorm elementwise ops are
    present while the fused contrib ops are gone.
    """
    # check output values match
    input_session = InferenceSession(original_model_path)
    output_session = InferenceSession(modified_model_path)
    input_feed = {"x": np.random.randn(1, hidden_size).astype(np.float32)}
    if has_skip:
        input_feed["skip"] = np.random.randn(1, hidden_size).astype(np.float32)
    input_result = input_session.run(None, input_feed)
    output_result = output_session.run(None, input_feed)
    # guard against a silently-dropped output before the pairwise comparison
    assert len(input_result) == len(output_result)
    for i_r, o_r in zip(input_result, output_result):
        np.testing.assert_allclose(i_r, o_r, rtol=1e-3, atol=1e-3)

    # count nodes and verify expected op types are present
    dag = OnnxDAG.from_model_path(modified_model_path)
    assert len(dag.nodes) == expected_num_nodes
    op_types = dag.get_node_op_types()
    assert "Pow" in op_types
    assert "ReduceMean" in op_types
    # the epsilon Add (and skip Add, when present) must survive the rewrite
    assert "Add" in op_types
    assert "Sqrt" in op_types
    assert "Div" in op_types
    assert "Mul" in op_types
    assert "SimplifiedLayerNormalization" not in op_types
    assert "SkipSimplifiedLayerNormalization" not in op_types


@pytest.mark.parametrize("all_ones", [True, False])
def test_simplifiedlayernorm_to_rmsnorm(tmp_path, all_ones):
    """SimplifiedLayerNormalization is replaced by the elementwise RMSNorm subgraph."""
    # setup
    hidden_size = 3
    inputs = [
        onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, hidden_size]),
    ]
    outputs = [
        onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, hidden_size]),
    ]
    weight = (np.ones(hidden_size) if all_ones else np.random.randn(hidden_size)).astype(np.float32)
    initializers = [onnx.numpy_helper.from_array(weight, name="weight")]
    nodes = [
        onnx.helper.make_node(
            "SimplifiedLayerNormalization",
            inputs=["x", "weight"],
            outputs=["layernorm_output"],
            name="layernorm/LayerNorm",
        ),
        onnx.helper.make_node("Identity", inputs=["layernorm_output"], outputs=["y"], name="Identity"),
    ]
    graph = helper.make_graph(
        nodes=nodes,
        name="TestGraph",
        inputs=inputs,
        outputs=outputs,
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    onnx.save(model, str(tmp_path / "input_model.onnx"))
    input_model = ONNXModelHandler(model_path=str(tmp_path / "input_model.onnx"))

    output_folder = str(tmp_path / "output")
    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "SimplifiedLayerNormToRMSNorm"}]},
        disable_search=True,
    )

    # execute
    onnx_model = p.run(input_model, output_folder)

    # assert
    # Pow, ReduceMean, Add(eps), Sqrt, Div, Mul, Identity = 7 nodes
    # (opset 20 => ReduceMean axes is an initializer input, not a node)
    check_rmsnorm(str(tmp_path / "input_model.onnx"), onnx_model.model_path, hidden_size, 7)


@pytest.mark.parametrize("all_ones", [True, False])
@pytest.mark.parametrize("output_skip_sum", [True, False])
def test_simplifiedlayernorm_to_rmsnorm_skip(tmp_path, all_ones, output_skip_sum):
    """SkipSimplifiedLayerNormalization is replaced; the residual-sum output is rewired."""
    # setup
    hidden_size = 3
    inputs = [
        onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, hidden_size]),
        onnx.helper.make_tensor_value_info("skip", TensorProto.FLOAT, [1, hidden_size]),
    ]
    outputs = [
        onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, hidden_size]),
    ]
    if output_skip_sum:
        outputs.append(
            onnx.helper.make_tensor_value_info("skip_sum", TensorProto.FLOAT, [1, hidden_size]),
        )
    initializers = [
        onnx.numpy_helper.from_array(
            (np.ones(hidden_size) if all_ones else np.random.randn(hidden_size)).astype(np.float32), name="weight"
        )
    ]
    nodes = [
        onnx.helper.make_node(
            "SkipSimplifiedLayerNormalization",
            inputs=["x", "skip", "weight"],
            # outputs 2 and 3 (mean, inv_std_var) are unused; empty names mark them absent
            outputs=["layernorm_output"] if not output_skip_sum else ["layernorm_output", "", "", "layernorm_skip_sum"],
            name="layernorm/LayerNorm",
            domain=MSFT_DOMAIN,
        ),
        onnx.helper.make_node("Identity", inputs=["layernorm_output"], outputs=["y"], name="Identity"),
    ]
    if output_skip_sum:
        nodes.append(
            onnx.helper.make_node(
                "Identity", inputs=["layernorm_skip_sum"], outputs=["skip_sum"], name="Identity_skip_sum"
            )
        )
    graph = helper.make_graph(
        nodes=nodes,
        name="TestGraph",
        inputs=inputs,
        outputs=outputs,
        initializer=initializers,
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    onnx.save(model, str(tmp_path / "input_model.onnx"))
    input_model = ONNXModelHandler(model_path=str(tmp_path / "input_model.onnx"))

    output_folder = str(tmp_path / "output")
    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "SimplifiedLayerNormToRMSNorm"}]},
        disable_search=True,
    )

    # execute
    output_model = p.run(input_model, output_folder)

    # assert
    # Add(skip), Pow, ReduceMean, Add(eps), Sqrt, Div, Mul, Identity[, Identity_skip_sum] = 8 or 9 nodes
    check_rmsnorm(
        str(tmp_path / "input_model.onnx"),
        output_model.model_path,
        hidden_size,
        8 + int(output_skip_sum),
        has_skip=True,
    )