diff --git a/CHANGELOG.md b/CHANGELOG.md index 83ce68646..349a15819 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,21 @@ ### Unreleased +#### Added + +- `da.xrs.open_geotiff(source, *, auto_reproject=False, **kwargs)` + DataArray accessor that mirrors the existing Dataset method, plus + backend-aware enhancements on both accessors. The accessor infers + the caller's backend (numpy / cupy / dask+numpy / dask+cupy) and + passes matching `gpu=` / `chunks=` to `open_geotiff` so the + returned DataArray matches the caller. Caller-supplied `gpu=` / + `chunks=` always override the inference. On CRS mismatch between + the caller and the file, the accessor raises a clear `ValueError` + by default (replacing the previous silently-wrong window) and + with `auto_reproject=True` it projects the caller's bbox into the + file's CRS for the windowed read and reprojects the result back + to the caller's CRS via `xrspatial.reproject.reproject`. (#2557) + #### Fixed - `crosstab(cat_ids=[...])` no longer overcounts when `cat_ids` skips a diff --git a/xrspatial/accessor.py b/xrspatial/accessor.py index 6eebb042a..05198d850 100644 --- a/xrspatial/accessor.py +++ b/xrspatial/accessor.py @@ -67,6 +67,187 @@ def _listed_colormap_from_attrs(attrs): return ListedColormap(colors, name='tiff_palette') +def _pick_representative_dataarray(ds, *, var=None): + """Return the DataArray used to drive backend / CRS lookup for + Dataset-level GeoTIFF accessor methods. + + ``var`` -> ``ds[var]``. Otherwise the first 2D ``y``/``x`` variable. + """ + if var is not None: + return ds[var] + for v in ds.data_vars: + d = ds[v] + if d.ndim >= 2 and 'y' in d.dims and 'x' in d.dims: + return d + raise ValueError( + "Dataset has no 2D variable with 'y' and 'x' dimensions to use " + "for backend inference. Pass var='' to select one, or " + "call xrspatial.geotiff.open_geotiff(source, ...) directly." + ) + + +def _infer_caller_y_chunk(obj): + """First y-axis chunk size of a dask-backed DataArray, or None.""" + try: + y_axis = obj.get_axis_num('y') + except (ValueError, KeyError): + return None + chunks_tuple = getattr(obj, 'chunks', None) + if not chunks_tuple: + return None + try: + y_chunks = chunks_tuple[y_axis] + except (IndexError, TypeError): + return None + if not y_chunks: + return None + return int(y_chunks[0]) + + +def _to_pyproj_crs(crs): + """Normalize a CRS value (int EPSG, ``"EPSG:xxxx"``, WKT, PROJ + string, or ``pyproj.CRS``) into a ``pyproj.CRS`` instance. + + Returns ``None`` only when ``crs`` is itself ``None`` (the + backward-compatible "no CRS to compare" path). A malformed but + present value raises ``ValueError`` so the caller sees the typo + instead of having the mismatch safety net silently disabled. + """ + if crs is None: + return None + from pyproj import CRS as _PyprojCRS + from pyproj.exceptions import CRSError + try: + return _PyprojCRS(crs) + except CRSError as e: + raise ValueError( + f"attrs['crs']={crs!r} is not a valid CRS: {e}" + ) from e + + +def _bbox_edge_samples(x_min, y_min, x_max, y_max, n_per_side=20): + """Return ``(xs, ys)`` arrays sampling each edge of the bbox. + + Sampling the perimeter (instead of only the 4 corners) before + transforming into the source CRS keeps the envelope of the + transformed points close to the true projected bbox even when the + transform has curvature across the box (high latitudes, large + extents). + """ + import numpy as np + a = np.linspace(x_min, x_max, n_per_side) + b = np.full(n_per_side, y_min) + c = np.full(n_per_side, y_max) + d = np.linspace(y_min, y_max, n_per_side) + e = np.full(n_per_side, x_min) + f = np.full(n_per_side, x_max) + xs = np.concatenate([a, a, e, f]) + ys = np.concatenate([b, c, d, d]) + return xs, ys + + +def _open_geotiff_windowed(obj, source, *, auto_reproject=False, **kwargs): + """Shared implementation for ``.xrs.open_geotiff`` on DataArray and + Dataset. + + ``obj`` is a DataArray that carries the bbox (``y``/``x`` coords), + the CRS (``attrs['crs']``), and the backend used to infer + ``gpu=`` / ``chunks=`` kwargs for the underlying read. + + The windowing extent is expanded by half a pixel on each side so + edge pixels of the caller's footprint are captured. CRS values on + both sides (``attrs['crs']`` and the file's georeference) are + normalized through ``pyproj.CRS`` for equality, so int EPSG codes, + ``"EPSG:xxxx"`` strings, WKT strings, and ``pyproj.CRS`` instances + all compare correctly. + + When ``auto_reproject=True`` and the CRSs differ, the caller's + bbox is sampled along its perimeter (not just the four corners) + before transformation into the file CRS, so the windowed read + covers the full footprint even when the transform has curvature + across the bbox. + """ + from .geotiff import open_geotiff, _read_geo_info, _extent_to_window + from .utils import _classify_backend + + if 'y' not in obj.coords or 'x' not in obj.coords: + raise ValueError( + "Caller must have 'y' and 'x' coordinates to compute a " + "spatial window" + ) + y = obj.coords['y'].values + x = obj.coords['x'].values + y_min, y_max = float(y.min()), float(y.max()) + x_min, x_max = float(x.min()), float(x.max()) + + geo_info, file_h, file_w, _dtype, _nbands = _read_geo_info(source) + t = geo_info.transform + caller_crs_raw = obj.attrs.get('crs') + file_crs_raw = geo_info.crs_epsg + caller_crs = _to_pyproj_crs(caller_crs_raw) + file_crs = _to_pyproj_crs(file_crs_raw) + + crs_mismatch = ( + caller_crs is not None + and file_crs is not None + and not caller_crs.equals(file_crs) + ) + + if crs_mismatch and not auto_reproject: + raise ValueError( + f"CRS mismatch: caller has {caller_crs_raw!r} but file has " + f"EPSG:{file_crs_raw}. Pass auto_reproject=True to project " + f"the caller bbox into the file CRS for the windowed read " + f"and reproject the result back to the caller CRS." + ) + + if crs_mismatch: + from pyproj import Transformer + transformer = Transformer.from_crs( + caller_crs, file_crs, always_xy=True + ) + # 20 samples per side is enough to keep the projected envelope + # within sub-pixel of the true bbox at any realistic latitude + # without making the transform call noticeably slower. + xs, ys = _bbox_edge_samples(x_min, y_min, x_max, y_max) + px, py = transformer.transform(xs, ys) + x_min = float(px.min()) + x_max = float(px.max()) + y_min = float(py.min()) + y_max = float(py.max()) + + # Expand extent by half a pixel so we capture edge pixels + y_min -= abs(t.pixel_height) * 0.5 + y_max += abs(t.pixel_height) * 0.5 + x_min -= abs(t.pixel_width) * 0.5 + x_max += abs(t.pixel_width) * 0.5 + + window = _extent_to_window(t, file_h, file_w, + y_min, y_max, x_min, x_max) + kwargs.pop('window', None) + + # Infer backend kwargs. Caller-supplied values always win. + backend = _classify_backend(obj) + if backend in ("cupy", "dask+cupy"): + kwargs.setdefault('gpu', True) + if backend in ("dask+numpy", "dask+cupy"): + inferred_chunk = _infer_caller_y_chunk(obj) + if inferred_chunk is not None: + kwargs.setdefault('chunks', inferred_chunk) + + result = open_geotiff(source, window=window, **kwargs) + + if crs_mismatch: + from .reproject import reproject + result = reproject( + result, + target_crs=caller_crs, + source_crs=file_crs, + ) + + return result + + @xr.register_dataarray_accessor("xrs") class XrsSpatialDataArrayAccessor: """DataArray accessor exposing xarray-spatial operations.""" @@ -602,6 +783,41 @@ def to_geotiff(self, path, **kwargs): from .geotiff import to_geotiff return to_geotiff(self._obj, path, **kwargs) + def open_geotiff(self, source, *, auto_reproject=False, **kwargs): + """Read a GeoTIFF windowed to this DataArray's spatial extent. + + Uses ``self``'s ``y``/``x`` coordinates to compute a pixel window + and reads only that region from the file. The returned DataArray's + backend is matched to ``self``'s backend by inferring + ``gpu=`` / ``chunks=`` from ``self`` (caller-supplied ``gpu=`` / + ``chunks=`` override the inference). + + Parameters + ---------- + source : str + File path to the GeoTIFF. + auto_reproject : bool + If False (default) and ``self.attrs['crs']`` differs from + the file's CRS, raises ``ValueError``. If True, projects + ``self``'s bbox into the file CRS for the windowed read, + then reprojects the result back to ``self``'s CRS via + :func:`xrspatial.reproject.reproject` so the returned + DataArray lines up with ``self``. + **kwargs + Forwarded to :func:`xrspatial.geotiff.open_geotiff` (except + ``window=``, which is computed automatically). + + Returns + ------- + xr.DataArray + The windowed portion of the GeoTIFF, in ``self``'s CRS when + ``auto_reproject`` reprojection occurred and otherwise in the + file's native CRS. + """ + return _open_geotiff_windowed( + self._obj, source, auto_reproject=auto_reproject, **kwargs + ) + # ---- Chunking ---- def rechunk_no_shuffle(self, **kwargs): @@ -1065,50 +1281,52 @@ def to_geotiff(self, path, var=None, **kwargs): "Dataset has no variable with 'y' and 'x' dimensions to write" ) - def open_geotiff(self, source, **kwargs): + def open_geotiff(self, source, *, auto_reproject=False, var=None, **kwargs): """Read a GeoTIFF windowed to this Dataset's spatial extent. - Uses the Dataset's y/x coordinates to compute a pixel window, - then reads only that region from the file. + Uses the Dataset's ``y``/``x`` coordinates to compute a pixel + window and reads only that region from the file. The returned + DataArray's backend is matched to the Dataset's first 2D + ``y``/``x`` data variable (or ``ds[var]`` if ``var`` is given) + by inferring ``gpu=`` / ``chunks=`` from that variable. + Caller-supplied ``gpu=`` / ``chunks=`` override the inference. Parameters ---------- source : str File path to the GeoTIFF. + auto_reproject : bool + If False (default) and the Dataset's CRS differs from the + file's CRS, raises ``ValueError``. If True, projects the + Dataset's bbox into the file CRS for the windowed read and + reprojects the result back to the Dataset's CRS via + :func:`xrspatial.reproject.reproject`. + var : str or None + Data variable used for backend inference and CRS lookup. If + None, picks the first 2D variable with ``y``/``x`` dims. **kwargs - Passed to :func:`xrspatial.geotiff.open_geotiff` (except - ``window``, which is computed automatically). + Forwarded to :func:`xrspatial.geotiff.open_geotiff` (except + ``window=``, which is computed automatically). Returns ------- xr.DataArray The windowed portion of the GeoTIFF. """ - from .geotiff import open_geotiff, _read_geo_info, _extent_to_window ds = self._obj if 'y' not in ds.coords or 'x' not in ds.coords: raise ValueError( "Dataset must have 'y' and 'x' coordinates to compute " "a spatial window" ) - y = ds.coords['y'].values - x = ds.coords['x'].values - y_min, y_max = float(y.min()), float(y.max()) - x_min, x_max = float(x.min()), float(x.max()) - - geo_info, file_h, file_w, _dtype, _nbands = _read_geo_info(source) - t = geo_info.transform - - # Expand extent by half a pixel so we capture edge pixels - y_min -= abs(t.pixel_height) * 0.5 - y_max += abs(t.pixel_height) * 0.5 - x_min -= abs(t.pixel_width) * 0.5 - x_max += abs(t.pixel_width) * 0.5 - - window = _extent_to_window(t, file_h, file_w, - y_min, y_max, x_min, x_max) - kwargs.pop('window', None) - return open_geotiff(source, window=window, **kwargs) + rep = _pick_representative_dataarray(ds, var=var) + # Fall back to Dataset-level CRS when the variable lacks it + if 'crs' not in rep.attrs and 'crs' in ds.attrs: + rep = rep.copy() + rep.attrs = {**rep.attrs, 'crs': ds.attrs['crs']} + return _open_geotiff_windowed( + rep, source, auto_reproject=auto_reproject, **kwargs + ) # ---- Chunking ---- diff --git a/xrspatial/geotiff/tests/integration/test_dask_pipeline.py b/xrspatial/geotiff/tests/integration/test_dask_pipeline.py index acce7e03f..b87959ecb 100644 --- a/xrspatial/geotiff/tests/integration/test_dask_pipeline.py +++ b/xrspatial/geotiff/tests/integration/test_dask_pipeline.py @@ -1034,3 +1034,318 @@ def test_kwargs_forwarded(self, tmp_path): }) result = template.xrs.open_geotiff(path, name='myname') assert result.name == 'myname' + + +# --------------------------------------------------------------------------- +# DataArray.xrs.open_geotiff (issue #2557 - DataArray-side windowed read +# with backend inference and auto-reproject on CRS mismatch) +# --------------------------------------------------------------------------- + + +class TestDataArrayOpenGeotiff_2557: + def test_windowed_read(self, tmp_path): + big = _make_da_accessor_io(height=20, width=20) + big_path = str(tmp_path / 'test_2557_da_window.tif') + to_geotiff(big, big_path, compression='none') + + y_sub = big.coords['y'].values[5:15] + x_sub = big.coords['x'].values[5:15] + template = xr.DataArray( + np.zeros((len(y_sub), len(x_sub)), dtype=np.float32), + dims=['y', 'x'], + coords={'y': y_sub, 'x': x_sub}, + attrs={'crs': 4326}, + ) + result = template.xrs.open_geotiff(big_path) + # Half-pixel expansion may include one extra edge row/col + assert len(y_sub) <= result.shape[0] <= len(y_sub) + 2 + assert len(x_sub) <= result.shape[1] <= len(x_sub) + 2 + + def test_no_coords_raises(self, tmp_path): + da = _make_da_accessor_io() + path = str(tmp_path / 'test_2557_da_nocoords.tif') + to_geotiff(da, path, compression='none') + + bad = xr.DataArray(np.zeros(5), dims=['z']) + with pytest.raises(ValueError, match="'y' and 'x' coordinates"): + bad.xrs.open_geotiff(path) + + def test_kwargs_forwarded(self, tmp_path): + da = _make_da_accessor_io(height=8, width=10) + path = str(tmp_path / 'test_2557_da_kwargs.tif') + to_geotiff(da, path, compression='none') + + template = xr.DataArray( + np.zeros_like(da.values), + dims=['y', 'x'], + coords={'y': da.coords['y'].values, + 'x': da.coords['x'].values}, + attrs={'crs': 4326}, + ) + result = template.xrs.open_geotiff(path, name='myname') + assert result.name == 'myname' + + +# --------------------------------------------------------------------------- +# Backend inference: caller backend -> result backend (issue #2557) +# --------------------------------------------------------------------------- + + +class TestOpenGeotiffBackendInference_2557: + def test_numpy_caller_returns_numpy(self, tmp_path): + big = _make_da_accessor_io(height=10, width=10) + path = str(tmp_path / 'test_2557_numpy.tif') + to_geotiff(big, path, compression='none') + + template = xr.DataArray( + np.zeros((10, 10), dtype=np.float32), + dims=['y', 'x'], + coords={'y': big.coords['y'].values, + 'x': big.coords['x'].values}, + attrs={'crs': 4326}, + ) + result = template.xrs.open_geotiff(path) + assert isinstance(result.data, np.ndarray) + assert result.chunks is None + + def test_dask_caller_returns_dask_with_inferred_chunks(self, tmp_path): + import dask.array as dda + big = _make_da_accessor_io(height=12, width=12) + path = str(tmp_path / 'test_2557_dask.tif') + to_geotiff(big, path, compression='none') + + template = xr.DataArray( + dda.zeros((12, 12), chunks=(4, 4), dtype=np.float32), + dims=['y', 'x'], + coords={'y': big.coords['y'].values, + 'x': big.coords['x'].values}, + attrs={'crs': 4326}, + ) + result = template.xrs.open_geotiff(path) + assert isinstance(result.data, dda.Array) + # First y-chunk should match the caller (4) + assert result.chunks[0][0] == 4 + + def test_explicit_chunks_override_inference(self, tmp_path): + import dask.array as dda + big = _make_da_accessor_io(height=12, width=12) + path = str(tmp_path / 'test_2557_chunks_override.tif') + to_geotiff(big, path, compression='none') + + template = xr.DataArray( + dda.zeros((12, 12), chunks=(4, 4), dtype=np.float32), + dims=['y', 'x'], + coords={'y': big.coords['y'].values, + 'x': big.coords['x'].values}, + attrs={'crs': 4326}, + ) + result = template.xrs.open_geotiff(path, chunks=3) + assert isinstance(result.data, dda.Array) + # Caller asked for 3, override wins over inferred 4 + assert result.chunks[0][0] == 3 + + def test_dataset_caller_infers_from_first_var(self, tmp_path): + import dask.array as dda + big = _make_da_accessor_io(height=10, width=10) + path = str(tmp_path / 'test_2557_ds_backend.tif') + to_geotiff(big, path, compression='none') + + var = xr.DataArray( + dda.zeros((10, 10), chunks=(5, 5), dtype=np.float32), + dims=['y', 'x'], + coords={'y': big.coords['y'].values, + 'x': big.coords['x'].values}, + attrs={'crs': 4326}, + ) + ds = xr.Dataset({'elevation': var}) + result = ds.xrs.open_geotiff(path) + assert isinstance(result.data, dda.Array) + assert result.chunks[0][0] == 5 + + +# --------------------------------------------------------------------------- +# CRS mismatch handling (issue #2557) +# --------------------------------------------------------------------------- + + +class TestOpenGeotiffCRSMismatch_2557: + def test_mismatch_raises_by_default(self, tmp_path): + # File in EPSG:4326 + big = _make_da_accessor_io(height=10, width=10, crs=4326) + path = str(tmp_path / 'test_2557_mismatch.tif') + to_geotiff(big, path, compression='none') + + # Caller in EPSG:3857 + template = xr.DataArray( + np.zeros((4, 4), dtype=np.float32), + dims=['y', 'x'], + coords={'y': np.linspace(5e6, 4.9e6, 4), + 'x': np.linspace(-1.3e7, -1.32e7, 4)}, + attrs={'crs': 3857}, + ) + with pytest.raises(ValueError, match="CRS mismatch"): + template.xrs.open_geotiff(path) + + def test_auto_reproject_returns_caller_crs(self, tmp_path): + # File in EPSG:4326 over a small mid-latitude box + height, width = 30, 30 + arr = np.arange(height * width, + dtype=np.float32).reshape(height, width) + y = np.linspace(45.5, 44.5, height) + x = np.linspace(-120.5, -119.5, width) + file_da = xr.DataArray( + arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}, + ) + path = str(tmp_path / 'test_2557_autoreproj.tif') + to_geotiff(file_da, path, compression='none') + + # Caller in EPSG:3857 over the same region + from pyproj import Transformer + tr = Transformer.from_crs(4326, 3857, always_xy=True) + x0, y0 = tr.transform(-120.25, 45.25) + x1, y1 = tr.transform(-119.75, 44.75) + template = xr.DataArray( + np.zeros((6, 6), dtype=np.float32), + dims=['y', 'x'], + coords={'y': np.linspace(max(y0, y1), min(y0, y1), 6), + 'x': np.linspace(min(x0, x1), max(x0, x1), 6)}, + attrs={'crs': 3857}, + ) + result = template.xrs.open_geotiff(path, auto_reproject=True) + # Result coords are in caller's CRS (mercator metres). Check the + # bbox roughly matches the caller's bbox so a future regression + # in projection direction would be caught. Tolerance is one + # output pixel's worth of metres at this latitude (~150 km of + # bbox / 6 pixels). + tol = abs(float(template.coords['x'][1] - template.coords['x'][0])) + assert abs(float(result.coords['x'].min()) + - float(template.coords['x'].min())) < 2 * tol + assert abs(float(result.coords['x'].max()) + - float(template.coords['x'].max())) < 2 * tol + + def test_chained_open_geotiff_with_wkt_crs(self, tmp_path): + # After auto_reproject, xrspatial.reproject sets attrs['crs'] to + # a WKT string. Calling open_geotiff again on that DataArray must + # not crash on int() conversion (regression for PR #2598 review). + height, width = 20, 20 + arr = np.arange(height * width, + dtype=np.float32).reshape(height, width) + y = np.linspace(45.5, 44.5, height) + x = np.linspace(-120.5, -119.5, width) + file_da = xr.DataArray( + arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}, + ) + path = str(tmp_path / 'test_2557_wkt.tif') + to_geotiff(file_da, path, compression='none') + + # Synthesize a caller with a WKT-string crs (what reproject sets) + from pyproj import CRS as _PyprojCRS + wkt_3857 = _PyprojCRS(3857).to_wkt() + from pyproj import Transformer + tr = Transformer.from_crs(4326, 3857, always_xy=True) + x0, y0 = tr.transform(-120.25, 45.25) + x1, y1 = tr.transform(-119.75, 44.75) + template = xr.DataArray( + np.zeros((4, 4), dtype=np.float32), + dims=['y', 'x'], + coords={'y': np.linspace(max(y0, y1), min(y0, y1), 4), + 'x': np.linspace(min(x0, x1), max(x0, x1), 4)}, + attrs={'crs': wkt_3857}, + ) + # Should raise ValueError (mismatch), not TypeError/ValueError + # from int(wkt_string) + with pytest.raises(ValueError, match="CRS mismatch"): + template.xrs.open_geotiff(path) + # And auto_reproject path also works. reproject preserves the + # file's native pixel resolution, so the result shape reflects + # the windowed slice of the file rather than the caller's + # template shape -- the (20x20-pixel, 1deg-extent) file at a + # ~0.5deg caller bbox gives ~10 pixels per side plus boundary + # rounding. Bound generously. + result = template.xrs.open_geotiff(path, auto_reproject=True) + assert result.ndim == 2 + assert 4 <= result.shape[0] <= 20 + assert 4 <= result.shape[1] <= 20 + + def test_no_caller_crs_no_mismatch_check(self, tmp_path): + # Caller without attrs['crs'] should skip mismatch logic and + # proceed (preserves prior behavior). + big = _make_da_accessor_io(height=10, width=10, crs=4326) + path = str(tmp_path / 'test_2557_nocrs.tif') + to_geotiff(big, path, compression='none') + + template = xr.DataArray( + np.zeros((10, 10), dtype=np.float32), + dims=['y', 'x'], + coords={'y': big.coords['y'].values, + 'x': big.coords['x'].values}, + # No attrs['crs'] + ) + # Should not raise + result = template.xrs.open_geotiff(path) + np.testing.assert_array_equal(result.values, big.values) + + def test_dataset_crs_falls_back_from_ds_attrs(self, tmp_path): + # CRS on Dataset attrs, not on the data_var, should still be + # picked up. + big = _make_da_accessor_io(height=10, width=10, crs=4326) + path = str(tmp_path / 'test_2557_ds_crs_attr.tif') + to_geotiff(big, path, compression='none') + + var = xr.DataArray( + np.zeros((4, 4), dtype=np.float32), + dims=['y', 'x'], + coords={'y': np.linspace(5e6, 4.9e6, 4), + 'x': np.linspace(-1.3e7, -1.32e7, 4)}, + ) + ds = xr.Dataset({'elevation': var}, attrs={'crs': 3857}) + with pytest.raises(ValueError, match="CRS mismatch"): + ds.xrs.open_geotiff(path) + + def test_malformed_crs_raises(self, tmp_path): + # A garbage attrs['crs'] must raise rather than silently skip + # the mismatch check (follow-up review hardening). + big = _make_da_accessor_io(height=10, width=10, crs=4326) + path = str(tmp_path / 'test_2557_bad_crs.tif') + to_geotiff(big, path, compression='none') + + template = xr.DataArray( + np.zeros((4, 4), dtype=np.float32), + dims=['y', 'x'], + coords={'y': big.coords['y'].values[:4], + 'x': big.coords['x'].values[:4]}, + attrs={'crs': 'not-a-real-crs'}, + ) + with pytest.raises(ValueError, match="not a valid CRS"): + template.xrs.open_geotiff(path) + + +# --------------------------------------------------------------------------- +# GPU backend inference (gated on CUDA + cupy availability) - issue #2557 +# --------------------------------------------------------------------------- + +from xrspatial.tests.general_checks import cuda_and_cupy_available # noqa: E402 + + +@cuda_and_cupy_available +class TestOpenGeotiffGPUBackendInference_2557: + def test_cupy_caller_returns_cupy(self, tmp_path): + import cupy as cp + big = _make_da_accessor_io(height=10, width=10) + path = str(tmp_path / 'test_2557_cupy.tif') + to_geotiff(big, path, compression='none') + + template = xr.DataArray( + cp.zeros((10, 10), dtype=cp.float32), + dims=['y', 'x'], + coords={'y': big.coords['y'].values, + 'x': big.coords['x'].values}, + attrs={'crs': 4326}, + ) + result = template.xrs.open_geotiff(path) + # Result should be cupy-backed (gpu=True was inferred) + assert type(result.data).__module__.split('.')[0] == 'cupy'