Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions lib/axon/initializers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,18 @@ defmodule Axon.Initializers do

{m, n} = get_flat_shape(shape)

# Generate a square random matrix of size max(m, n) so QR
# produces enough orthogonal columns for wide matrices
# (e.g. LSTM weights shaped {hidden, 4*hidden})
k = max(m, n)

random_seed =
case distribution do
:uniform ->
Nx.Random.uniform_split(key, 0.0, 1.0, shape: {m, n}, type: type)
Nx.Random.uniform_split(key, 0.0, 1.0, shape: {k, k}, type: type)

:normal ->
Nx.Random.normal_split(key, 0.0, 1.0, shape: {m, n}, type: type)
Nx.Random.normal_split(key, 0.0, 1.0, shape: {k, k}, type: type)

dist ->
raise ArgumentError,
Expand Down
25 changes: 25 additions & 0 deletions test/axon/initializers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,31 @@ defmodule Axon.InitializersTest do
)
end

test "works with wide matrices (n > m, e.g. LSTM/GRU weights)" do
  # Small stand-in for an LSTM kernel shaped {hidden, 4 * hidden}; the
  # dimensions are kept tiny so QR stays fast on BinaryBackend.
  shape = {8, 32}
  weights = Axon.Initializers.orthogonal().(shape, {:f, 32}, Nx.Random.key(1))

  assert Nx.shape(weights) == shape

  # With more columns than rows, the rows are the orthonormal set, so the
  # Gram matrix weights * weights^T is expected to equal the identity.
  gram = Nx.dot(weights, Nx.transpose(weights))
  expected = Nx.eye(Nx.shape(gram))

  assert_all_close(gram, expected, atol: 1.0e-3, rtol: 1.0e-3)
end

test "works with wide high-rank shapes" do
  init_fn = Axon.Initializers.orthogonal()

  # Rank-3 shape so the test actually exercises the high-rank path its name
  # promises. NOTE(review): this assumes get_flat_shape collapses the leading
  # axes into rows, i.e. {2, 2, 8} -> {4, 8} with n > m; confirm against the
  # initializer implementation.
  shape = {2, 2, 8}
  t = init_fn.(shape, {:f, 32}, Nx.Random.key(1))

  # The initializer must hand back a tensor of the originally requested shape.
  assert Nx.shape(t) == shape
end

test "raises on input rank less than 2" do
assert_raise ArgumentError,
~r/Axon.Initializers.orthogonal: expected input_shape shape to have at least rank 2/,
Expand Down