zonal.apply 3D dask path leaves axis-2 chunked at size 1 after da.stack #2526

@brendancol

Summary

When xrspatial.zonal.apply is given a 3D dask-backed values array, it calls
da.stack(layers, axis=2) to assemble the per-layer map_blocks
results. da.stack produces unit chunks along the new axis, and the
function never calls rechunk() afterward, so the output has
chunks[2] = (1, 1, ..., 1) regardless of the input chunking.
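
The stacking behavior can be seen with plain dask, independent of xrspatial. The sketch below is illustrative only: the `layers` list stands in for the per-layer map_blocks results and is not the actual code in zonal.py.

```python
import dask.array as da

# Three stand-in per-layer results, chunked like the 256x256 raster in the probe.
layers = [da.zeros((256, 256), chunks=(64, 64)) for _ in range(3)]

# Stacking along a new trailing axis gives every input layer its own chunk there.
stacked = da.stack(layers, axis=2)

print(stacked.shape)      # (256, 256, 3)
print(stacked.chunks[2])  # (1, 1, 1)
```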

Confirmed by graph probe (256x256 raster, chunks=(64,64), 3 bands):

"apply_dask_3d": {
  "task_count": 176,
  "tasks_per_chunk": 11.0,
  "output_chunks_axis2": [1, 1, 1]
}

This is the documented anti-pattern from CLAUDE.md:

da.stack creates chunks of size 1 along the new axis — use
.rechunk({0: N}) to merge

Impact

  • Downstream operations that touch all bands together pay scheduling
    overhead proportional to the band count, because each spatial block is
    split into one chunk per band instead of a single chunk.
  • Task count inflates linearly with band count.
  • Both _apply_dask_numpy (zonal.py:1691) and _apply_dask_cupy
    (zonal.py:1731) have the same pattern.
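
As a rough illustration of the scaling, the sketch below counts output chunks (not graph tasks, which is what the probe above measured) for a stacked array with and without a merge along axis 2. It uses plain dask with stand-in layers; `chunk_counts` is a hypothetical helper, not part of xrspatial.

```python
import dask.array as da


def chunk_counts(n_bands):
    """Return (chunks after da.stack, chunks after merging axis 2)."""
    layers = [da.zeros((256, 256), chunks=(64, 64)) for _ in range(n_bands)]
    stacked = da.stack(layers, axis=2)
    merged = stacked.rechunk({2: -1})  # -1 means one chunk spanning the axis
    return stacked.npartitions, merged.npartitions


for n_bands in (3, 6, 12):
    print(n_bands, chunk_counts(n_bands))
# 3 (48, 16)
# 6 (96, 16)
# 12 (192, 16)
```

The unmerged chunk count is 16 spatial blocks times the band count, while the merged count stays at 16.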

Fix

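Rechunk the stacked output to merge the unit chunks along the new axis
(axis 2 here, not axis 0 as in the CLAUDE.md example), in both
_apply_dask_numpy and _apply_dask_cupy.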
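
One possible shape for the fix, as a minimal sketch in plain dask. `stack_layers` is a hypothetical helper name and the `layers` list is a stand-in; the actual change would go wherever zonal.py calls da.stack. Merging to a single chunk with -1 is an assumption; it makes each output chunk hold all bands, which may not be wanted for very large band counts.

```python
import dask.array as da
import numpy as np


def stack_layers(layers, axis=2):
    """Hypothetical helper: stack per-layer arrays, then merge the new axis."""
    stacked = da.stack(layers, axis=axis)
    # -1 asks dask for a single chunk covering the whole stacked axis.
    return stacked.rechunk({axis: -1})


# Stand-in layers filled with 0, 1, 2 so the band order can be checked.
layers = [da.full((256, 256), i, chunks=(64, 64)) for i in range(3)]
out = stack_layers(layers)

print(out.chunks[2])  # (3,)
# Values are unchanged by the rechunk; only the chunk layout differs.
print(np.array_equal(out.compute()[0, 0], [0, 1, 2]))  # True
```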

How to reproduce

import dask.array as da
import numpy as np
import xarray as xr
from xrspatial.zonal import apply

shape = (256, 256, 3)
z = xr.DataArray(
    da.from_array(np.zeros(shape[:2], dtype=np.int64), chunks=(64, 64)),
    dims=['y', 'x'],
)
v = xr.DataArray(
    da.from_array(np.zeros(shape), chunks=(64, 64, 3)),
    dims=['y', 'x', 'band'],
)
out = apply(z, v, lambda x: x * 2)
print(out.data.chunks[2])  # (1, 1, 1) -- bug

Labels: bug, performance
