From 90944cf877f411cd8a5569a3f303967b342d6dd7 Mon Sep 17 00:00:00 2001 From: Greg Partin Date: Wed, 11 Mar 2026 12:01:19 -0700 Subject: [PATCH 1/3] Add Klein-Gordon equation stepper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a KleinGordon stepper for the relativistic wave equation uₜₜ = c² Δu - m² u, using the same spectral diagonalization approach as the Wave stepper. Key differences from Wave: - Uses Klein-Gordon dispersion: ω(k) = √(c²|k|² + m²) - Has mass gap: no modes with ω < m exist - DC mode (k=0) is diagonalizable when m > 0 (ω(0) = m) - Setting mass=0 recovers the standard wave equation Files: - exponax/stepper/_klein_gordon.py: KleinGordon stepper class - exponax/stepper/__init__.py: Updated exports - tests/test_builtin_solvers.py: Added to instantiation tests --- exponax/stepper/__init__.py | 6 ++ exponax/stepper/_klein_gordon.py | 178 +++++++++++++++++++++++++++++++ tests/test_builtin_solvers.py | 1 + 3 files changed, 185 insertions(+) create mode 100644 exponax/stepper/_klein_gordon.py diff --git a/exponax/stepper/__init__.py b/exponax/stepper/__init__.py index e012ec7..a81a8f1 100644 --- a/exponax/stepper/__init__.py +++ b/exponax/stepper/__init__.py @@ -61,6 +61,10 @@ The Wave stepper uses a handcrafted diagonalization in Fourier space specific to the wave equation. It has no corresponding generic stepper. +The KleinGordon stepper extends the Wave stepper with a mass term, using the +Klein-Gordon dispersion relation ω(k) = √(c²|k|² + m²). Setting m=0 recovers +the wave equation. + In the reaction submodule you find specific steppers that are special cases of the GeneralPolynomialStepper, e.g., the FisherKPPStepper. @@ -91,6 +95,7 @@ NavierStokesVelocity, NavierStokesVorticity, ) +from ._klein_gordon import KleinGordon from ._wave import Wave __all__ = [ @@ -100,6 +105,7 @@ "Dispersion", "HyperDiffusion", "Wave", + "KleinGordon", "Burgers", "KortewegDeVries", "KuramotoSivashinsky", diff --git a/exponax/stepper/_klein_gordon.py b/exponax/stepper/_klein_gordon.py new file mode 100644 index 0000000..7700e9a --- /dev/null +++ b/exponax/stepper/_klein_gordon.py @@ -0,0 +1,178 @@ +import jax.numpy as jnp +from jaxtyping import Array, Complex, Float + +from .._base_stepper import BaseStepper +from .._spectral import build_scaled_wavenumbers +from ..nonlin_fun import ZeroNonlinearFun + + +class KleinGordon(BaseStepper): + mass: float + speed_of_sound: float + frequency: Float[Array, " 1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + speed_of_sound: float = 1.0, + mass: float = 1.0, + ): + """ + Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) Klein-Gordon + equation on periodic boundary conditions. + + In 1d, the Klein-Gordon equation is given by + + ``` + uₜₜ = c² uₓₓ - m² u + ``` + + with `c ∈ ℝ` being the wave speed and `m ∈ ℝ` being the mass + parameter. This is the relativistic generalization of the wave + equation, fundamental to quantum field theory and lattice field + simulations. + + In higher dimensions: + + ``` + uₜₜ = c² Δu - m² u + ``` + + **Dispersion relation:** ω(k) = √(c²|k|² + m²) + + Unlike the wave equation (ω = c|k|), the Klein-Gordon equation has a + **mass gap** — no modes with ω < m exist. The group velocity + v_g = c²|k|/ω is always less than c (massive dispersion). + + Internally, the same diagonalization approach as the + [`exponax.stepper.Wave`][] stepper is used, but with + the Klein-Gordon dispersion relation. + + The second-order equation is rewritten as a first-order system: + + ``` + hₜ = v + vₜ = c² Δh - m² h + ``` + + In Fourier space, each wavenumber k oscillates at frequency + ω(k) = √(c²|k|² + m²). The system is diagonalized into + forward/backward traveling modes that each evolve as a pure + phase rotation. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. + - `dt`: The timestep size `Δt` between two consecutive states. + - `speed_of_sound` (keyword-only): The wave speed `c`. Default: `1.0`. + - `mass` (keyword-only): The mass parameter `m`. Default: `1.0`. + + **Notes:** + + - The stepper is unconditionally stable, no matter the choice of + any argument because the equation is solved analytically in Fourier + space. + - Setting `mass = 0.0` recovers the standard wave equation. + - The factors `c Δt / L` and `m Δt` together affect the dynamics. + """ + self.speed_of_sound = speed_of_sound + self.mass = mass + wavenumber_norm = jnp.linalg.norm( + build_scaled_wavenumbers( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + ), + axis=0, + keepdims=True, + ) + # Klein-Gordon dispersion: ω(k) = sqrt(c²|k|² + m²) + self.frequency = jnp.sqrt( + speed_of_sound**2 * wavenumber_norm**2 + mass**2 + ) + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=2, + order=0, + ) + + def _forward_transform( + self, u_hat: Complex[Array, " 2 ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + """Transform (h, v) into diagonalized Klein-Gordon wave modes.""" + h_hat, v_hat = u_hat[0:1], u_hat[1:2] + # Scale height to match velocity units: w = iω h + omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency) + w_hat = 1j * omega_guard * h_hat + + # Orthonormal rotation into wave modes + pos = (1 / jnp.sqrt(2)) * (w_hat + v_hat) + neg = (1 / jnp.sqrt(2)) * (w_hat - v_hat) + return jnp.concatenate([pos, neg], axis=0) + + def _inverse_transform( + self, waves_hat: Complex[Array, " 2 ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + """Transform diagonalized wave modes back into (h, v).""" + pos, neg = waves_hat[0:1], waves_hat[1:2] + # Inverse rotation + w_hat = (1 / jnp.sqrt(2)) * (pos + neg) + v_hat = (1 / jnp.sqrt(2)) * (pos - neg) + + # Undo scaling to recover height + omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency) + h_hat = w_hat / (1j * omega_guard) + return jnp.concatenate([h_hat, v_hat], axis=0) + + def _build_linear_operator( + self, derivative_operator: Complex[Array, " D ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + val = 1j * self.frequency + return jnp.concatenate( + ( + val, + -val, + ), + axis=0, + ) + + def _build_nonlinear_fun( + self, derivative_operator: Complex[Array, " D ... (N//2)+1"] + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun(self.num_spatial_dims, self.num_points) + + def step_fourier( + self, u_hat: Complex[Array, " 2 ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + """ + Advance the state by one timestep in Fourier space. + + Overrides the base method to wrap the ETDRK step with the + forward/inverse diagonalization transforms. + """ + waves_hat = self._forward_transform(u_hat) + waves_hat_next = super().step_fourier(waves_hat) + u_hat_next = self._inverse_transform(waves_hat_next) + + # At k=0 with m>0, the system is still diagonalizable (ω(0) = m ≠ 0), + # so no special DC correction is needed. However, if m=0 we fall back + # to the wave equation DC behavior. + if self.mass == 0.0: + h_dc_idx = (0,) + (0,) * self.num_spatial_dims + v_dc_idx = (1,) + (0,) * self.num_spatial_dims + u_hat_next = u_hat_next.at[h_dc_idx].add(self.dt * u_hat[v_dc_idx]) + + return u_hat_next diff --git a/tests/test_builtin_solvers.py b/tests/test_builtin_solvers.py index 368791a..9312412 100644 --- a/tests/test_builtin_solvers.py +++ b/tests/test_builtin_solvers.py @@ -18,6 +18,7 @@ def test_instantiate(): ex.stepper.Dispersion, ex.stepper.HyperDiffusion, ex.stepper.Wave, + ex.stepper.KleinGordon, ex.stepper.Burgers, ex.stepper.KuramotoSivashinsky, ex.stepper.KuramotoSivashinskyConservative, From a9730f0c4b71488be43d0f6bddb6f24df69f903a Mon Sep 17 00:00:00 2001 From: Greg Partin Date: Thu, 12 Mar 2026 06:08:32 -0700 Subject: [PATCH 2/3] Address review: add KG tests, update stepper list - Add test_klein_gordon.py with dedicated solver tests: - Instantiation and output shape checks - mass=0 recovers Wave stepper (numerical equivalence) - Analytical standing-mode correctness - Mass gap test (k=0 oscillates at omega=m) - Energy conservation bounds - Update stepper list in __init__.py docstring to include KleinGordon --- exponax/stepper/__init__.py | 1 + tests/test_klein_gordon.py | 173 ++++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+) create mode 100644 tests/test_klein_gordon.py diff --git a/exponax/stepper/__init__.py b/exponax/stepper/__init__.py index a81a8f1..4030fbf 100644 --- a/exponax/stepper/__init__.py +++ b/exponax/stepper/__init__.py @@ -13,6 +13,7 @@ - Dispersion - HyperDiffusion - Wave + - KleinGordon - Burgers - KortewegDeVries - KuramotoSivashinsky diff --git a/tests/test_klein_gordon.py b/tests/test_klein_gordon.py new file mode 100644 index 0000000..48df02c --- /dev/null +++ b/tests/test_klein_gordon.py @@ -0,0 +1,173 @@ +import jax.numpy as jnp +import pytest + +import exponax as ex +from exponax.stepper import KleinGordon, Wave + +L = 2 * jnp.pi +PI = jnp.pi + + +# =========================================================================== +# Instantiation +# =========================================================================== + + +class TestKleinGordonInstantiation: + @pytest.mark.parametrize("num_spatial_dims", [1, 2, 3]) + def test_instantiate(self, num_spatial_dims): + stepper = KleinGordon(num_spatial_dims, 10.0, 25, 0.1) + assert stepper.num_channels == 2 + assert stepper.num_spatial_dims == num_spatial_dims + + @pytest.mark.parametrize("num_spatial_dims", [1, 2, 3]) + def test_output_shape(self, num_spatial_dims): + N = 16 + stepper = KleinGordon(num_spatial_dims, L, N, 0.01) + u0 = jnp.zeros((2,) + (N,) * num_spatial_dims) + u1 = stepper(u0) + assert u1.shape == u0.shape + assert jnp.all(jnp.isfinite(u1)) + + def test_wrong_input_shape_raises(self): + stepper = KleinGordon(1, L, 32, 0.01) + with pytest.raises(ValueError, match="Expected shape"): + stepper(jnp.zeros((1, 32))) # needs 2 channels + + def test_default_params(self): + stepper = KleinGordon(1, L, 32, 0.01) + assert stepper.speed_of_sound == 1.0 + assert stepper.mass == 1.0 + + def test_custom_params(self): + stepper = KleinGordon(1, L, 32, 0.01, speed_of_sound=2.0, mass=3.0) + assert stepper.speed_of_sound == 2.0 + assert stepper.mass == 3.0 + + +# =========================================================================== +# mass=0 should recover Wave equation +# =========================================================================== + + +class TestKleinGordonRecoverWave: + """When mass=0, KleinGordon must produce identical results to Wave.""" + + @pytest.mark.parametrize("num_spatial_dims", [1, 2]) + def test_mass_zero_matches_wave(self, num_spatial_dims): + N, dt, c = 32, 0.01, 1.5 + kg = KleinGordon( + num_spatial_dims, L, N, dt, speed_of_sound=c, mass=0.0 + ) + wave = Wave(num_spatial_dims, L, N, dt, speed_of_sound=c) + + x = jnp.linspace(0, L, N, endpoint=False) + if num_spatial_dims == 1: + h0 = jnp.cos(2 * x)[None] + else: + h0 = jnp.cos(2 * x)[None, :, None] * jnp.ones((1, N, N)) + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + u_kg = u0 + u_wave = u0 + for _ in range(20): + u_kg = kg(u_kg) + u_wave = wave(u_wave) + + assert u_kg == pytest.approx(u_wave, abs=1e-5) + + def test_mass_zero_matches_wave_multi_step(self): + """Longer evolution to catch accumulation drift.""" + N, dt, c = 64, 0.005, 1.0 + kg = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=0.0) + wave = Wave(1, L, N, dt, speed_of_sound=c) + + x = jnp.linspace(0, L, N, endpoint=False) + h0 = (jnp.cos(x) + 0.5 * jnp.cos(3 * x))[None] + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + u_kg = u0 + u_wave = u0 + for _ in range(100): + u_kg = kg(u_kg) + u_wave = wave(u_wave) + + assert u_kg == pytest.approx(u_wave, abs=1e-4) + + +# =========================================================================== +# Analytical correctness — 1D Klein-Gordon standing mode +# =========================================================================== + + +class TestKleinGordonAnalytical1D: + """For h(x,0) = cos(k0 x), v(x,0) = 0: + ω = sqrt(c²k0² + m²) + h(x,t) = cos(k0 x) cos(ω t) + v(x,t) = -ω cos(k0 x) sin(ω t) + """ + + def _make_stepper_and_ic(self, k0, c=1.0, m=1.0, N=64, dt=0.01): + stepper = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=m) + x = jnp.linspace(0, L, N, endpoint=False) + h0 = jnp.cos(k0 * x)[None] + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + omega = jnp.sqrt(c**2 * k0**2 + m**2) + return stepper, x, u0, float(omega) + + @pytest.mark.parametrize("k0", [1, 2, 3, 5]) + def test_single_mode(self, k0): + c, m, N, dt = 1.0, 2.0, 64, 0.01 + stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt) + + n_steps = 10 + u = u0 + for _ in range(n_steps): + u = stepper(u) + t = n_steps * dt + + h_exact = jnp.cos(k0 * x) * jnp.cos(omega * t) + v_exact = -omega * jnp.cos(k0 * x) * jnp.sin(omega * t) + + assert u[0] == pytest.approx(h_exact, abs=1e-4) + assert u[1] == pytest.approx(v_exact, abs=1e-3) + + def test_mass_gap(self): + """With k0=0 (uniform mode), oscillation is at ω = m (the mass gap).""" + m, N, dt = 3.0, 32, 0.01 + stepper = KleinGordon(1, L, N, dt, speed_of_sound=1.0, mass=m) + + h0 = jnp.ones((1, N)) # k=0 mode + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + n_steps = 20 + u = u0 + for _ in range(n_steps): + u = stepper(u) + t = n_steps * dt + + h_exact = jnp.cos(m * t) * jnp.ones(N) + assert u[0] == pytest.approx(h_exact, abs=1e-4) + + def test_energy_bounded(self): + """Total energy should be conserved (bounded) over many steps.""" + k0, c, m, N, dt = 3, 1.0, 2.0, 64, 0.005 + stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt) + + def energy(u): + h, v = u[0], u[1] + # KE + gradient PE + mass PE + return jnp.sum(v**2 + c**2 * jnp.abs(jnp.fft.rfft(h))**2 + m**2 * h**2) + + e0 = energy(u0) + u = u0 + for _ in range(200): + u = stepper(u) + e_final = energy(u) + + # Spectral solver should conserve energy to machine precision + assert e_final == pytest.approx(float(e0), rel=1e-3) From 3dc88d6f3b00566072a1664e5a9f71c54cc235b7 Mon Sep 17 00:00:00 2001 From: Greg Partin Date: Sat, 14 Mar 2026 06:11:37 -0700 Subject: [PATCH 3/3] Address review: fix import order, remove unused imports, correct energy formula - Move KleinGordon import to alphabetical position (after _hyper_diffusion) - Remove unused 'import exponax as ex' (F401) - Remove unused PI constant (F841) - Decouple test_energy_bounded from helper to avoid unused variables (F841) - Fix energy formula: add |k|^2 weighting for gradient PE, proper dx scaling - Fix misleading 'machine precision' comment to match rel=1e-3 tolerance --- exponax/stepper/__init__.py | 2 +- tests/test_klein_gordon.py | 28 ++++++++++++++++++++-------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/exponax/stepper/__init__.py b/exponax/stepper/__init__.py index 4030fbf..ddc6e42 100644 --- a/exponax/stepper/__init__.py +++ b/exponax/stepper/__init__.py @@ -88,6 +88,7 @@ from ._diffusion import Diffusion from ._dispersion import Dispersion from ._hyper_diffusion import HyperDiffusion +from ._klein_gordon import KleinGordon from ._korteweg_de_vries import KortewegDeVries from ._kuramoto_sivashinsky import KuramotoSivashinsky, KuramotoSivashinskyConservative from ._navier_stokes import ( @@ -96,7 +97,6 @@ NavierStokesVelocity, NavierStokesVorticity, ) -from ._klein_gordon import KleinGordon from ._wave import Wave __all__ = [ diff --git a/tests/test_klein_gordon.py b/tests/test_klein_gordon.py index 48df02c..f1bea32 100644 --- a/tests/test_klein_gordon.py +++ b/tests/test_klein_gordon.py @@ -1,11 +1,9 @@ import jax.numpy as jnp import pytest -import exponax as ex from exponax.stepper import KleinGordon, Wave -L = 2 * jnp.pi -PI = jnp.pi +L = 2 * jnp.pi # domain length # =========================================================================== @@ -154,14 +152,28 @@ def test_mass_gap(self): assert u[0] == pytest.approx(h_exact, abs=1e-4) def test_energy_bounded(self): - """Total energy should be conserved (bounded) over many steps.""" + """Total Klein-Gordon energy should be conserved over many steps.""" k0, c, m, N, dt = 3, 1.0, 2.0, 64, 0.005 - stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt) + stepper = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=m) + dx = L / N + + x = jnp.linspace(0, L, N, endpoint=False) + h0 = jnp.cos(k0 * x)[None] + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + # Wavenumbers for gradient energy in Fourier space + k = jnp.fft.rfftfreq(N, d=dx) * 2 * jnp.pi def energy(u): h, v = u[0], u[1] - # KE + gradient PE + mass PE - return jnp.sum(v**2 + c**2 * jnp.abs(jnp.fft.rfft(h))**2 + m**2 * h**2) + h_hat = jnp.fft.rfft(h) + # KE: ½∫v² dx, gradient PE: ½c²∫|∇h|² dx, mass PE: ½m²∫h² dx + # Parseval: ∫|∇h|² dx = (1/N) Σ |k|² |ĥ(k)|² + ke = 0.5 * jnp.sum(v**2) * dx + grad_pe = 0.5 * c**2 * jnp.sum(k**2 * jnp.abs(h_hat)**2) / N + mass_pe = 0.5 * m**2 * jnp.sum(h**2) * dx + return ke + grad_pe + mass_pe e0 = energy(u0) u = u0 @@ -169,5 +181,5 @@ def energy(u): u = stepper(u) e_final = energy(u) - # Spectral solver should conserve energy to machine precision + # Spectral stepper conserves energy well; f32 accumulation limits precision assert e_final == pytest.approx(float(e0), rel=1e-3)