Make focal output dtype consistent across backends#3226
Merged
Conversation
mean() cast to float32 on the cupy and dask+cupy paths while the CPU paths returned float64, losing precision on the GPU. The dask paths of mean, apply, and focal_stats also passed an untyped meta to map_overlap, so the lazy dtype advertised float64 while compute() returned float32 for float32/int input. Keep the dispatched dtype on the GPU mean paths and type every map_overlap meta with data.dtype. Adds parametrized regression tests over all four backends.
brendancol
commented
Jun 10, 2026
brendancol
left a comment
Contributor
Author
There was a problem hiding this comment.
PR Review: Make focal output dtype consistent across backends
Blockers (must fix before merge)
None.
Suggestions (should fix, not blocking)
-
xrspatial/focal.py:385-402: themean()docstring's cupy example still shows output computed under the old float32 cast (array([[0.47928995, ...]])printed at float32 precision) and an outdatedcupy.core.core.ndarrayclass path. With this PR the GPU path computes in float64, so the example no longer matches what a user would see. Worth refreshing while you are touching this contract.
Nits (optional improvements)
-
xrspatial/tests/test_focal.py(new_computed_dtypehelper): the helper is defined mid-file between test sections. Fine as is, but thegeneral_checks.pymodule is where the other cross-backend helpers live if it ever gets a second user.
What looks good
- The fix targets exactly the five sites that produce or advertise the wrong dtype, and nothing else.
hotspots()meta sites were already typed and are untouched. _mean_cupykeeps the device array and inherits the dispatched dtype instead of forcing float32; the new test asserts exact (not approximate) CPU/GPU equality in float64, which holds because both loops accumulate sequentially.- Tests cover 4 backends x 3 input dtypes for
mean()and both dask backends x 3 dtypes forapply()/focal_stats(), asserting advertised dtype == computed dtype, which is the actual regression. - 4x4 input with (2, 2) chunks exercises multi-chunk map_overlap, so the typed meta is checked against real chunk boundaries.
Checklist
- Algorithm matches reference (no algorithm change; dtype plumbing only)
- All implemented backends produce consistent results (verified live on CUDA host)
- NaN handling unchanged and covered by existing suite (258 passed)
- Edge cases covered (float64/float32/int32 inputs)
- Dask chunk boundaries handled correctly (typed meta, depth unchanged)
- No premature materialization (meta typing is graph-construction only)
- Benchmark not needed (bug fix, no new function); note GPU mean now runs in float64, so a throughput drop on large GPU rasters is expected and intentional
- README feature matrix not applicable
- Docstrings: see suggestion about the stale cupy example
brendancol
commented
Jun 10, 2026
brendancol
left a comment
Contributor
Author
There was a problem hiding this comment.
Follow-up review (after 1f20224)
The suggestion from the first pass is addressed: the mean() cupy docstring example now shows the float64 output (0.47928994, verified by running the example on a CUDA host) and the current cupy.ndarray class path.
Disposition of first-pass findings:
- Suggestion (stale cupy docstring example): fixed in 1f20224.
- Nit (_computed_dtype helper location): left in test_focal.py. It has a single consumer; moving a 4-line helper to general_checks.py before a second user exists adds indirection for no benefit.
No new findings. Full focal suite still passes (258).
…ocal-2026-06-10-01
…ocal-2026-06-10-01 Conflicts: - xrspatial/focal.py: main's #3221 (issue #3214, a duplicate of #3217) changed mean() to the _promote_float contract; kept main's GPU/dask cast lines, this branch keeps the typed map_overlap metas. - xrspatial/tests/test_focal.py: aligned the 3217 mean dtype test with the _promote_float contract (float32 in -> float32 out). - .claude/sweep-metadata-state.csv: restored LF line endings, kept main's geotiff row and this branch's focal row (notes updated).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Closes #3217
mean(): drop the hardcoded float32 cast in_mean_cupyand_mean_dask_cupy. The function already casts input to float64 before dispatch, so all four backends now return float64 and the GPU result matches the CPU result exactly (the float32 cast cost ~1e-4 relative error).apply()/focal_stats()/mean()dask paths: pass a typedmeta(dtype=data.dtype) to everymap_overlapcall, so the lazy DataArray advertises the dtype the chunks actually compute. Before, float32 and integer input advertised float64 but computed float32. Same fix as aspect() planar dask backends report float64 dtype but compute float32 #2682 (aspect) and proximity/allocation/direction: output dtype and .name differ across backends #2723 (proximity).Backend coverage: numpy, cupy, dask+numpy, dask+cupy all verified live on this host (CUDA available).
Test plan:
test_mean_dtype_consistent_across_backends_3217(4 backends x 3 input dtypes),test_apply_dask_advertised_dtype_matches_computed_3217,test_focal_stats_dask_advertised_dtype_matches_computed_3217(dask backends x 3 input dtypes), andtest_mean_gpu_matches_cpu_float64_3217(exact CPU/GPU value parity).test_focal.pysuite: 258 passed.