Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions xrspatial/hydro/flow_accumulation_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,11 @@ def _flow_accum_mfd_dask_iterative(fractions_da, chunks_y, chunks_x):
n_tile_y = len(chunks_y)
n_tile_x = len(chunks_x)

# The 8 direction bands must stay in a single chunk: every tile kernel
# needs all 8 fractions, and the lazy assembly drops axis 0 per block.
if fractions_da.chunks[0] != (fractions_da.shape[0],):
fractions_da = fractions_da.rechunk({0: fractions_da.shape[0]})

# Phase 0: extract boundary fraction strips
frac_bdry = _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x)

Expand Down Expand Up @@ -737,34 +742,37 @@ def _flow_accum_mfd_dask_iterative(fractions_da, chunks_y, chunks_x):

def _assemble_result_mfd(fractions_da, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x):
"""Build lazy dask array by re-running each MFD tile with converged seeds.
"""Build a lazy dask array by re-running each MFD tile with converged seeds.

fractions_da is (8, H, W). We build a 2-D dask result by using
da.block from individually computed tiles.
fractions_da is (8, H, W) chunked one tile per (chunks_y, chunks_x)
block. The converged boundary snapshot and fraction strips are small,
so we capture them in a closure and let ``map_blocks`` run the per-tile
kernel at compute time. Nothing here materializes the full output
raster during the API call.
"""
rows = []
for iy in range(n_tile_y):
row = []
for ix in range(n_tile_x):
y_start = sum(chunks_y[:iy])
y_end = y_start + chunks_y[iy]
x_start = sum(chunks_x[:ix])
x_end = x_start + chunks_x[ix]

chunk = np.asarray(
fractions_da[:, y_start:y_end, x_start:x_end].compute(),
dtype=np.float64)
_, h, w = chunk.shape

seeds = _compute_seeds_mfd(
iy, ix, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x)

tile_accum = _flow_accum_mfd_tile_kernel(chunk, h, w, *seeds)
row.append(da.from_array(tile_accum, chunks=tile_accum.shape))
rows.append(row)

return da.block(rows)
# Cumulative tile-start offsets to map a block's spatial origin to (iy, ix).
y_starts = np.cumsum((0,) + tuple(chunks_y[:-1]))
x_starts = np.cumsum((0,) + tuple(chunks_x[:-1]))

def _tile(chunk, block_info=None):
# block_info[0]['array-location'] gives ((0, 8), (y0, y1), (x0, x1)).
loc = block_info[0]['array-location']
y0 = loc[1][0]
x0 = loc[2][0]
iy = int(np.searchsorted(y_starts, y0, side='right')) - 1
ix = int(np.searchsorted(x_starts, x0, side='right')) - 1

chunk = np.asarray(chunk, dtype=np.float64)
_, h, w = chunk.shape
seeds = _compute_seeds_mfd(
iy, ix, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x)
return _flow_accum_mfd_tile_kernel(chunk, h, w, *seeds)

return da.map_blocks(
_tile, fractions_da, drop_axis=0,
dtype=np.float64, meta=np.array((), dtype=np.float64),
)


def _flow_accum_mfd_dask_cupy(fractions_da, chunks_y, chunks_x):
Expand Down
68 changes: 37 additions & 31 deletions xrspatial/hydro/flow_length_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,11 @@ def _flow_length_mfd_dask_iterative(fractions_da, direction,
n_tile_y = len(chunks_y)
n_tile_x = len(chunks_x)

# The 8 direction bands must stay in a single chunk: every tile kernel
# needs all 8 fractions, and the lazy assembly drops axis 0 per block.
if fractions_da.chunks[0] != (fractions_da.shape[0],):
fractions_da = fractions_da.rechunk({0: fractions_da.shape[0]})

frac_bdry = _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x)

fill = np.nan if direction == 'downstream' else 0.0
Expand Down Expand Up @@ -910,39 +915,40 @@ def _flow_length_mfd_dask_iterative(fractions_da, direction,
def _assemble_result(fractions_da, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x,
direction, cellsize_x, cellsize_y):
"""Build dask array by re-running tiles with converged boundaries."""
rows = []
for iy in range(n_tile_y):
row = []
for ix in range(n_tile_x):
y_start = sum(chunks_y[:iy])
y_end = y_start + chunks_y[iy]
x_start = sum(chunks_x[:ix])
x_end = x_start + chunks_x[ix]

chunk = np.asarray(
fractions_da[:, y_start:y_end, x_start:x_end].compute(),
dtype=np.float64)
_, h, w = chunk.shape

if direction == 'downstream':
seeds = _compute_exit_seeds_downstream(
iy, ix, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x)
tile = _flow_length_mfd_downstream_tile(
chunk, h, w, cellsize_x, cellsize_y, *seeds)
else:
seeds = _compute_entry_seeds_upstream(
iy, ix, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x,
cellsize_x, cellsize_y)
tile = _flow_length_mfd_upstream_tile(
chunk, h, w, cellsize_x, cellsize_y, *seeds)
"""Build a lazy dask array by re-running tiles with converged boundaries.

row.append(da.from_array(tile, chunks=tile.shape))
rows.append(row)
The converged boundary snapshot and fraction strips are small, so we
capture them in a closure and let ``map_blocks`` run the per-tile
kernel at compute time. Nothing here materializes the full output
raster during the API call.
"""
y_starts = np.cumsum((0,) + tuple(chunks_y[:-1]))
x_starts = np.cumsum((0,) + tuple(chunks_x[:-1]))

def _tile(chunk, block_info=None):
loc = block_info[0]['array-location']
iy = int(np.searchsorted(y_starts, loc[1][0], side='right')) - 1
ix = int(np.searchsorted(x_starts, loc[2][0], side='right')) - 1

return da.block(rows)
chunk = np.asarray(chunk, dtype=np.float64)
_, h, w = chunk.shape
if direction == 'downstream':
seeds = _compute_exit_seeds_downstream(
iy, ix, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x)
return _flow_length_mfd_downstream_tile(
chunk, h, w, cellsize_x, cellsize_y, *seeds)
seeds = _compute_entry_seeds_upstream(
iy, ix, boundaries, frac_bdry,
chunks_y, chunks_x, n_tile_y, n_tile_x,
cellsize_x, cellsize_y)
return _flow_length_mfd_upstream_tile(
chunk, h, w, cellsize_x, cellsize_y, *seeds)

return da.map_blocks(
_tile, fractions_da, drop_axis=0,
dtype=np.float64, meta=np.array((), dtype=np.float64),
)


def _flow_length_mfd_dask_cupy(fractions_da, direction,
Expand Down
55 changes: 55 additions & 0 deletions xrspatial/hydro/tests/test_flow_accumulation_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,61 @@ def test_dask_many_chunks(self):
np.testing.assert_allclose(
accum_dask.values, accum_np.values, atol=1e-10, equal_nan=True)

def test_dask_assembly_is_lazy(self, monkeypatch):
"""The returned DataArray must not run the assembly kernel until compute.

The boundary-convergence sweep runs the tile kernel eagerly during the
call, but assembling the output raster must be deferred to compute time.
Spy on the tile kernel and confirm that ``.compute()`` triggers
additional kernel calls beyond what the convergence sweep ran.
"""
dask = pytest.importorskip('dask.array')
import importlib
mod = importlib.import_module('xrspatial.hydro.flow_accumulation_mfd')

counter = {'n': 0}
orig = mod._flow_accum_mfd_tile_kernel

def _spy(*args, **kwargs):
counter['n'] += 1
return orig(*args, **kwargs)

monkeypatch.setattr(mod, '_flow_accum_mfd_tile_kernel', _spy)

elev = _make_bowl(11)
mfd_np = flow_direction_mfd(elev)
data_dask = dask.from_array(mfd_np.values, chunks=(8, 4, 4))
mfd_dask = xr.DataArray(data_dask, dims=mfd_np.dims, coords=mfd_np.coords)

accum = flow_accumulation_mfd(mfd_dask)
calls_after_call = counter['n']

# The output is a lazy dask-backed DataArray.
assert isinstance(accum.data, dask.Array)

accum.compute()
calls_added_by_compute = counter['n'] - calls_after_call

# If assembly were eager, compute would add zero kernel calls.
assert calls_added_by_compute > 0

def test_dask_band_axis_chunked(self):
"""An input chunked along the 8-band axis still matches numpy."""
dask = pytest.importorskip('dask.array')
elev = _make_bowl(9)
mfd_np = flow_direction_mfd(elev)
accum_np = flow_accumulation_mfd(mfd_np)

# Chunk axis 0 into two blocks to exercise the rechunk guard.
data_dask = dask.from_array(mfd_np.values, chunks=(4, 5, 5))
mfd_dask = xr.DataArray(data_dask,
dims=mfd_np.dims,
coords=mfd_np.coords)
accum_dask = flow_accumulation_mfd(mfd_dask)

np.testing.assert_allclose(
accum_dask.values, accum_np.values, atol=1e-10, equal_nan=True)


class TestFlowAccumulationMFDDataset:
"""Dataset support tests."""
Expand Down
50 changes: 50 additions & 0 deletions xrspatial/hydro/tests/test_flow_length_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,56 @@ def test_cross_tile_flow(self, direction):
result_np.data, result_da.data.compute(),
equal_nan=True, rtol=1e-10)

@pytest.mark.parametrize("direction", ['downstream', 'upstream'])
def test_band_axis_chunked(self, direction):
"""Input chunked along the 8-band axis still matches numpy."""
from xrspatial.hydro.flow_direction_mfd import flow_direction_mfd

np.random.seed(7)
elev = np.random.uniform(0, 100, (6, 6)).astype(np.float64)
elev_r = create_test_raster(elev, backend='numpy', name='elev')
mfd = flow_direction_mfd(elev_r)
mfd_data = mfd.data.astype(np.float64)

np_raster = _make_mfd_raster(mfd_data, backend='numpy')
da_raster = _make_mfd_raster(mfd_data, backend='dask', chunks=(3, 3))
# Split the band axis into two chunks to exercise the rechunk guard.
da_raster = da_raster.chunk({'neighbor': 4})

result_np = flow_length_mfd(np_raster, direction=direction)
result_da = flow_length_mfd(da_raster, direction=direction)

np.testing.assert_allclose(
result_np.data, result_da.data.compute(),
equal_nan=True, rtol=1e-10)

def test_assembly_is_lazy(self, monkeypatch):
"""Assembling the output raster must be deferred to compute time."""
import importlib
mod = importlib.import_module('xrspatial.hydro.flow_length_mfd')
from xrspatial.hydro.flow_direction_mfd import flow_direction_mfd

counter = {'n': 0}
orig = mod._flow_length_mfd_downstream_tile

def _spy(*args, **kwargs):
counter['n'] += 1
return orig(*args, **kwargs)

monkeypatch.setattr(mod, '_flow_length_mfd_downstream_tile', _spy)

np.random.seed(11)
elev = np.random.uniform(0, 100, (8, 8)).astype(np.float64)
elev_r = create_test_raster(elev, backend='numpy', name='elev')
mfd = flow_direction_mfd(elev_r)
mfd_data = mfd.data.astype(np.float64)
da_raster = _make_mfd_raster(mfd_data, backend='dask', chunks=(3, 3))

result = flow_length_mfd(da_raster, direction='downstream')
calls_after_call = counter['n']
result.data.compute()
assert counter['n'] - calls_after_call > 0


@cuda_and_cupy_available
class TestFlowLengthMfdCuPy:
Expand Down
70 changes: 70 additions & 0 deletions xrspatial/hydro/tests/test_watershed_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,76 @@ def test_numpy_equals_dask(chunks):
np_result.data, dk_result.data.compute(), equal_nan=True)


@dask_array_available
def test_dask_band_axis_chunked():
"""Input chunked along the 8-band axis still matches numpy."""
fracs = _make_all_east(4, 6)
pp = np.full((4, 6), np.nan, dtype=np.float64)
pp[0, 5] = 1.0
pp[3, 5] = 2.0

fd_np = _make_mfd_raster(fracs, backend='numpy')
pp_np = create_test_raster(pp, backend='numpy')
fd_dk = _make_mfd_raster(fracs, backend='dask', chunks=(2, 2))
fd_dk = fd_dk.chunk({'neighbor': 4})
pp_dk = create_test_raster(pp, backend='dask', chunks=(2, 2))

np_result = watershed_mfd(fd_np, pp_np)
dk_result = watershed_mfd(fd_dk, pp_dk)

np.testing.assert_allclose(
np_result.data, dk_result.data.compute(), equal_nan=True)


@dask_array_available
def test_dask_pour_points_chunk_mismatch():
"""Pour points chunked differently from fractions still match numpy."""
fracs = _make_all_east(4, 6)
pp = np.full((4, 6), np.nan, dtype=np.float64)
pp[0, 5] = 1.0
pp[3, 5] = 2.0

fd_np = _make_mfd_raster(fracs, backend='numpy')
pp_np = create_test_raster(pp, backend='numpy')
fd_dk = _make_mfd_raster(fracs, backend='dask', chunks=(2, 2))
# Pour points chunked 3x3 while fractions are 2x2.
pp_dk = create_test_raster(pp, backend='dask', chunks=(3, 3))

np_result = watershed_mfd(fd_np, pp_np)
dk_result = watershed_mfd(fd_dk, pp_dk)

np.testing.assert_allclose(
np_result.data, dk_result.data.compute(), equal_nan=True)


@dask_array_available
def test_dask_assembly_is_lazy(monkeypatch):
"""Assembling the output raster must be deferred to compute time."""
import importlib
mod = importlib.import_module('xrspatial.hydro.watershed_mfd')

counter = {'n': 0}
orig = mod._watershed_mfd_tile_kernel

def _spy(*args, **kwargs):
counter['n'] += 1
return orig(*args, **kwargs)

monkeypatch.setattr(mod, '_watershed_mfd_tile_kernel', _spy)

fracs = _make_all_east(6, 6)
pp = np.full((6, 6), np.nan, dtype=np.float64)
pp[0, 5] = 1.0
pp[5, 5] = 2.0
fd_dk = _make_mfd_raster(fracs, backend='dask', chunks=(2, 2))
pp_dk = create_test_raster(pp, backend='dask', chunks=(2, 2))

result = watershed_mfd(fd_dk, pp_dk)
calls_after_call = counter['n']
result.data.compute()
assert counter['n'] - calls_after_call > 0


@cuda_and_cupy_available
def test_numpy_equals_cupy():
fracs = _make_all_east(3, 4)
Expand Down
Loading
Loading