From 2de35c85f786cecdde1c9db3822b7c15b7adfb16 Mon Sep 17 00:00:00 2001 From: Kamil Monicz Date: Sun, 1 Mar 2026 22:37:26 +0100 Subject: [PATCH] Optimize Dask EWA persist and prune fornav tasks --- pyresample/ewa/dask_ewa.py | 216 +++++++++++++++++++----- pyresample/test/test_dask_ewa.py | 273 +++++++++++++++++++++++++++++++ 2 files changed, 446 insertions(+), 43 deletions(-) diff --git a/pyresample/ewa/dask_ewa.py b/pyresample/ewa/dask_ewa.py index e41291d0..f69637a7 100644 --- a/pyresample/ewa/dask_ewa.py +++ b/pyresample/ewa/dask_ewa.py @@ -71,6 +71,79 @@ def _call_mapped_ll2cr(lons, lats, target_geo_def): return res +def _ll2cr_block_extent(ll2cr_block): + """Compute row/column bounds for a single ll2cr block. + + Args: + ll2cr_block: ll2cr output block as ``(cols, rows)`` arrays, or the + empty sentinel returned by ``_call_ll2cr``. + + Returns: + ``(row_min, row_max, col_min, col_max)`` as floats, or ``None`` when + the block contains no valid finite coordinates. + """ + # Empty ll2cr results: ((shape, fill, dtype), (shape, fill, dtype)) + if isinstance(ll2cr_block[0], tuple): + return None + + cols = np.asarray(ll2cr_block[0]) + rows = np.asarray(ll2cr_block[1]) + valid = np.isfinite(cols) & np.isfinite(rows) + if not np.any(valid): + return None + + valid_rows = rows[valid] + valid_cols = cols[valid] + row_min = float(valid_rows.min()) + row_max = float(valid_rows.max()) + col_min = float(valid_cols.min()) + col_max = float(valid_cols.max()) + return row_min, row_max, col_min, col_max + + +def _pad_bounds(bounds, overlap_margin): + """Pad ll2cr bounds by a constant overlap margin. + + Args: + bounds: ll2cr bounds tuple ``(row_min, row_max, col_min, col_max)``, + or ``None``. + overlap_margin: Non-negative overlap margin in grid cells. + + Returns: + Padded bounds tuple, or ``None`` when input bounds is ``None``. + """ + if bounds is None: + return None + row_min, row_max, col_min, col_max = bounds + return ( + row_min - overlap_margin, + row_max + overlap_margin, + col_min - overlap_margin, + col_max + overlap_margin, + ) + + +def _chunk_intersects_bounds(bounds, y_slice, x_slice): + """Check whether a target chunk overlaps pre-padded ll2cr bounds. + + Args: + bounds: ll2cr bounds tuple ``(row_min, row_max, col_min, col_max)``, + already padded for overlap, or ``None``. + y_slice: Output chunk row slice. + x_slice: Output chunk column slice. + + Returns: + ``True`` if the chunk intersects the bounds. + """ + if bounds is None: + return True + row_min, row_max, col_min, col_max = bounds + return ( + y_slice.stop > row_min and y_slice.start <= row_max and + x_slice.stop > col_min and x_slice.start <= col_max + ) + + def _delayed_fornav(ll2cr_result, target_geo_def, y_slice, x_slice, data, fill_value, kwargs): # Adjust cols and rows for this sub-area subdef = target_geo_def[y_slice, x_slice] @@ -107,6 +180,21 @@ def _chunk_callable(x_chunk, axis, keepdims, **kwargs): return x_chunk +def _sum_arrays(arrays): + """Sum arrays with one initial copy and in-place accumulation. + + Args: + arrays: Non-empty sequence of NumPy arrays with compatible shapes. + + Returns: + Element-wise sum as a NumPy array. + """ + total = arrays[0].copy() + for arr in arrays[1:]: + total += arr + return total + + def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False, maximum_weight_mode=False): if computing_meta or _is_empty_chunk(x_chunk): @@ -126,6 +214,8 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False, # split step - return "empty" chunk placeholder return x_chunk[0] return np.full(*x_chunk[0][0]), np.full(*x_chunk[0][1]) + if len(valid_chunks) == 1: + return valid_chunks[0] weights = [x[0] for x in valid_chunks] accums = [x[1] for x in valid_chunks] if maximum_weight_mode: @@ -135,9 +225,7 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False, weights = np.take_along_axis(weights, max_indexes, axis=0).squeeze(axis=0) accums = np.take_along_axis(accums, max_indexes, axis=0).squeeze(axis=0) return weights, accums - # NOTE: We use the builtin "sum" function below because it does not copy - # the numpy arrays. Using numpy.sum would do that. - return sum(weights), sum(accums) + return _sum_arrays(weights), _sum_arrays(accums) def _is_empty_chunk(x_chunk): @@ -224,60 +312,79 @@ def _get_rows_per_scan(self, rows_per_scan=None): rows_per_scan = self.source_geo_def.shape[0] return rows_per_scan - def _fill_block_cache_with_ll2cr_results(self, ll2cr_result, - num_row_blocks, - num_col_blocks, - persist): + def _ll2cr_cache_matches(self, rows_per_scan, persist): + return ( + self.cache.get('rows_per_scan') == rows_per_scan and + self.cache.get('persist') == persist + ) + + def _get_ll2cr_blocks(self, ll2cr_result, persist): + ll2cr_blocks = [] + block_dependencies = None if persist: ll2cr_delayeds = ll2cr_result.to_delayed() - ll2cr_delayeds = dask.persist(*ll2cr_delayeds.tolist()) - - block_cache = {} - for in_row_idx in range(num_row_blocks): - for in_col_idx in range(num_col_blocks): - key = (ll2cr_result.name, in_row_idx, in_col_idx) - if persist: - this_delayed = ll2cr_delayeds[in_row_idx][in_col_idx] - result = dask.compute(this_delayed)[0] - # XXX: Is this optimization lost because the persisted keys - # in `ll2cr_delayeds` are used in future computations? - if not isinstance(result[0], tuple): - block_cache[key] = this_delayed.key - else: - block_cache[key] = key - return block_cache + flat_delayeds = [ + (in_row_idx, in_col_idx, delayed_block) + for in_row_idx, delayed_row in enumerate(ll2cr_delayeds) + for in_col_idx, delayed_block in enumerate(delayed_row) + ] + block_dependencies = [] + persisted_delayeds = dask.persist( + *(delayed for _, _, delayed in flat_delayeds)) + # Compute only per-block extents on workers to avoid materializing + # full ll2cr blocks in the client process. + extent_delayeds = [dask.delayed(_ll2cr_block_extent)(d) for d in persisted_delayeds] + computed_extents = dask.compute(*extent_delayeds) + for (in_row_idx, in_col_idx, _), persisted_delayed, extent in zip( + flat_delayeds, persisted_delayeds, computed_extents, strict=True): + if extent is None: + continue + ll2cr_blocks.append((in_row_idx, in_col_idx, persisted_delayed.key, extent)) + block_dependencies.append(persisted_delayed) + else: + num_row_blocks, num_col_blocks = ll2cr_result.numblocks[-2:] + for in_row_idx in range(num_row_blocks): + for in_col_idx in range(num_col_blocks): + ll2cr_blocks.append(( + in_row_idx, + in_col_idx, + (ll2cr_result.name, in_row_idx, in_col_idx), + None, + )) + return ll2cr_blocks, block_dependencies def precompute(self, cache_dir=None, rows_per_scan=None, persist=False, **kwargs): """Generate row and column arrays and store it for later use.""" - if self.cache: + rows_per_scan = self._get_rows_per_scan(rows_per_scan) + if self._ll2cr_cache_matches(rows_per_scan, persist): # this resampler should be used for one SwathDefinition - # no need to recompute ll2cr output again + # no need to recompute matching ll2cr output again return None if kwargs.get('mask') is not None: logger.warning("'mask' parameter has no affect during EWA " "resampling") - source_geo_def = self.source_geo_def - target_geo_def = self.target_geo_def if cache_dir: logger.warning("'cache_dir' is not used by EWA resampling") - rows_per_scan = self._get_rows_per_scan(rows_per_scan) - new_chunks = self._new_chunks(source_geo_def.lons, rows_per_scan) - lons, lats = source_geo_def.get_lonlats(chunks=new_chunks) + new_chunks = self._new_chunks(self.source_geo_def.lons, rows_per_scan) + lons, lats = self.source_geo_def.get_lonlats(chunks=new_chunks) # run ll2cr to get column/row indexes # if chunk does not overlap target area then None is returned # otherwise a 3D array (2, y, x) of cols, rows are returned - ll2cr_result = _call_mapped_ll2cr(lons, lats, target_geo_def) - block_cache = self._fill_block_cache_with_ll2cr_results( - ll2cr_result, lons.numblocks[0], lons.numblocks[1], persist) + ll2cr_result = _call_mapped_ll2cr(lons, lats, self.target_geo_def) + ll2cr_blocks, block_dependencies = self._get_ll2cr_blocks( + ll2cr_result, persist) # save the dask arrays in the class instance cache self.cache = { 'll2cr_result': ll2cr_result, - 'll2cr_blocks': block_cache, + 'll2cr_blocks': ll2cr_blocks, + 'll2cr_block_dependencies': block_dependencies, + 'rows_per_scan': rows_per_scan, + 'persist': persist, } return None @@ -323,6 +430,15 @@ def _generate_fornav_dask_tasks(out_chunks, ll2cr_blocks, task_name, input_name, target_geo_def, fill_value, kwargs): y_start = 0 output_stack = {} + overlap_margin = max( + float(kwargs.get("weight_delta_max", 0.0)), + float(kwargs.get("weight_distance_max", 0.0)), + 0.0, + ) + indexed_blocks = [] + for z_idx, (in_row_idx, in_col_idx, ll2cr_block, block_extent) in enumerate(ll2cr_blocks): + block_bounds = _pad_bounds(block_extent, overlap_margin) + indexed_blocks.append((z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds)) for out_row_idx in range(len(out_chunks[0])): y_end = y_start + out_chunks[0][out_row_idx] x_start = 0 @@ -330,20 +446,30 @@ def _generate_fornav_dask_tasks(out_chunks, ll2cr_blocks, task_name, x_end = x_start + out_chunks[1][out_col_idx] y_slice = slice(y_start, y_end) x_slice = slice(x_start, x_end) - for z_idx, ((_, in_row_idx, in_col_idx), ll2cr_block) in enumerate(ll2cr_blocks): + placeholder = ( + ((y_end - y_start, x_end - x_start), 0, np.float32), + ((y_end - y_start, x_end - x_start), 0, np.float32), + ) + for z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds in indexed_blocks: key = (task_name, z_idx, out_row_idx, out_col_idx) - output_stack[key] = (_delayed_fornav, - ll2cr_block, - target_geo_def, y_slice, x_slice, - (input_name, in_row_idx, in_col_idx), fill_value, kwargs) + if _chunk_intersects_bounds(block_bounds, y_slice, x_slice): + output_stack[key] = (_delayed_fornav, + ll2cr_block, + target_geo_def, y_slice, x_slice, + (input_name, in_row_idx, in_col_idx), fill_value, kwargs) + else: + output_stack[key] = placeholder x_start = x_end y_start = y_end return output_stack def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwargs): ll2cr_result = self.cache['ll2cr_result'] - ll2cr_blocks = self.cache['ll2cr_blocks'].items() - ll2cr_numblocks = ll2cr_result.shape if isinstance(ll2cr_result, np.ndarray) else ll2cr_result.numblocks + ll2cr_blocks = self.cache['ll2cr_blocks'] + ll2cr_block_dependencies = self.cache.get('ll2cr_block_dependencies') + if not ll2cr_blocks: + return da.full(target_geo_def.shape, fill_value, dtype=data.dtype, + chunks=out_chunks) fornav_task_name = f"fornav-{data.name}-{ll2cr_result.name}" maximum_weight_mode = kwargs.setdefault('maximum_weight_mode', False) weight_sum_min = kwargs.setdefault('weight_sum_min', -1.0) @@ -357,8 +483,12 @@ def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwa dsk_graph = HighLevelGraph.from_collections(fornav_task_name, output_stack, - dependencies=[data, ll2cr_result]) - stack_chunks = ((1,) * (ll2cr_numblocks[0] * ll2cr_numblocks[1]),) + out_chunks + dependencies=( + (data, ll2cr_result) + if ll2cr_block_dependencies is None + else (data, *ll2cr_block_dependencies) + )) + stack_chunks = ((1,) * len(ll2cr_blocks),) + out_chunks out_stack = da.Array(dsk_graph, fornav_task_name, stack_chunks, data.dtype) combine_fornav_with_kwargs = partial( _combine_fornav, maximum_weight_mode=maximum_weight_mode) diff --git a/pyresample/test/test_dask_ewa.py b/pyresample/test/test_dask_ewa.py index af6b7b88..89de9afe 100644 --- a/pyresample/test/test_dask_ewa.py +++ b/pyresample/test/test_dask_ewa.py @@ -154,6 +154,27 @@ def _coord_and_crs_checks(new_data, target_area, has_bands=False): ['R', 'G', 'B']) +def _fornav_task_keys(output_stack): + return { + key + for key, task in output_stack.items() + if isinstance(task, tuple) and task and task[0] is dask_ewa._delayed_fornav + } + + +def _fornav_task_count(output_stack): + return len(_fornav_task_keys(output_stack)) + + +OUT_CHUNKS_2X2 = ((2, 2), (2, 2)) +LL2CR_BLOCKS_2X2 = ( + (0, 0, "b00", None), + (0, 1, "b01", None), + (1, 0, "b10", None), + (1, 1, "b11", None), +) + + def _get_num_chunks(source_swath, resampler_class, rows_per_scan=10): if resampler_class is DaskEWAResampler: # ignore column-wise chunks because DaskEWA should rechunk to use whole scans @@ -385,3 +406,255 @@ def test_multiple_targets(self): assert res1.name != res2.name assert res1.compute().shape != res2.compute().shape + + def test_xarray_ewa_persist_computes(self): + """Ensure persisted ll2cr path builds a computable graph.""" + swath_data, source_swath, target_area = get_test_data( + input_shape=(100, 50), output_shape=(200, 100), + input_dims=('y', 'x'), input_dtype=np.float32, + ) + resampler = DaskEWAResampler(source_swath, target_area) + with dask.config.set(scheduler='sync'): + new_data = resampler.resample( + swath_data, + rows_per_scan=10, + persist=True, + chunks=(50, 50), + weight_delta_max=40, + ) + computed = new_data.compute() + assert computed.shape == (200, 100) + assert computed.dtype == np.float32 + + +def test_get_ll2cr_blocks_persist_uses_single_batched_compute(): + """Persisted ll2cr chunks should be computed in one batched call.""" + class FakeDelayed: + def __init__(self, key): + self.key = key + + class FakeLl2CrResult: + name = "ll2cr-test" + + def to_delayed(self): + return [ + [FakeDelayed("00"), FakeDelayed("01")], + [FakeDelayed("10"), FakeDelayed("11")], + ] + + fake_result = FakeLl2CrResult() + empty = (((), np.nan, np.float64), ((), np.nan, np.float64)) + ll2cr_by_key = { + "00": empty, + "01": np.stack([np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]])]), + "10": np.stack([np.array([[5.0, 6.0]]), np.array([[7.0, 8.0]])]), + "11": np.stack([np.array([[9.0, 10.0]]), np.array([[11.0, 12.0]])]), + } + + def _persist(*args): + return args + + def _delayed(func): + def _wrap(arg): + return FakeDelayed(("extent", arg.key)) + return _wrap + + def _compute(*args): + out = [] + for d in args: + _, key = d.key + out.append(dask_ewa._ll2cr_block_extent(ll2cr_by_key[key])) + return tuple(out) + + with mock.patch.object(dask_ewa.dask, "persist", side_effect=_persist), \ + mock.patch.object(dask_ewa.dask, "delayed", side_effect=_delayed), \ + mock.patch.object(dask_ewa.dask, "compute", side_effect=_compute) as compute_mock: + ll2cr_blocks, block_dependencies = DaskEWAResampler._get_ll2cr_blocks( + None, fake_result, persist=True) + + assert compute_mock.call_count == 1 + assert ll2cr_blocks == [ + (0, 1, "01", (3.0, 4.0, 1.0, 2.0)), + (1, 0, "10", (7.0, 8.0, 5.0, 6.0)), + (1, 1, "11", (11.0, 12.0, 9.0, 10.0)), + ] + assert len(block_dependencies) == 3 + + +def test_get_ll2cr_blocks_without_persist_uses_numblocks_only(): + """Non-persist path should not build delayed wrappers just to count blocks.""" + class FakeLl2CrResult: + name = "ll2cr-test" + numblocks = (2, 3) + + def to_delayed(self): + raise AssertionError("to_delayed should not be used when persist=False") + + ll2cr_blocks, block_dependencies = DaskEWAResampler._get_ll2cr_blocks( + None, FakeLl2CrResult(), persist=False) + + assert block_dependencies is None + assert ll2cr_blocks == [ + (0, 0, ("ll2cr-test", 0, 0), None), + (0, 1, ("ll2cr-test", 0, 1), None), + (0, 2, ("ll2cr-test", 0, 2), None), + (1, 0, ("ll2cr-test", 1, 0), None), + (1, 1, ("ll2cr-test", 1, 1), None), + (1, 2, ("ll2cr-test", 1, 2), None), + ] + + +def test_generate_fornav_dask_tasks_filters_non_overlapping_pairs(): + """Only overlapping input/output chunk pairs should produce tasks.""" + ll2cr_blocks = ( + (0, 0, "b00", (0.1, 1.8, 0.1, 1.8)), + (0, 1, "b01", (0.1, 1.8, 2.1, 3.8)), + (1, 0, "b10", (2.1, 3.8, 0.1, 1.8)), + (1, 1, "b11", (2.1, 3.8, 2.1, 3.8)), + ) + output_stack = DaskEWAResampler._generate_fornav_dask_tasks( + OUT_CHUNKS_2X2, ll2cr_blocks, "fornav-test", "input", mock.Mock(), np.nan, + {"weight_delta_max": 0.0}) + assert len(output_stack) == 16 + fornav_pairs = {(key[1], key[2], key[3]) for key in _fornav_task_keys(output_stack)} + assert fornav_pairs == {(0, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 1)} + + +def test_generate_fornav_dask_tasks_falls_back_to_cartesian_without_extents(): + """Without ll2cr extents the previous cartesian behavior is preserved.""" + output_stack = DaskEWAResampler._generate_fornav_dask_tasks( + OUT_CHUNKS_2X2, LL2CR_BLOCKS_2X2, "fornav-test", "input", mock.Mock(), np.nan, + {"weight_delta_max": 0.0}) + assert len(output_stack) == 16 + assert _fornav_task_count(output_stack) == 16 + + +@pytest.mark.parametrize( + ("kwargs", "expected_count"), + [ + ({"weight_delta_max": 0.0}, 1), + ({"weight_delta_max": 1.0}, 4), + ({"weight_delta_max": 0.0, "weight_distance_max": 1.0}, 4), + ], + ids=("no-padding", "delta-padding", "distance-padding"), +) +def test_generate_fornav_overlap_padding(kwargs, expected_count): + """Overlap padding should expand to neighboring output chunks.""" + output_stack = DaskEWAResampler._generate_fornav_dask_tasks( + OUT_CHUNKS_2X2, + ((0, 0, "b00", (1.9, 1.9, 1.9, 1.9)),), + "fornav-test", + "input", + mock.Mock(), + np.nan, + kwargs) + assert _fornav_task_count(output_stack) == expected_count + + +def test_ll2cr_block_extent_returns_none_for_all_non_finite(): + ll2cr_block = np.stack( + [np.full((2, 2), np.nan, dtype=np.float64), np.full((2, 2), np.nan, dtype=np.float64)], + axis=0, + ) + assert dask_ewa._ll2cr_block_extent(ll2cr_block) is None + + +def test_average_fornav_empty_keepdims_returns_fill(): + empty = (((2, 2), 0, np.float32), ((2, 2), 0, np.float32)) + out = dask_ewa._average_fornav([empty], axis=(0,), keepdims=True, dtype=np.float32, fill_value=np.nan) + assert out.shape == (2, 2) + assert np.all(np.isnan(out)) + + +def test_persisted_ll2cr_blocks_are_reused_between_resample_calls(): + """Persisted ll2cr blocks should not be recomputed for subsequent calls.""" + swath_data, source_swath, target_area = get_test_data( + input_shape=(100, 50), output_shape=(200, 100), + input_dims=('y', 'x'), input_dtype=np.float32, + ) + + with mock.patch.object(dask_ewa, 'll2cr', wraps=dask_ewa.ll2cr) as ll2cr_mock, \ + dask.config.set(scheduler='sync'): + resampler = DaskEWAResampler(source_swath, target_area) + + out1 = resampler.resample( + swath_data, + rows_per_scan=10, + persist=True, + chunks=(50, 50), + weight_delta_max=40, + ) + out1.compute() + calls_after_first = ll2cr_mock.call_count + + out2 = resampler.resample( + swath_data, + rows_per_scan=10, + persist=True, + chunks=(50, 50), + weight_delta_max=40, + ) + out2.compute() + calls_after_second = ll2cr_mock.call_count + + assert calls_after_first > 0 + assert calls_after_second == calls_after_first + + +@pytest.mark.parametrize( + ("first_kwargs", "second_kwargs"), + [ + ( + {"rows_per_scan": 10, "persist": False}, + {"rows_per_scan": 100, "persist": False}, + ), + ( + {"rows_per_scan": 10, "persist": False}, + {"rows_per_scan": 10, "persist": True}, + ), + ], + ids=("rows-per-scan-change", "persist-change"), +) +def test_ll2cr_cache_recomputes_when_precompute_mode_changes(first_kwargs, second_kwargs): + """Changing precompute mode should invalidate the cached ll2cr block layout.""" + swath_data, source_swath, target_area = get_test_data( + input_shape=(100, 50), output_shape=(200, 100), + input_dims=('y', 'x'), input_dtype=np.float32, + ) + + with mock.patch.object(source_swath, 'get_lonlats', wraps=source_swath.get_lonlats) as get_lonlats_mock, \ + dask.config.set(scheduler='sync'): + resampler = DaskEWAResampler(source_swath, target_area) + + resampler.resample( + swath_data, + chunks=(50, 50), + weight_delta_max=40, + **first_kwargs, + ).compute() + calls_after_first = get_lonlats_mock.call_count + + resampler.resample( + swath_data, + chunks=(50, 50), + weight_delta_max=40, + **second_kwargs, + ).compute() + calls_after_second = get_lonlats_mock.call_count + + assert calls_after_first > 0 + assert calls_after_second > calls_after_first + + +def test_xarray_ewa_persist_empty_returns_fill(): + """Persisted path should return full fill when all ll2cr blocks are empty.""" + output_proj = ('+proj=lcc +datum=WGS84 +ellps=WGS84 ' + '+lon_0=-55. +lat_0=25 +lat_1=25 +units=m +no_defs') + swath_data, source_swath, target_area = get_test_data( + input_shape=(100, 50), output_shape=(200, 100), + input_dims=('y', 'x'), input_dtype=np.float32, + output_proj=output_proj, + ) + resampler = DaskEWAResampler(source_swath, target_area) + with dask.config.set(scheduler='sync'): + assert np.all(np.isnan(resampler.resample(swath_data, rows_per_scan=10, persist=True).compute().values))