diff --git a/pyresample/bilinear/xarr.py b/pyresample/bilinear/xarr.py index e53fed47..65c1be80 100644 --- a/pyresample/bilinear/xarr.py +++ b/pyresample/bilinear/xarr.py @@ -65,6 +65,13 @@ class XArrayBilinearResampler(BilinearBase): """Bilinear interpolation using XArray.""" + def __init__(self, source_geo_def, target_geo_def, radius_of_influence, + neighbours=32, epsilon=0, reduce_data=True, limit_output=True): + """Initialize xarray bilinear resampler.""" + super().__init__(source_geo_def, target_geo_def, radius_of_influence, + neighbours=neighbours, epsilon=epsilon, reduce_data=reduce_data) + self._limit_output = limit_output + def resample(self, data, fill_value=None, nprocs=1): """Resample the given data.""" del nprocs @@ -102,7 +109,7 @@ def _limit_output_values_to_input(self, data, res, fill_value): find_indices_outside_min_and_max(res, data_min, data_max), fill_value, res) - return da.where(np.isnan(res), fill_value, res) + return da.where(da.isnan(res), fill_value, res) def _reshape_to_target_area(self, res, ndim): if ndim == 3: @@ -127,7 +134,10 @@ def _reshape_to_target_area(self, res, ndim): return res def _finalize_output_data(self, data, res, fill_value): - res = self._limit_output_values_to_input(data, res, fill_value) + if self._limit_output: + res = self._limit_output_values_to_input(data, res, fill_value) + else: + res = da.where(da.isnan(res), fill_value, res) res = self._reshape_to_target_area(res, data.ndim) self._add_missing_coordinates(data) diff --git a/pyresample/test/test_bilinear.py b/pyresample/test/test_bilinear.py index f4711ef1..b4e7f1e9 100644 --- a/pyresample/test/test_bilinear.py +++ b/pyresample/test/test_bilinear.py @@ -539,6 +539,7 @@ def test_init(self): self.assertEqual(resampler._neighbours, 32) self.assertEqual(resampler._epsilon, 0) self.assertTrue(resampler._reduce_data) + self.assertTrue(resampler._limit_output) # These should be None self.assertIsNone(resampler._valid_input_index) self.assertIsNone(resampler._index_array) @@ -557,10 +558,12 @@ def test_init(self): # Override defaults resampler = XArrayBilinearResampler(self.source_def, self.target_def, self.radius, neighbours=16, - epsilon=0.1, reduce_data=False) + epsilon=0.1, reduce_data=False, + limit_output=False) self.assertEqual(resampler._neighbours, 16) self.assertEqual(resampler._epsilon, 0.1) self.assertFalse(resampler._reduce_data) + self.assertFalse(resampler._limit_output) def test_get_bil_info(self): """Test calculation of bilinear info.""" @@ -680,6 +683,46 @@ def test_get_sample_from_bil_info(self): assert res.shape == (2,) + self.target_def.shape assert res.dims == data.dims + def test_get_sample_from_bil_info_without_output_limiting(self): + """Test disabling output value limiting.""" + import dask.array as da + from xarray import DataArray + + from pyresample.bilinear import XArrayBilinearResampler + + pattern = ((np.indices(self.source_def.shape).sum(axis=0) % 2) + 1).astype(np.float32) + data = DataArray(da.from_array(pattern, chunks=pattern.shape), dims=("y", "x")) + data_min = float(np.nanmin(pattern)) + data_max = float(np.nanmax(pattern)) + def _compute_values(limit_output): + resampler = XArrayBilinearResampler( + self.source_def, + self.target_def, + self.radius, + limit_output=limit_output, + ) + resampler.get_bil_info() + bilinear_s = np.asarray(resampler.bilinear_s).copy() + bilinear_t = np.asarray(resampler.bilinear_t).copy() + valid = np.isfinite(bilinear_s) & np.isfinite(bilinear_t) + bilinear_s[valid] = 2.0 + resampler.bilinear_s = da.from_array(bilinear_s, chunks=bilinear_s.shape) + resampler.bilinear_t = da.from_array(bilinear_t, chunks=bilinear_t.shape) + return resampler.get_sample_from_bil_info(data, fill_value=-999.0).compute().values + + def _outside_mask(values): + valid = np.isfinite(values) & (values != -999.0) + return valid & ((values < data_min - 1e-6) | (values > data_max + 1e-6)) + + no_limit_values = _compute_values(limit_output=False) + limited_values = _compute_values(limit_output=True) + + assert np.any(_outside_mask(no_limit_values)) + assert not np.any(_outside_mask(limited_values)) + assert np.count_nonzero(limited_values == -999.0) > np.count_nonzero( + no_limit_values == -999.0 + ) + def test_add_missing_coordinates(self): """Test coordinate updating.""" import dask.array as da