Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 156 additions & 86 deletions xrspatial/geotiff/_writers/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import numbers
import os
import shutil
import uuid
import warnings
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -1175,11 +1177,21 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
tiles_dir_name = stem + '_tiles'
tiles_dir = os.path.join(vrt_dir, tiles_dir_name)

# Validate tiles directory
# Validate tiles directory. A non-empty ``tiles_dir`` is treated as a
# prior completed write and refused. Partial output from a *failed*
# write never reaches this name -- tiles are staged in a temp dir and
# only renamed into place once every tile is written (see below), so a
# retry after a failure starts from a clean slate.
if os.path.isdir(tiles_dir) and os.listdir(tiles_dir):
raise FileExistsError(
f"Tiles directory already contains files: {tiles_dir}")
os.makedirs(tiles_dir, exist_ok=True)

# Stage tiles in a temp directory next to the final one. Same parent
# filesystem keeps the final ``os.replace`` atomic. The unique suffix
# avoids collisions between concurrent writers targeting the same VRT.
staging_dir = os.path.join(
vrt_dir, f'{tiles_dir_name}.tmp-{uuid.uuid4().hex}')
os.makedirs(staging_dir, exist_ok=True)

# Resolve CRS. ``numbers.Integral`` covers numpy integer scalars
# (``np.int32``, ``np.int64``) so ``crs=np.int64(4326)`` does not
Expand Down Expand Up @@ -1297,93 +1309,151 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
# Zero-padding width for tile names
pad_width = max(2, len(str(max(n_row_tiles, n_col_tiles) - 1)))

tile_paths = []
# Tiles are written into ``staging_dir``; ``tile_names`` records the
# bare filenames so the final tile paths (under ``tiles_dir``) can be
# rebuilt for the VRT index after the atomic rename below.
tile_names = []
delayed_tasks = []

row_offset = 0
for ri in range(n_row_tiles):
if is_dask:
chunk_h = row_chunks[ri]
else:
chunk_h = min(tile_size, height - row_offset)

col_offset = 0
for ci in range(n_col_tiles):
if is_dask:
chunk_w = col_chunks[ci]
else:
chunk_w = min(tile_size, width - col_offset)

tile_name = f'tile_{ri:0{pad_width}d}_{ci:0{pad_width}d}.tif'
tile_path = os.path.join(tiles_dir, tile_name)
tile_paths.append(tile_path)

# Compute per-tile geo_transform
tile_gt = None
if geo_transform is not None:
tile_gt = GeoTransform(
origin_x=geo_transform.origin_x + col_offset * geo_transform.pixel_width,
origin_y=geo_transform.origin_y + row_offset * geo_transform.pixel_height,
pixel_width=geo_transform.pixel_width,
pixel_height=geo_transform.pixel_height,
)
def _cleanup_staging():
    """Remove the staging directory tree on a best-effort basis.

    Removal errors are ignored (``ignore_errors=True``), so this helper
    does not raise for a directory that is already gone or cannot be
    fully deleted.
    """
    shutil.rmtree(
        staging_dir,
        ignore_errors=True,
    )

def _safe_write_tile(*args, **kwargs):
    """Write one tile, handing any failure back as a return value.

    All positional and keyword arguments are forwarded unchanged to
    ``_write_single_tile``.

    Returns ``None`` when the write completes, otherwise the exception
    instance that was raised. Nothing propagates out of this wrapper.
    """
    # An exception escaping a threaded dask task would abort
    # ``dask.compute`` while sibling threads are still writing into the
    # staging dir. Returning the failure lets every task settle first;
    # the caller raises the first captured exception afterwards.
    failure = None
    try:
        _write_single_tile(*args, **kwargs)
    except BaseException as caught:  # noqa: BLE001 - re-raised by caller
        failure = caught
    return failure

try:
row_offset = 0
for ri in range(n_row_tiles):
if is_dask:
# Slice the dask array for this chunk
r_end = row_offset + chunk_h
c_end = col_offset + chunk_w
chunk_data = raw[row_offset:r_end, col_offset:c_end]

task = dask.delayed(_write_single_tile)(
chunk_data, tile_path, tile_gt, epsg, wkt_fallback,
nodata, compression, compression_level,
tile_size, predictor, bigtiff, max_z_error,
raster_type=raster_type,
x_resolution=x_res,
y_resolution=y_res,
resolution_unit=res_unit,
gdal_metadata_xml=gdal_meta_xml,
extra_tags=extra_tags_list,
photometric=photometric,
restore_sentinel=restore_sentinel,
allow_internal_only_jpeg=allow_internal_only_jpeg,
allow_unparseable_crs=allow_unparseable_crs)
delayed_tasks.append(task)
chunk_h = row_chunks[ri]
else:
# Numpy: slice and write directly
chunk_data = np_arr[row_offset:row_offset + chunk_h,
col_offset:col_offset + chunk_w]
_write_single_tile(
chunk_data, tile_path, tile_gt, epsg, wkt_fallback,
nodata, compression, compression_level,
tile_size, predictor, bigtiff, max_z_error,
raster_type=raster_type,
x_resolution=x_res,
y_resolution=y_res,
resolution_unit=res_unit,
gdal_metadata_xml=gdal_meta_xml,
extra_tags=extra_tags_list,
photometric=photometric,
restore_sentinel=restore_sentinel,
allow_internal_only_jpeg=allow_internal_only_jpeg,
allow_unparseable_crs=allow_unparseable_crs)

col_offset += chunk_w
row_offset += chunk_h

# Execute all dask tasks.
#
# Each delayed task is an independent ``_write_single_tile`` call on
# a distinct output path, with no shared mutable Python state, so
# the writes are embarrassingly parallel. Using ``scheduler='threads'``
# lets zlib / zstd / LZW release the GIL during compression and the
# OS coalesce concurrent writes; in a 256-tile zstd write on a
# 4096x4096 dask DataArray the wall time drops ~33% versus the
# ``synchronous`` scheduler this used to call.
if delayed_tasks:
import dask
dask.compute(*delayed_tasks, scheduler='threads')

# Write VRT index with relative paths
chunk_h = min(tile_size, height - row_offset)

col_offset = 0
for ci in range(n_col_tiles):
if is_dask:
chunk_w = col_chunks[ci]
else:
chunk_w = min(tile_size, width - col_offset)

tile_name = f'tile_{ri:0{pad_width}d}_{ci:0{pad_width}d}.tif'
tile_path = os.path.join(staging_dir, tile_name)
tile_names.append(tile_name)

# Compute per-tile geo_transform
tile_gt = None
if geo_transform is not None:
tile_gt = GeoTransform(
origin_x=geo_transform.origin_x + col_offset * geo_transform.pixel_width,
origin_y=geo_transform.origin_y + row_offset * geo_transform.pixel_height,
pixel_width=geo_transform.pixel_width,
pixel_height=geo_transform.pixel_height,
)

if is_dask:
# Slice the dask array for this chunk
r_end = row_offset + chunk_h
c_end = col_offset + chunk_w
chunk_data = raw[row_offset:r_end, col_offset:c_end]

task = dask.delayed(_safe_write_tile)(
chunk_data, tile_path, tile_gt, epsg, wkt_fallback,
nodata, compression, compression_level,
tile_size, predictor, bigtiff, max_z_error,
raster_type=raster_type,
x_resolution=x_res,
y_resolution=y_res,
resolution_unit=res_unit,
gdal_metadata_xml=gdal_meta_xml,
extra_tags=extra_tags_list,
photometric=photometric,
restore_sentinel=restore_sentinel,
allow_internal_only_jpeg=allow_internal_only_jpeg,
allow_unparseable_crs=allow_unparseable_crs)
delayed_tasks.append(task)
else:
# Numpy: slice and write directly
chunk_data = np_arr[row_offset:row_offset + chunk_h,
col_offset:col_offset + chunk_w]
_write_single_tile(
chunk_data, tile_path, tile_gt, epsg, wkt_fallback,
nodata, compression, compression_level,
tile_size, predictor, bigtiff, max_z_error,
raster_type=raster_type,
x_resolution=x_res,
y_resolution=y_res,
resolution_unit=res_unit,
gdal_metadata_xml=gdal_meta_xml,
extra_tags=extra_tags_list,
photometric=photometric,
restore_sentinel=restore_sentinel,
allow_internal_only_jpeg=allow_internal_only_jpeg,
allow_unparseable_crs=allow_unparseable_crs)

col_offset += chunk_w
row_offset += chunk_h

# Execute all dask tasks.
#
# Each delayed task is an independent ``_write_single_tile`` call on
# a distinct output path, with no shared mutable Python state, so
# the writes are embarrassingly parallel. Using ``scheduler='threads'``
# lets zlib / zstd / LZW release the GIL during compression and the
# OS coalesce concurrent writes; in a 256-tile zstd write on a
# 4096x4096 dask DataArray the wall time drops ~33% versus the
# ``synchronous`` scheduler this used to call.
#
# ``dask.compute`` raises as soon as one task errors, while sibling
# worker threads may still be writing tiles into the staging dir.
# Cleaning up at that point races those threads and can leave a
# partial staging dir behind. Capture each task's outcome instead
# so every worker reaches a barrier before we inspect results, then
# re-raise the first failure once no thread is still writing.
if delayed_tasks:
import dask
results = dask.compute(*delayed_tasks, scheduler='threads')
for r in results:
if isinstance(r, BaseException):
raise r
except BaseException:
# Any tile failure leaves no partial output: drop the staging dir
# so the final ``tiles_dir`` name is never created and a retry
# starts clean.
_cleanup_staging()
raise

# Every tile is written. Promote the staging dir to its final name in
# one atomic rename so the VRT only ever references a complete tile set.
# A pre-existing *empty* ``tiles_dir`` passes the leftover-state guard
# above (the old code reused it via ``makedirs(exist_ok=True)``), but
# ``os.replace`` onto an existing directory raises on Windows even when
# the target is empty. Drop the empty target first so the rename has a
# clear slot on every platform.
try:
if os.path.isdir(tiles_dir):
os.rmdir(tiles_dir) # only succeeds if empty; guard ensured that
os.replace(staging_dir, tiles_dir)
except BaseException:
_cleanup_staging()
raise

# Write VRT index with relative paths. The VRT lives at ``vrt_path``;
# tile paths now resolve under the final ``tiles_dir``.
tile_paths = [os.path.join(tiles_dir, name) for name in tile_names]
from .._vrt import write_vrt as _write_vrt_fn
_write_vrt_fn(vrt_path, tile_paths, relative=True, nodata=nodata)
try:
_write_vrt_fn(vrt_path, tile_paths, relative=True, nodata=nodata)
except BaseException:
# The index step failed after the rename. Remove the now-renamed
# tile dir too so a retry is not blocked by the leftover-state
# guard above.
shutil.rmtree(tiles_dir, ignore_errors=True)
raise
Loading
Loading