Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@

### Unreleased

#### Added

- `da.xrs.open_geotiff(source, *, auto_reproject=False, **kwargs)`
DataArray accessor that mirrors the existing Dataset method, plus
backend-aware enhancements on both accessors. The accessor infers
the caller's backend (numpy / cupy / dask+numpy / dask+cupy) and
passes matching `gpu=` / `chunks=` to `open_geotiff` so the
returned DataArray matches the caller. Caller-supplied `gpu=` /
`chunks=` always override the inference. On CRS mismatch between
the caller and the file, the accessor raises a clear `ValueError`
by default (replacing the previous silently-wrong window) and
with `auto_reproject=True` it projects the caller's bbox into the
file's CRS for the windowed read and reprojects the result back
to the caller's CRS via `xrspatial.reproject.reproject`. (#2557)

#### Fixed

- `crosstab(cat_ids=[...])` no longer overcounts when `cat_ids` skips a
Expand Down
266 changes: 242 additions & 24 deletions xrspatial/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,187 @@ def _listed_colormap_from_attrs(attrs):
return ListedColormap(colors, name='tiff_palette')


def _pick_representative_dataarray(ds, *, var=None):
"""Return the DataArray used to drive backend / CRS lookup for
Dataset-level GeoTIFF accessor methods.

``var`` -> ``ds[var]``. Otherwise the first 2D ``y``/``x`` variable.
"""
if var is not None:
return ds[var]
for v in ds.data_vars:
d = ds[v]
if d.ndim >= 2 and 'y' in d.dims and 'x' in d.dims:
return d
raise ValueError(
"Dataset has no 2D variable with 'y' and 'x' dimensions to use "
"for backend inference. Pass var='<name>' to select one, or "
"call xrspatial.geotiff.open_geotiff(source, ...) directly."
)


def _infer_caller_y_chunk(obj):
"""First y-axis chunk size of a dask-backed DataArray, or None."""
try:
y_axis = obj.get_axis_num('y')
except (ValueError, KeyError):
return None
chunks_tuple = getattr(obj, 'chunks', None)
if not chunks_tuple:
return None
try:
y_chunks = chunks_tuple[y_axis]
except (IndexError, TypeError):
return None
if not y_chunks:
return None
return int(y_chunks[0])


def _to_pyproj_crs(crs):
"""Normalize a CRS value (int EPSG, ``"EPSG:xxxx"``, WKT, PROJ
string, or ``pyproj.CRS``) into a ``pyproj.CRS`` instance.

Returns ``None`` only when ``crs`` is itself ``None`` (the
backward-compatible "no CRS to compare" path). A malformed but
present value raises ``ValueError`` so the caller sees the typo
instead of having the mismatch safety net silently disabled.
"""
if crs is None:
return None
from pyproj import CRS as _PyprojCRS
from pyproj.exceptions import CRSError
try:
return _PyprojCRS(crs)
except CRSError as e:
raise ValueError(
f"attrs['crs']={crs!r} is not a valid CRS: {e}"
) from e


def _bbox_edge_samples(x_min, y_min, x_max, y_max, n_per_side=20):
"""Return ``(xs, ys)`` arrays sampling each edge of the bbox.

Sampling the perimeter (instead of only the 4 corners) before
transforming into the source CRS keeps the envelope of the
transformed points close to the true projected bbox even when the
transform has curvature across the box (high latitudes, large
extents).
"""
import numpy as np
a = np.linspace(x_min, x_max, n_per_side)
b = np.full(n_per_side, y_min)
c = np.full(n_per_side, y_max)
d = np.linspace(y_min, y_max, n_per_side)
e = np.full(n_per_side, x_min)
f = np.full(n_per_side, x_max)
xs = np.concatenate([a, a, e, f])
ys = np.concatenate([b, c, d, d])
return xs, ys


def _open_geotiff_windowed(obj, source, *, auto_reproject=False, **kwargs):
"""Shared implementation for ``.xrs.open_geotiff`` on DataArray and
Dataset.

``obj`` is a DataArray that carries the bbox (``y``/``x`` coords),
the CRS (``attrs['crs']``), and the backend used to infer
``gpu=`` / ``chunks=`` kwargs for the underlying read.

The windowing extent is expanded by half a pixel on each side so
edge pixels of the caller's footprint are captured. CRS values on
both sides (``attrs['crs']`` and the file's georeference) are
normalized through ``pyproj.CRS`` for equality, so int EPSG codes,
``"EPSG:xxxx"`` strings, WKT strings, and ``pyproj.CRS`` instances
all compare correctly.

When ``auto_reproject=True`` and the CRSs differ, the caller's
bbox is sampled along its perimeter (not just the four corners)
before transformation into the file CRS, so the windowed read
covers the full footprint even when the transform has curvature
across the bbox.
"""
from .geotiff import open_geotiff, _read_geo_info, _extent_to_window
from .utils import _classify_backend

if 'y' not in obj.coords or 'x' not in obj.coords:
raise ValueError(
"Caller must have 'y' and 'x' coordinates to compute a "
"spatial window"
)
y = obj.coords['y'].values
x = obj.coords['x'].values
y_min, y_max = float(y.min()), float(y.max())
x_min, x_max = float(x.min()), float(x.max())

geo_info, file_h, file_w, _dtype, _nbands = _read_geo_info(source)
t = geo_info.transform
caller_crs_raw = obj.attrs.get('crs')
file_crs_raw = geo_info.crs_epsg
caller_crs = _to_pyproj_crs(caller_crs_raw)
file_crs = _to_pyproj_crs(file_crs_raw)

crs_mismatch = (
caller_crs is not None
and file_crs is not None
and not caller_crs.equals(file_crs)
)

if crs_mismatch and not auto_reproject:
raise ValueError(
f"CRS mismatch: caller has {caller_crs_raw!r} but file has "
f"EPSG:{file_crs_raw}. Pass auto_reproject=True to project "
f"the caller bbox into the file CRS for the windowed read "
f"and reproject the result back to the caller CRS."
)

if crs_mismatch:
from pyproj import Transformer
transformer = Transformer.from_crs(
caller_crs, file_crs, always_xy=True
)
# 20 samples per side is enough to keep the projected envelope
# within sub-pixel of the true bbox at any realistic latitude
# without making the transform call noticeably slower.
xs, ys = _bbox_edge_samples(x_min, y_min, x_max, y_max)
px, py = transformer.transform(xs, ys)
x_min = float(px.min())
x_max = float(px.max())
y_min = float(py.min())
y_max = float(py.max())

# Expand extent by half a pixel so we capture edge pixels
y_min -= abs(t.pixel_height) * 0.5
y_max += abs(t.pixel_height) * 0.5
x_min -= abs(t.pixel_width) * 0.5
x_max += abs(t.pixel_width) * 0.5

window = _extent_to_window(t, file_h, file_w,
y_min, y_max, x_min, x_max)
kwargs.pop('window', None)

# Infer backend kwargs. Caller-supplied values always win.
backend = _classify_backend(obj)
if backend in ("cupy", "dask+cupy"):
kwargs.setdefault('gpu', True)
if backend in ("dask+numpy", "dask+cupy"):
inferred_chunk = _infer_caller_y_chunk(obj)
if inferred_chunk is not None:
kwargs.setdefault('chunks', inferred_chunk)

result = open_geotiff(source, window=window, **kwargs)

if crs_mismatch:
from .reproject import reproject
result = reproject(
result,
target_crs=caller_crs,
source_crs=file_crs,
)

return result


@xr.register_dataarray_accessor("xrs")
class XrsSpatialDataArrayAccessor:
"""DataArray accessor exposing xarray-spatial operations."""
Expand Down Expand Up @@ -602,6 +783,41 @@ def to_geotiff(self, path, **kwargs):
from .geotiff import to_geotiff
return to_geotiff(self._obj, path, **kwargs)

def open_geotiff(self, source, *, auto_reproject=False, **kwargs):
"""Read a GeoTIFF windowed to this DataArray's spatial extent.

Uses ``self``'s ``y``/``x`` coordinates to compute a pixel window
and reads only that region from the file. The returned DataArray's
backend is matched to ``self``'s backend by inferring
``gpu=`` / ``chunks=`` from ``self`` (caller-supplied ``gpu=`` /
``chunks=`` override the inference).

Parameters
----------
source : str
File path to the GeoTIFF.
auto_reproject : bool
If False (default) and ``self.attrs['crs']`` differs from
the file's CRS, raises ``ValueError``. If True, projects
``self``'s bbox into the file CRS for the windowed read,
then reprojects the result back to ``self``'s CRS via
:func:`xrspatial.reproject.reproject` so the returned
DataArray lines up with ``self``.
**kwargs
Forwarded to :func:`xrspatial.geotiff.open_geotiff` (except
``window=``, which is computed automatically).

Returns
-------
xr.DataArray
The windowed portion of the GeoTIFF, in ``self``'s CRS when
``auto_reproject`` reprojection occurred and otherwise in the
file's native CRS.
"""
return _open_geotiff_windowed(
self._obj, source, auto_reproject=auto_reproject, **kwargs
)

# ---- Chunking ----

def rechunk_no_shuffle(self, **kwargs):
Expand Down Expand Up @@ -1065,50 +1281,52 @@ def to_geotiff(self, path, var=None, **kwargs):
"Dataset has no variable with 'y' and 'x' dimensions to write"
)

def open_geotiff(self, source, **kwargs):
def open_geotiff(self, source, *, auto_reproject=False, var=None, **kwargs):
"""Read a GeoTIFF windowed to this Dataset's spatial extent.

Uses the Dataset's y/x coordinates to compute a pixel window,
then reads only that region from the file.
Uses the Dataset's ``y``/``x`` coordinates to compute a pixel
window and reads only that region from the file. The returned
DataArray's backend is matched to the Dataset's first 2D
``y``/``x`` data variable (or ``ds[var]`` if ``var`` is given)
by inferring ``gpu=`` / ``chunks=`` from that variable.
Caller-supplied ``gpu=`` / ``chunks=`` override the inference.

Parameters
----------
source : str
File path to the GeoTIFF.
auto_reproject : bool
If False (default) and the Dataset's CRS differs from the
file's CRS, raises ``ValueError``. If True, projects the
Dataset's bbox into the file CRS for the windowed read and
reprojects the result back to the Dataset's CRS via
:func:`xrspatial.reproject.reproject`.
var : str or None
Data variable used for backend inference and CRS lookup. If
None, picks the first 2D variable with ``y``/``x`` dims.
**kwargs
Passed to :func:`xrspatial.geotiff.open_geotiff` (except
``window``, which is computed automatically).
Forwarded to :func:`xrspatial.geotiff.open_geotiff` (except
``window=``, which is computed automatically).

Returns
-------
xr.DataArray
The windowed portion of the GeoTIFF.
"""
from .geotiff import open_geotiff, _read_geo_info, _extent_to_window
ds = self._obj
if 'y' not in ds.coords or 'x' not in ds.coords:
raise ValueError(
"Dataset must have 'y' and 'x' coordinates to compute "
"a spatial window"
)
y = ds.coords['y'].values
x = ds.coords['x'].values
y_min, y_max = float(y.min()), float(y.max())
x_min, x_max = float(x.min()), float(x.max())

geo_info, file_h, file_w, _dtype, _nbands = _read_geo_info(source)
t = geo_info.transform

# Expand extent by half a pixel so we capture edge pixels
y_min -= abs(t.pixel_height) * 0.5
y_max += abs(t.pixel_height) * 0.5
x_min -= abs(t.pixel_width) * 0.5
x_max += abs(t.pixel_width) * 0.5

window = _extent_to_window(t, file_h, file_w,
y_min, y_max, x_min, x_max)
kwargs.pop('window', None)
return open_geotiff(source, window=window, **kwargs)
rep = _pick_representative_dataarray(ds, var=var)
# Fall back to Dataset-level CRS when the variable lacks it
if 'crs' not in rep.attrs and 'crs' in ds.attrs:
rep = rep.copy()
rep.attrs = {**rep.attrs, 'crs': ds.attrs['crs']}
return _open_geotiff_windowed(
rep, source, auto_reproject=auto_reproject, **kwargs
)

# ---- Chunking ----

Expand Down
Loading
Loading