From 7fc69d6bb1b8037d2c3a7f5ec3c56bb271aff9d8 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 2 Jun 2026 14:59:45 -0700 Subject: [PATCH] Keep MFD dask output assembly lazy (#2865) The dask assembly step in flow_accumulation_mfd, flow_length_mfd, and watershed_mfd computed each tile and wrapped the numpy result with da.from_array / da.block, materializing the full output raster during the API call. Rebuild the result with da.map_blocks over the lazy fractions array, passing the small converged boundary snapshot into the block function so the per-tile kernel runs at compute time. The boundary-convergence sweep still reads tile data eagerly by design; lazifying that phase is tracked separately. --- xrspatial/hydro/flow_accumulation_mfd.py | 60 +++++++++------- xrspatial/hydro/flow_length_mfd.py | 68 ++++++++++-------- .../hydro/tests/test_flow_accumulation_mfd.py | 55 +++++++++++++++ xrspatial/hydro/tests/test_flow_length_mfd.py | 50 +++++++++++++ xrspatial/hydro/tests/test_watershed_mfd.py | 70 +++++++++++++++++++ xrspatial/hydro/watershed_mfd.py | 58 ++++++++------- 6 files changed, 278 insertions(+), 83 deletions(-) diff --git a/xrspatial/hydro/flow_accumulation_mfd.py b/xrspatial/hydro/flow_accumulation_mfd.py index 18d407177..8c422309b 100644 --- a/xrspatial/hydro/flow_accumulation_mfd.py +++ b/xrspatial/hydro/flow_accumulation_mfd.py @@ -698,6 +698,11 @@ def _flow_accum_mfd_dask_iterative(fractions_da, chunks_y, chunks_x): n_tile_y = len(chunks_y) n_tile_x = len(chunks_x) + # The 8 direction bands must stay in a single chunk: every tile kernel + # needs all 8 fractions, and the lazy assembly drops axis 0 per block. + if fractions_da.chunks[0] != (fractions_da.shape[0],): + fractions_da = fractions_da.rechunk({0: fractions_da.shape[0]}) + # Phase 0: extract boundary fraction strips frac_bdry = _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x) @@ -737,34 +742,37 @@ def _flow_accum_mfd_dask_iterative(fractions_da, chunks_y, chunks_x): def _assemble_result_mfd(fractions_da, boundaries, frac_bdry, chunks_y, chunks_x, n_tile_y, n_tile_x): - """Build lazy dask array by re-running each MFD tile with converged seeds. + """Build a lazy dask array by re-running each MFD tile with converged seeds. - fractions_da is (8, H, W). We build a 2-D dask result by using - da.block from individually computed tiles. + fractions_da is (8, H, W) chunked one tile per (chunks_y, chunks_x) + block. The converged boundary snapshot and fraction strips are small, + so we capture them in a closure and let ``map_blocks`` run the per-tile + kernel at compute time. Nothing here materializes the full output + raster during the API call. """ - rows = [] - for iy in range(n_tile_y): - row = [] - for ix in range(n_tile_x): - y_start = sum(chunks_y[:iy]) - y_end = y_start + chunks_y[iy] - x_start = sum(chunks_x[:ix]) - x_end = x_start + chunks_x[ix] - - chunk = np.asarray( - fractions_da[:, y_start:y_end, x_start:x_end].compute(), - dtype=np.float64) - _, h, w = chunk.shape - - seeds = _compute_seeds_mfd( - iy, ix, boundaries, frac_bdry, - chunks_y, chunks_x, n_tile_y, n_tile_x) - - tile_accum = _flow_accum_mfd_tile_kernel(chunk, h, w, *seeds) - row.append(da.from_array(tile_accum, chunks=tile_accum.shape)) - rows.append(row) - - return da.block(rows) + # Cumulative tile-start offsets to map a block's spatial origin to (iy, ix). + y_starts = np.cumsum((0,) + tuple(chunks_y[:-1])) + x_starts = np.cumsum((0,) + tuple(chunks_x[:-1])) + + def _tile(chunk, block_info=None): + # block_info[0]['array-location'] gives ((0, 8), (y0, y1), (x0, x1)). + loc = block_info[0]['array-location'] + y0 = loc[1][0] + x0 = loc[2][0] + iy = int(np.searchsorted(y_starts, y0, side='right')) - 1 + ix = int(np.searchsorted(x_starts, x0, side='right')) - 1 + + chunk = np.asarray(chunk, dtype=np.float64) + _, h, w = chunk.shape + seeds = _compute_seeds_mfd( + iy, ix, boundaries, frac_bdry, + chunks_y, chunks_x, n_tile_y, n_tile_x) + return _flow_accum_mfd_tile_kernel(chunk, h, w, *seeds) + + return da.map_blocks( + _tile, fractions_da, drop_axis=0, + dtype=np.float64, meta=np.array((), dtype=np.float64), + ) def _flow_accum_mfd_dask_cupy(fractions_da, chunks_y, chunks_x): diff --git a/xrspatial/hydro/flow_length_mfd.py b/xrspatial/hydro/flow_length_mfd.py index d207b81e8..cca9e1476 100644 --- a/xrspatial/hydro/flow_length_mfd.py +++ b/xrspatial/hydro/flow_length_mfd.py @@ -866,6 +866,11 @@ def _flow_length_mfd_dask_iterative(fractions_da, direction, n_tile_y = len(chunks_y) n_tile_x = len(chunks_x) + # The 8 direction bands must stay in a single chunk: every tile kernel + # needs all 8 fractions, and the lazy assembly drops axis 0 per block. + if fractions_da.chunks[0] != (fractions_da.shape[0],): + fractions_da = fractions_da.rechunk({0: fractions_da.shape[0]}) + frac_bdry = _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x) fill = np.nan if direction == 'downstream' else 0.0 @@ -910,39 +915,40 @@ def _flow_length_mfd_dask_iterative(fractions_da, direction, def _assemble_result(fractions_da, boundaries, frac_bdry, chunks_y, chunks_x, n_tile_y, n_tile_x, direction, cellsize_x, cellsize_y): - """Build dask array by re-running tiles with converged boundaries.""" - rows = [] - for iy in range(n_tile_y): - row = [] - for ix in range(n_tile_x): - y_start = sum(chunks_y[:iy]) - y_end = y_start + chunks_y[iy] - x_start = sum(chunks_x[:ix]) - x_end = x_start + chunks_x[ix] - - chunk = np.asarray( - fractions_da[:, y_start:y_end, x_start:x_end].compute(), - dtype=np.float64) - _, h, w = chunk.shape - - if direction == 'downstream': - seeds = _compute_exit_seeds_downstream( - iy, ix, boundaries, frac_bdry, - chunks_y, chunks_x, n_tile_y, n_tile_x) - tile = _flow_length_mfd_downstream_tile( - chunk, h, w, cellsize_x, cellsize_y, *seeds) - else: - seeds = _compute_entry_seeds_upstream( - iy, ix, boundaries, frac_bdry, - chunks_y, chunks_x, n_tile_y, n_tile_x, - cellsize_x, cellsize_y) - tile = _flow_length_mfd_upstream_tile( - chunk, h, w, cellsize_x, cellsize_y, *seeds) + """Build a lazy dask array by re-running tiles with converged boundaries. - row.append(da.from_array(tile, chunks=tile.shape)) - rows.append(row) + The converged boundary snapshot and fraction strips are small, so we + capture them in a closure and let ``map_blocks`` run the per-tile + kernel at compute time. Nothing here materializes the full output + raster during the API call. + """ + y_starts = np.cumsum((0,) + tuple(chunks_y[:-1])) + x_starts = np.cumsum((0,) + tuple(chunks_x[:-1])) + + def _tile(chunk, block_info=None): + loc = block_info[0]['array-location'] + iy = int(np.searchsorted(y_starts, loc[1][0], side='right')) - 1 + ix = int(np.searchsorted(x_starts, loc[2][0], side='right')) - 1 - return da.block(rows) + chunk = np.asarray(chunk, dtype=np.float64) + _, h, w = chunk.shape + if direction == 'downstream': + seeds = _compute_exit_seeds_downstream( + iy, ix, boundaries, frac_bdry, + chunks_y, chunks_x, n_tile_y, n_tile_x) + return _flow_length_mfd_downstream_tile( + chunk, h, w, cellsize_x, cellsize_y, *seeds) + seeds = _compute_entry_seeds_upstream( + iy, ix, boundaries, frac_bdry, + chunks_y, chunks_x, n_tile_y, n_tile_x, + cellsize_x, cellsize_y) + return _flow_length_mfd_upstream_tile( + chunk, h, w, cellsize_x, cellsize_y, *seeds) + + return da.map_blocks( + _tile, fractions_da, drop_axis=0, + dtype=np.float64, meta=np.array((), dtype=np.float64), + ) def _flow_length_mfd_dask_cupy(fractions_da, direction, diff --git a/xrspatial/hydro/tests/test_flow_accumulation_mfd.py b/xrspatial/hydro/tests/test_flow_accumulation_mfd.py index b30faf1cb..ecf37959b 100644 --- a/xrspatial/hydro/tests/test_flow_accumulation_mfd.py +++ b/xrspatial/hydro/tests/test_flow_accumulation_mfd.py @@ -332,6 +332,61 @@ def test_dask_many_chunks(self): np.testing.assert_allclose( accum_dask.values, accum_np.values, atol=1e-10, equal_nan=True) + def test_dask_assembly_is_lazy(self, monkeypatch): + """The returned DataArray must not run the assembly kernel until compute. + + The boundary-convergence sweep runs the tile kernel eagerly during the + call, but assembling the output raster must be deferred to compute time. + Spy on the tile kernel and confirm that ``.compute()`` triggers + additional kernel calls beyond what the convergence sweep ran. + """ + dask = pytest.importorskip('dask.array') + import importlib + mod = importlib.import_module('xrspatial.hydro.flow_accumulation_mfd') + + counter = {'n': 0} + orig = mod._flow_accum_mfd_tile_kernel + + def _spy(*args, **kwargs): + counter['n'] += 1 + return orig(*args, **kwargs) + + monkeypatch.setattr(mod, '_flow_accum_mfd_tile_kernel', _spy) + + elev = _make_bowl(11) + mfd_np = flow_direction_mfd(elev) + data_dask = dask.from_array(mfd_np.values, chunks=(8, 4, 4)) + mfd_dask = xr.DataArray(data_dask, dims=mfd_np.dims, coords=mfd_np.coords) + + accum = flow_accumulation_mfd(mfd_dask) + calls_after_call = counter['n'] + + # The output is a lazy dask-backed DataArray. + assert isinstance(accum.data, dask.Array) + + accum.compute() + calls_added_by_compute = counter['n'] - calls_after_call + + # If assembly were eager, compute would add zero kernel calls. + assert calls_added_by_compute > 0 + + def test_dask_band_axis_chunked(self): + """An input chunked along the 8-band axis still matches numpy.""" + dask = pytest.importorskip('dask.array') + elev = _make_bowl(9) + mfd_np = flow_direction_mfd(elev) + accum_np = flow_accumulation_mfd(mfd_np) + + # Chunk axis 0 into two blocks to exercise the rechunk guard. + data_dask = dask.from_array(mfd_np.values, chunks=(4, 5, 5)) + mfd_dask = xr.DataArray(data_dask, + dims=mfd_np.dims, + coords=mfd_np.coords) + accum_dask = flow_accumulation_mfd(mfd_dask) + + np.testing.assert_allclose( + accum_dask.values, accum_np.values, atol=1e-10, equal_nan=True) + class TestFlowAccumulationMFDDataset: """Dataset support tests.""" diff --git a/xrspatial/hydro/tests/test_flow_length_mfd.py b/xrspatial/hydro/tests/test_flow_length_mfd.py index 77e3a68f1..3a2cd4cf5 100644 --- a/xrspatial/hydro/tests/test_flow_length_mfd.py +++ b/xrspatial/hydro/tests/test_flow_length_mfd.py @@ -366,6 +366,56 @@ def test_cross_tile_flow(self, direction): result_np.data, result_da.data.compute(), equal_nan=True, rtol=1e-10) + @pytest.mark.parametrize("direction", ['downstream', 'upstream']) + def test_band_axis_chunked(self, direction): + """Input chunked along the 8-band axis still matches numpy.""" + from xrspatial.hydro.flow_direction_mfd import flow_direction_mfd + + np.random.seed(7) + elev = np.random.uniform(0, 100, (6, 6)).astype(np.float64) + elev_r = create_test_raster(elev, backend='numpy', name='elev') + mfd = flow_direction_mfd(elev_r) + mfd_data = mfd.data.astype(np.float64) + + np_raster = _make_mfd_raster(mfd_data, backend='numpy') + da_raster = _make_mfd_raster(mfd_data, backend='dask', chunks=(3, 3)) + # Split the band axis into two chunks to exercise the rechunk guard. + da_raster = da_raster.chunk({'neighbor': 4}) + + result_np = flow_length_mfd(np_raster, direction=direction) + result_da = flow_length_mfd(da_raster, direction=direction) + + np.testing.assert_allclose( + result_np.data, result_da.data.compute(), + equal_nan=True, rtol=1e-10) + + def test_assembly_is_lazy(self, monkeypatch): + """Assembling the output raster must be deferred to compute time.""" + import importlib + mod = importlib.import_module('xrspatial.hydro.flow_length_mfd') + from xrspatial.hydro.flow_direction_mfd import flow_direction_mfd + + counter = {'n': 0} + orig = mod._flow_length_mfd_downstream_tile + + def _spy(*args, **kwargs): + counter['n'] += 1 + return orig(*args, **kwargs) + + monkeypatch.setattr(mod, '_flow_length_mfd_downstream_tile', _spy) + + np.random.seed(11) + elev = np.random.uniform(0, 100, (8, 8)).astype(np.float64) + elev_r = create_test_raster(elev, backend='numpy', name='elev') + mfd = flow_direction_mfd(elev_r) + mfd_data = mfd.data.astype(np.float64) + da_raster = _make_mfd_raster(mfd_data, backend='dask', chunks=(3, 3)) + + result = flow_length_mfd(da_raster, direction='downstream') + calls_after_call = counter['n'] + result.data.compute() + assert counter['n'] - calls_after_call > 0 + @cuda_and_cupy_available class TestFlowLengthMfdCuPy: diff --git a/xrspatial/hydro/tests/test_watershed_mfd.py b/xrspatial/hydro/tests/test_watershed_mfd.py index e997c444d..275ccbdc1 100644 --- a/xrspatial/hydro/tests/test_watershed_mfd.py +++ b/xrspatial/hydro/tests/test_watershed_mfd.py @@ -167,6 +167,76 @@ def test_numpy_equals_dask(chunks): np_result.data, dk_result.data.compute(), equal_nan=True) +@dask_array_available +def test_dask_band_axis_chunked(): + """Input chunked along the 8-band axis still matches numpy.""" + fracs = _make_all_east(4, 6) + pp = np.full((4, 6), np.nan, dtype=np.float64) + pp[0, 5] = 1.0 + pp[3, 5] = 2.0 + + fd_np = _make_mfd_raster(fracs, backend='numpy') + pp_np = create_test_raster(pp, backend='numpy') + fd_dk = _make_mfd_raster(fracs, backend='dask', chunks=(2, 2)) + fd_dk = fd_dk.chunk({'neighbor': 4}) + pp_dk = create_test_raster(pp, backend='dask', chunks=(2, 2)) + + np_result = watershed_mfd(fd_np, pp_np) + dk_result = watershed_mfd(fd_dk, pp_dk) + + np.testing.assert_allclose( + np_result.data, dk_result.data.compute(), equal_nan=True) + + +@dask_array_available +def test_dask_pour_points_chunk_mismatch(): + """Pour points chunked differently from fractions still match numpy.""" + fracs = _make_all_east(4, 6) + pp = np.full((4, 6), np.nan, dtype=np.float64) + pp[0, 5] = 1.0 + pp[3, 5] = 2.0 + + fd_np = _make_mfd_raster(fracs, backend='numpy') + pp_np = create_test_raster(pp, backend='numpy') + fd_dk = _make_mfd_raster(fracs, backend='dask', chunks=(2, 2)) + # Pour points chunked 3x3 while fractions are 2x2. + pp_dk = create_test_raster(pp, backend='dask', chunks=(3, 3)) + + np_result = watershed_mfd(fd_np, pp_np) + dk_result = watershed_mfd(fd_dk, pp_dk) + + np.testing.assert_allclose( + np_result.data, dk_result.data.compute(), equal_nan=True) + + +@dask_array_available +def test_dask_assembly_is_lazy(monkeypatch): + """Assembling the output raster must be deferred to compute time.""" + import importlib + mod = importlib.import_module('xrspatial.hydro.watershed_mfd') + + counter = {'n': 0} + orig = mod._watershed_mfd_tile_kernel + + def _spy(*args, **kwargs): + counter['n'] += 1 + return orig(*args, **kwargs) + + monkeypatch.setattr(mod, '_watershed_mfd_tile_kernel', _spy) + + fracs = _make_all_east(6, 6) + pp = np.full((6, 6), np.nan, dtype=np.float64) + pp[0, 5] = 1.0 + pp[5, 5] = 2.0 + fd_dk = _make_mfd_raster(fracs, backend='dask', chunks=(2, 2)) + pp_dk = create_test_raster(pp, backend='dask', chunks=(2, 2)) + + result = watershed_mfd(fd_dk, pp_dk) + calls_after_call = counter['n'] + result.data.compute() + assert counter['n'] - calls_after_call > 0 + + @cuda_and_cupy_available def test_numpy_equals_cupy(): fracs = _make_all_east(3, 4) diff --git a/xrspatial/hydro/watershed_mfd.py b/xrspatial/hydro/watershed_mfd.py index 6dfbd7838..c4575450b 100644 --- a/xrspatial/hydro/watershed_mfd.py +++ b/xrspatial/hydro/watershed_mfd.py @@ -561,6 +561,14 @@ def _watershed_mfd_dask(fractions_da, pour_points_da, chunks_y, chunks_x): n_tile_y = len(chunks_y) n_tile_x = len(chunks_x) + # The 8 direction bands must stay in a single chunk: every tile kernel + # needs all 8 fractions, and the lazy assembly drops axis 0 per block. + if fractions_da.chunks[0] != (fractions_da.shape[0],): + fractions_da = fractions_da.rechunk({0: fractions_da.shape[0]}) + # Align pour points to the fractions' spatial tile grid so the lazy + # assembly can map both arrays block-for-block. + pour_points_da = pour_points_da.rechunk((chunks_y, chunks_x)) + frac_bdry = _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x) boundaries = BoundaryStore(chunks_y, chunks_x, fill_value=np.nan) @@ -592,32 +600,30 @@ def _watershed_mfd_dask(fractions_da, pour_points_da, chunks_y, chunks_x): boundaries = boundaries.snapshot() - # Assemble final result - rows = [] - for iy in range(n_tile_y): - row = [] - for ix in range(n_tile_x): - y_start = sum(chunks_y[:iy]) - y_end = y_start + chunks_y[iy] - x_start = sum(chunks_x[:ix]) - x_end = x_start + chunks_x[ix] - - chunk = np.asarray( - fractions_da[:, y_start:y_end, x_start:x_end].compute(), - dtype=np.float64) - pp_chunk = np.asarray( - pour_points_da.blocks[iy, ix].compute(), dtype=np.float64) - _, h, w = chunk.shape - - exits = _compute_exit_labels_mfd( - iy, ix, boundaries, frac_bdry, - chunks_y, chunks_x, n_tile_y, n_tile_x) - - tile = _watershed_mfd_tile_kernel(chunk, h, w, pp_chunk, *exits) - row.append(da.from_array(tile, chunks=tile.shape)) - rows.append(row) - - return da.block(rows) + # Assemble the final result lazily. The converged boundary snapshot and + # fraction strips are small, so we capture them in a closure and let + # map_blocks run the per-tile kernel at compute time. Nothing here + # materializes the full output raster during the API call. + y_starts = np.cumsum((0,) + tuple(chunks_y[:-1])) + x_starts = np.cumsum((0,) + tuple(chunks_x[:-1])) + + def _tile(chunk, pp_chunk, block_info=None): + loc = block_info[0]['array-location'] + iy = int(np.searchsorted(y_starts, loc[1][0], side='right')) - 1 + ix = int(np.searchsorted(x_starts, loc[2][0], side='right')) - 1 + + chunk = np.asarray(chunk, dtype=np.float64) + pp_chunk = np.asarray(pp_chunk, dtype=np.float64) + _, h, w = chunk.shape + exits = _compute_exit_labels_mfd( + iy, ix, boundaries, frac_bdry, + chunks_y, chunks_x, n_tile_y, n_tile_x) + return _watershed_mfd_tile_kernel(chunk, h, w, pp_chunk, *exits) + + return da.map_blocks( + _tile, fractions_da, pour_points_da, drop_axis=0, + dtype=np.float64, meta=np.array((), dtype=np.float64), + ) # =====================================================================