diff --git a/xrspatial/tests/test_zonal.py b/xrspatial/tests/test_zonal.py index 764b1d9a4..645964078 100644 --- a/xrspatial/tests/test_zonal.py +++ b/xrspatial/tests/test_zonal.py @@ -1546,6 +1546,34 @@ def test_apply_dask_3d_axis2_rechunked_2526(): ) +@pytest.mark.skipif(not dask_array_available(), reason="Requires Dask") +def test_apply_3d_mixed_backend_raises(): + """Regression for #2639: 3D apply() with mixed dask/numpy zones and + values must fail early with a clear backend error, not crash with an + AttributeError or silently return eager numpy output. + """ + zones_data = np.array([[1, 0], + [0, 2]], dtype=np.int32) + values_data = np.ones((2, 2, 3)) * 5.0 + + # numpy zones + dask values: dask backend would hit zones.chunks + zones_np = xr.DataArray(zones_data, dims=['y', 'x']) + values_dask = xr.DataArray( + da.from_array(values_data, chunks=(2, 2, 3)), + dims=['y', 'x', 'band'], + ) + with pytest.raises(ValueError, match="same backend"): + apply(zones_np, values_dask, lambda x: x + 10, nodata=0) + + # dask zones + numpy values: numpy backend would silently go eager + zones_dask = xr.DataArray( + da.from_array(zones_data, chunks=(2, 2)), dims=['y', 'x'], + ) + values_np = xr.DataArray(values_data, dims=['y', 'x', 'band']) + with pytest.raises(ValueError, match="same backend"): + apply(zones_dask, values_np, lambda x: x + 10, nodata=0) + + def test_apply_nodata_none(): zones_data = np.array([[0, 1], [2, 3]], dtype=np.int32) values_data = np.array([[1.0, 2.0], [3.0, 4.0]]) diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py index 18f864e88..6e43f1143 100644 --- a/xrspatial/zonal.py +++ b/xrspatial/zonal.py @@ -35,9 +35,9 @@ class cupy(object): ndarray = False # local modules -from xrspatial.utils import (ArrayTypeFunctionMapping, _validate_raster, cuda_args, - has_cuda_and_cupy, has_dask_array, is_cupy_array, is_dask_cupy, ngjit, - validate_arrays) +from xrspatial.utils import (ArrayTypeFunctionMapping, _classify_backend, _validate_raster, + cuda_args, has_cuda_and_cupy, has_dask_array, is_cupy_array, + is_dask_cupy, ngjit, validate_arrays) TOTAL_COUNT = '_total_count' @@ -1928,6 +1928,21 @@ def apply( # align chunks for 2D values if values.ndim == 2: validate_arrays(zones, values) + else: + # 3D values: validate_arrays can't be used because it requires equal + # full shapes (a 2D zones never equals a 3D values). Check backend + # compatibility directly so mixed dask/numpy inputs fail here with a + # clear error instead of crashing in the dask backend with an + # AttributeError or silently returning eager numpy output. + zones_backend = _classify_backend(zones) + values_backend = _classify_backend(values) + if zones_backend != values_backend: + # Wording mirrors validate_arrays() in utils.py so the two stay + # greppable together; the labels replace its "array 0"/"array N". + raise ValueError( + "input arrays must share the same backend; got " + f"'{zones_backend}' (zones) and '{values_backend}' (values)" + ) mapper = ArrayTypeFunctionMapping( numpy_func=_apply_numpy,