219 changes: 219 additions & 0 deletions olive/passes/onnx/graph_surgeries.py
@@ -850,6 +850,225 @@ def get_rmsnorm_nodes(pow_node: str, dag: OnnxDAG) -> list[str] | None:
return rmsnorm_nodes if len(rmsnorm_nodes) >= (len(pattern) - 1) else []


class SimplifiedLayerNormToRMSNorm(ProtoSurgeon):
"""Replace SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization with an RMSNorm subgraph built from elementwise ops.

RMS(x) = sqrt(mean(x^2, axis=-1, keepdims=1) + eps)
y = (x / RMS(x)) * gamma

For SkipSimplifiedLayerNormalization, we first do:
s = input + skip
and use 's' as x for RMSNorm. If the original node exposes a second output
(residual sum), we rewire its consumers to 's' to preserve graph behavior.

IMPORTANT: ReduceMean schema change across opsets:
- opset < 18: axes is an ATTRIBUTE
- opset >= 18: axes is an INPUT tensor (int64); keepdims remains an attribute.
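
Illustrative numpy sketch of the math this surgery reproduces (the tensors
below are hypothetical, not part of the pass):

    import numpy as np
    x = np.random.randn(1, 4).astype(np.float32)  # hidden states
    gamma = np.ones(4, dtype=np.float32)          # scale weight
    eps = 1e-6
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    y = (x / rms) * gamma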
"""

def __call__(self, model: onnx.ModelProto):
from onnx import numpy_helper
from onnx.helper import tensor_dtype_to_np_dtype

dag = OnnxDAG(model)

# Determine the default ONNX opset for the main domain ("", "ai.onnx").
# We'll use this to decide how to build ReduceMean.
default_opset = None
for imp in model.opset_import:
if imp.domain in ("", "ai.onnx"):
default_opset = imp.version
break
if default_opset is None:
# Fall back defensively; most models have a default import.
default_opset = 13

use_axes_input_for_reduce_mean = default_opset >= 18
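# e.g. a model importing opset 17 gets ReduceMean with an `axes=[-1]`
# attribute below, while opset 18+ gets an explicit int64 `axes` initializer
# passed as a second input.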

modified = 0

for node_name in dag.get_node_names():
op_type = dag.get_node_op_type(node_name)
if op_type not in {"SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"}:
continue

graph_idx = dag.get_graph_idx(node_name)
inputs = dag.get_node_inputs(node_name, True)
outputs = dag.get_node_outputs(node_name, True)

# ---------------------------
# Build the input to be normalized: ln_input
# ---------------------------
if op_type == "SkipSimplifiedLayerNormalization":
# Expect inputs: [input, skip, gamma]
if len(inputs) != 3:
continue
root1, root2, gamma = inputs

# Add(input, skip) => skip_add_out
skip_add_name = self.create_new_name(node_name, op_type, "Add")
skip_add_out = f"{skip_add_name}_out"
skip_add_node = onnx.helper.make_node(
"Add",
inputs=[root1, root2],
outputs=[skip_add_out],
name=skip_add_name,
)
dag.add_node(skip_add_node, graph_idx)

ln_input = skip_add_out
else:
# SimplifiedLayerNormalization: inputs = [x, gamma]
if len(inputs) != 2:
continue
ln_input, gamma = inputs

# The original primary output (normalized tensor)
ln_output = outputs[0]

ln_elem_type = dag.get_io_elem_type(inputs[0]) or onnx.TensorProto.FLOAT
ln_np_dtype = tensor_dtype_to_np_dtype(ln_elem_type)

# ---------------------------
# Step 1: Pow(x, 2)
# ---------------------------
pow_name = self.create_new_name(node_name, op_type, "Pow")
pow_out = f"{pow_name}_out"
pow_const = numpy_helper.from_array(np.array([2.0], dtype=ln_np_dtype), name=f"{pow_name}_const")
dag.add_initializer(pow_const, graph_idx)
pow_node = onnx.helper.make_node(
"Pow",
inputs=[ln_input, pow_const.name],
outputs=[pow_out],
name=pow_name,
)
dag.add_node(pow_node, graph_idx)

# ---------------------------
# Step 2: ReduceMean over last dim, keepdims=1
# - opset < 18 : axes is an attribute
# - opset >= 18: axes is an input tensor (INT64)
# ---------------------------
mean_name = self.create_new_name(node_name, op_type, "ReduceMean")
mean_out = f"{mean_name}_out"

if use_axes_input_for_reduce_mean:
axes_init = numpy_helper.from_array(np.array([-1], dtype=np.int64), name=f"{mean_name}_axes")
dag.add_initializer(axes_init, graph_idx)

mean_node = onnx.helper.make_node(
"ReduceMean",
inputs=[pow_out, axes_init.name],
outputs=[mean_out],
name=mean_name,
keepdims=1,
)
else:
# Older schema: axes is an attribute
mean_node = onnx.helper.make_node(
"ReduceMean",
inputs=[pow_out],
outputs=[mean_out],
name=mean_name,
axes=[-1],
keepdims=1,
)
dag.add_node(mean_node, graph_idx)

# ---------------------------
# Step 3: Add epsilon
# ---------------------------
eps_value = 1e-06
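# NOTE: a fixed epsilon is assumed here; the original node's `epsilon`
# attribute (if set) is not read, so models using a different epsilon
# will differ slightly.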
add_eps_name = self.create_new_name(node_name, op_type, "AddEps")
add_eps_out = f"{add_eps_name}_out"

eps_const = numpy_helper.from_array(np.array([eps_value], dtype=ln_np_dtype), name=f"{add_eps_name}_const")
dag.add_initializer(eps_const, graph_idx)

add_eps_node = onnx.helper.make_node(
"Add",
inputs=[mean_out, eps_const.name],
outputs=[add_eps_out],
name=add_eps_name,
)
dag.add_node(add_eps_node, graph_idx)

# ---------------------------
# Step 4: Sqrt
# ---------------------------
sqrt_name = self.create_new_name(node_name, op_type, "Sqrt")
sqrt_out = f"{sqrt_name}_out"
sqrt_node = onnx.helper.make_node(
"Sqrt",
inputs=[add_eps_out],
outputs=[sqrt_out],
name=sqrt_name,
)
dag.add_node(sqrt_node, graph_idx)

# ---------------------------
# Step 5: Div (x / sqrt(...))
# ---------------------------
div_name = self.create_new_name(node_name, op_type, "Div")
div_out = f"{div_name}_out"
div_node = onnx.helper.make_node(
"Div",
inputs=[ln_input, sqrt_out],
outputs=[div_out],
name=div_name,
)
dag.add_node(div_node, graph_idx)

# ---------------------------
# Step 6: Mul with gamma
# ---------------------------
mul_name = self.create_new_name(node_name, op_type, "Mul")
mul_out = f"{mul_name}_out"
mul_node = onnx.helper.make_node(
"Mul",
inputs=[div_out, gamma],
outputs=[mul_out],
name=mul_name,
)
dag.add_node(mul_node, graph_idx)

# ---------------------------
# Rewire consumers of the original main output
# ---------------------------
for consumer in dag.get_consumers(ln_output):
dag.replace_node_input(consumer, ln_output, mul_out)

# ---------------------------
# For SkipSimplifiedLayerNormalization that exposed two (non-empty) outputs:
# - the second one is typically the residual sum (input_skip_bias_sum)
# - redirect its consumers to the skip-sum Add output
# ---------------------------
if op_type == "SkipSimplifiedLayerNormalization" and len(outputs) == 2:
second_output = outputs[1]

second_vi = dag.get_value_info_proto(second_output)
if second_vi is not None:
new_vi = onnx.ValueInfoProto()
new_vi.CopyFrom(second_vi)
new_vi.name = skip_add_out
dag.add_value_info(new_vi, graph_idx)

# Redirect all consumers of the second output
for consumer in dag.get_consumers(second_output):
dag.replace_node_input(consumer, second_output, skip_add_out)

dag.remove_node(node_name)
modified += 1

if modified > 0:
logger.debug(
"Replaced %d Simplified/SkipSimplifiedLayerNormalization nodes with RMSNorm subgraphs", modified
)

dag.update()
return dag.model


class SimplifiedLayerNormToL2Norm(ProtoSurgeon):
"""Replace Skip/SimplifiedLayerNormalization node with L2Norm subgraph.

149 changes: 149 additions & 0 deletions test/passes/onnx/test_graph_surgeries.py
@@ -588,6 +588,155 @@ def test_simplifiedlayernorm_to_l2norm_skip(tmp_path, all_ones, output_skip_sum)
)


def check_rmsnorm(
original_model_path: str,
modified_model_path: str,
hidden_size: int,
expected_num_nodes: int,
has_skip: bool = False,
):
# check output values match
input_session = InferenceSession(original_model_path)
output_session = InferenceSession(modified_model_path)
input_feed = {"x": np.random.randn(1, hidden_size).astype(np.float32)}
if has_skip:
input_feed["skip"] = np.random.randn(1, hidden_size).astype(np.float32)
input_result = input_session.run(None, input_feed)
output_result = output_session.run(None, input_feed)
for i_r, o_r in zip(input_result, output_result):
np.testing.assert_allclose(i_r, o_r, rtol=1e-3, atol=1e-3)

# count nodes and verify expected op types are present
dag = OnnxDAG.from_model_path(modified_model_path)
assert len(dag.nodes) == expected_num_nodes
op_types = dag.get_node_op_types()
assert "Pow" in op_types
assert "ReduceMean" in op_types
assert "Sqrt" in op_types
assert "Div" in op_types
assert "Mul" in op_types
assert "SimplifiedLayerNormalization" not in op_types
assert "SkipSimplifiedLayerNormalization" not in op_types


@pytest.mark.parametrize("all_ones", [True, False])
def test_simplifiedlayernorm_to_rmsnorm(tmp_path, all_ones):
# setup
hidden_size = 3
inputs = [
onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, hidden_size]),
]
outputs = [
onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, hidden_size]),
]
weight = (np.ones(hidden_size) if all_ones else np.random.randn(hidden_size)).astype(np.float32)
initializers = [onnx.numpy_helper.from_array(weight, name="weight")]
nodes = [
onnx.helper.make_node(
"SimplifiedLayerNormalization",
inputs=["x", "weight"],
outputs=["layernorm_output"],
name="layernorm/LayerNorm",
),
onnx.helper.make_node("Identity", inputs=["layernorm_output"], outputs=["y"], name="Identity"),
]
graph = helper.make_graph(
nodes=nodes,
name="TestGraph",
inputs=inputs,
outputs=outputs,
initializer=initializers,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
model.ir_version = 10
onnx.save(model, str(tmp_path / "input_model.onnx"))
input_model = ONNXModelHandler(model_path=str(tmp_path / "input_model.onnx"))

output_folder = str(tmp_path / "output")
p = create_pass_from_dict(
GraphSurgeries,
{"surgeries": [{"surgeon": "SimplifiedLayerNormToRMSNorm"}]},
disable_search=True,
)

# execute
onnx_model = p.run(input_model, output_folder)

# assert
# Pow, ReduceMean, Add(eps), Sqrt, Div, Mul, Identity = 7 nodes
check_rmsnorm(str(tmp_path / "input_model.onnx"), onnx_model.model_path, hidden_size, 7)


@pytest.mark.parametrize("all_ones", [True, False])
@pytest.mark.parametrize("output_skip_sum", [True, False])
def test_simplifiedlayernorm_to_rmsnorm_skip(tmp_path, all_ones, output_skip_sum):
# setup
hidden_size = 3
inputs = [
onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, hidden_size]),
onnx.helper.make_tensor_value_info("skip", TensorProto.FLOAT, [1, hidden_size]),
]
outputs = [
onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, hidden_size]),
]
if output_skip_sum:
outputs.append(
onnx.helper.make_tensor_value_info("skip_sum", TensorProto.FLOAT, [1, hidden_size]),
)
initializers = [
onnx.numpy_helper.from_array(
(np.ones(hidden_size) if all_ones else np.random.randn(hidden_size)).astype(np.float32), name="weight"
)
]
nodes = [
onnx.helper.make_node(
"SkipSimplifiedLayerNormalization",
inputs=["x", "skip", "weight"],
outputs=["layernorm_output"] if not output_skip_sum else ["layernorm_output", "", "", "layernorm_skip_sum"],
name="layernorm/LayerNorm",
domain=MSFT_DOMAIN,
),
onnx.helper.make_node("Identity", inputs=["layernorm_output"], outputs=["y"], name="Identity"),
]
if output_skip_sum:
nodes.append(
onnx.helper.make_node(
"Identity", inputs=["layernorm_skip_sum"], outputs=["skip_sum"], name="Identity_skip_sum"
)
)
graph = helper.make_graph(
nodes=nodes,
name="TestGraph",
inputs=inputs,
outputs=outputs,
initializer=initializers,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
model.ir_version = 10
onnx.save(model, str(tmp_path / "input_model.onnx"))
input_model = ONNXModelHandler(model_path=str(tmp_path / "input_model.onnx"))

output_folder = str(tmp_path / "output")
p = create_pass_from_dict(
GraphSurgeries,
{"surgeries": [{"surgeon": "SimplifiedLayerNormToRMSNorm"}]},
disable_search=True,
)

# execute
output_model = p.run(input_model, output_folder)

# assert
# Add(skip), Pow, ReduceMean, Add(eps), Sqrt, Div, Mul, Identity[, Identity_skip_sum] = 8 or 9 nodes
check_rmsnorm(
str(tmp_path / "input_model.onnx"),
output_model.model_path,
hidden_size,
8 + int(output_skip_sum),
has_skip=True,
)


@pytest.mark.parametrize("use_large_cache", [True, False])
def test_remove_rope_multi_cache(tmp_path, use_large_cache):
# setup