Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions stingray/pulse/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from ..utils import contiguous_regions
from astropy.stats import poisson_conf_interval
import matplotlib.pyplot as plt
try:
import jax
import jax.numpy as jnp
HAS_JAX = True
except ImportError:
HAS_JAX = False


__all__ = [
Expand Down Expand Up @@ -244,6 +250,47 @@ def stat_fun(t, f, fd=0, **kwargs):
fdots=fdots,
)

def _z_search_jax_kernel(times, frequencies, fdots, nharm=2):
"""Calculates Z^2_n statistics for a grid of (frequency, fdot) pairs.

Parameters
----------
times : array
Photon arrival times (in seconds).
frequencies : array
1D array of trial frequencies.
fdots : array
1D array of trial frequency derivatives (can be just [0]).
nharm : int
Number of harmonics to sum.

Returns
-------
z_stats : 2D array (len(frequencies), len(fdots))
"""
n_events = times.shape[0]

def compute_statistic(f, fd):
phase = (times * f) + (0.5 * fd * times**2)
phase = phase - jnp.floor(phase)
phase *= 2 * jnp.pi

z_stat = 0.0
for k in range(1, nharm + 1):
sin_s = jnp.sum(jnp.sin(k * phase))
cos_s = jnp.sum(jnp.cos(k * phase))
z_stat += sin_s**2 + cos_s**2

return (2.0 / n_events) * z_stat

vectorized_inner = jax.vmap(compute_statistic, in_axes=(None, 0))
vectorized_outer = jax.vmap(vectorized_inner, in_axes=(0, None))

return vectorized_outer(frequencies, fdots)


if HAS_JAX:
_z_search_jax_kernel = jax.jit(_z_search_jax_kernel, static_argnums=(3,))

def z_n_search(
times,
Expand All @@ -255,6 +302,7 @@ def z_n_search(
weights=1,
gti=None,
fdots=0,
use_jax=False,
):
"""Calculates the Z^2_n statistics at trial frequencies in photon data.

Expand Down Expand Up @@ -300,6 +348,11 @@ def z_n_search(
weight for each time. This might be, for example, the number of counts
if the times array contains the time bins of a light curve

use_jax : bool
If True and JAX is installed, use JAX-accelerated computation for
the Z^2_n search. Computes the exact Z^2_n statistic directly from
event phases without binning. Default is False.

Returns
-------
(fgrid, stats) or (fgrid, fdgrid, stats), as follows:
Expand All @@ -311,6 +364,27 @@ def z_n_search(
stats : array-like
the Z^2_n statistics corresponding to each frequency bin.
"""
if use_jax and HAS_JAX:
if np.ndim(fdots) == 0:
fdots = np.array([fdots])
else:
fdots = np.asarray(fdots)

j_times = jax.device_put(np.asarray(times, dtype=np.float64) - times[0])
j_freqs = jax.device_put(np.asarray(frequencies, dtype=np.float64))
j_fdots = jax.device_put(np.asarray(fdots, dtype=np.float64))

# Run Kernel
stats = _z_search_jax_kernel(j_times, j_freqs, j_fdots, nharm)

# If fdots was scalar, flatten the result to match standard Stingray output
if len(fdots) == 1:
return frequencies, stats.flatten()

# If 2D search, return the full grid
f_grid, fd_grid = np.meshgrid(frequencies, fdots)
return f_grid, fd_grid, np.array(stats).T

phase = np.arange(0, 1, 1 / nbin)
if expocorr or not HAS_NUMBA or isinstance(weights, Iterable):
if expocorr and gti is None:
Expand Down
Loading