diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index ed5942af..b7b09d8a 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,27 +1,65 @@ import galsim as _galsim import jax import jax.numpy as jnp +import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( cast_to_float, cast_to_int, ensure_hashable, + has_tracers, implements, ) from jax_galsim.position import Position, PositionD, PositionI +CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64) +CONST_TYPES_WITH_JAX = CONST_TYPES + ( + jax.Array, + jnp.array, + jnp.int32, + jnp.int64, + jnp.float32, + jnp.float64, +) + +# TODO: write extra docs for JAX changes BOUNDS_LAX_DESCR = """\ The JAX implementation - will not always test whether the bounds are valid - will not always test whether BoundsI is initialized with integers + +Further, the JAX implementation adds a new method, ``isStatic`` to the +``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance +has been instantiated with static, known values, ``isStatic()`` will +return ``True``. You can indicate to JAX-GalSim that a ``BoundsI`` +instance should be static via initializing it with the ``static`` +keyword set to the ``True``. If the object detects that it is being +initialized with non-static data, an error will be raised. + +``BoundsI`` objects in JAX-Galsim support an additional initialization +call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case, +the values for ``deltax/y`` indicate the width of the bounds and must be +static constants. + +When calling ``jax.vmap`` over ``BoundsI`` objects, only ``x/ymin`` +are vectorized over. This restriction allows for code that renders +objects in fixed sized stamps with variable locations, a common +operation. ``BoundsI`` objects which are static (i.e., ``isStatic()`` +returns ``True``) are treated as constants with respect to ``vmap``, +``jit``, and other JAX transforms. """ @implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class -class Bounds(_galsim.Bounds): +class Bounds: + def __init__(self): + raise NotImplementedError( + "Cannot instantiate the base class. Use either BoundsD or BoundsI." + ) + def _parse_args(self, *args, **kwargs): if len(kwargs) == 0: if len(args) == 4: @@ -29,14 +67,23 @@ def _parse_args(self, *args, **kwargs): self.xmin, self.xmax, self.ymin, self.ymax = args elif len(args) == 0: self._isdefined = False - self.xmin = self.xmax = self.ymin = self.ymax = 0 + self.xmin = 0 + self.ymin = 0 + self.deltax = 0 + self.deltay = 0 elif len(args) == 1: if isinstance(args[0], Bounds): - self._isdefined = True + if isinstance(self, BoundsI) and isinstance(args[0], BoundsD): + offset = 1 + elif isinstance(self, BoundsD) and isinstance(args[0], BoundsI): + offset = -1 + else: + offset = 0 + self._isdefined = args[0]._isdefined self.xmin = args[0].xmin - self.xmax = args[0].xmax + self.deltax = args[0].deltax + offset self.ymin = args[0].ymin - self.ymax = args[0].ymax + self.deltay = args[0].deltay + offset elif isinstance(args[0], Position): self._isdefined = True self.xmin = self.xmax = args[0].x @@ -73,27 +120,73 @@ def _parse_args(self, *args, **kwargs): try: self._isdefined = True self.xmin = kwargs.pop("xmin") - self.xmax = kwargs.pop("xmax") self.ymin = kwargs.pop("ymin") - self.ymax = kwargs.pop("ymax") except KeyError: raise TypeError( - "Keyword arguments, xmin, xmax, ymin, ymax are required for %s" + "Keyword arguments, xmin, ymin are required for %s" + % (self.__class__.__name__) + ) + + if "xmax" in kwargs and "ymax" in kwargs: + self.xmax = kwargs.pop("xmax") + self.ymax = kwargs.pop("ymax") + elif "deltax" in kwargs and "deltay" in kwargs: + self.deltax = kwargs.pop("deltax") + self.deltay = kwargs.pop("deltay") + else: + raise TypeError( + "Keyword arguments, either (xmax, ymax) " + "or (deltax, deltay) are required for %s" % (self.__class__.__name__) ) + if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) # for simple inputs, we can check if the bounds are valid + if isinstance(self, BoundsD): + max_delta = 0 + else: + max_delta = 1 if ( - isinstance(self.xmin, (float, int)) - and isinstance(self.xmax, (float, int)) - and isinstance(self.ymin, (float, int)) - and isinstance(self.ymax, (float, int)) - and ((self.xmin > self.xmax) or (self.ymin > self.ymax)) + isinstance(self.deltax, CONST_TYPES) + and isinstance(self.deltay, CONST_TYPES) + and (self.deltax < max_delta or self.deltay < max_delta) ): self._isdefined = False + @implements(_galsim.Bounds.area) + def area(self): + return self._area() + + @implements(_galsim.Bounds.withBorder) + def withBorder(self, dx, dy=None): + self._check_scalar(dx, "dx") + if dy is None: + dy = dx + else: + self._check_scalar(dy, "dy") + return self.__class__( + xmin=self.xmin - dx, + deltax=self.deltax + 2 * dx, + ymin=self.ymin - dy, + deltay=self.deltay + 2 * dy, + ) + + @property + @implements(_galsim.Bounds.origin) + def origin(self): + return self._pos_class(self.xmin, self.ymin) + + @property + @implements(_galsim.Bounds.center) + def center(self): + if not self.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "center is invalid for an undefined Bounds" + ) + return self._center + @property @implements(_galsim.Bounds.true_center) def true_center(self): @@ -110,18 +203,20 @@ def includes(self, *args): b = args[0] return ( self.isDefined() - and b.isDefined() - and self.xmin <= b.xmin - and self.xmax >= b.xmax - and self.ymin <= b.ymin - and self.ymax >= b.ymax + & b.isDefined() + & (self.xmin <= b.xmin) + & (self.xmax >= b.xmax) + & (self.ymin <= b.ymin) + & (self.ymax >= b.ymax) ) elif isinstance(args[0], Position): p = args[0] return ( self.isDefined() - and self.xmin <= p.x <= self.xmax - and self.ymin <= p.y <= self.ymax + & (self.xmin <= p.x) + & (self.ymin <= p.y) + & (p.x <= self.xmax) + & (p.y <= self.ymax) ) else: raise TypeError("Invalid argument %s" % args[0]) @@ -129,8 +224,10 @@ def includes(self, *args): x, y = args return ( self.isDefined() - and self.xmin <= float(x) <= self.xmax - and self.ymin <= float(y) <= self.ymax + & (self.xmin <= float(x)) + & (self.ymin <= float(y)) + & (float(x) <= self.xmax) + & (float(y) <= self.ymax) ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") @@ -148,6 +245,37 @@ def expand(self, factor_x, factor_y=None): dy = jnp.ceil(dy) return self.withBorder(dx, dy) + @implements(_galsim.Bounds.isDefined) + def isDefined(self): + return self._isdefined + + @implements(_galsim.Bounds.getXMin) + def getXMin(self): + return self.xmin + + @implements(_galsim.Bounds.getXMax) + def getXMax(self): + return self.xmax + + @implements(_galsim.Bounds.getYMin) + def getYMin(self): + return self.ymin + + @implements(_galsim.Bounds.getYMax) + def getYMax(self): + return self.ymax + + @implements(_galsim.Bounds.shift) + def shift(self, delta): + if not isinstance(delta, self._pos_class): + raise TypeError("delta must be a %s instance" % self._pos_class) + return self.__class__( + xmin=self.xmin + delta.x, + deltax=self.deltax, + ymin=self.ymin + delta.y, + deltay=self.deltay, + ) + def __and__(self, other): if not isinstance(other, self.__class__): raise TypeError("other must be a %s instance" % self.__class__.__name__) @@ -190,38 +318,29 @@ def __add__(self, other): % (self.__class__.__name__, self._pos_class.__name__) ) - def __repr__(self): + def _getinitargs(self): if self.isDefined(): - return "galsim.%s(xmin=%r, xmax=%r, ymin=%r, ymax=%r)" % ( - self.__class__.__name__, - ensure_hashable(self.xmin), - ensure_hashable(self.xmax), - ensure_hashable(self.ymin), - ensure_hashable(self.ymax), - ) + return (self.xmin, self.xmax, self.ymin, self.ymax) else: - return "galsim.%s()" % (self.__class__.__name__) + return () - def __str__(self): - if self.isDefined(): - return "galsim.%s(%s,%s,%s,%s)" % ( - self.__class__.__name__, - ensure_hashable(self.xmin), - ensure_hashable(self.xmax), - ensure_hashable(self.ymin), - ensure_hashable(self.ymax), - ) - else: - return "galsim.%s()" % (self.__class__.__name__) + def __eq__(self, other): + return self is other or ( + isinstance(other, self.__class__) + and self._getinitargs() == other._getinitargs() + ) + + def __ne__(self, other): + return not self.__eq__(other) def __hash__(self): return hash( ( self.__class__.__name__, ensure_hashable(self.xmin), - ensure_hashable(self.xmax), + ensure_hashable(self.deltax), ensure_hashable(self.ymin), - ensure_hashable(self.ymax), + ensure_hashable(self.deltay), ) ) @@ -230,7 +349,7 @@ def tree_flatten(self): nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing if self.isDefined(): - children = (self.xmin, self.xmax, self.ymin, self.ymax) + children = (self.xmin, self.deltax, self.ymin, self.deltay) else: children = tuple() # Define auxiliary static data that doesn’t need to be traced @@ -240,15 +359,25 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - return cls(*children) + if children: + return cls( + xmin=children[0], + deltax=children[1], + ymin=children[2], + deltay=children[3], + ) + else: + return cls() @classmethod def from_galsim(cls, galsim_bounds): """Create a jax_galsim `BoundsD/I` from a `galsim.BoundsD/I` object.""" if isinstance(galsim_bounds, _galsim.BoundsD): _cls = BoundsD + kwargs = {} elif isinstance(galsim_bounds, _galsim.BoundsI): _cls = BoundsI + kwargs = {"static": True} else: raise TypeError( "galsim_bounds must be either a %s or a %s" @@ -260,6 +389,7 @@ def from_galsim(cls, galsim_bounds): galsim_bounds.xmax, galsim_bounds.ymin, galsim_bounds.ymax, + **kwargs, ) else: return _cls() @@ -283,6 +413,12 @@ def to_galsim(self): else: return gs_class() + def isStatic(self): + """Returns ``True`` if the ``BoundsI`` instance + has static, known dimensions and location. Always returns + ``False`` for ``BoundsD``.""" + return self._isstatic + @implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -290,11 +426,12 @@ class BoundsD(Bounds): _pos_class = PositionD def __init__(self, *args, **kwargs): + self._isstatic = False self._parse_args(*args, **kwargs) self.xmin = cast_to_float(self.xmin) - self.xmax = cast_to_float(self.xmax) + self.deltax = cast_to_float(self.deltax) self.ymin = cast_to_float(self.ymin) - self.ymax = cast_to_float(self.ymax) + self.deltay = cast_to_float(self.deltay) def _check_scalar(self, x, name): try: @@ -310,13 +447,64 @@ def _check_scalar(self, x, name): pass raise TypeError("%s must be a float value" % name) + @property + def xmax(self): + return self.xmin + self.deltax + + @xmax.setter + def xmax(self, value): + self.deltax = value - self.xmin + + @property + def ymax(self): + return self.ymin + self.deltay + + @ymax.setter + def ymax(self, value): + self.deltay = value - self.ymin + def _area(self): - return (self.xmax - self.xmin) * (self.ymax - self.ymin) + return self.deltax * self.deltay @property def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) + def __repr__(self): + if self.isDefined(): + return "galsim.%s(%r, %r, %r, %r)" % ( + self.__class__.__name__, + ensure_hashable(self.xmin), + ensure_hashable(self.xmax), + ensure_hashable(self.ymin), + ensure_hashable(self.ymax), + ) + else: + return "galsim.%s()" % (self.__class__.__name__) + + def __str__(self): + if self.isDefined(): + return "galsim.%s(%s,%s,%s,%s)" % ( + self.__class__.__name__, + ensure_hashable(self.xmin), + ensure_hashable(self.xmax), + ensure_hashable(self.ymin), + ensure_hashable(self.ymax), + ) + else: + return "galsim.%s()" % (self.__class__.__name__) + + def __hash__(self): + return hash( + ( + self.__class__.__name__, + ensure_hashable(self.xmin), + ensure_hashable(self.deltax), + ensure_hashable(self.ymin), + ensure_hashable(self.deltay), + ) + ) + @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -324,26 +512,50 @@ class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): + # initial setting to let stuff pass through freely + self._isstatic = True + + force_static = kwargs.pop("static", False) + self._parse_args(*args, **kwargs) - # for simple inputs, we can check if the bounds are valid ints - if ( - isinstance(self.xmin, (float, int)) - and isinstance(self.xmax, (float, int)) - and isinstance(self.ymin, (float, int)) - and isinstance(self.ymax, (float, int)) - and ( - self.xmin != int(self.xmin) - or self.xmax != int(self.xmax) - or self.ymin != int(self.ymin) - or self.ymax != int(self.ymax) + + if has_tracers(self.deltax) or has_tracers(self.deltay): + raise RuntimeError( + "Jax-GalSim BoundsI instances must have a fixed width! " + f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." ) - ): + + self.deltax = int(cast_to_int(self.deltax)) + self.deltay = int(cast_to_int(self.deltay)) + + if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): raise TypeError("BoundsI must be initialized with integer values") - self.xmin = cast_to_int(self.xmin) - self.xmax = cast_to_int(self.xmax) - self.ymin = cast_to_int(self.ymin) - self.ymax = cast_to_int(self.ymax) + if self.deltax < 1 and self.deltay < 1: + self._isdefined = False + + # for simple inputs, we can check if the bounds are valid ints + if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin): + raise TypeError("BoundsI must be initialized with integer values") + + if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin): + raise TypeError("BoundsI must be initialized with integer values") + + if not has_tracers(self._xmin) and not has_tracers(self._ymin): + self._isstatic = True + self._xmin = int(np.trunc(self._xmin)) + self._ymin = int(np.trunc(self._ymin)) + else: + self._isstatic = False + self._xmin = cast_to_float(jnp.trunc(self._xmin)) + self._ymin = cast_to_float(jnp.trunc(self._ymin)) + + if force_static and not self._isstatic: + raise RuntimeError( + "BoundsI initialized with non-static " + f"data (xmin,ymin = {self._xmin},{self._yminb}) " + "when static data was explicitly requested." + ) def _check_scalar(self, x, name): try: @@ -362,16 +574,60 @@ def _check_scalar(self, x, name): def numpyShape(self): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." if self.isDefined(): - return self.ymax - self.ymin + 1, self.xmax - self.xmin + 1 + return self.deltay, self.deltax else: return 0, 0 + @property + def xmin(self): + if self._isstatic: + return self._xmin + else: + return jnp.astype(self._xmin, jnp.int_) + + @xmin.setter + def xmin(self, value): + if self._isstatic: + self._xmin = value + else: + self._xmin = jnp.astype(value, jnp.float_) + + @property + def xmax(self): + return self.xmin + self.deltax - 1 + + @xmax.setter + def xmax(self, value): + self.deltax = value - self.xmin + 1 + + @property + def ymin(self): + if self._isstatic: + return self._ymin + else: + return jnp.astype(self._ymin, jnp.int_) + + @ymin.setter + def ymin(self, value): + if self._isstatic: + self._ymin = value + else: + self._ymin = jnp.astype(value, jnp.float_) + + @property + def ymax(self): + return self.ymin + self.deltay - 1 + + @ymax.setter + def ymax(self, value): + self.deltay = value - self.ymin + 1 + def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. if not self.isDefined(): return 0 else: - return (self.xmax - self.xmin + 1) * (self.ymax - self.ymin + 1) + return self.deltax * self.deltay @property def _center(self): @@ -381,6 +637,94 @@ def _center(self): # (-10,-1,-10,-1) -> (-5,-5) # Just up and to the right of the true center in both cases. return PositionI( - self.xmin + (self.xmax - self.xmin + 1) // 2, - self.ymin + (self.ymax - self.ymin + 1) // 2, + self.xmin + self.deltax // 2, + self.ymin + self.deltay // 2, + ) + + def tree_flatten(self): + """This function flattens the Bounds into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + if self._isstatic: + # Define the children nodes of the PyTree that need tracing + children = tuple() + + # Define auxiliary static data that doesn’t need to be traced + aux_data = { + "xmin": self._xmin, + "ymin": self._ymin, + "deltax": self.deltax, + "deltay": self.deltay, + } + else: + children = (self._xmin, self._ymin) + # Define auxiliary static data that doesn’t need to be traced + aux_data = {"deltax": self.deltax, "deltay": self.deltay} + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + ret = cls.__new__(cls) + if "xmin" in aux_data and "ymin" in aux_data: + ret._isstatic = True + ret._xmin = aux_data["xmin"] + ret._ymin = aux_data["ymin"] + else: + ret._isstatic = False + ret._xmin = children[0] + ret._ymin = children[1] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + if ret.deltax < 1 and ret.deltay < 1: + ret._isdefined = False + else: + ret._isdefined = True + + return ret + + def __repr__(self): + if self.isDefined(): + return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % ( + self.__class__.__name__, + ensure_hashable(self.xmin), + ensure_hashable(self.deltax), + ensure_hashable(self.ymin), + ensure_hashable(self.deltay), + ) + else: + return "galsim.%s()" % (self.__class__.__name__) + + def __str__(self): + if self.isDefined(): + return "galsim.%s(xmin=%s, deltax=%s, ymin=%s, deltay=%s)" % ( + self.__class__.__name__, + ensure_hashable(self.xmin), + ensure_hashable(self.deltax), + ensure_hashable(self.ymin), + ensure_hashable(self.deltay), + ) + else: + return "galsim.%s()" % (self.__class__.__name__) + + def _getinitargs(self): + if self.isDefined(): + return (self.xmin, self.deltax, self.ymin, self.deltay) + else: + return () + + def __eq__(self, other): + return self is other or ( + isinstance(other, BoundsI) and self._getinitargs() == other._getinitargs() + ) + + def __hash__(self): + return hash( + ( + self.__class__.__name__, + ensure_hashable(self.xmin), + ensure_hashable(self.deltax), + ensure_hashable(self.ymin), + ensure_hashable(self.deltay), + ) ) diff --git a/jax_galsim/core/wrap_image.py b/jax_galsim/core/wrap_image.py index 72f1bb2d..de410408 100644 --- a/jax_galsim/core/wrap_image.py +++ b/jax_galsim/core/wrap_image.py @@ -55,7 +55,7 @@ def _block_reduce_loop(sim, nx, ny, nxwrap, nywrap): return fim -@partial(jax.jit, static_argnames=("xmin", "ymin", "nxwrap", "nywrap")) +@partial(jax.jit, static_argnames=("nxwrap", "nywrap")) def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap): # these bits compute how many total blocks we need to cover the image nx = im.shape[1] // nxwrap @@ -81,7 +81,11 @@ def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap): else: fim = _block_reduce_loop(sim, nx, ny, nxwrap, nywrap) - im = im.at[ymin : ymin + nywrap, xmin : xmin + nxwrap].set(fim) + im = jax.lax.dynamic_update_slice( + im, + fim, + (ymin, xmin), + ) return im @@ -98,10 +102,6 @@ def contract_hermitian_x(im): @partial( jax.jit, static_argnames=[ - "im_xmin", - "im_ymin", - "wrap_xmin", - "wrap_ymin", "wrap_nx", "wrap_ny", ], @@ -127,10 +127,6 @@ def contract_hermitian_y(im): @partial( jax.jit, static_argnames=[ - "im_xmin", - "im_ymin", - "wrap_xmin", - "wrap_ymin", "wrap_nx", "wrap_ny", ], diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index b687e175..64cc9ec7 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -380,7 +380,7 @@ def _setup_image( N = self.getGoodImageSize(1.0) if odd: N += 1 - bounds = BoundsI(1, N, 1, N) + bounds = BoundsI(xmin=1, deltax=N, ymin=1, deltay=N) image.resize(bounds) # Else use the given image as is @@ -486,7 +486,7 @@ def _get_new_bounds(self, image, nx, ny, bounds, center): if image is not None and image.bounds.isDefined(): return image.bounds elif nx is not None and ny is not None: - b = BoundsI(1, nx, 1, ny) + b = BoundsI(xmin=1, deltax=nx, ymin=1, deltay=ny) if center is not None: # this code has to match the code in _setup_image # for the same branch of the if statement block @@ -853,7 +853,14 @@ def drawFFT_makeKImage(self, image): image_N = jnp.max( jnp.array( [ - jnp.max(jnp.abs(jnp.array(image.bounds._getinitargs()))) * 2, + jnp.max( + jnp.abs( + jnp.array( + [image.xmin, image.xmax, image.ymin, image.ymax] + ) + ) + ) + * 2, jnp.max(jnp.array(image.bounds.numpyShape())), ] ) @@ -880,7 +887,9 @@ def drawFFT_makeKImage(self, image): "drawFFT requires an FFT that is too large.", Nk ) - bounds = BoundsI(0, Nk // 2, -Nk // 2, Nk // 2) + bounds = BoundsI( + xmin=0, deltax=Nk // 2 + 1, ymin=-Nk // 2, deltay=2 * (Nk // 2) + 1 + ) if image.dtype in (np.complex128, np.float64, np.int32, np.uint32): kimage = ImageCD(bounds=bounds, scale=dk) else: @@ -895,12 +904,20 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image): # Wrap the full image to the size we want for the FT. # Even if N == Nk, this is useful to make this portion properly Hermitian in the # N/2 column and N/2 row. - bwrap = BoundsI(0, wrap_size // 2, -wrap_size // 2, wrap_size // 2 - 1) - kimage_wrap = kimage._wrap(bwrap, True, False) + bwrap = BoundsI( + xmin=0, + deltax=wrap_size // 2 + 1, + ymin=-wrap_size // 2, + deltay=2 * (wrap_size // 2), + ) + kimage_wrap = kimage._wrap(bwrap, True, False, wrap_size) # Perform the fourier transform. breal = BoundsI( - -wrap_size // 2, wrap_size // 2 - 1, -wrap_size // 2, wrap_size // 2 - 1 + xmin=-wrap_size // 2, + deltax=2 * (wrap_size // 2), + ymin=-wrap_size // 2, + deltay=2 * (wrap_size // 2), ) kimg_shift = jnp.fft.ifftshift(kimage_wrap.array, axes=(-2,)) real_image_arr = jnp.fft.fftshift( diff --git a/jax_galsim/image.py b/jax_galsim/image.py index f6f0f518..3cd3605a 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1,10 +1,11 @@ import galsim as _galsim +import jax import jax.numpy as jnp import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.bounds import Bounds, BoundsD, BoundsI -from jax_galsim.core.utils import ensure_hashable, implements +from jax_galsim.core.utils import ensure_hashable, has_tracers, implements from jax_galsim.errors import GalSimImmutableError from jax_galsim.position import PositionI from jax_galsim.utilities import parse_pos_args @@ -183,7 +184,12 @@ def __init__(self, *args, **kwargs): ncol = int(ncol) nrow = int(nrow) self._array = self._make_empty(shape=(nrow, ncol), dtype=self._dtype) - self._bounds = BoundsI(xmin, xmin + ncol - 1, ymin, ymin + nrow - 1) + if not has_tracers(xmin) and not has_tracers(ymin): + self._bounds = BoundsI( + xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True + ) + else: + self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value: self._array = self._array + init_value elif bounds is not None: @@ -196,7 +202,12 @@ def __init__(self, *args, **kwargs): elif array is not None: self._array = array.view(dtype=self._dtype) nrow, ncol = array.shape - self._bounds = BoundsI(xmin, xmin + ncol - 1, ymin, ymin + nrow - 1) + if not has_tracers(xmin) and not has_tracers(ymin): + self._bounds = BoundsI( + xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True + ) + else: + self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow) if init_value is not None: raise _galsim.GalSimIncompatibleValuesError( "Cannot specify init_value with array", @@ -260,7 +271,14 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") - if check_bounds and b.isDefined(): + if ( + check_bounds + and b.isDefined() + and not has_tracers(b.xmin) + and not has_tracers(b.ymin) + and not has_tracers(b.xmax) + and not has_tracers(b.ymax) + ): # We need to disable this when jitting if b.xmax - b.xmin + 1 != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( @@ -487,16 +505,28 @@ def subImage(self, bounds): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" ) - if not self.bounds.includes(bounds): + if ( + not has_tracers(self.bounds.xmin) + and not has_tracers(self.bounds.xmax) + and not has_tracers(self.bounds.ymin) + and not has_tracers(self.bounds.ymax) + and not has_tracers(bounds.xmin) + and not has_tracers(bounds.xmax) + and not has_tracers(bounds.ymin) + and not has_tracers(bounds.ymax) + and not self.bounds.includes(bounds) + ): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) - i1 = bounds.ymin - self.ymin - i2 = bounds.ymax - self.ymin + 1 - j1 = bounds.xmin - self.xmin - j2 = bounds.xmax - self.xmin + 1 - subarray = self.array[i1:i2, j1:j2] + start_inds = ( + bounds.ymin - self.ymin, + bounds.xmin - self.xmin, + ) + shape = bounds.numpyShape() + subarray = jax.lax.dynamic_slice(self.array, start_inds, shape) + # NB. The wcs is still accurate, since the sub-image uses the same (x,y) values # as the original image did for those pixels. It's only once you recenter or # reorigin that you need to update the wcs. So that's taken care of in im.shift. @@ -512,7 +542,17 @@ def setSubImage(self, bounds, rhs): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(bounds): + if ( + not has_tracers(self.bounds.xmin) + and not has_tracers(self.bounds.xmax) + and not has_tracers(self.bounds.ymin) + and not has_tracers(self.bounds.ymax) + and not has_tracers(bounds.xmin) + and not has_tracers(bounds.xmax) + and not has_tracers(bounds.ymin) + and not has_tracers(bounds.ymax) + and not self.bounds.includes(bounds) + ): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) @@ -524,11 +564,15 @@ def setSubImage(self, bounds, rhs): self_image=self, rhs=rhs, ) - i1 = bounds.ymin - self.ymin - i2 = bounds.ymax - self.ymin + 1 - j1 = bounds.xmin - self.xmin - j2 = bounds.xmax - self.xmin + 1 - self._array = self._array.at[i1:i2, j1:j2].set(rhs.array) + start_inds = ( + bounds.ymin - self.ymin, + bounds.xmin - self.xmin, + ) + self._array = jax.lax.dynamic_update_slice( + self.array, + jnp.astype(rhs.array, self.dtype), + start_inds, + ) def __getitem__(self, *args): """Return either a subimage or a single pixel value. @@ -587,42 +631,42 @@ def wrap(self, bounds, hermitian=False): # Get this at the start to check for invalid bounds and raise the exception before # possibly writing data past the edge of the image. if not hermitian: - return self._wrap(bounds, False, False) + return self._wrap(bounds, False, False, None) elif hermitian == "x": - if self.bounds.xmin != 0: + if not has_tracers(self.bounds.xmin) and self.bounds.xmin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'x' requires self.bounds.xmin == 0", hermitian=hermitian, bounds=self.bounds, ) - if bounds.xmin != 0: + if not has_tracers(bounds.xmin) and bounds.xmin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'x' requires bounds.xmin == 0", hermitian=hermitian, bounds=bounds, ) - return self._wrap(bounds, True, False) + return self._wrap(bounds, True, False, 2 * bounds.xmax) elif hermitian == "y": - if self.bounds.ymin != 0: + if not has_tracers(self.bounds.ymin) and self.bounds.ymin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'y' requires self.bounds.ymin == 0", hermitian=hermitian, bounds=self.bounds, ) - if bounds.ymin != 0: + if not has_tracers(bounds.ymin) and bounds.ymin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'y' requires bounds.ymin == 0", hermitian=hermitian, bounds=bounds, ) - return self._wrap(bounds, False, True) + return self._wrap(bounds, False, True, 2 * bounds.ymax) else: raise _galsim.GalSimValueError( "Invalid value for hermitian", hermitian, (False, "x", "y") ) @implements(_galsim.Image._wrap) - def _wrap(self, bounds, hermx, hermy): + def _wrap(self, bounds, hermx, hermy, hermitian_wrap_size): if not hermx and not hermy: from jax_galsim.core.wrap_image import wrap_nonhermitian @@ -631,9 +675,8 @@ def _wrap(self, bounds, hermx, hermy): # zero indexed location of subimage bounds.xmin - self.xmin, bounds.ymin - self.ymin, - # we include pixels on the edges so +1 here - bounds.xmax - bounds.xmin + 1, - bounds.ymax - bounds.ymin + 1, + bounds.deltax, + bounds.deltay, ) elif hermx and not hermy: from jax_galsim.core.wrap_image import wrap_hermitian_x @@ -644,8 +687,8 @@ def _wrap(self, bounds, hermx, hermy): self.ymin, -bounds.xmax + 1, bounds.ymin, - 2 * bounds.xmax, - bounds.ymax - bounds.ymin + 1, + hermitian_wrap_size, + bounds.deltay, ) elif not hermx and hermy: from jax_galsim.core.wrap_image import wrap_hermitian_y @@ -656,8 +699,8 @@ def _wrap(self, bounds, hermx, hermy): -self.ymax, bounds.xmin, -bounds.ymax + 1, - bounds.xmax - bounds.xmin + 1, - 2 * bounds.ymax, + bounds.deltax, + hermitian_wrap_size, ) return self.subImage(bounds) @@ -682,6 +725,8 @@ def calculate_fft(self): "JAX-GalSim does not support forward FFTs of complex dtypes." ) + # TODO: figure out how to do FFT at fixed size and then reconstruct + # the result No2 = max( max( -self.bounds.xmin, @@ -693,7 +738,7 @@ def calculate_fft(self): ), ) - full_bounds = BoundsI(-No2, No2 - 1, -No2, No2 - 1) + full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) if self.bounds == full_bounds: # Then the image is already in the shape we need. ximage = self @@ -706,7 +751,11 @@ def calculate_fft(self): # dk = 2pi / (N dk) dk = jnp.pi / (No2 * dx) - out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) + out = Image( + BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2), + dtype=np.complex128, + scale=dk, + ) # we shift the image before and after the FFT to match the layout of the modes # used by GalSim out._array = jnp.fft.fftshift( @@ -743,32 +792,41 @@ def calculate_inverse_fft(self): self.bounds.ymax, ) - target_bounds = BoundsI(0, No2, -No2, No2 - 1) + target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2) if self.bounds == target_bounds: # Then the image is already in the shape we need. kimage = self else: # Then we can pad out with zeros and wrap to get this in the form we need. - full_bounds = BoundsI(0, No2, -No2, No2) + full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1) kimage = Image(full_bounds, dtype=self.dtype, init_value=0) posx_bounds = BoundsI( - 0, self.bounds.xmax, self.bounds.ymin, self.bounds.ymax + xmin=0, + xmax=self.bounds.xmax, + ymin=self.bounds.ymin, + ymax=self.bounds.ymax, ) kimage[posx_bounds] = self[posx_bounds] - kimage = kimage.wrap(target_bounds, hermitian="x") + kimage = kimage._wrap(target_bounds, True, False, 2 * No2) dk = self.scale # dx = 2pi / (N dk) dx = jnp.pi / (No2 * dk) # For the inverse, we need a bit of extra space for the fft. - out_extra = Image(BoundsI(-No2, No2 + 1, -No2, No2 - 1), dtype=float, scale=dx) + out_extra = Image( + BoundsI(xmin=-No2, deltax=2 * No2 + 2, ymin=-No2, deltay=2 * No2), + dtype=float, + scale=dx, + ) # we shift the image before and after the FFT to match the layout used by galsim out_extra._array = jnp.fft.fftshift( jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) ) # Now cut off the bit we don't need. - out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1)) + out = out_extra.subImage( + BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) + ) out *= (dk * No2 / jnp.pi) ** 2 out.setCenter(0, 0) return out @@ -1033,7 +1091,12 @@ def __ne__(self, other): @implements(_galsim.Image.transpose) def transpose(self): - bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) + bT = self.bounds.__class__( + xmin=self.ymin, + deltax=self.bounds.deltay, + ymin=self.xmin, + deltay=self.bounds.deltax, + ) return _Image(self.array.T, bT, None) @implements(_galsim.Image.flip_lr) @@ -1046,12 +1109,22 @@ def flip_ud(self): @implements(_galsim.Image.rot_cw) def rot_cw(self): - bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) + bT = self.bounds.__class__( + xmin=self.ymin, + deltax=self.bounds.deltay, + ymin=self.xmin, + deltay=self.bounds.deltax, + ) return _Image(self.array.T.at[::-1, :].get(), bT, None) @implements(_galsim.Image.rot_ccw) def rot_ccw(self): - bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) + bT = self.bounds.__class__( + xmin=self.ymin, + deltax=self.bounds.deltay, + ymin=self.xmin, + deltay=self.bounds.deltax, + ) return _Image(self.array.T.at[:, ::-1].get(), bT, None) @implements(_galsim.Image.rot_180) @@ -1061,8 +1134,16 @@ def rot_180(self): def tree_flatten(self): """Flatten the image into a list of values.""" # Define the children nodes of the PyTree that need tracing - children = (self.array, self.wcs) - aux_data = {"dtype": self.dtype, "bounds": self.bounds, "isconst": self.isconst} + if self.bounds.isStatic(): + children = (self.array, self.wcs) + aux_data = { + "dtype": self.dtype, + "bounds": self.bounds, + "isconst": self.isconst, + } + else: + children = (self.array, self.wcs, self.bounds) + aux_data = {"dtype": self.dtype, "isconst": self.isconst} # other routines may add these attributes to images on the fly # we have to include them here so that JAX knows how to handle them in jitting etc. if hasattr(self, "added_flux"): @@ -1080,15 +1161,26 @@ def tree_unflatten(cls, aux_data, children): obj = object.__new__(cls) obj._array = children[0] obj.wcs = children[1] - obj._bounds = aux_data["bounds"] - obj._dtype = aux_data["dtype"] - obj._is_const = aux_data["isconst"] - if len(children) > 2: - obj.added_flux = children[2] - if "header" in aux_data: - obj.header = aux_data["header"] - if len(children) > 3: - obj.photons = children[3] + if "bounds" in aux_data: + obj._bounds = aux_data["bounds"] + obj._dtype = aux_data["dtype"] + obj._is_const = aux_data["isconst"] + if len(children) > 2: + obj.added_flux = children[2] + if "header" in aux_data: + obj.header = aux_data["header"] + if len(children) > 3: + obj.photons = children[3] + else: + obj._bounds = children[2] + obj._dtype = aux_data["dtype"] + obj._is_const = aux_data["isconst"] + if len(children) > 3: + obj.added_flux = children[3] + if "header" in aux_data: + obj.header = aux_data["header"] + if len(children) > 4: + obj.photons = children[4] return obj @classmethod @@ -1100,7 +1192,7 @@ def from_galsim(cls, galsim_image): else None ) im = cls( - array=galsim_image.array, + array=jnp.asarray(galsim_image.array), wcs=wcs, bounds=Bounds.from_galsim(galsim_image.bounds), ) @@ -1269,9 +1361,9 @@ def Image_iadd(self, other): a = other dt = type(a) if dt == self.array.dtype: - self._array = self.array + a + self._array = self.array.at[...].add(a) else: - self._array = (self.array + a).astype(self.array.dtype) + self._array = self.array.at[...].set((self.array + a).astype(self.array.dtype)) return self @@ -1297,9 +1389,9 @@ def Image_isub(self, other): a = other dt = type(a) if dt == self.array.dtype: - self._array = self.array - a + self._array = self.array.at[...].subtract(a) else: - self._array = (self.array - a).astype(self.array.dtype) + self._array = self.array.at[...].set((self.array - a).astype(self.array.dtype)) return self @@ -1321,9 +1413,9 @@ def Image_imul(self, other): a = other dt = type(a) if dt == self.array.dtype: - self._array = self.array * a + self._array = self.array.at[...].multiply(a) else: - self._array = (self.array * a).astype(self.array.dtype) + self._array = self.array.at[...].set((self.array * a).astype(self.array.dtype)) return self @@ -1351,9 +1443,9 @@ def Image_idiv(self, other): if dt == self.array.dtype and not self.isinteger: # if dtype is an integer type, then numpy doesn't allow true division /= to assign # back to an integer array. So for integers (or mixed types), don't use /=. - self._array = self.array / a + self._array = self.array.at[...].divide(a) else: - self._array = (self.array / a).astype(self.array.dtype) + self._array = self.array.at[...].set((self.array / a).astype(self.array.dtype)) return self @@ -1380,9 +1472,9 @@ def Image_ifloordiv(self, other): a = other dt = type(a) if dt == self.array.dtype: - self._array = self.array // a + self._array = self.array.at[...].set(self.array // a) else: - self._array = (self.array // a).astype(self.array.dtype) + self._array = self.array.at[...].set((self.array // a).astype(self.array.dtype)) return self @@ -1409,9 +1501,9 @@ def Image_imod(self, other): a = other dt = type(a) if dt == self.array.dtype: - self._array = self.array % a + self._array = self.array.at[...].set(self.array % a) else: - self._array = (self.array % a).astype(self.array.dtype) + self._array = self.array.at[...].set((self.array % a).astype(self.array.dtype)) return self @@ -1422,7 +1514,7 @@ def Image_pow(self, other): def Image_ipow(self, other): if not isinstance(other, int) and not isinstance(other, float): raise TypeError("Can only raise an image to a float or int power!") - self._array = self.array**other + self._array = self.array.at[...].power(other) return self @@ -1448,7 +1540,7 @@ def Image_iand(self, other): a = other.array except AttributeError: a = other - self._array = self.array & a + self._array = self.array.at[...].set(self.array & a) return self @@ -1467,7 +1559,7 @@ def Image_ixor(self, other): a = other.array except AttributeError: a = other - self._array = self.array ^ a + self._array = self.array.at[...].set(self.array ^ a) return self @@ -1486,7 +1578,7 @@ def Image_ior(self, other): a = other.array except AttributeError: a = other - self._array = self.array | a + self._array = self.array.at[...].set(self.array | a) return self diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 2baeb6e6..a3edbed5 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -714,7 +714,7 @@ def _getStepK(self, calculate_stepk): else: # If not a bool, then value is max_stepk R = (jnp.ceil(jnp.pi / calculate_stepk)).astype(int) - b = BoundsI(-R, R, -R, R) + b = BoundsI(xmin=-R, deltax=2 * R + 1, ymin=-R, deltay=2 * R + 1) b = self.image.bounds & b im = self.image[b] thresh = (1.0 - self.gsparams.folding_threshold) * self._image_flux @@ -880,13 +880,13 @@ def _shoot(self, photons, rng): ).astype(int) yinds, xinds = jnp.unravel_index(inds, img.array.shape) - xedges = jnp.arange(img.bounds.xmin, img.bounds.xmax + 2) - 0.5 - yedges = jnp.arange(img.bounds.ymin, img.bounds.ymax + 2) - 0.5 + xedges = jnp.arange(0, img.bounds.deltax + 1) - 0.5 + yedges = jnp.arange(0, img.bounds.deltay + 1) - 0.5 # now we draw the position within the pixel ud = UniformDeviate(rng) - photons.x = ud.generate(photons.x) + xedges[xinds] - photons.y = ud.generate(photons.y) + yedges[yinds] + photons.x = ud.generate(photons.x) + xedges[xinds] + img.bounds.xmin + photons.y = ud.generate(photons.y) + yedges[yinds] + img.bounds.ymin # this magic set of factors comes from the galsim C++ code in # a few spots it is # diff --git a/tests/GalSim b/tests/GalSim index 04918b11..471a7a1e 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 04918b118926eafc01ec9403b8afed29fb918d51 +Subproject commit 471a7a1e45b76c5b67f202a4a67d6eed702e6643 diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e79f320c..e76b081c 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -355,6 +355,20 @@ def _reg_fun(p): ): continue + # jax-galsim Bounds classes do not store xmax, ymax + if issubclass(cls, jax_galsim.Bounds) and method in [ + "xmax", + "ymax", + "isStatic", + ]: + continue + + if issubclass(cls, jax_galsim.BoundsI) and method in [ + "xmin", + "ymin", + ]: + continue + assert method in dir(gscls), ( cls.__name__ + "." + method + " not in galsim." + gscls.__name__ ) @@ -497,7 +511,9 @@ def _reg_sfun(g1): jax_galsim.BoundsD( jnp.array(0.2), jnp.array(4.0), jnp.array(-0.5), jnp.array(4.7) ), - jax_galsim.BoundsI(jnp.array(-10), jnp.array(5), jnp.array(0), jnp.array(7)), + jax_galsim.BoundsI(xmin=jnp.array(-10), deltax=5, ymin=jnp.array(0), deltay=7), + jax_galsim.BoundsI(xmin=np.array(-10), deltax=5, ymin=0, deltay=7), + jax_galsim.BoundsI(-10, -6, 0, 6), ], ) def test_api_bounds(obj): diff --git a/tests/jax/test_draw_bounds.py b/tests/jax/test_draw_bounds.py new file mode 100644 index 00000000..7f8cba6b --- /dev/null +++ b/tests/jax/test_draw_bounds.py @@ -0,0 +1,63 @@ +import jax +import numpy as np + +import jax_galsim + + +def test_draw_bounds_center(): + + def _draw(center, flux): + return jax_galsim.Gaussian( + fwhm=1.5, + flux=flux, + gsparams=jax_galsim.GSParams(minimum_fft_size=1024, maximum_fft_size=1024), + ).drawImage(nx=52, ny=52, center=center, scale=0.2) + + img = _draw(jax_galsim.PositionD(5.7, -2.1), 10) + np.testing.assert_allclose(img.array.sum(), 10, rtol=1e-5, atol=1e-5) + assert img.bounds.xmin != 1 + assert img.bounds.ymin != 1 + + +def test_draw_bounds_center_jit(): + + @jax.jit + def _draw(center, flux): + return jax_galsim.Gaussian( + fwhm=1.5, + flux=flux, + gsparams=jax_galsim.GSParams(minimum_fft_size=1024, maximum_fft_size=1024), + ).drawImage(nx=52, ny=52, center=center, scale=0.2) + + img = _draw(jax_galsim.PositionD(5.7, -2.1), 10) + np.testing.assert_allclose(img.array.sum(), 10, rtol=1e-5, atol=1e-5) + assert img.bounds.xmin != 1 + assert img.bounds.ymin != 1 + + +def test_draw_bounds_center_jit_vmap(): + + @jax.jit + def _draw(center, flux): + return jax_galsim.Gaussian( + fwhm=1.5, + flux=flux, + gsparams=jax_galsim.GSParams(minimum_fft_size=1024, maximum_fft_size=1024), + ).drawImage(nx=101, ny=101, center=center, scale=0.2) + + ng = 7 + rng = np.random.default_rng(seed=10) + pos_x = rng.uniform(low=-10, high=10, size=ng) + pos_y = rng.uniform(low=-10, high=10, size=ng) + flux = rng.uniform(low=1, high=10, size=ng) + pos = jax.vmap(lambda x, y: jax_galsim.PositionD(x, y))(pos_x, pos_y) + img = jax.jit(jax.vmap(_draw))(pos, flux) + assert img.array.shape == (ng, 101, 101) + assert not any(xmin == 1 for xmin in img.bounds.xmin) + assert not any(ymin == 1 for ymin in img.bounds.ymin) + for i in range(ng): + for j in range(i + 1, ng): + assert not np.array_equal(img.array[i, ...], img.array[j, ...]) + + fluxes = img.array.sum(axis=(1, 2)) + np.testing.assert_allclose(fluxes, flux, rtol=1e-5, atol=1e-5) diff --git a/tests/jax/test_image_wrapping.py b/tests/jax/test_image_wrapping.py index 63136a4b..0a3219fb 100644 --- a/tests/jax/test_image_wrapping.py +++ b/tests/jax/test_image_wrapping.py @@ -130,13 +130,13 @@ def test_image_wrapping_autodiff(func, K, L): # make sure these run without error if func == "wrap": - b3 = galsim.BoundsI(0, K, -L + 1, L) + b3 = galsim.BoundsI(xmin=0, deltax=K + 1, ymin=-L + 1, deltay=2 * L) im.wrap(b3) elif func == "vjp-jit" or func == "jvp-jit": @jax.jit def _wrapit(im): - b3 = galsim.BoundsI(0, K, -L + 1, L) + b3 = galsim.BoundsI(xmin=0, deltax=K + 1, ymin=-L + 1, deltay=2 * L) return im.wrap(b3) if func == "vjp-jit": @@ -148,7 +148,7 @@ def _wrapit(im): elif func == "vjp" or func == "jvp": def _wrapit(im): - b3 = galsim.BoundsI(0, K, -L + 1, L) + b3 = galsim.BoundsI(xmin=0, deltax=K + 1, ymin=-L + 1, deltay=2 * L) return im.wrap(b3) if func == "vjp": diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 44c98f13..d349cd88 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -289,7 +289,9 @@ def _compute_fft_with_numpy_jax_galsim(im): dk = np.pi / (No2 * dx) out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk) - out._array = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0) + out._array = jnp.asarray( + np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0) + ) out *= dx * dx out.setOrigin(0, -No2) return out diff --git a/tests/jax/test_render_scene.py b/tests/jax/test_render_scene.py new file mode 100644 index 00000000..da5910a1 --- /dev/null +++ b/tests/jax/test_render_scene.py @@ -0,0 +1,403 @@ +from functools import partial + +import galsim as _galsim +import jax +import jax.numpy as jnp +import jax.random as jrng +import numpy as np +import pytest + +import jax_galsim as jgs +from jax_galsim.photon_array import fixed_photon_array_size + + +def _generate_image_one(rng_key, psf): + rng_key, use_key = jrng.split(rng_key) + flux = jrng.uniform(use_key, minval=1.5, maxval=2.5) + rng_key, use_key = jrng.split(rng_key) + hlr = jrng.uniform(use_key, minval=0.5, maxval=2.5) + rng_key, use_key = jrng.split(rng_key) + g1 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + rng_key, use_key = jrng.split(rng_key) + g2 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + + rng_key, use_key = jrng.split(rng_key) + dx = jrng.uniform(use_key, minval=-10, maxval=10) + rng_key, use_key = jrng.split(rng_key) + dy = jrng.uniform(use_key, minval=-10, maxval=10) + + return ( + jgs.Convolve( + [ + jgs.Exponential(half_light_radius=hlr) + .shear(g1=g1, g2=g2) + .shift(dx, dy) + .withFlux(flux), + psf, + ] + ) + .withGSParams(minimum_fft_size=1024, maximum_fft_size=1024) + .drawImage(nx=200, ny=200, scale=0.2) + ) + + +@partial(jax.jit, static_argnames=("n_obj")) +def _generate_image(rng_key, psf, n_obj): + use_keys = jrng.split(rng_key, num=n_obj + 1) + rng_key = use_keys[0] + use_keys = use_keys[1:] + + return jax.vmap(_generate_image_one, in_axes=(0, None))(use_keys, psf) + + +def test_render_scene_draw_many_ffts_full_img(): + psf = jgs.Gaussian(fwhm=0.9) + img = _generate_image(jrng.key(10), psf, 5) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(img.array.sum(axis=0)) + pdb.set_trace() + + assert img.array.shape == (5, 200, 200) + assert img.array.sum() > 5.0 + + +def _generate_image_one_phot(rng_key, psf): + rng_key, use_key = jrng.split(rng_key) + flux = jrng.uniform(use_key, minval=10000.5, maxval=20000.5) + rng_key, use_key = jrng.split(rng_key) + hlr = jrng.uniform(use_key, minval=0.5, maxval=2.5) + rng_key, use_key = jrng.split(rng_key) + g1 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + rng_key, use_key = jrng.split(rng_key) + g2 = jrng.uniform(use_key, minval=-0.1, maxval=0.1) + + rng_key, use_key = jrng.split(rng_key) + dx = jrng.uniform(use_key, minval=-10, maxval=10) + rng_key, use_key = jrng.split(rng_key) + dy = jrng.uniform(use_key, minval=-10, maxval=10) + + rng_key, use_key = jrng.split(rng_key) + + return ( + jgs.Convolve( + [ + jgs.Exponential(half_light_radius=hlr) + .shear(g1=g1, g2=g2) + .shift(dx, dy) + .withFlux(flux), + psf, + ] + ) + .withGSParams(minimum_fft_size=1024, maximum_fft_size=1024) + .drawImage( + nx=200, ny=200, scale=0.2, method="phot", rng=jgs.BaseDeviate(use_key) + ) + ) + + +@partial(jax.jit, static_argnames=("n_obj")) +def _generate_image_phot(rng_key, psf, n_obj): + use_keys = jrng.split(rng_key, num=n_obj + 1) + rng_key = use_keys[0] + use_keys = use_keys[1:] + + with fixed_photon_array_size(1000): + return jax.vmap(_generate_image_one_phot, in_axes=(0, None))(use_keys, psf) + + +def test_render_scene_draw_many_ffts_full_img_phot(): + psf = jgs.Gaussian(fwhm=0.9) + img = _generate_image_phot(jrng.key(10), psf, 5) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(img.array.sum(axis=0)) + pdb.set_trace() + + assert img.array.shape == (5, 200, 200) + assert img.array.sum() > 5.0 + + +def _get_bd_jgs( + flux_d, + flux_b, + hlr_b, + hlr_d, + q_b, + q_d, + beta, + *, + psf_hlr=0.7, +): + components = [] + + # disk + disk = jgs.Exponential(flux=flux_d, half_light_radius=hlr_d).shear( + q=q_d, beta=beta * jgs.degrees + ) + components.append(disk) + + # bulge + bulge = jgs.Spergel(nu=-0.6, flux=flux_b, half_light_radius=hlr_b).shear( + q=q_b, beta=beta * jgs.degrees + ) + components.append(bulge) + + galaxy = jgs.Add(components) + + # psf + psf = jgs.Moffat(2, flux=1.0, half_light_radius=0.7) + + gal_conv = jgs.Convolve([galaxy, psf]) + return gal_conv + + +@partial(jax.jit, static_argnames=("fft_size", "slen")) +def _draw_stamp_jgs( + galaxy_params: dict, + image_pos: jgs.PositionD, + local_wcs: jgs.PixelScale, + fft_size: int, + slen: int, +) -> jax.Array: + gsparams = jgs.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + + convolved_object = _get_bd_jgs(**galaxy_params).withGSParams(gsparams) + + stamp = convolved_object.drawImage( + nx=slen, + ny=slen, + center=image_pos, + wcs=local_wcs, + dtype=jnp.float64, + ) + + return stamp + + +@partial(jax.jit, static_argnames=("slen",)) +def _add_to_image(carry, x, slen): + image = carry[0] + stamp = x + + image[stamp.bounds] += stamp + + return (image,), None + + +@partial(jax.jit, static_argnames=("fft_size", "slen", "ilen", "ng")) +def _render_scene_stamps_jax_galsim( + galaxy_params: dict, + x: jnp.ndarray, + y: jnp.ndarray, + fft_size: int, + slen: int, + ilen: int, + ng: int, +): + image = jgs.Image(ncol=ilen, nrow=ilen, scale=0.2, dtype=jnp.float64) + wcs = image.wcs + + image_positions = jax.vmap(lambda x, y: jgs.PositionD(x=x, y=y))(x, y) + local_wcss = jax.vmap(lambda x: wcs.local(image_pos=x))(image_positions) + + stamps = jax.jit(jax.vmap(partial(_draw_stamp_jgs, slen=slen, fft_size=fft_size)))( + galaxy_params, image_positions, local_wcss + ) + + pad_image = jgs.ImageD( + jnp.pad(image.array, slen), wcs=image.wcs, bounds=image.bounds.withBorder(slen) + ) + + final_pad_image = jax.lax.scan( + partial(_add_to_image, slen=slen), + (pad_image,), + xs=stamps, + length=ng, + )[0][0] + + return stamps, final_pad_image + + +def _get_bd_gs( + flux_d, + flux_b, + hlr_b, + hlr_d, + q_b, + q_d, + beta, + *, + psf_hlr=0.7, +): + components = [] + + # disk + disk = _galsim.Exponential(flux=flux_d, half_light_radius=hlr_d).shear( + q=q_d, beta=beta * _galsim.degrees + ) + components.append(disk) + + # bulge + bulge = _galsim.Spergel(nu=-0.6, flux=flux_b, half_light_radius=hlr_b).shear( + q=q_b, beta=beta * _galsim.degrees + ) + components.append(bulge) + + galaxy = _galsim.Add(components) + + # psf + psf = _galsim.Moffat(2, flux=1.0, half_light_radius=0.7) + + gal_conv = _galsim.Convolve([galaxy, psf]) + return gal_conv + + +def _render_scene_stamps_galsim( + galaxy_params: dict, + x: np.ndarray, + y: np.ndarray, + fft_size: int, + slen: int, + ilen: int, + ng: int, +): + image = _galsim.Image(ncol=ilen, nrow=ilen, scale=0.2, dtype=np.float64) + wcs = image.wcs + + image_pos = list(map(lambda tup: _galsim.PositionD(x=tup[0], y=tup[1]), zip(x, y))) + local_wcs = list(map(lambda x: wcs.local(image_pos=x), image_pos)) + + gsparams = _galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size) + + for i in range(ng): + gpars = {k: v[i] for k, v in galaxy_params.items()} + convolved_object = _get_bd_gs(**gpars).withGSParams(gsparams) + + stamp = convolved_object.drawImage( + nx=slen, + ny=slen, + center=(image_pos[i].x, image_pos[i].y), + wcs=local_wcs[i], + dtype=np.float64, + ) + + b = stamp.bounds & image.bounds + if b.isDefined(): + image[b] += stamp[b] + + return image + + +@pytest.mark.parametrize("slen", [51, 52]) +def test_render_scene_stamps(slen): + rng = np.random.default_rng(seed=10) + ng = 5 + fft_size = 2048 + ilen = 200 + + galaxy_params = { + "flux_d": rng.uniform(low=0, high=1.0, size=ng), + "flux_b": rng.uniform(low=0, high=1.0, size=ng), + "hlr_b": rng.uniform(low=0.3, high=0.5, size=ng), + "hlr_d": rng.uniform(low=0.5, high=0.7, size=ng), + "q_b": rng.uniform(low=0.1, high=0.9, size=ng), + "q_d": rng.uniform(low=0.1, high=0.9, size=ng), + "beta": rng.uniform(low=0, high=360, size=ng), + "x": rng.uniform(low=10, high=190, size=ng), + "y": rng.uniform(low=10, high=190, size=ng), + } + + x = galaxy_params.pop("x") + y = galaxy_params.pop("y") + + stamps, final_pad_image = _render_scene_stamps_jax_galsim( + galaxy_params, + x, + y, + fft_size, + slen, + ilen, + ng, + ) + assert stamps.array.shape == (ng, slen, slen) + assert stamps.array.sum() > 0 + + np.testing.assert_allclose(final_pad_image.array.sum(), stamps.array.sum()) + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(final_pad_image.array) + pdb.set_trace() + + gs_image = _render_scene_stamps_galsim( + galaxy_params, + x, + y, + fft_size, + slen, + ilen, + ng, + ) + + gs_image_mo = _render_scene_stamps_galsim( + galaxy_params, + x, + y, + fft_size, + slen + 1, + ilen, + ng, + ) + + abs_eps = np.max(np.abs(gs_image_mo.array - gs_image.array)) + rel_eps = 0.0 + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(gs_image_mo.array - gs_image.array) + pdb.set_trace() + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(gs_image.array) + pdb.set_trace() + + if False: + import pdb + + import matplotlib.pyplot as plt + + plt.imshow(final_pad_image.array[slen:-slen, slen:-slen] - gs_image.array) + pdb.set_trace() + + np.testing.assert_allclose( + final_pad_image.array[slen:-slen, slen:-slen].sum(), + gs_image.array.sum(), + atol=abs_eps, + rtol=rel_eps, + ) + + np.testing.assert_allclose( + final_pad_image.array[slen:-slen, slen:-slen], + gs_image.array, + atol=abs_eps, + rtol=rel_eps, + ) diff --git a/tests/jax/test_vmapping.py b/tests/jax/test_vmapping.py index 6b8d7c40..dd892e9e 100644 --- a/tests/jax/test_vmapping.py +++ b/tests/jax/test_vmapping.py @@ -142,11 +142,15 @@ def test_eq(self, other): def test_bounds_vmapping(): + from functools import partial + obj = galsim.BoundsD(0.0, 1.0, 0.0, 1.0) obj_d = jax.vmap(galsim.BoundsD)(0.0 * e, 1.0 * e, 0.0 * e, 1.0 * e) objI = galsim.BoundsI(0.0, 1.0, 0.0, 1.0) - objI_d = jax.vmap(galsim.BoundsI)(0.0 * e, 1.0 * e, 0.0 * e, 1.0 * e) + objI_d = jax.vmap(partial(galsim.BoundsI, deltax=2.0, deltay=2.0))( + xmin=0.0 * e, ymin=0.0 * e + ) def test_eq(self, other): return (