diff --git a/exponax/stepper/__init__.py b/exponax/stepper/__init__.py index e012ec7..ddc6e42 100644 --- a/exponax/stepper/__init__.py +++ b/exponax/stepper/__init__.py @@ -13,6 +13,7 @@ - Dispersion - HyperDiffusion - Wave + - KleinGordon - Burgers - KortewegDeVries - KuramotoSivashinsky @@ -61,6 +62,10 @@ The Wave stepper uses a handcrafted diagonalization in Fourier space specific to the wave equation. It has no corresponding generic stepper. +The KleinGordon stepper extends the Wave stepper with a mass term, using the +Klein-Gordon dispersion relation ω(k) = √(c²|k|² + m²). Setting m=0 recovers +the wave equation. + In the reaction submodule you find specific steppers that are special cases of the GeneralPolynomialStepper, e.g., the FisherKPPStepper. @@ -83,6 +88,7 @@ from ._diffusion import Diffusion from ._dispersion import Dispersion from ._hyper_diffusion import HyperDiffusion +from ._klein_gordon import KleinGordon from ._korteweg_de_vries import KortewegDeVries from ._kuramoto_sivashinsky import KuramotoSivashinsky, KuramotoSivashinskyConservative from ._navier_stokes import ( @@ -100,6 +106,7 @@ "Dispersion", "HyperDiffusion", "Wave", + "KleinGordon", "Burgers", "KortewegDeVries", "KuramotoSivashinsky", diff --git a/exponax/stepper/_klein_gordon.py b/exponax/stepper/_klein_gordon.py new file mode 100644 index 0000000..7700e9a --- /dev/null +++ b/exponax/stepper/_klein_gordon.py @@ -0,0 +1,178 @@ +import jax.numpy as jnp +from jaxtyping import Array, Complex, Float + +from .._base_stepper import BaseStepper +from .._spectral import build_scaled_wavenumbers +from ..nonlin_fun import ZeroNonlinearFun + + +class KleinGordon(BaseStepper): + mass: float + speed_of_sound: float + frequency: Float[Array, " 1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + speed_of_sound: float = 1.0, + mass: float = 1.0, + ): + """ + Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) Klein-Gordon + equation on periodic boundary conditions. + + In 1d, the Klein-Gordon equation is given by + + ``` + uₜₜ = c² uₓₓ - m² u + ``` + + with `c ∈ ℝ` being the wave speed and `m ∈ ℝ` being the mass + parameter. This is the relativistic generalization of the wave + equation, fundamental to quantum field theory and lattice field + simulations. + + In higher dimensions: + + ``` + uₜₜ = c² Δu - m² u + ``` + + **Dispersion relation:** ω(k) = √(c²|k|² + m²) + + Unlike the wave equation (ω = c|k|), the Klein-Gordon equation has a + **mass gap** — no modes with ω < m exist. The group velocity + v_g = c²|k|/ω is always less than c (massive dispersion). + + Internally, the same diagonalization approach as the + [`exponax.stepper.Wave`][] stepper is used, but with + the Klein-Gordon dispersion relation. + + The second-order equation is rewritten as a first-order system: + + ``` + hₜ = v + vₜ = c² Δh - m² h + ``` + + In Fourier space, each wavenumber k oscillates at frequency + ω(k) = √(c²|k|² + m²). The system is diagonalized into + forward/backward traveling modes that each evolve as a pure + phase rotation. + + **Arguments:** + + - `num_spatial_dims`: The number of spatial dimensions `d`. + - `domain_extent`: The size of the domain `L`; in higher dimensions + the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. + - `num_points`: The number of points `N` used to discretize the + domain. This **includes** the left boundary point and **excludes** + the right boundary point. In higher dimensions; the number of points + in each dimension is the same. + - `dt`: The timestep size `Δt` between two consecutive states. + - `speed_of_sound` (keyword-only): The wave speed `c`. Default: `1.0`. + - `mass` (keyword-only): The mass parameter `m`. Default: `1.0`. + + **Notes:** + + - The stepper is unconditionally stable, no matter the choice of + any argument because the equation is solved analytically in Fourier + space. + - Setting `mass = 0.0` recovers the standard wave equation. + - The factors `c Δt / L` and `m Δt` together affect the dynamics. + """ + self.speed_of_sound = speed_of_sound + self.mass = mass + wavenumber_norm = jnp.linalg.norm( + build_scaled_wavenumbers( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + ), + axis=0, + keepdims=True, + ) + # Klein-Gordon dispersion: ω(k) = sqrt(c²|k|² + m²) + self.frequency = jnp.sqrt( + speed_of_sound**2 * wavenumber_norm**2 + mass**2 + ) + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=2, + order=0, + ) + + def _forward_transform( + self, u_hat: Complex[Array, " 2 ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + """Transform (h, v) into diagonalized Klein-Gordon wave modes.""" + h_hat, v_hat = u_hat[0:1], u_hat[1:2] + # Scale height to match velocity units: w = iω h + omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency) + w_hat = 1j * omega_guard * h_hat + + # Orthonormal rotation into wave modes + pos = (1 / jnp.sqrt(2)) * (w_hat + v_hat) + neg = (1 / jnp.sqrt(2)) * (w_hat - v_hat) + return jnp.concatenate([pos, neg], axis=0) + + def _inverse_transform( + self, waves_hat: Complex[Array, " 2 ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + """Transform diagonalized wave modes back into (h, v).""" + pos, neg = waves_hat[0:1], waves_hat[1:2] + # Inverse rotation + w_hat = (1 / jnp.sqrt(2)) * (pos + neg) + v_hat = (1 / jnp.sqrt(2)) * (pos - neg) + + # Undo scaling to recover height + omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency) + h_hat = w_hat / (1j * omega_guard) + return jnp.concatenate([h_hat, v_hat], axis=0) + + def _build_linear_operator( + self, derivative_operator: Complex[Array, " D ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + val = 1j * self.frequency + return jnp.concatenate( + ( + val, + -val, + ), + axis=0, + ) + + def _build_nonlinear_fun( + self, derivative_operator: Complex[Array, " D ... (N//2)+1"] + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun(self.num_spatial_dims, self.num_points) + + def step_fourier( + self, u_hat: Complex[Array, " 2 ... (N//2)+1"] + ) -> Complex[Array, " 2 ... (N//2)+1"]: + """ + Advance the state by one timestep in Fourier space. + + Overrides the base method to wrap the ETDRK step with the + forward/inverse diagonalization transforms. + """ + waves_hat = self._forward_transform(u_hat) + waves_hat_next = super().step_fourier(waves_hat) + u_hat_next = self._inverse_transform(waves_hat_next) + + # At k=0 with m>0, the system is still diagonalizable (ω(0) = m ≠ 0), + # so no special DC correction is needed. However, if m=0 we fall back + # to the wave equation DC behavior. + if self.mass == 0.0: + h_dc_idx = (0,) + (0,) * self.num_spatial_dims + v_dc_idx = (1,) + (0,) * self.num_spatial_dims + u_hat_next = u_hat_next.at[h_dc_idx].add(self.dt * u_hat[v_dc_idx]) + + return u_hat_next diff --git a/tests/test_builtin_solvers.py b/tests/test_builtin_solvers.py index 368791a..9312412 100644 --- a/tests/test_builtin_solvers.py +++ b/tests/test_builtin_solvers.py @@ -18,6 +18,7 @@ def test_instantiate(): ex.stepper.Dispersion, ex.stepper.HyperDiffusion, ex.stepper.Wave, + ex.stepper.KleinGordon, ex.stepper.Burgers, ex.stepper.KuramotoSivashinsky, ex.stepper.KuramotoSivashinskyConservative, diff --git a/tests/test_klein_gordon.py b/tests/test_klein_gordon.py new file mode 100644 index 0000000..f1bea32 --- /dev/null +++ b/tests/test_klein_gordon.py @@ -0,0 +1,185 @@ +import jax.numpy as jnp +import pytest + +from exponax.stepper import KleinGordon, Wave + +L = 2 * jnp.pi # domain length + + +# =========================================================================== +# Instantiation +# =========================================================================== + + +class TestKleinGordonInstantiation: + @pytest.mark.parametrize("num_spatial_dims", [1, 2, 3]) + def test_instantiate(self, num_spatial_dims): + stepper = KleinGordon(num_spatial_dims, 10.0, 25, 0.1) + assert stepper.num_channels == 2 + assert stepper.num_spatial_dims == num_spatial_dims + + @pytest.mark.parametrize("num_spatial_dims", [1, 2, 3]) + def test_output_shape(self, num_spatial_dims): + N = 16 + stepper = KleinGordon(num_spatial_dims, L, N, 0.01) + u0 = jnp.zeros((2,) + (N,) * num_spatial_dims) + u1 = stepper(u0) + assert u1.shape == u0.shape + assert jnp.all(jnp.isfinite(u1)) + + def test_wrong_input_shape_raises(self): + stepper = KleinGordon(1, L, 32, 0.01) + with pytest.raises(ValueError, match="Expected shape"): + stepper(jnp.zeros((1, 32))) # needs 2 channels + + def test_default_params(self): + stepper = KleinGordon(1, L, 32, 0.01) + assert stepper.speed_of_sound == 1.0 + assert stepper.mass == 1.0 + + def test_custom_params(self): + stepper = KleinGordon(1, L, 32, 0.01, speed_of_sound=2.0, mass=3.0) + assert stepper.speed_of_sound == 2.0 + assert stepper.mass == 3.0 + + +# =========================================================================== +# mass=0 should recover Wave equation +# =========================================================================== + + +class TestKleinGordonRecoverWave: + """When mass=0, KleinGordon must produce identical results to Wave.""" + + @pytest.mark.parametrize("num_spatial_dims", [1, 2]) + def test_mass_zero_matches_wave(self, num_spatial_dims): + N, dt, c = 32, 0.01, 1.5 + kg = KleinGordon( + num_spatial_dims, L, N, dt, speed_of_sound=c, mass=0.0 + ) + wave = Wave(num_spatial_dims, L, N, dt, speed_of_sound=c) + + x = jnp.linspace(0, L, N, endpoint=False) + if num_spatial_dims == 1: + h0 = jnp.cos(2 * x)[None] + else: + h0 = jnp.cos(2 * x)[None, :, None] * jnp.ones((1, N, N)) + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + u_kg = u0 + u_wave = u0 + for _ in range(20): + u_kg = kg(u_kg) + u_wave = wave(u_wave) + + assert u_kg == pytest.approx(u_wave, abs=1e-5) + + def test_mass_zero_matches_wave_multi_step(self): + """Longer evolution to catch accumulation drift.""" + N, dt, c = 64, 0.005, 1.0 + kg = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=0.0) + wave = Wave(1, L, N, dt, speed_of_sound=c) + + x = jnp.linspace(0, L, N, endpoint=False) + h0 = (jnp.cos(x) + 0.5 * jnp.cos(3 * x))[None] + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + u_kg = u0 + u_wave = u0 + for _ in range(100): + u_kg = kg(u_kg) + u_wave = wave(u_wave) + + assert u_kg == pytest.approx(u_wave, abs=1e-4) + + +# =========================================================================== +# Analytical correctness — 1D Klein-Gordon standing mode +# =========================================================================== + + +class TestKleinGordonAnalytical1D: + """For h(x,0) = cos(k0 x), v(x,0) = 0: + ω = sqrt(c²k0² + m²) + h(x,t) = cos(k0 x) cos(ω t) + v(x,t) = -ω cos(k0 x) sin(ω t) + """ + + def _make_stepper_and_ic(self, k0, c=1.0, m=1.0, N=64, dt=0.01): + stepper = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=m) + x = jnp.linspace(0, L, N, endpoint=False) + h0 = jnp.cos(k0 * x)[None] + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + omega = jnp.sqrt(c**2 * k0**2 + m**2) + return stepper, x, u0, float(omega) + + @pytest.mark.parametrize("k0", [1, 2, 3, 5]) + def test_single_mode(self, k0): + c, m, N, dt = 1.0, 2.0, 64, 0.01 + stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt) + + n_steps = 10 + u = u0 + for _ in range(n_steps): + u = stepper(u) + t = n_steps * dt + + h_exact = jnp.cos(k0 * x) * jnp.cos(omega * t) + v_exact = -omega * jnp.cos(k0 * x) * jnp.sin(omega * t) + + assert u[0] == pytest.approx(h_exact, abs=1e-4) + assert u[1] == pytest.approx(v_exact, abs=1e-3) + + def test_mass_gap(self): + """With k0=0 (uniform mode), oscillation is at ω = m (the mass gap).""" + m, N, dt = 3.0, 32, 0.01 + stepper = KleinGordon(1, L, N, dt, speed_of_sound=1.0, mass=m) + + h0 = jnp.ones((1, N)) # k=0 mode + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + n_steps = 20 + u = u0 + for _ in range(n_steps): + u = stepper(u) + t = n_steps * dt + + h_exact = jnp.cos(m * t) * jnp.ones(N) + assert u[0] == pytest.approx(h_exact, abs=1e-4) + + def test_energy_bounded(self): + """Total Klein-Gordon energy should be conserved over many steps.""" + k0, c, m, N, dt = 3, 1.0, 2.0, 64, 0.005 + stepper = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=m) + dx = L / N + + x = jnp.linspace(0, L, N, endpoint=False) + h0 = jnp.cos(k0 * x)[None] + v0 = jnp.zeros_like(h0) + u0 = jnp.concatenate([h0, v0], axis=0) + + # Wavenumbers for gradient energy in Fourier space + k = jnp.fft.rfftfreq(N, d=dx) * 2 * jnp.pi + + def energy(u): + h, v = u[0], u[1] + h_hat = jnp.fft.rfft(h) + # KE: ½∫v² dx, gradient PE: ½c²∫|∇h|² dx, mass PE: ½m²∫h² dx + # Parseval: ∫|∇h|² dx = (1/N) Σ |k|² |ĥ(k)|² + ke = 0.5 * jnp.sum(v**2) * dx + grad_pe = 0.5 * c**2 * jnp.sum(k**2 * jnp.abs(h_hat)**2) / N + mass_pe = 0.5 * m**2 * jnp.sum(h**2) * dx + return ke + grad_pe + mass_pe + + e0 = energy(u0) + u = u0 + for _ in range(200): + u = stepper(u) + e_final = energy(u) + + # Spectral stepper conserves energy well; f32 accumulation limits precision + assert e_final == pytest.approx(float(e0), rel=1e-3)