Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 173 additions & 43 deletions pyresample/ewa/dask_ewa.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,79 @@ def _call_mapped_ll2cr(lons, lats, target_geo_def):
return res


def _ll2cr_block_extent(ll2cr_block):
"""Compute row/column bounds for a single ll2cr block.

Args:
ll2cr_block: ll2cr output block as ``(cols, rows)`` arrays, or the
empty sentinel returned by ``_call_ll2cr``.

Returns:
``(row_min, row_max, col_min, col_max)`` as floats, or ``None`` when
the block contains no valid finite coordinates.
"""
# Empty ll2cr results: ((shape, fill, dtype), (shape, fill, dtype))
if isinstance(ll2cr_block[0], tuple):
return None

cols = np.asarray(ll2cr_block[0])
rows = np.asarray(ll2cr_block[1])
valid = np.isfinite(cols) & np.isfinite(rows)
if not np.any(valid):
return None

valid_rows = rows[valid]
valid_cols = cols[valid]
row_min = float(valid_rows.min())
row_max = float(valid_rows.max())
col_min = float(valid_cols.min())
col_max = float(valid_cols.max())
return row_min, row_max, col_min, col_max


def _pad_bounds(bounds, overlap_margin):
"""Pad ll2cr bounds by a constant overlap margin.

Args:
bounds: ll2cr bounds tuple ``(row_min, row_max, col_min, col_max)``,
or ``None``.
overlap_margin: Non-negative overlap margin in grid cells.

Returns:
Padded bounds tuple, or ``None`` when input bounds is ``None``.
"""
if bounds is None:
return None
row_min, row_max, col_min, col_max = bounds
return (
row_min - overlap_margin,
row_max + overlap_margin,
col_min - overlap_margin,
col_max + overlap_margin,
)


def _chunk_intersects_bounds(bounds, y_slice, x_slice):
"""Check whether a target chunk overlaps pre-padded ll2cr bounds.

Args:
bounds: ll2cr bounds tuple ``(row_min, row_max, col_min, col_max)``,
already padded for overlap, or ``None``.
y_slice: Output chunk row slice.
x_slice: Output chunk column slice.

Returns:
``True`` if the chunk intersects the bounds.
"""
if bounds is None:
return True
row_min, row_max, col_min, col_max = bounds
return (
y_slice.stop > row_min and y_slice.start <= row_max and
x_slice.stop > col_min and x_slice.start <= col_max
)
Comment on lines +141 to +144
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure I could be reading this wrong, but doesn't this only allow chunks to be processed that entirely encompass the output chunk? That's not what we want. We want to process the data if there is any overlap.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is checking for any overlap. test_generate_fornav_overlap_padding verifies this behavior.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's me printing out the bounds, y_slice, x_slice in this function while running all instances of that test:

(1.9, 1.9, 1.9, 1.9) slice(0, 2, None) slice(0, 2, None)
(1.9, 1.9, 1.9, 1.9) slice(0, 2, None) slice(2, 4, None)
(1.9, 1.9, 1.9, 1.9) slice(2, 4, None) slice(0, 2, None)
(1.9, 1.9, 1.9, 1.9) slice(2, 4, None) slice(2, 4, None)

.(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(0, 2, None) slice(0, 2, None)
(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(0, 2, None) slice(2, 4, None)
(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(2, 4, None) slice(0, 2, None)
(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(2, 4, None) slice(2, 4, None)

.(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(0, 2, None) slice(0, 2, None)
(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(0, 2, None) slice(2, 4, None)
(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(2, 4, None) slice(0, 2, None)
(0.8999999999999999, 2.9, 0.8999999999999999, 2.9) slice(2, 4, None) slice(2, 4, None)

The first test case is the only one that doesn't expect an overlap of all 4 chunks and it has a zero-size input unless I'm missing something (all bounds are 1.9). So that doesn't seem realistic.

That said, I see now that the logic was the inverse of what I expected and that it works as expected...although it took me way too long to wrap my head around it at the end of a long day. Overall, makes sense. Thanks.



def _delayed_fornav(ll2cr_result, target_geo_def, y_slice, x_slice, data, fill_value, kwargs):
# Adjust cols and rows for this sub-area
subdef = target_geo_def[y_slice, x_slice]
Expand Down Expand Up @@ -107,6 +180,21 @@ def _chunk_callable(x_chunk, axis, keepdims, **kwargs):
return x_chunk


def _sum_arrays(arrays):
"""Sum arrays with one initial copy and in-place accumulation.

Args:
arrays: Non-empty sequence of NumPy arrays with compatible shapes.

Returns:
Element-wise sum as a NumPy array.
"""
total = arrays[0].copy()
for arr in arrays[1:]:
total += arr
return total


def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
maximum_weight_mode=False):
if computing_meta or _is_empty_chunk(x_chunk):
Expand All @@ -126,6 +214,8 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
# split step - return "empty" chunk placeholder
return x_chunk[0]
return np.full(*x_chunk[0][0]), np.full(*x_chunk[0][1])
if len(valid_chunks) == 1:
return valid_chunks[0]
weights = [x[0] for x in valid_chunks]
accums = [x[1] for x in valid_chunks]
if maximum_weight_mode:
Expand All @@ -135,9 +225,7 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
weights = np.take_along_axis(weights, max_indexes, axis=0).squeeze(axis=0)
accums = np.take_along_axis(accums, max_indexes, axis=0).squeeze(axis=0)
return weights, accums
# NOTE: We use the builtin "sum" function below because it does not copy
# the numpy arrays. Using numpy.sum would do that.
return sum(weights), sum(accums)
return _sum_arrays(weights), _sum_arrays(accums)


def _is_empty_chunk(x_chunk):
Expand Down Expand Up @@ -224,60 +312,79 @@ def _get_rows_per_scan(self, rows_per_scan=None):
rows_per_scan = self.source_geo_def.shape[0]
return rows_per_scan

def _fill_block_cache_with_ll2cr_results(self, ll2cr_result,
num_row_blocks,
num_col_blocks,
persist):
def _ll2cr_cache_matches(self, rows_per_scan, persist):
return (
self.cache.get('rows_per_scan') == rows_per_scan and
self.cache.get('persist') == persist
)

def _get_ll2cr_blocks(self, ll2cr_result, persist):
ll2cr_blocks = []
block_dependencies = None
if persist:
ll2cr_delayeds = ll2cr_result.to_delayed()
ll2cr_delayeds = dask.persist(*ll2cr_delayeds.tolist())

block_cache = {}
for in_row_idx in range(num_row_blocks):
for in_col_idx in range(num_col_blocks):
key = (ll2cr_result.name, in_row_idx, in_col_idx)
if persist:
this_delayed = ll2cr_delayeds[in_row_idx][in_col_idx]
result = dask.compute(this_delayed)[0]
# XXX: Is this optimization lost because the persisted keys
# in `ll2cr_delayeds` are used in future computations?
if not isinstance(result[0], tuple):
block_cache[key] = this_delayed.key
else:
block_cache[key] = key
return block_cache
flat_delayeds = [
(in_row_idx, in_col_idx, delayed_block)
for in_row_idx, delayed_row in enumerate(ll2cr_delayeds)
for in_col_idx, delayed_block in enumerate(delayed_row)
]
block_dependencies = []
persisted_delayeds = dask.persist(
*(delayed for _, _, delayed in flat_delayeds))
# Compute only per-block extents on workers to avoid materializing
# full ll2cr blocks in the client process.
extent_delayeds = [dask.delayed(_ll2cr_block_extent)(d) for d in persisted_delayeds]
computed_extents = dask.compute(*extent_delayeds)
for (in_row_idx, in_col_idx, _), persisted_delayed, extent in zip(
flat_delayeds, persisted_delayeds, computed_extents, strict=True):
if extent is None:
continue
ll2cr_blocks.append((in_row_idx, in_col_idx, persisted_delayed.key, extent))
block_dependencies.append(persisted_delayed)
else:
num_row_blocks, num_col_blocks = ll2cr_result.numblocks[-2:]
for in_row_idx in range(num_row_blocks):
for in_col_idx in range(num_col_blocks):
ll2cr_blocks.append((
in_row_idx,
in_col_idx,
(ll2cr_result.name, in_row_idx, in_col_idx),
None,
))
return ll2cr_blocks, block_dependencies

def precompute(self, cache_dir=None, rows_per_scan=None, persist=False,
**kwargs):
"""Generate row and column arrays and store it for later use."""
if self.cache:
rows_per_scan = self._get_rows_per_scan(rows_per_scan)
if self._ll2cr_cache_matches(rows_per_scan, persist):
# this resampler should be used for one SwathDefinition
# no need to recompute ll2cr output again
# no need to recompute matching ll2cr output again
return None

if kwargs.get('mask') is not None:
logger.warning("'mask' parameter has no affect during EWA "
"resampling")

source_geo_def = self.source_geo_def
target_geo_def = self.target_geo_def
if cache_dir:
logger.warning("'cache_dir' is not used by EWA resampling")

rows_per_scan = self._get_rows_per_scan(rows_per_scan)
new_chunks = self._new_chunks(source_geo_def.lons, rows_per_scan)
lons, lats = source_geo_def.get_lonlats(chunks=new_chunks)
new_chunks = self._new_chunks(self.source_geo_def.lons, rows_per_scan)
lons, lats = self.source_geo_def.get_lonlats(chunks=new_chunks)
# run ll2cr to get column/row indexes
# if chunk does not overlap target area then None is returned
# otherwise a 3D array (2, y, x) of cols, rows are returned
ll2cr_result = _call_mapped_ll2cr(lons, lats, target_geo_def)
block_cache = self._fill_block_cache_with_ll2cr_results(
ll2cr_result, lons.numblocks[0], lons.numblocks[1], persist)
ll2cr_result = _call_mapped_ll2cr(lons, lats, self.target_geo_def)
ll2cr_blocks, block_dependencies = self._get_ll2cr_blocks(
ll2cr_result, persist)

# save the dask arrays in the class instance cache
self.cache = {
'll2cr_result': ll2cr_result,
'll2cr_blocks': block_cache,
'll2cr_blocks': ll2cr_blocks,
'll2cr_block_dependencies': block_dependencies,
'rows_per_scan': rows_per_scan,
'persist': persist,
}
return None

Expand Down Expand Up @@ -323,27 +430,46 @@ def _generate_fornav_dask_tasks(out_chunks, ll2cr_blocks, task_name,
input_name, target_geo_def, fill_value, kwargs):
y_start = 0
output_stack = {}
overlap_margin = max(
float(kwargs.get("weight_delta_max", 0.0)),
float(kwargs.get("weight_distance_max", 0.0)),
0.0,
)
indexed_blocks = []
for z_idx, (in_row_idx, in_col_idx, ll2cr_block, block_extent) in enumerate(ll2cr_blocks):
block_bounds = _pad_bounds(block_extent, overlap_margin)
indexed_blocks.append((z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds))
for out_row_idx in range(len(out_chunks[0])):
y_end = y_start + out_chunks[0][out_row_idx]
x_start = 0
for out_col_idx in range(len(out_chunks[1])):
x_end = x_start + out_chunks[1][out_col_idx]
y_slice = slice(y_start, y_end)
x_slice = slice(x_start, x_end)
for z_idx, ((_, in_row_idx, in_col_idx), ll2cr_block) in enumerate(ll2cr_blocks):
placeholder = (
((y_end - y_start, x_end - x_start), 0, np.float32),
((y_end - y_start, x_end - x_start), 0, np.float32),
)
for z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds in indexed_blocks:
key = (task_name, z_idx, out_row_idx, out_col_idx)
output_stack[key] = (_delayed_fornav,
ll2cr_block,
target_geo_def, y_slice, x_slice,
(input_name, in_row_idx, in_col_idx), fill_value, kwargs)
if _chunk_intersects_bounds(block_bounds, y_slice, x_slice):
output_stack[key] = (_delayed_fornav,
ll2cr_block,
target_geo_def, y_slice, x_slice,
(input_name, in_row_idx, in_col_idx), fill_value, kwargs)
else:
output_stack[key] = placeholder
x_start = x_end
y_start = y_end
return output_stack

def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwargs):
ll2cr_result = self.cache['ll2cr_result']
ll2cr_blocks = self.cache['ll2cr_blocks'].items()
ll2cr_numblocks = ll2cr_result.shape if isinstance(ll2cr_result, np.ndarray) else ll2cr_result.numblocks
ll2cr_blocks = self.cache['ll2cr_blocks']
ll2cr_block_dependencies = self.cache.get('ll2cr_block_dependencies')
if not ll2cr_blocks:
return da.full(target_geo_def.shape, fill_value, dtype=data.dtype,
chunks=out_chunks)
fornav_task_name = f"fornav-{data.name}-{ll2cr_result.name}"
maximum_weight_mode = kwargs.setdefault('maximum_weight_mode', False)
weight_sum_min = kwargs.setdefault('weight_sum_min', -1.0)
Expand All @@ -357,8 +483,12 @@ def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwa

dsk_graph = HighLevelGraph.from_collections(fornav_task_name,
output_stack,
dependencies=[data, ll2cr_result])
stack_chunks = ((1,) * (ll2cr_numblocks[0] * ll2cr_numblocks[1]),) + out_chunks
dependencies=(
(data, ll2cr_result)
if ll2cr_block_dependencies is None
else (data, *ll2cr_block_dependencies)
))
stack_chunks = ((1,) * len(ll2cr_blocks),) + out_chunks
out_stack = da.Array(dsk_graph, fornav_task_name, stack_chunks, data.dtype)
combine_fornav_with_kwargs = partial(
_combine_fornav, maximum_weight_mode=maximum_weight_mode)
Expand Down
Loading
Loading