diff --git a/.gitignore b/.gitignore index b52b152a..b2197db0 100644 --- a/.gitignore +++ b/.gitignore @@ -267,3 +267,4 @@ package.json # Log files generated by 'vagrant up' *.log +.worktree diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8e0601c7..64489eee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.14.10" hooks: - - id: ruff + - id: ruff-check args: ["--fix"] - id: ruff-format - repo: https://github.com/PyCQA/isort diff --git a/changelog/239.feature.rst b/changelog/239.feature.rst new file mode 100644 index 00000000..72a338c9 --- /dev/null +++ b/changelog/239.feature.rst @@ -0,0 +1 @@ +Add a new `sunkit_spex.spectrum.spectrum.Spectrum` object to hold spectral data. `~sunkit_spex.spectrum.spectrum.Spectrum` is based on `NDCube` and butils on it coordinate aware methods and metadata handling. diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 6c60dd2d..fc15b38d 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -9,7 +9,7 @@ Software and API. .. toctree:: :maxdepth: 2 - - fitting + spectrum models + fitting legacy diff --git a/docs/reference/spectrum.rst b/docs/reference/spectrum.rst new file mode 100644 index 00000000..5ce9a9e8 --- /dev/null +++ b/docs/reference/spectrum.rst @@ -0,0 +1,7 @@ +Models (`sunkit_spex.spectrum`) +******************************* + +``sunkit_spex.spectrum`` module contains objects for holding spectral data + +.. automodapi:: sunkit_spex.spectrum +.. automodapi:: sunkit_spex.spectrum.spectrum diff --git a/examples/spectrum.py b/examples/spectrum.py new file mode 100644 index 00000000..413a0970 --- /dev/null +++ b/examples/spectrum.py @@ -0,0 +1,233 @@ +""" +======== +Spectrum +======== + +This example will demonstrate how to store spectral data in `~sunkit_spex.spectrum.Specutm` container +""" + +##################################################### +# +# Imports + +import numpy as np +from ndcube import NDMeta +from ndcube.extra_coords import QuantityTableCoordinate, TimeTableCoordinate + +import astropy.units as u +from astropy.coordinates import SpectralCoord +from astropy.time import Time + +from sunkit_spex.spectrum import Spectrum + +rng = np.random.default_rng() +##################################################### +# +# 1D Spectrum +# ----------- +# Let's being with the simplest case a single spectrum that is a series of measurements as function of wavelength or +# energy. We will start of by creating some synthetic data and corresponding energy bins as well as some important metadata +# in this case the exposure time. + +data = rng.random(50) * u.ct +energy = np.linspace(1, 50, 50) * u.keV +time = Time("2025-02-18T15:08") + +exposure_time = 5 * u.s + +##################################################### +# +# Once we have our synthetic data we can create our metadata container `NDMeta` and `Spectrum` object. + +meta = NDMeta() +meta.add("exposure_time", exposure_time) +meta.add("date-obs", time) + +spec_1d = Spectrum(data, spectral_axis=energy, meta=meta) +spec_1d + +##################################################### +# +# One of the key feature of the `Spectrum` object is the ability to slice, crop and perform other operations using +# standard sliceing methods: + +spec_1d_sliced = spec_1d[10:20] +print(spec_1d_sliced.shape) +print(spec_1d_sliced.axis_world_coords_values()) +print(spec_1d_sliced.meta) +print(spec_1d_sliced.spectral_axis) + +##################################################### +# +# High level coordinate objects such as SkyCoord and SpectralCoord + +spec_1d_crop = spec_1d.crop(SpectralCoord(10.5, unit=u.keV), SpectralCoord(20, unit=u.keV)) +print(spec_1d_crop.shape) +print(spec_1d_crop.axis_world_coords_values()) +print(spec_1d_crop.meta) +print(spec_1d_crop.spectral_axis) + +##################################################### +# +# And Quantities + +spec_1d_crop_value = spec_1d.crop_by_values((10.5 * u.keV), (20.5 * u.keV)) +print(spec_1d_crop_value.shape) +print(spec_1d_crop_value.axis_world_coords_values()) +print(spec_1d_crop_value.meta) +print(spec_1d_crop_value.spectral_axis) + +##################################################### +# +# 2D Spectrum (spectrogram or time v energy) +# ------------------------------------------ +# Let build on the previous example by increasing the dimensionality of the data in this case to a spectrogram or a +# series of spectra as a function of time. Here we will simulate a series of 10 spectra taken over 10 minutes. Again we +# begin by creating our synthetic data as before but additionally creating the time variable. + +data = rng.random((10, 50)) * u.ct +energy = np.linspace(1, 50, 51) * u.keV +times = Time("2025-02-18T15:08") + np.arange(10) * u.min +exposure_time = np.arange(5, 15) * u.s + +##################################################### +# +# We are also going to demonstrate the power of the sliceable metadata, so in this example each of the individual +# spectra have different exposure times (this could be another important information regard the observation) + +meta = NDMeta() +meta.add("exposure_time", exposure_time, axes=(0,)) + +time_coord = TimeTableCoordinate(times, names="time", physical_types="time") +energy_coord = QuantityTableCoordinate(energy, names="energy", physical_types="em.energy") +wcs = (energy_coord & time_coord).wcs + +spec_2d_time_energy = Spectrum(data, spectral_axis=energy, wcs=wcs, spectral_axis_index=1, meta=meta) + +###################################################### +# +# Again all standard slicing works + +spec_2d_time_energy[2:5] +spec_2d_time_energy[:, 10:20] +spec_2d_time_energy_sliced = spec_2d_time_energy[2:5, 10:20] + +###################################################### +# +# We can being to see the usefulness of the sliceable metadata notice how the exposure time entry has been sliced +# appropriately + +print(spec_2d_time_energy_sliced.shape) +print(spec_2d_time_energy_sliced.axis_world_coords_values()) +print(spec_2d_time_energy_sliced.meta) +print(spec_2d_time_energy_sliced.spectral_axis) + +###################################################### +# +# The same can be archived using height level coordinate objects +# + +spec_2d_time_energy_crop = spec_2d_time_energy.crop( + [SpectralCoord(10, unit=u.keV), Time("2025-02-18T15:10")], [SpectralCoord(20, unit=u.keV), Time("2025-02-18T15:12")] +) + +print(spec_2d_time_energy_crop.shape) +print(spec_2d_time_energy_crop.axis_world_coords_values()) +print(spec_2d_time_energy_crop.meta) +print(spec_2d_time_energy_crop.spectral_axis) + +###################################################### +# +# Or Quantities as before +spec_2d_time_energy_crop_values = spec_2d_time_energy.crop_by_values((10 * u.keV, 2 * u.min), (19.5 * u.keV, 4 * u.min)) + +print(spec_2d_time_energy_crop_values.shape) +print(spec_2d_time_energy_crop_values.axis_world_coords_values()) +print(spec_2d_time_energy_crop_values.meta) +print(spec_2d_time_energy_crop_values.spectral_axis) + +##################################################### +# +# 2D Spectrum ( e.g. detector v energy) +# ------------------------------------- + +data = rng.random((10, 50)) * u.ct +energy = np.linspace(1, 50, 50) * u.keV + +exposure_time = np.arange(10) * u.s +labels = np.array([f"det_+{chr(97 + i)}" for i in range(10)]) + +meta = NDMeta() +meta.add("exposure_time", exposure_time, axes=0) +meta.add("detector", labels, axes=0) + +spec_2d_det_time = Spectrum(data, spectral_axis=energy, spectral_axis_index=1, meta=meta) +spec_2d_det_time + + +##################################################### +# + +# spec_2d_det_time.crop((SpectralCoord(10 * u.keV), None), (SpectralCoord(20 * u.keV), None)) + +##################################################### +# + +# spec_2d_det_time.crop_by_values((10 * u.keV, 0), (20 * u.keV, 2)) + +##################################################### +# +# 3D Spectrum ( e.g. detector v energy v time) +# -------------------------------------------- + +# data = rng.random(10, 20, 30) * u.ct +# energy = np.linspace(1, 31, 31) * u.keV +# +# labels = np.array([chr(97 + i) for i in range(10)]) +# exposure_time = np.arange(10 * 20).reshape(10, 20) * u.s +# times = Time.now() + np.arange(20) * u.s +# +# meta = NDMeta() +# meta.add("exposure_time", exposure_time, axes=(0, 1)) +# meta.add("detector", labels, axes=(0,)) +# +# spec_3d_det_energy_time = Spectrum(data, spectral_axis=energy, spectral_axis_index=2, meta=meta) +# spec_3d_det_energy_time.extra_coords.add("time", (0,), times) +# +# spec_3d_det_energy_time[:, 10:15, :].meta +# spec_3d_det_energy_time[2:3, 10:15, :].meta + +##################################################### +# +# 4D Spectrum ( e.g. spatial v spatial v energy v time) +# ----------------------------------------------------- + +# import numpy as np +# from ndcube import NDMeta +# +# import astropy.units as u +# from astropy.time import Time +# +# data = np.random.rand(10, 10, 20, 30) * u.ct +# energy = np.linspace(1, 31, 31) * u.keV +# exposure_time = np.arange(20) * u.s +# times = Time.now() + np.arange(20) * u.s +# +# meta = NDMeta() +# meta.add("exposure_time", exposure_time, axes=(2,)) +# +# wcs = astropy.wcs.WCS(naxis=2) +# wcs.wcs.ctype = "HPLT-TAN", "HPLN-TAN" +# wcs.wcs.cunit = "deg", "deg" +# wcs.wcs.cdelt = 0.5, 0.4 +# wcs.wcs.crpix = 5, 6 +# wcs.wcs.crval = 0.5, 1 +# wcs.wcs.cname = "HPC lat", "HPC lon" +# +# cube = NDCube(data=data, wcs=wcs, meta=meta) +# +# # Now instantiate the NDCube +# spec_4d_lon_lat_time_energy = Spectrum(data, wcs=wcs, spectral_axis=energy, spectral_axis_index=3, meta=meta) +# spec_4d_lon_lat_time_energy.extra_coords.add("time", (2,), times) +# +# spec_4d_lon_lat_time_energy diff --git a/pyproject.toml b/pyproject.toml index be250874..e8f44c46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,10 @@ dependencies = [ "numdifftools>=0.9.42", "numpy>=1.26", # Note: keeping support for numpy 1.x for now "parfive>=2.1", - "scipy>=1.12", + "scipy>=1.14.1", "sunpy>=7.0", "xarray>=2023.12", - "gwcs>=0.21.0", + "gwcs>=0.26.0,<1.0.0", #until https://github.com/sunpy/ndcube/issues/913 is resolved "ndcube>=2.3", ] diff --git a/pytest.ini b/pytest.ini index 7ec6ac41..4f886703 100644 --- a/pytest.ini +++ b/pytest.ini @@ -40,6 +40,7 @@ filterwarnings = # Oldestdeps issues ignore:`finfo.machar` is deprecated:DeprecationWarning ignore:Please use `convolve1d` from the `scipy.ndimage` namespace, the `scipy.ndimage.filters` namespace is deprecated.:DeprecationWarning - ignore::pyparsing.warnings.PyparsingDeprecationWarning ignore::FutureWarning:arviz.* ignore:The isiterable function.*:astropy.utils.exceptions.AstropyDeprecationWarning + ignore:'datfix' made the change:astropy.wcs.wcs.FITSFixedWarning + ignore::pyparsing.warnings.PyparsingDeprecationWarning diff --git a/sunkit_spex/conftest.py b/sunkit_spex/conftest.py deleted file mode 100644 index 44d984ab..00000000 --- a/sunkit_spex/conftest.py +++ /dev/null @@ -1,44 +0,0 @@ -# This file is used to configure the behavior of pytest when using the Astropy -# test infrastructure. - - -# Uncomment the following line to treat all DeprecationWarnings as -# exceptions. For Astropy v2.0 or later, there are 2 additional keywords, -# as follow (although default should work for most cases). -# To ignore some packages that produce deprecation warnings on import -# (in addition to 'compiler', 'scipy', 'pygments', 'ipykernel', and -# 'setuptools'), add: -# modules_to_ignore_on_import=['module_1', 'module_2'] -# To ignore some specific deprecation warning messages for Python version -# MAJOR.MINOR or later, add: -# warnings_to_ignore_by_pyver={(MAJOR, MINOR): ['Message to ignore']} -# enable_deprecations_as_exceptions() - -# Uncomment and customize the following lines to add/remove entries from -# the list of packages for which version numbers are displayed when running -# the tests. Making it pass for KeyError is essential in some cases when -# the package uses other astropy affiliated packages. -# try: -# PYTEST_HEADER_MODULES['Astropy'] = 'astropy' -# PYTEST_HEADER_MODULES['scikit-image'] = 'skimage' -# del PYTEST_HEADER_MODULES['h5py'] -# except (NameError, KeyError): # NameError is needed to support Astropy < 1.0 -# pass - -# Uncomment the following lines to display the version number of the -# package rather than the version number of Astropy in the top line when -# running the tests. -# import os -# -# This is to figure out the package version, rather than -# using Astropy's -# try: -# from .version import version -# except ImportError: -# version = 'dev' -# -# try: -# packagename = os.path.basename(os.path.dirname(__file__)) -# TESTED_VERSIONS[packagename] = version -# except NameError: # Needed to support Astropy <= 1.0.0 -# pass diff --git a/sunkit_spex/spectrum/conftest.py b/sunkit_spex/spectrum/conftest.py new file mode 100644 index 00000000..e69de29b diff --git a/sunkit_spex/spectrum/spectrum.py b/sunkit_spex/spectrum/spectrum.py index b1ca2d86..9ee36569 100644 --- a/sunkit_spex/spectrum/spectrum.py +++ b/sunkit_spex/spectrum/spectrum.py @@ -1,3 +1,6 @@ +import copy +from copy import deepcopy + import numpy as np from gwcs import WCS as GWCS from gwcs import coordinate_frames as cf @@ -5,53 +8,193 @@ import astropy.units as u from astropy.coordinates import SpectralCoord +from astropy.modeling.mappings import Identity, Mapping from astropy.modeling.tabular import Tabular1D from astropy.utils import lazyproperty +from astropy.wcs.wcsapi import sanitize_slices __all__ = ["SpectralAxis", "Spectrum", "gwcs_from_array"] -__doctest_requires__ = {"Spectrum": ["ndcube>=2.3"]} __doctest_requires__ = {"Spectrum": ["ndcube>=2.3"]} -def gwcs_from_array(array): +class SpectralGWCS(GWCS): + """ + This is a placeholder lookup-table GWCS created when a :class:`~specutils.Spectrum` is + instantiated with a ``spectral_axis`` and no WCS. + """ + + def __init__(self, *args, **kwargs): + self.original_unit = kwargs.pop("original_unit", "") + super().__init__(*args, **kwargs) + + def copy(self): + """ + Return a shallow copy of the object. + + Convenience method so user doesn't have to import the + :mod:`copy` stdlib module. + + .. warning:: + Use `deepcopy` instead of `copy` unless you know why you need a + shallow copy. + """ + return copy.copy(self) + + def deepcopy(self): + """ + Return a deep copy of the object. + + Convenience method so user doesn't have to import the + :mod:`copy` stdlib module. + """ + return copy.deepcopy(self) + + +def gwcs_from_array(array, flux_shape, spectral_axis_index=None): """ Create a new WCS from provided tabular data. This defaults to being - a GWCS object. + a GWCS object with a lookup table for the spectral axis and filler + pixel to pixel identity conversions for spatial axes, if they exist. """ orig_array = u.Quantity(array) - - coord_frame = cf.CoordinateFrame(naxes=1, axes_type=("SPECTRAL",), axes_order=(0,)) - spec_frame = cf.SpectralFrame(unit=array.unit, axes_order=(0,)) + naxes = len(flux_shape) + + if naxes > 1: + if spectral_axis_index is None: + raise ValueError("spectral_axis_index must be set for multidimensional flux arrays") + # Axis order is reversed for WCS from numpy array + spectral_axis_index = naxes - spectral_axis_index - 1 + elif naxes == 1: + spectral_axis_index = 0 + + axes_order = list(np.arange(naxes)) + axes_type = [ + "SPATIAL", + ] * naxes + axes_type[spectral_axis_index] = "SPECTRAL" + + detector_frame = cf.CoordinateFrame( + naxes=naxes, + name="detector", + unit=[ + u.pix, + ] + * naxes, + axes_order=axes_order, + axes_type=axes_type, + ) + + if array.unit in ("", "pix", "pixel"): + # Spectrum was initialized without a wcs or spectral axis + spectral_frame = cf.CoordinateFrame( + naxes=1, + unit=[ + array.unit, + ], + axes_type=[ + "Spectral", + ], + axes_order=(spectral_axis_index,), + ) + else: + phys_types = None + # Note that some units have multiple physical types, so we can't just set the + # axis name to the physical type string. + if array.unit.physical_type == "length": + axes_names = [ + "wavelength", + ] + elif array.unit.physical_type == "frequency": + axes_names = [ + "frequency", + ] + elif array.unit.physical_type == "velocity": + axes_names = [ + "velocity", + ] + phys_types = [ + "spect.dopplerVeloc.optical", + ] + elif array.unit.physical_type == "wavenumber": + axes_names = [ + "wavenumber", + ] + elif array.unit.physical_type == "energy": + axes_names = [ + "energy", + ] + else: + raise ValueError("Spectral axis units must be one of length,frequency, velocity, energy, or wavenumber") + + spectral_frame = cf.SpectralFrame( + unit=array.unit, axes_order=(spectral_axis_index,), axes_names=axes_names, axis_physical_types=phys_types + ) + + if naxes > 1: + axes_order.remove(spectral_axis_index) + spatial_frame = cf.CoordinateFrame( + naxes=naxes - 1, + unit=[ + "", + ] + * (naxes - 1), + axes_type=[ + "Spatial", + ] + * (naxes - 1), + axes_order=axes_order, + ) + output_frame = cf.CompositeFrame(frames=[spatial_frame, spectral_frame]) + else: + output_frame = spectral_frame # In order for the world_to_pixel transformation to automatically convert - # input units, the equivalencies in the lookup table have to be extended + # input units, the equivalencies in the look up table have to be extended # with spectral unit information. - SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,), {"input_units_equivalencies": {"x0": u.spectral()}}) + SpectralTabular1D = type( + "SpectralTabular1D", (Tabular1D,), {"input_units_equivalencies": {"x0": u.spectral()}, "bounds_error": True} + ) + + # We pass through the pixel values of spatial axes with Identity and use a lookup + # table for the spectral axis values. We use Mapping to pipe the values to the correct + # model depending on which axis is the spectral axis + if naxes == 1: + forward_transform = SpectralTabular1D(np.arange(len(array)) * u.pix, lookup_table=array) + else: + axes_order.append(spectral_axis_index) + # WCS axis order is reverse of numpy array order + mapped_axes = axes_order + out_mapping = np.ones(len(mapped_axes)).astype(int) + for i in range(len(mapped_axes)): + out_mapping[mapped_axes[i]] = i + forward_transform = ( + Mapping(mapped_axes) + | Identity(naxes - 1) & SpectralTabular1D(np.arange(len(array)) * u.pix, lookup_table=array) + | Mapping(out_mapping) + ) - forward_transform = SpectralTabular1D(np.arange(len(array)), lookup_table=array) # If our spectral axis is in descending order, we have to flip the lookup # table to be ascending in order for world_to_pixel to work. - if len(array) == 0 or array[-1] > array[0]: - forward_transform.inverse = SpectralTabular1D(array, lookup_table=np.arange(len(array))) - else: - forward_transform.inverse = SpectralTabular1D(array[::-1], lookup_table=np.arange(len(array))[::-1]) - - class SpectralGWCS(GWCS): - def pixel_to_world(self, *args, **kwargs): - if orig_array.unit == "": - return u.Quantity(super().pixel_to_world_values(*args, **kwargs)) - return super().pixel_to_world(*args, **kwargs).to(orig_array.unit, equivalencies=u.spectral()) + forward_transform.inverse = SpectralTabular1D(array, lookup_table=np.arange(len(array)) * u.pix) - return SpectralGWCS(forward_transform=forward_transform, input_frame=coord_frame, output_frame=spec_frame) + tabular_gwcs = SpectralGWCS( + original_unit=orig_array.unit, + forward_transform=forward_transform, + input_frame=detector_frame, + output_frame=output_frame, + ) + tabular_gwcs.bounding_box = None # Store the intended unit from the origin input array # tabular_gwcs._input_unit = orig_array.unit + return tabular_gwcs + class SpectralAxis(SpectralCoord): - """ + r""" Coordinate object representing spectral values corresponding to a specific spectrum. Overloads SpectralCoord with additional information (currently only bin edges). @@ -75,6 +218,9 @@ def __new__(cls, value, *args, bin_specification="centers", **kwargs): ): raise ValueError("u.pix spectral axes should always be ascending") + if bin_specification == "edges" and value.size < 2: + raise ValueError('If bin_specification="centers" have at least two bin edges.') + # Convert to bin centers if bin edges were given, since SpectralCoord # only accepts centers if bin_specification == "edges": @@ -88,43 +234,35 @@ def __new__(cls, value, *args, bin_specification="centers", **kwargs): return obj - @staticmethod - def _edges_from_centers(centers, unit): - """ - Calculates interior bin edges based on the average of each pair of - centers, with the two outer edges based on extrapolated centers added - to the beginning and end of the spectral axis. - """ - a = np.insert(centers, 0, 2 * centers[0] - centers[1]) - b = np.append(centers, 2 * centers[-1] - centers[-2]) - edges = (a + b) / 2 - return edges * unit - @staticmethod def _centers_from_edges(edges): - """ + r""" Calculates the bin centers as the average of each pair of edges """ return (edges[1:] + edges[:-1]) / 2 @lazyproperty def bin_edges(self): - """ + r""" Calculates bin edges if the spectral axis was created with centers specified. """ - if hasattr(self, "_bin_edges"): + if hasattr(self, "_bin_edges") and self._bin_edges is not None: return self._bin_edges - return self._edges_from_centers(self.value, self.unit) + return None + + def __array_finalize__(self, obj): + super().__array_finalize__(obj) + self._bin_edges = getattr(obj, "_bin_edges", None) class Spectrum(NDCube): r""" - Spectrum container for data with one spectral axis. + Spectrum container for data which share a common spectral axis. Note that "1D" in this case refers to the fact that there is only one - spectral axis. `Spectrum` can contain "vector 1D spectra" by having the - ``flux`` have a shape with dimension greater than 1. + spectral axis. `Spectrum` can contain ND data where + ``data`` have a shape with dimension greater than 1. Notes ----- @@ -134,15 +272,15 @@ class Spectrum(NDCube): ---------- data : `~astropy.units.Quantity` The data for this spectrum. This can be a simple `~astropy.units.Quantity`, - or an existing `~Spectrum1D` or `~ndcube.NDCube` object. + or an existing `~Spectrum` or `~ndcube.NDCube` object. uncertainty : `~astropy.nddata.NDUncertainty` Contains uncertainty information along with propagation rules for spectrum arithmetic. Can take a unit, but if none is given, will use - the unit defined in the flux. + the unit defined in the data. spectral_axis : `~astropy.units.Quantity` or `~specutils.SpectralAxis` - Dispersion information with the same shape as the dimension specified by spectral_dimension - of shape plus one if specifying bin edges. - spectral_dimension : `int` default 0 + Dispersion information with the same shape as the dimension specified by spectral_axis_index + or shape plus one if specifying bin edges. + spectral_axis_index : `int` default 0 The dimension of the data which represents the spectral information default to first dimension index 0. mask : `~numpy.ndarray`-like Array where values in the flux to be masked are those that @@ -157,7 +295,7 @@ class Spectrum(NDCube): >>> import numpy as np >>> import astropy.units as u >>> from sunkit_spex.spectrum import Spectrum - >>> spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=np.arange(1, 12)*u.keV) + >>> spec = Spectrum(np.arange(1, 11)*u.watt,spectral_axis=np.arange(1, 12)*u.keV) >>> spec 0: + raise ValueError( + "Initializer contains unknown arguments(s): {}.".format(", ".join(map(str, unknown_kwargs))) + ) + + # Handle initializing from NDCube objects + if isinstance(data, NDCube): + if data.unit is None: + raise ValueError("Input NDCube missing unit parameter") + + if spectral_axis is None: + raise ValueError("Spectral axis must be specified") + + # Change the data array from bare ndarray to a Quantity + q_data = data.data << u.Unit(data.unit) + + self.__init__( + q_data, wcs=data.wcs, mask=data.mask, uncertainty=data.uncertainty, spectral_axis=spectral_axis + ) + return + + self._spectral_axis_index = spectral_axis_index + # If here data should be an array or quantity + if spectral_axis_index is None and data is not None: + if data.ndim == 1: + self._spectral_axis_index = 0 + elif data is None: + self._spectral_axis_index = 0 # Ensure that the unit information codified in the quantity object is # the One True Unit. kwargs.setdefault("unit", data.unit if isinstance(data, u.Quantity) else kwargs.get("unit")) - # If flux and spectral axis are both specified, check that their lengths + # If a WCS is provided, determine which axis is the spectral axis + if wcs is not None: + if spectral_axis is None: + raise ValueError("Spectral axis must be specified") + + naxis = None + if hasattr(wcs, "naxis"): + naxis = wcs.naxis + # GWCS doesn't have naxis + elif hasattr(wcs, "world_n_dim"): + naxis = wcs.world_n_dim + + if naxis is not None and naxis > 1: + temp_axes = [] + phys_axes = wcs.world_axis_physical_types + if self._spectral_axis_index is None: + for i in range(len(phys_axes)): + if phys_axes[i] is None: + continue + if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect" or phys_axes[i][7:12] == "Spect": + temp_axes.append(i) + if len(temp_axes) != 1: + raise ValueError( + f"Input WCS must have exactly one axis with spectral units, found {len(temp_axes)}" + ) + # Due to FITS conventions, the WCS axes are listed in opposite + # order compared to the data array. + self._spectral_axis_index = len(data.shape) - temp_axes[0] - 1 + + else: + if data is not None and data.ndim == 1: + self._spectral_axis_index = 0 + else: + if self.spectral_axis_index is None: + raise ValueError("WCS is 1D but flux is multi-dimensional. Please specify spectral_axis_index.") + + # If data and spectral axis are both specified, check that their lengths # match or are off by one (implying the spectral axis stores bin edges) + bin_specification = "centers" # default value if data is not None and spectral_axis is not None: - if spectral_axis.shape[0] == data.shape[spectral_dimension]: + if spectral_axis.shape[0] == data.shape[self.spectral_axis_index]: bin_specification = "centers" - elif spectral_axis.shape[0] == data.shape[spectral_dimension] + 1: + elif spectral_axis.shape[0] == data.shape[self.spectral_axis_index] + 1: bin_specification = "edges" else: raise ValueError( - f"Spectral axis length ({spectral_axis.shape[0]}) must be the same size or one " - "greater (if specifying bin edges) than that of the spextral" - f"axis ({data.shape[spectral_dimension]})" + f"Spectral axis length ({spectral_axis.shape[0]}) must be the " + "same size or one greater (if specifying bin edges) than that " + f"of the corresponding data axis ({data.shape[self.spectral_axis_index]})" ) # Attempt to parse the spectral axis. If none is given, try instead to # parse a given wcs. This is put into a GWCS object to # then be used behind-the-scenes for all operations. - if spectral_axis is not None: - # Ensure that the spectral axis is an astropy Quantity - if not isinstance(spectral_axis, u.Quantity): - raise ValueError("Spectral axis must be a `Quantity` or `SpectralAxis` object.") - - # If a spectral axis is provided as an astropy Quantity, convert it - # to a SpectralAxis object. - if not isinstance(spectral_axis, SpectralAxis): - if spectral_axis.shape[0] == data.shape[spectral_dimension] + 1: - bin_specification = "edges" - else: - bin_specification = "centers" - self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification) - - wcs = gwcs_from_array(self._spectral_axis) - - super().__init__( - data=data.value if isinstance(data, u.Quantity) else data, - wcs=wcs, - mask=mask, - meta=meta, - uncertainty=uncertainty, - **kwargs, - ) + + # Ensure that the spectral axis is an astropy Quantity or SpectralAxis + if not isinstance(spectral_axis, (u.Quantity, SpectralAxis)): + raise ValueError("Spectral axis must be a `Quantity` or `SpectralAxis` object.") + + # If spectral axis is provided as an astropy Quantity, convert it + # to a specutils SpectralAxis object. + if not isinstance(spectral_axis, SpectralAxis): + self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification) + # If a SpectralAxis object is provided, we assume it doesn't need + # information from other keywords added + else: + self._spectral_axis = spectral_axis + + # Check the spectral_axis matches the wcs + if wcs is not None: + wsc_coords = None + if hasattr(wcs, "spectral") and getattr(wcs, "is_spectral", False): + wcs_coords = wcs.spectral.pixel_to_world(np.arange(data.shape[self.spectral_axis_index])).to("keV") + elif wcs.pixel_n_dim == 1: + wcs_coords = wcs.pixel_to_world(np.arange(data.shape[self.spectral_axis_index])) + # else: + # array_index = wcs.pixel_n_dim - self._spectral_axis_index - 1 + # pixels = [0] * wcs.pixel_n_dim + # pixels[array_index] = np.arange(data.shape[self.spectral_axis_index]) + # wcs_coords = wcs.pixel_to_world(*pixels)[array_index] + if wsc_coords is not None: + if not u.allclose(self._spectral_axis, wcs_coords): + raise ValueError( + f"Spectral axis {self._spectral_axis} and wcs spectral axis {wcs_coords} must match." + ) + + if wcs is None: + wcs = gwcs_from_array(self._spectral_axis, data.shape, spectral_axis_index=self.spectral_axis_index) + + super().__init__(data=data.value if isinstance(data, u.Quantity) else data, wcs=wcs, **kwargs) + + # make sure that spectral axis is strictly increasing or strictly decreasing + is_strictly_increasing = np.all(self._spectral_axis[1:] > self._spectral_axis[:-1]) + if len(self._spectral_axis) > 1 and not (is_strictly_increasing): + raise ValueError("Spectral axis must be strictly increasing decreasing.") + + if hasattr(self, "uncertainty") and self.uncertainty is not None: + if not data.shape == self.uncertainty.array.shape: + raise ValueError( + f"Data axis ({data.shape}) and uncertainty ({self.uncertainty.array.shape}) shapes must be the " + "same." + ) + + def __getitem__(self, item): + sliced_cube = super().__getitem__(item) + item = tuple(sanitize_slices(item, len(self.shape))) + sliced_spec_axis = self.spectral_axis[item[self.spectral_axis_index]] + return Spectrum(sliced_cube, spectral_axis=sliced_spec_axis) + + def _slice(self, item): + kwargs = super()._slice(item) + item = tuple(sanitize_slices(item, len(self.shape))) + + kwargs["spectral_axis_index"] = self.spectral_axis_index + kwargs["spectral_axis"] = self.spectral_axis[item[self.spectral_axis_index]] + return kwargs + + def _new_instance(self, **kwargs): + keys = ("unit", "wcs", "mask", "meta", "uncertainty", "psf", "spectral_axis") + full_kwargs = {k: deepcopy(getattr(self, k, None)) for k in keys} + # We Explicitly DO NOT deepcopy any data + full_kwargs["data"] = self.data + full_kwargs.update(kwargs) + new_spectrum = type(self)(**full_kwargs) + if self.extra_coords is not None: + new_spectrum._extra_coords = deepcopy(self.extra_coords) + if self.global_coords is not None: + new_spectrum._global_coords = deepcopy(self.global_coords) + return new_spectrum + + @property + def spectral_axis(self): + return self._spectral_axis + + @property + def spectral_axis_index(self): + return self._spectral_axis_index diff --git a/sunkit_spex/spectrum/tests/test_spectrum.py b/sunkit_spex/spectrum/tests/test_spectrum.py index 1763bb05..3fec06a1 100644 --- a/sunkit_spex/spectrum/tests/test_spectrum.py +++ b/sunkit_spex/spectrum/tests/test_spectrum.py @@ -1,16 +1,410 @@ +from operator import add, mul, sub, truediv + import numpy as np +import pytest +from gwcs import coordinate_frames as cf +from ndcube import NDCube +from ndcube.extra_coords import QuantityTableCoordinate, TimeTableCoordinate +from ndcube.wcs.wrappers import CompoundLowLevelWCS from numpy.testing import assert_array_equal import astropy.units as u +from astropy.modeling import models +from astropy.nddata import StdDevUncertainty +from astropy.tests.helper import assert_quantity_allclose +from astropy.time import Time +from astropy.wcs import WCS + +from sunkit_spex.spectrum.spectrum import SpectralAxis, SpectralGWCS, Spectrum, gwcs_from_array + +rng = np.random.default_rng() + + +def test_spectral_gwcs_init_and_copy(): + # Setup dummy transform and frames + trans = models.Identity(1) + # Create distinct frames with unique names + # Usually 'detector' or 'pixel' for input, and 'world' or 'spectral' for output + input_frame = cf.CoordinateFrame(naxes=1, axes_type=["SPECTRAL"], axes_order=(0,), name="pixel_frame") + + output_frame = cf.CoordinateFrame(naxes=1, axes_type=["SPECTRAL"], axes_order=(0,), name="world_frame") + sgwcs = SpectralGWCS( + forward_transform=trans, input_frame=input_frame, output_frame=output_frame, original_unit="Angstrom" + ) + + assert sgwcs.original_unit == "Angstrom" + + # Test shallow copy + sgwcs_copy = sgwcs.copy() + assert sgwcs_copy.original_unit == sgwcs.original_unit + assert sgwcs_copy is not sgwcs + + # Test deep copy + sgwcs_deepcopy = sgwcs.deepcopy() + assert sgwcs_deepcopy.original_unit == sgwcs.original_unit + assert sgwcs_deepcopy is not sgwcs + + +def test_gwcs_from_array_1d_wavelength(): + wavelengths = np.linspace(4000, 7000, 100) * u.AA + flux_shape = (100,) + + wcs = gwcs_from_array(wavelengths, flux_shape) + + assert isinstance(wcs, SpectralGWCS) + assert wcs.output_frame.unit[0] == u.AA + assert wcs.output_frame.axes_names[0] == "wavelength" -from sunkit_spex.spectrum.spectrum import Spectrum + # Test forward transform (pixel to world) + assert np.allclose(wcs(0), 4000 << u.AA) + assert np.allclose(wcs(99), 7000 << u.AA) + # Test inverse transform (world to pixel) + assert np.allclose(wcs.invert(4000).value, 0) -def test_spectrum_bin_edges(): + +def test_gwcs_from_array_3d_cube(): + # 3D cube: (Spatial, Spatial, Spectral) -> (y, x, lambda) + # In numpy: shape is (ny, nx, nlambda) + # We want spectral axis to be index 2 + n_lambda = 50 + flux_shape = (10, 20, n_lambda) + freqs = np.linspace(100, 200, n_lambda) * u.keV + + # Note: spectral_axis_index is relative to numpy shape + wcs = gwcs_from_array(freqs, flux_shape, spectral_axis_index=2) + + assert wcs.output_frame.naxes == 3 + assert wcs.forward_transform.n_inputs == 3 + + # Test mapping: (x, y, lambda_pix) -> (spatial, spatial, freq) + # GWCS/WCS usually expects (x, y, z) input order + world = wcs.pixel_to_world(0, 0, 0) # pixels for x, y, lambda + assert world[0] == 100 * u.keV + assert wcs.output_frame.frames[1].axes_names[0] == "energy" + + +def test_gwcs_from_array_invalid_units(): + data = np.arange(10) * u.Jy # Flux units are not valid for spectral axis + with pytest.raises(ValueError, match="Spectral axis units must be one of"): + gwcs_from_array(data, (10,)) + + +def test_gwcs_from_array_missing_index(): + data = np.linspace(1, 10, 10) * u.m + # 2D flux but no index provided + with pytest.raises(ValueError, match="spectral_axis_index must be set"): + gwcs_from_array(data, (10, 10)) + + +def test_spectrum_quantity_bin_edges(): spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV) assert_array_equal(spec._spectral_axis, [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5] * u.keV) -def test_spectrum_bin_centers(): - spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) - 0.5) * u.keV) - assert_array_equal(spec._spectral_axis, [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5] * u.keV) +def test_spectrum_quantity_bin_centers(): + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + assert_array_equal(spec._spectral_axis, [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5] * u.keV) + + +def test_spectrum_spectral_axis_bin_edges(): + spec_axis = SpectralAxis(np.arange(1, 12) * u.keV, bin_specification="edges") + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=spec_axis) + assert_array_equal(spec._spectral_axis, [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5] * u.keV) + + +def test_spectrum_spectral_axis_bin_centers(): + spec_axis = SpectralAxis((np.arange(1, 11) + 0.5) * u.keV, bin_specification="centers") + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=spec_axis) + assert_array_equal(spec._spectral_axis, [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5] * u.keV) + + +def test_spectrum_from_spectrum(): + spec_orig = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV) + spec_new = Spectrum(spec_orig) + spec_orig == spec_new + + +def test_spectrum_unknow_keywords(): + with pytest.raises(ValueError, match=r"Initializer contains unknown arguments."): + Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV, mykeyword="myvalue") + + +def test_spectrum_from_ndcube_wcs(): + header = { + "CTYPE1": "TIME ", + "CUNIT1": "s", + "CDELT1": 10, + "CRPIX1": 0, + "CRVAL1": 0, + "CTYPE2": "ENER ", + "CUNIT2": "keV", + "CDELT2": 1, + "CRPIX2": 0.5, + "CRVAL2": 0.0, + "DATEREF": "2020-01-01T00:00:00", + } + wcs = WCS(header=header) + shape = (10, 5) + wcs.array_shape = shape + + spec_axis_edges = np.arange(11) * u.keV + spec_axis_centers = spec_axis_edges[:-1] + np.diff(spec_axis_edges) * 0.5 + + data = np.arange(np.prod(shape), dtype=int).reshape(shape) + cube = NDCube(data.T, wcs=wcs) + with pytest.raises(ValueError, match=r"Input NDCube missing unit.*"): + Spectrum(cube) + + cube = NDCube(data, wcs=wcs, unit=u.ph) + + with pytest.raises(ValueError, match="Spectral axis must be specified"): + Spectrum(cube) + + with pytest.raises(ValueError, match=r"Spectral axis"): + Spectrum(cube, spectral_axis=spec_axis_edges[1:-1]) + + spec = Spectrum(cube, spectral_axis=spec_axis_edges) + assert isinstance(spec, Spectrum) + assert spec.spectral_axis_index == 0 + assert spec.shape == (10, 5) + assert_quantity_allclose(spec_axis_centers, spec.wcs.pixel_to_world(0, np.arange(10))[1].to("keV")) + assert_quantity_allclose(spec_axis_edges, spec.spectral_axis.bin_edges) + + +def test_spectrum_from_cube_wcs_tab(): + spec_axis_edges = np.arange(11) * u.keV + spec_axis_centers = spec_axis_edges[:-1] + np.diff(spec_axis_edges) * 0.5 + energy_coord = QuantityTableCoordinate(spec_axis_centers, names="energy", physical_types="em.energy") + data = rng.random(len(spec_axis_centers)) + cube = NDCube(data=data, wcs=energy_coord.wcs, unit=u.ph) + + spec = Spectrum(cube, spectral_axis=spec_axis_edges) + assert isinstance(spec, Spectrum) + + assert spec.spectral_axis_index == 0 + assert spec.shape == (10,) + assert_quantity_allclose(spec_axis_centers, spec.wcs.pixel_to_world(np.arange(10)).to("keV")) + + +def test_spectrum_spectra_axis_detection(): + energy = (np.arange(0, 10) + 0.5) * u.keV + energy_coord = QuantityTableCoordinate(energy, names="energy", physical_types="em.energy") + times = Time("2020-01-01") + 5 * np.arange(0, 11) * u.s + time_coord = TimeTableCoordinate(times, names="time", physical_types="time") + time_energy_wcs = (time_coord & energy_coord).wcs + data = np.arange(5 * 10).reshape(10, 5) + spec1 = Spectrum(data * u.ph, wcs=time_energy_wcs, spectral_axis=np.arange(11) * u.keV) + assert spec1.spectral_axis_index == 0 + + energy_energy_wcs = (energy_coord & time_coord).wcs + data = np.arange(10 * 5).reshape(5, 10) + spec2 = Spectrum(data * u.ph, wcs=energy_energy_wcs, spectral_axis=np.arange(11) * u.keV) + assert spec2.spectral_axis_index == 1 + + +def test_spectrum_from_cubs_wcs_norm_tab(): + header = { + "CTYPE1": "TIME ", + "CUNIT1": "s", + "CDELT1": 10, + "CRPIX1": 0, + "CRVAL1": 0, + "DATEREF": "2020-01-01T00:00:00", + } + time_wcs = WCS(header=header) + energy = (0.5 + np.arange(10)) * u.keV + energy_coord = QuantityTableCoordinate(energy, names="energy", physical_types="em.energy") + comp_wcs = CompoundLowLevelWCS(time_wcs, energy_coord.wcs) + cube = NDCube(np.arange(10 * 5).reshape(10, 5), unit=u.ph, wcs=comp_wcs) + spec = Spectrum(cube, spectral_axis=np.arange(11) * u.keV) + assert spec.shape == (10, 5) + assert spec.spectral_axis_index == 0 + assert_quantity_allclose(energy, spec.spectral_axis) + + +def test_slice(): + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + sliced_spec = spec[5:] + assert sliced_spec.shape == (5,) + assert sliced_spec.spectral_axis.shape == (5,) + + +@pytest.mark.parametrize( + "op, value, res", + [ + (add, 2 * u.W, np.arange(1.0, 11) + 2), + (sub, 2 * u.W, np.arange(1.0, 11) - 2), + (mul, 2, np.arange(1.0, 11) * 2), + (truediv, 2 * u.W, np.arange(1.0, 11) / 2), + ], +) +def test_arithmetic_operators(op, value, res): + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + res_spec = op(spec, value) + assert_array_equal(res_spec.data, res) + + +def test_spectral_axis_bin_edges_from_centers(): + """Test that bin_edges are correctly calculated when SpectralAxis is created with centers.""" + spec_axis = SpectralAxis(np.array([1.5, 2.5, 3.5, 4.5]) * u.keV, bin_specification="centers") + edges = spec_axis.bin_edges + assert edges is None + + +def test_spectral_axis_bin_edges_preserved(): + """Test that bin_edges are preserved when SpectralAxis is created with edges.""" + input_edges = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) * u.keV + spec_axis = SpectralAxis(input_edges, bin_specification="edges") + assert_quantity_allclose(spec_axis.bin_edges, input_edges) + + +def test_spectral_axis_centers_from_edges(): + """Test that centers are correctly calculated from edges.""" + input_edges = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) * u.keV + spec_axis = SpectralAxis(input_edges, bin_specification="edges") + assert_quantity_allclose(spec_axis, [1.5, 2.5, 3.5, 4.5] * u.keV) + + +def test_spectral_axis_single_center(): + """Test SpectralAxis handles single-element arrays.""" + spec_axis = SpectralAxis(np.array([5.0]) * u.keV, bin_specification="centers") + edges = spec_axis.bin_edges + assert edges is None + + +def test_spectral_axis_single_bin(): + """Test SpectralAxis handles single bins""" + with pytest.raises(ValueError, match="If bin_specification"): + SpectralAxis(np.array([5.0]) * u.keV, bin_specification="edges") + + spec_axis = SpectralAxis(np.array([5.0, 6.0]) * u.keV, bin_specification="edges") + edges = spec_axis.bin_edges + assert edges is not None + assert len(edges) == 2 + assert spec_axis == 5.5 * u.keV + + +def test_spectral_axis_empty_array(): + """Test SpectralAxis handles empty arrays.""" + edges = SpectralAxis(np.array([]), u.keV) + assert len(edges) == 0 + + +def test_spectral_axis_pixel_ascending(): + """Test that pixel spectral axes must be ascending.""" + with pytest.raises(ValueError, match=r"u\.pix spectral axes should always be ascending"): + SpectralAxis(np.array([5, 4, 3, 2, 1]) * u.pix) + + +def test_spectral_axis_pixel_ascending_valid(): + """Test that ascending pixel spectral axes are accepted.""" + spec_axis = SpectralAxis(np.array([1, 2, 3, 4, 5]) * u.pix) + assert len(spec_axis) == 5 + + +def test_spectrum_from_spectrum_inherits_attributes(): + """Test that Spectrum created from another Spectrum inherits spectral_axis and spectral_axis_index.""" + spec_orig = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 12) * u.keV) + spec_new = Spectrum(spec_orig) + + # Verify spectral_axis_index is inherited (Bug #1 fix) + assert spec_new.spectral_axis_index == spec_orig.spectral_axis_index + assert spec_new.spectral_axis_index == 0 + + # Verify spectral_axis is inherited + assert_quantity_allclose(spec_new.spectral_axis, spec_orig.spectral_axis) + + +def test_spectrum_from_spectrum_preserves_data(): + """Test that Spectrum created from another Spectrum preserves data.""" + data = np.arange(1, 11) * u.watt + spec_orig = Spectrum(data, spectral_axis=np.arange(1, 12) * u.keV) + spec_new = Spectrum(spec_orig) + + assert_array_equal(spec_new.data, spec_orig.data) + assert spec_new.unit == spec_orig.unit + + +def test_spectrum_strictly_increasing_spectral_axis(): + """Test that strictly increasing spectral axis is accepted.""" + spec = Spectrum(np.arange(1, 6) * u.watt, spectral_axis=np.array([1, 2, 3, 4, 5]) * u.keV) + assert_quantity_allclose(spec.spectral_axis, [1, 2, 3, 4, 5] * u.keV) + + +def test_spectrum_non_monotonic_spectral_axis_raises(): + """Test that non-monotonic spectral axis raises ValueError.""" + with pytest.raises(ValueError, match="strictly increasing"): + Spectrum(np.arange(1, 6) * u.watt, spectral_axis=np.array([1, 3, 2, 4, 5]) * u.keV) + + +def test_spectrum_duplicate_values_in_spectral_axis_raises(): + """Test that duplicate values in spectral axis raises ValueError.""" + with pytest.raises(ValueError, match="strictly increasing"): + Spectrum(np.arange(1, 5) * u.watt, spectral_axis=np.array([1, 2, 2, 3]) * u.keV) + + +def test_spectrum_single_element_spectral_axis(): + """Test that single-element spectral axis is accepted.""" + spec = Spectrum(np.array([5]) * u.watt, spectral_axis=np.array([10]) * u.keV) + assert spec.shape == (1,) + assert_quantity_allclose(spec.spectral_axis, [10] * u.keV) + + +def test_spectrum_spectral_axis_length_mismatch(): + """Test that mismatched spectral axis length raises ValueError.""" + with pytest.raises(ValueError, match="Spectral axis length"): + Spectrum(np.arange(1, 11) * u.watt, spectral_axis=np.arange(1, 5) * u.keV) + + +def test_spectrum_uncertainty_shape_mismatch(): + """Test that mismatched uncertainty shape raises ValueError.""" + data = np.arange(1, 11) * u.watt + uncertainty = StdDevUncertainty(np.arange(1, 6)) # Wrong shape + with pytest.raises(ValueError, match=r"Data axis .* and uncertainty .* shapes must be the same"): + Spectrum(data, spectral_axis=np.arange(1, 12) * u.keV, uncertainty=uncertainty) + + +def test_spectrum_with_valid_uncertainty(): + """Test Spectrum with correctly shaped uncertainty.""" + data = np.arange(1, 11) * u.watt + uncertainty = StdDevUncertainty(np.ones(10) * 0.1) + spec = Spectrum(data, spectral_axis=np.arange(1, 12) * u.keV, uncertainty=uncertainty) + assert spec.uncertainty is not None + assert spec.uncertainty.array.shape == data.shape + + +def test_slice_preserves_spectral_axis_index(): + """Test that slicing preserves spectral_axis_index.""" + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + sliced = spec[2:7] + assert sliced.spectral_axis_index == spec.spectral_axis_index + + +def test_slice_updates_spectral_axis(): + """Test that slicing correctly slices spectral_axis.""" + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + sliced = spec[2:5] + assert_quantity_allclose(sliced.spectral_axis, [3.5, 4.5, 5.5] * u.keV) + + +def test_slice_single_element(): + """Test slicing to a single element.""" + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + sliced = spec[5:6] + assert sliced.shape == (1,) + assert_quantity_allclose(sliced.spectral_axis, [6.5] * u.keV) + + +def test_arithmetic_preserves_spectral_axis(): + """Test that arithmetic operations preserve spectral_axis.""" + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + result = spec + 1 * u.watt + assert_quantity_allclose(result.spectral_axis, spec.spectral_axis) + + +def test_arithmetic_preserves_spectral_axis_index(): + """Test that arithmetic operations preserve spectral_axis_index.""" + spec = Spectrum(np.arange(1, 11) * u.watt, spectral_axis=(np.arange(1, 11) + 0.5) * u.keV) + result = spec * 2 + assert result.spectral_axis_index == spec.spectral_axis_index