-
Notifications
You must be signed in to change notification settings - Fork 17
Add Klein-Gordon equation stepper #103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| import jax.numpy as jnp | ||
| from jaxtyping import Array, Complex, Float | ||
|
|
||
| from .._base_stepper import BaseStepper | ||
| from .._spectral import build_scaled_wavenumbers | ||
| from ..nonlin_fun import ZeroNonlinearFun | ||
|
|
||
|
|
||
| class KleinGordon(BaseStepper): | ||
| mass: float | ||
| speed_of_sound: float | ||
| frequency: Float[Array, " 1 ... (N//2)+1"] | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_spatial_dims: int, | ||
| domain_extent: float, | ||
| num_points: int, | ||
| dt: float, | ||
| *, | ||
| speed_of_sound: float = 1.0, | ||
| mass: float = 1.0, | ||
| ): | ||
| """ | ||
| Timestepper for the d-dimensional (`d ∈ {1, 2, 3}`) Klein-Gordon | ||
| equation on periodic boundary conditions. | ||
|
|
||
| In 1d, the Klein-Gordon equation is given by | ||
|
|
||
| ``` | ||
| uₜₜ = c² uₓₓ - m² u | ||
| ``` | ||
|
|
||
| with `c ∈ ℝ` being the wave speed and `m ∈ ℝ` being the mass | ||
| parameter. This is the relativistic generalization of the wave | ||
| equation, fundamental to quantum field theory and lattice field | ||
| simulations. | ||
|
|
||
| In higher dimensions: | ||
|
|
||
| ``` | ||
| uₜₜ = c² Δu - m² u | ||
| ``` | ||
|
|
||
| **Dispersion relation:** ω(k) = √(c²|k|² + m²) | ||
|
|
||
| Unlike the wave equation (ω = c|k|), the Klein-Gordon equation has a | ||
| **mass gap** — no modes with ω < m exist. The group velocity | ||
| v_g = c²|k|/ω is always less than c (massive dispersion). | ||
|
|
||
| Internally, the same diagonalization approach as the | ||
| [`exponax.stepper.Wave`][] stepper is used, but with | ||
| the Klein-Gordon dispersion relation. | ||
|
|
||
| The second-order equation is rewritten as a first-order system: | ||
|
|
||
| ``` | ||
| hₜ = v | ||
| vₜ = c² Δh - m² h | ||
| ``` | ||
|
|
||
| In Fourier space, each wavenumber k oscillates at frequency | ||
| ω(k) = √(c²|k|² + m²). The system is diagonalized into | ||
| forward/backward traveling modes that each evolve as a pure | ||
| phase rotation. | ||
|
|
||
| **Arguments:** | ||
|
|
||
| - `num_spatial_dims`: The number of spatial dimensions `d`. | ||
| - `domain_extent`: The size of the domain `L`; in higher dimensions | ||
| the domain is assumed to be a scaled hypercube `Ω = (0, L)ᵈ`. | ||
| - `num_points`: The number of points `N` used to discretize the | ||
| domain. This **includes** the left boundary point and **excludes** | ||
| the right boundary point. In higher dimensions; the number of points | ||
| in each dimension is the same. | ||
| - `dt`: The timestep size `Δt` between two consecutive states. | ||
| - `speed_of_sound` (keyword-only): The wave speed `c`. Default: `1.0`. | ||
| - `mass` (keyword-only): The mass parameter `m`. Default: `1.0`. | ||
|
|
||
| **Notes:** | ||
|
|
||
| - The stepper is unconditionally stable, no matter the choice of | ||
| any argument because the equation is solved analytically in Fourier | ||
| space. | ||
| - Setting `mass = 0.0` recovers the standard wave equation. | ||
| - The factors `c Δt / L` and `m Δt` together affect the dynamics. | ||
| """ | ||
| self.speed_of_sound = speed_of_sound | ||
| self.mass = mass | ||
| wavenumber_norm = jnp.linalg.norm( | ||
| build_scaled_wavenumbers( | ||
| num_spatial_dims=num_spatial_dims, | ||
| domain_extent=domain_extent, | ||
| num_points=num_points, | ||
| ), | ||
| axis=0, | ||
| keepdims=True, | ||
| ) | ||
| # Klein-Gordon dispersion: ω(k) = sqrt(c²|k|² + m²) | ||
| self.frequency = jnp.sqrt( | ||
| speed_of_sound**2 * wavenumber_norm**2 + mass**2 | ||
| ) | ||
| super().__init__( | ||
| num_spatial_dims=num_spatial_dims, | ||
| domain_extent=domain_extent, | ||
| num_points=num_points, | ||
| dt=dt, | ||
| num_channels=2, | ||
| order=0, | ||
| ) | ||
|
|
||
| def _forward_transform( | ||
| self, u_hat: Complex[Array, " 2 ... (N//2)+1"] | ||
| ) -> Complex[Array, " 2 ... (N//2)+1"]: | ||
| """Transform (h, v) into diagonalized Klein-Gordon wave modes.""" | ||
| h_hat, v_hat = u_hat[0:1], u_hat[1:2] | ||
| # Scale height to match velocity units: w = iω h | ||
| omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency) | ||
| w_hat = 1j * omega_guard * h_hat | ||
|
|
||
| # Orthonormal rotation into wave modes | ||
| pos = (1 / jnp.sqrt(2)) * (w_hat + v_hat) | ||
| neg = (1 / jnp.sqrt(2)) * (w_hat - v_hat) | ||
| return jnp.concatenate([pos, neg], axis=0) | ||
|
|
||
| def _inverse_transform( | ||
| self, waves_hat: Complex[Array, " 2 ... (N//2)+1"] | ||
| ) -> Complex[Array, " 2 ... (N//2)+1"]: | ||
| """Transform diagonalized wave modes back into (h, v).""" | ||
| pos, neg = waves_hat[0:1], waves_hat[1:2] | ||
| # Inverse rotation | ||
| w_hat = (1 / jnp.sqrt(2)) * (pos + neg) | ||
| v_hat = (1 / jnp.sqrt(2)) * (pos - neg) | ||
|
|
||
| # Undo scaling to recover height | ||
| omega_guard = jnp.where(self.frequency == 0, 1.0, self.frequency) | ||
| h_hat = w_hat / (1j * omega_guard) | ||
| return jnp.concatenate([h_hat, v_hat], axis=0) | ||
|
|
||
| def _build_linear_operator( | ||
| self, derivative_operator: Complex[Array, " D ... (N//2)+1"] | ||
| ) -> Complex[Array, " 2 ... (N//2)+1"]: | ||
| val = 1j * self.frequency | ||
| return jnp.concatenate( | ||
| ( | ||
| val, | ||
| -val, | ||
| ), | ||
| axis=0, | ||
| ) | ||
|
|
||
| def _build_nonlinear_fun( | ||
| self, derivative_operator: Complex[Array, " D ... (N//2)+1"] | ||
| ) -> ZeroNonlinearFun: | ||
| return ZeroNonlinearFun(self.num_spatial_dims, self.num_points) | ||
|
|
||
| def step_fourier( | ||
| self, u_hat: Complex[Array, " 2 ... (N//2)+1"] | ||
| ) -> Complex[Array, " 2 ... (N//2)+1"]: | ||
| """ | ||
| Advance the state by one timestep in Fourier space. | ||
|
|
||
| Overrides the base method to wrap the ETDRK step with the | ||
| forward/inverse diagonalization transforms. | ||
| """ | ||
| waves_hat = self._forward_transform(u_hat) | ||
| waves_hat_next = super().step_fourier(waves_hat) | ||
| u_hat_next = self._inverse_transform(waves_hat_next) | ||
|
|
||
| # At k=0 with m>0, the system is still diagonalizable (ω(0) = m ≠ 0), | ||
| # so no special DC correction is needed. However, if m=0 we fall back | ||
| # to the wave equation DC behavior. | ||
| if self.mass == 0.0: | ||
| h_dc_idx = (0,) + (0,) * self.num_spatial_dims | ||
| v_dc_idx = (1,) + (0,) * self.num_spatial_dims | ||
| u_hat_next = u_hat_next.at[h_dc_idx].add(self.dt * u_hat[v_dc_idx]) | ||
|
|
||
| return u_hat_next |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ def test_instantiate(): | |
| ex.stepper.Dispersion, | ||
| ex.stepper.HyperDiffusion, | ||
| ex.stepper.Wave, | ||
| ex.stepper.KleinGordon, | ||
| ex.stepper.Burgers, | ||
|
Comment on lines
20
to
22
|
||
| ex.stepper.KuramotoSivashinsky, | ||
| ex.stepper.KuramotoSivashinskyConservative, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| import jax.numpy as jnp | ||
| import pytest | ||
|
|
||
| import exponax as ex | ||
| from exponax.stepper import KleinGordon, Wave | ||
|
|
||
| L = 2 * jnp.pi | ||
| PI = jnp.pi | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # Instantiation | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| class TestKleinGordonInstantiation: | ||
| @pytest.mark.parametrize("num_spatial_dims", [1, 2, 3]) | ||
| def test_instantiate(self, num_spatial_dims): | ||
| stepper = KleinGordon(num_spatial_dims, 10.0, 25, 0.1) | ||
| assert stepper.num_channels == 2 | ||
| assert stepper.num_spatial_dims == num_spatial_dims | ||
|
|
||
| @pytest.mark.parametrize("num_spatial_dims", [1, 2, 3]) | ||
| def test_output_shape(self, num_spatial_dims): | ||
| N = 16 | ||
| stepper = KleinGordon(num_spatial_dims, L, N, 0.01) | ||
| u0 = jnp.zeros((2,) + (N,) * num_spatial_dims) | ||
| u1 = stepper(u0) | ||
| assert u1.shape == u0.shape | ||
| assert jnp.all(jnp.isfinite(u1)) | ||
|
|
||
| def test_wrong_input_shape_raises(self): | ||
| stepper = KleinGordon(1, L, 32, 0.01) | ||
| with pytest.raises(ValueError, match="Expected shape"): | ||
| stepper(jnp.zeros((1, 32))) # needs 2 channels | ||
|
|
||
| def test_default_params(self): | ||
| stepper = KleinGordon(1, L, 32, 0.01) | ||
| assert stepper.speed_of_sound == 1.0 | ||
| assert stepper.mass == 1.0 | ||
|
|
||
| def test_custom_params(self): | ||
| stepper = KleinGordon(1, L, 32, 0.01, speed_of_sound=2.0, mass=3.0) | ||
| assert stepper.speed_of_sound == 2.0 | ||
| assert stepper.mass == 3.0 | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # mass=0 should recover Wave equation | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| class TestKleinGordonRecoverWave: | ||
| """When mass=0, KleinGordon must produce identical results to Wave.""" | ||
|
|
||
| @pytest.mark.parametrize("num_spatial_dims", [1, 2]) | ||
| def test_mass_zero_matches_wave(self, num_spatial_dims): | ||
| N, dt, c = 32, 0.01, 1.5 | ||
| kg = KleinGordon( | ||
| num_spatial_dims, L, N, dt, speed_of_sound=c, mass=0.0 | ||
| ) | ||
| wave = Wave(num_spatial_dims, L, N, dt, speed_of_sound=c) | ||
|
|
||
| x = jnp.linspace(0, L, N, endpoint=False) | ||
| if num_spatial_dims == 1: | ||
| h0 = jnp.cos(2 * x)[None] | ||
| else: | ||
| h0 = jnp.cos(2 * x)[None, :, None] * jnp.ones((1, N, N)) | ||
| v0 = jnp.zeros_like(h0) | ||
| u0 = jnp.concatenate([h0, v0], axis=0) | ||
|
|
||
| u_kg = u0 | ||
| u_wave = u0 | ||
| for _ in range(20): | ||
| u_kg = kg(u_kg) | ||
| u_wave = wave(u_wave) | ||
|
|
||
| assert u_kg == pytest.approx(u_wave, abs=1e-5) | ||
|
|
||
| def test_mass_zero_matches_wave_multi_step(self): | ||
| """Longer evolution to catch accumulation drift.""" | ||
| N, dt, c = 64, 0.005, 1.0 | ||
| kg = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=0.0) | ||
| wave = Wave(1, L, N, dt, speed_of_sound=c) | ||
|
|
||
| x = jnp.linspace(0, L, N, endpoint=False) | ||
| h0 = (jnp.cos(x) + 0.5 * jnp.cos(3 * x))[None] | ||
| v0 = jnp.zeros_like(h0) | ||
| u0 = jnp.concatenate([h0, v0], axis=0) | ||
|
|
||
| u_kg = u0 | ||
| u_wave = u0 | ||
| for _ in range(100): | ||
| u_kg = kg(u_kg) | ||
| u_wave = wave(u_wave) | ||
|
|
||
| assert u_kg == pytest.approx(u_wave, abs=1e-4) | ||
|
|
||
|
|
||
| # =========================================================================== | ||
| # Analytical correctness — 1D Klein-Gordon standing mode | ||
| # =========================================================================== | ||
|
|
||
|
|
||
| class TestKleinGordonAnalytical1D: | ||
| """For h(x,0) = cos(k0 x), v(x,0) = 0: | ||
| ω = sqrt(c²k0² + m²) | ||
| h(x,t) = cos(k0 x) cos(ω t) | ||
| v(x,t) = -ω cos(k0 x) sin(ω t) | ||
| """ | ||
|
|
||
| def _make_stepper_and_ic(self, k0, c=1.0, m=1.0, N=64, dt=0.01): | ||
| stepper = KleinGordon(1, L, N, dt, speed_of_sound=c, mass=m) | ||
| x = jnp.linspace(0, L, N, endpoint=False) | ||
| h0 = jnp.cos(k0 * x)[None] | ||
| v0 = jnp.zeros_like(h0) | ||
| u0 = jnp.concatenate([h0, v0], axis=0) | ||
| omega = jnp.sqrt(c**2 * k0**2 + m**2) | ||
| return stepper, x, u0, float(omega) | ||
|
|
||
| @pytest.mark.parametrize("k0", [1, 2, 3, 5]) | ||
| def test_single_mode(self, k0): | ||
| c, m, N, dt = 1.0, 2.0, 64, 0.01 | ||
| stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt) | ||
|
|
||
| n_steps = 10 | ||
| u = u0 | ||
| for _ in range(n_steps): | ||
| u = stepper(u) | ||
| t = n_steps * dt | ||
|
|
||
| h_exact = jnp.cos(k0 * x) * jnp.cos(omega * t) | ||
| v_exact = -omega * jnp.cos(k0 * x) * jnp.sin(omega * t) | ||
|
|
||
| assert u[0] == pytest.approx(h_exact, abs=1e-4) | ||
| assert u[1] == pytest.approx(v_exact, abs=1e-3) | ||
|
|
||
| def test_mass_gap(self): | ||
| """With k0=0 (uniform mode), oscillation is at ω = m (the mass gap).""" | ||
| m, N, dt = 3.0, 32, 0.01 | ||
| stepper = KleinGordon(1, L, N, dt, speed_of_sound=1.0, mass=m) | ||
|
|
||
| h0 = jnp.ones((1, N)) # k=0 mode | ||
| v0 = jnp.zeros_like(h0) | ||
| u0 = jnp.concatenate([h0, v0], axis=0) | ||
|
|
||
| n_steps = 20 | ||
| u = u0 | ||
| for _ in range(n_steps): | ||
| u = stepper(u) | ||
| t = n_steps * dt | ||
|
|
||
| h_exact = jnp.cos(m * t) * jnp.ones(N) | ||
| assert u[0] == pytest.approx(h_exact, abs=1e-4) | ||
|
|
||
| def test_energy_bounded(self): | ||
| """Total energy should be conserved (bounded) over many steps.""" | ||
| k0, c, m, N, dt = 3, 1.0, 2.0, 64, 0.005 | ||
| stepper, x, u0, omega = self._make_stepper_and_ic(k0, c, m, N, dt) | ||
|
|
||
| def energy(u): | ||
| h, v = u[0], u[1] | ||
| # KE + gradient PE + mass PE | ||
| return jnp.sum(v**2 + c**2 * jnp.abs(jnp.fft.rfft(h))**2 + m**2 * h**2) | ||
|
|
||
| e0 = energy(u0) | ||
| u = u0 | ||
| for _ in range(200): | ||
| u = stepper(u) | ||
| e_final = energy(u) | ||
|
|
||
| # Spectral solver should conserve energy to machine precision | ||
| assert e_final == pytest.approx(float(e0), rel=1e-3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This new paragraph documents
KleinGordon, but the earlier “The concrete PDE steppers are:” list in the same module docstring still omits it. Please update that list to includeKleinGordonso the public-facing documentation remains consistent.