Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions xrspatial/tests/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,34 @@ def test_apply_dask_3d_axis2_rechunked_2526():
)


@pytest.mark.skipif(not dask_array_available(), reason="Requires Dask")
def test_apply_3d_mixed_backend_raises():
"""Regression for #2639: 3D apply() with mixed dask/numpy zones and
values must fail early with a clear backend error, not crash with an
AttributeError or silently return eager numpy output.
"""
zones_data = np.array([[1, 0],
[0, 2]], dtype=np.int32)
values_data = np.ones((2, 2, 3)) * 5.0

# numpy zones + dask values: dask backend would hit zones.chunks
zones_np = xr.DataArray(zones_data, dims=['y', 'x'])
values_dask = xr.DataArray(
da.from_array(values_data, chunks=(2, 2, 3)),
dims=['y', 'x', 'band'],
)
with pytest.raises(ValueError, match="same backend"):
apply(zones_np, values_dask, lambda x: x + 10, nodata=0)

# dask zones + numpy values: numpy backend would silently go eager
zones_dask = xr.DataArray(
da.from_array(zones_data, chunks=(2, 2)), dims=['y', 'x'],
)
values_np = xr.DataArray(values_data, dims=['y', 'x', 'band'])
with pytest.raises(ValueError, match="same backend"):
apply(zones_dask, values_np, lambda x: x + 10, nodata=0)


def test_apply_nodata_none():
zones_data = np.array([[0, 1], [2, 3]], dtype=np.int32)
values_data = np.array([[1.0, 2.0], [3.0, 4.0]])
Expand Down
21 changes: 18 additions & 3 deletions xrspatial/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class cupy(object):
ndarray = False

# local modules
from xrspatial.utils import (ArrayTypeFunctionMapping, _validate_raster, cuda_args,
has_cuda_and_cupy, has_dask_array, is_cupy_array, is_dask_cupy, ngjit,
validate_arrays)
from xrspatial.utils import (ArrayTypeFunctionMapping, _classify_backend, _validate_raster,
cuda_args, has_cuda_and_cupy, has_dask_array, is_cupy_array,
is_dask_cupy, ngjit, validate_arrays)

TOTAL_COUNT = '_total_count'

Expand Down Expand Up @@ -1928,6 +1928,21 @@ def apply(
# align chunks for 2D values
if values.ndim == 2:
validate_arrays(zones, values)
else:
# 3D values: validate_arrays can't be used because it requires equal
# full shapes (a 2D zones never equals a 3D values). Check backend
# compatibility directly so mixed dask/numpy inputs fail here with a
# clear error instead of crashing in the dask backend with an
# AttributeError or silently returning eager numpy output.
zones_backend = _classify_backend(zones)
values_backend = _classify_backend(values)
if zones_backend != values_backend:
# Wording mirrors validate_arrays() in utils.py so the two stay
# greppable together; the labels replace its "array 0"/"array N".
raise ValueError(
"input arrays must share the same backend; got "
f"'{zones_backend}' (zones) and '{values_backend}' (values)"
)

mapper = ArrayTypeFunctionMapping(
numpy_func=_apply_numpy,
Expand Down
Loading