diff --git a/xrspatial/geotiff/_writers/eager.py b/xrspatial/geotiff/_writers/eager.py index 484042bfe..0c027d82b 100644 --- a/xrspatial/geotiff/_writers/eager.py +++ b/xrspatial/geotiff/_writers/eager.py @@ -12,6 +12,8 @@ import numbers import os +import shutil +import uuid import warnings from typing import TYPE_CHECKING @@ -1175,11 +1177,21 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None, tiles_dir_name = stem + '_tiles' tiles_dir = os.path.join(vrt_dir, tiles_dir_name) - # Validate tiles directory + # Validate tiles directory. A non-empty ``tiles_dir`` is treated as a + # prior completed write and refused. Partial output from a *failed* + # write never reaches this name -- tiles are staged in a temp dir and + # only renamed into place once every tile is written (see below), so a + # retry after a failure starts from a clean slate. if os.path.isdir(tiles_dir) and os.listdir(tiles_dir): raise FileExistsError( f"Tiles directory already contains files: {tiles_dir}") - os.makedirs(tiles_dir, exist_ok=True) + + # Stage tiles in a temp directory next to the final one. Same parent + # filesystem keeps the final ``os.replace`` atomic. The unique suffix + # avoids collisions between concurrent writers targeting the same VRT. + staging_dir = os.path.join( + vrt_dir, f'{tiles_dir_name}.tmp-{uuid.uuid4().hex}') + os.makedirs(staging_dir, exist_ok=True) # Resolve CRS. ``numbers.Integral`` covers numpy integer scalars # (``np.int32``, ``np.int64``) so ``crs=np.int64(4326)`` does not @@ -1297,93 +1309,151 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None, # Zero-padding width for tile names pad_width = max(2, len(str(max(n_row_tiles, n_col_tiles) - 1))) - tile_paths = [] + # Tiles are written into ``staging_dir``; ``tile_names`` records the + # bare filenames so the final tile paths (under ``tiles_dir``) can be + # rebuilt for the VRT index after the atomic rename below. + tile_names = [] delayed_tasks = [] - row_offset = 0 - for ri in range(n_row_tiles): - if is_dask: - chunk_h = row_chunks[ri] - else: - chunk_h = min(tile_size, height - row_offset) - - col_offset = 0 - for ci in range(n_col_tiles): - if is_dask: - chunk_w = col_chunks[ci] - else: - chunk_w = min(tile_size, width - col_offset) - - tile_name = f'tile_{ri:0{pad_width}d}_{ci:0{pad_width}d}.tif' - tile_path = os.path.join(tiles_dir, tile_name) - tile_paths.append(tile_path) - - # Compute per-tile geo_transform - tile_gt = None - if geo_transform is not None: - tile_gt = GeoTransform( - origin_x=geo_transform.origin_x + col_offset * geo_transform.pixel_width, - origin_y=geo_transform.origin_y + row_offset * geo_transform.pixel_height, - pixel_width=geo_transform.pixel_width, - pixel_height=geo_transform.pixel_height, - ) + def _cleanup_staging(): + shutil.rmtree(staging_dir, ignore_errors=True) + def _safe_write_tile(*args, **kwargs): + # Return the exception instead of raising so a failure in one + # threaded dask task does not abort ``dask.compute`` while sibling + # threads are still writing into the staging dir. The caller raises + # the first captured failure after every task has settled. + try: + _write_single_tile(*args, **kwargs) + except BaseException as exc: # noqa: BLE001 - re-raised by caller + return exc + return None + + try: + row_offset = 0 + for ri in range(n_row_tiles): if is_dask: - # Slice the dask array for this chunk - r_end = row_offset + chunk_h - c_end = col_offset + chunk_w - chunk_data = raw[row_offset:r_end, col_offset:c_end] - - task = dask.delayed(_write_single_tile)( - chunk_data, tile_path, tile_gt, epsg, wkt_fallback, - nodata, compression, compression_level, - tile_size, predictor, bigtiff, max_z_error, - raster_type=raster_type, - x_resolution=x_res, - y_resolution=y_res, - resolution_unit=res_unit, - gdal_metadata_xml=gdal_meta_xml, - extra_tags=extra_tags_list, - photometric=photometric, - restore_sentinel=restore_sentinel, - allow_internal_only_jpeg=allow_internal_only_jpeg, - allow_unparseable_crs=allow_unparseable_crs) - delayed_tasks.append(task) + chunk_h = row_chunks[ri] else: - # Numpy: slice and write directly - chunk_data = np_arr[row_offset:row_offset + chunk_h, - col_offset:col_offset + chunk_w] - _write_single_tile( - chunk_data, tile_path, tile_gt, epsg, wkt_fallback, - nodata, compression, compression_level, - tile_size, predictor, bigtiff, max_z_error, - raster_type=raster_type, - x_resolution=x_res, - y_resolution=y_res, - resolution_unit=res_unit, - gdal_metadata_xml=gdal_meta_xml, - extra_tags=extra_tags_list, - photometric=photometric, - restore_sentinel=restore_sentinel, - allow_internal_only_jpeg=allow_internal_only_jpeg, - allow_unparseable_crs=allow_unparseable_crs) - - col_offset += chunk_w - row_offset += chunk_h - - # Execute all dask tasks. - # - # Each delayed task is an independent ``_write_single_tile`` call on - # a distinct output path, with no shared mutable Python state, so - # the writes are embarrassingly parallel. Using ``scheduler='threads'`` - # lets zlib / zstd / LZW release the GIL during compression and the - # OS coalesce concurrent writes; in a 256-tile zstd write on a - # 4096x4096 dask DataArray the wall time drops ~33% versus the - # ``synchronous`` scheduler this used to call. - if delayed_tasks: - import dask - dask.compute(*delayed_tasks, scheduler='threads') - - # Write VRT index with relative paths + chunk_h = min(tile_size, height - row_offset) + + col_offset = 0 + for ci in range(n_col_tiles): + if is_dask: + chunk_w = col_chunks[ci] + else: + chunk_w = min(tile_size, width - col_offset) + + tile_name = f'tile_{ri:0{pad_width}d}_{ci:0{pad_width}d}.tif' + tile_path = os.path.join(staging_dir, tile_name) + tile_names.append(tile_name) + + # Compute per-tile geo_transform + tile_gt = None + if geo_transform is not None: + tile_gt = GeoTransform( + origin_x=geo_transform.origin_x + col_offset * geo_transform.pixel_width, + origin_y=geo_transform.origin_y + row_offset * geo_transform.pixel_height, + pixel_width=geo_transform.pixel_width, + pixel_height=geo_transform.pixel_height, + ) + + if is_dask: + # Slice the dask array for this chunk + r_end = row_offset + chunk_h + c_end = col_offset + chunk_w + chunk_data = raw[row_offset:r_end, col_offset:c_end] + + task = dask.delayed(_safe_write_tile)( + chunk_data, tile_path, tile_gt, epsg, wkt_fallback, + nodata, compression, compression_level, + tile_size, predictor, bigtiff, max_z_error, + raster_type=raster_type, + x_resolution=x_res, + y_resolution=y_res, + resolution_unit=res_unit, + gdal_metadata_xml=gdal_meta_xml, + extra_tags=extra_tags_list, + photometric=photometric, + restore_sentinel=restore_sentinel, + allow_internal_only_jpeg=allow_internal_only_jpeg, + allow_unparseable_crs=allow_unparseable_crs) + delayed_tasks.append(task) + else: + # Numpy: slice and write directly + chunk_data = np_arr[row_offset:row_offset + chunk_h, + col_offset:col_offset + chunk_w] + _write_single_tile( + chunk_data, tile_path, tile_gt, epsg, wkt_fallback, + nodata, compression, compression_level, + tile_size, predictor, bigtiff, max_z_error, + raster_type=raster_type, + x_resolution=x_res, + y_resolution=y_res, + resolution_unit=res_unit, + gdal_metadata_xml=gdal_meta_xml, + extra_tags=extra_tags_list, + photometric=photometric, + restore_sentinel=restore_sentinel, + allow_internal_only_jpeg=allow_internal_only_jpeg, + allow_unparseable_crs=allow_unparseable_crs) + + col_offset += chunk_w + row_offset += chunk_h + + # Execute all dask tasks. + # + # Each delayed task is an independent ``_write_single_tile`` call on + # a distinct output path, with no shared mutable Python state, so + # the writes are embarrassingly parallel. Using ``scheduler='threads'`` + # lets zlib / zstd / LZW release the GIL during compression and the + # OS coalesce concurrent writes; in a 256-tile zstd write on a + # 4096x4096 dask DataArray the wall time drops ~33% versus the + # ``synchronous`` scheduler this used to call. + # + # ``dask.compute`` raises as soon as one task errors, while sibling + # worker threads may still be writing tiles into the staging dir. + # Cleaning up at that point races those threads and can leave a + # partial staging dir behind. Capture each task's outcome instead + # so every worker reaches a barrier before we inspect results, then + # re-raise the first failure once no thread is still writing. + if delayed_tasks: + import dask + results = dask.compute(*delayed_tasks, scheduler='threads') + for r in results: + if isinstance(r, BaseException): + raise r + except BaseException: + # Any tile failure leaves no partial output: drop the staging dir + # so the final ``tiles_dir`` name is never created and a retry + # starts clean. + _cleanup_staging() + raise + + # Every tile is written. Promote the staging dir to its final name in + # one atomic rename so the VRT only ever references a complete tile set. + # A pre-existing *empty* ``tiles_dir`` passes the leftover-state guard + # above (the old code reused it via ``makedirs(exist_ok=True)``), but + # ``os.replace`` onto an existing directory raises on Windows even when + # the target is empty. Drop the empty target first so the rename has a + # clear slot on every platform. + try: + if os.path.isdir(tiles_dir): + os.rmdir(tiles_dir) # only succeeds if empty; guard ensured that + os.replace(staging_dir, tiles_dir) + except BaseException: + _cleanup_staging() + raise + + # Write VRT index with relative paths. The VRT lives at ``vrt_path``; + # tile paths now resolve under the final ``tiles_dir``. + tile_paths = [os.path.join(tiles_dir, name) for name in tile_names] from .._vrt import write_vrt as _write_vrt_fn - _write_vrt_fn(vrt_path, tile_paths, relative=True, nodata=nodata) + try: + _write_vrt_fn(vrt_path, tile_paths, relative=True, nodata=nodata) + except BaseException: + # The index step failed after the rename. Remove the now-renamed + # tile dir too so a retry is not blocked by the leftover-state + # guard above. + shutil.rmtree(tiles_dir, ignore_errors=True) + raise diff --git a/xrspatial/geotiff/tests/write/test_vrt_atomic.py b/xrspatial/geotiff/tests/write/test_vrt_atomic.py new file mode 100644 index 000000000..d9102b1c8 --- /dev/null +++ b/xrspatial/geotiff/tests/write/test_vrt_atomic.py @@ -0,0 +1,157 @@ +"""Transactional VRT tiled writes (issue #2669). + +A tiled ``.vrt`` write must be atomic: a successful write leaves a +complete tile directory plus a readable VRT, and a write that fails +partway through leaves no poisoned partial output so a retry can +succeed. +""" +from __future__ import annotations + +import os + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff +from xrspatial.geotiff._writers import eager + + +def _make_da(n=32): + """A georeferenced DataArray that tiles into a 2x2 grid at + ``tile_size=16`` (TIFF 6 requires tile sizes to be multiples of 16).""" + arr = np.arange(n * n, dtype=np.float32).reshape(n, n) + return xr.DataArray( + arr, + dims=['y', 'x'], + coords={'y': np.linspace(n - 0.5, 0.5, n), + 'x': np.linspace(0.5, n - 0.5, n)}, + attrs={'crs': 4326}, + ) + + +def _tiles_dir_for(vrt_path): + vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) + stem = os.path.splitext(os.path.basename(vrt_path))[0] + return os.path.join(vrt_dir, stem + '_tiles') + + +def test_successful_tiled_write_produces_readable_vrt(tmp_path): + """A clean tiled write yields the tile directory and a VRT that reads + back to the original values.""" + da = _make_da() + vrt_path = str(tmp_path / 'atomic_2669_ok.vrt') + + to_geotiff(da, vrt_path, tiled=True, tile_size=16, compression='none') + + tiles_dir = _tiles_dir_for(vrt_path) + assert os.path.isfile(vrt_path) + assert os.path.isdir(tiles_dir) + tiles = [f for f in os.listdir(tiles_dir) if f.endswith('.tif')] + assert len(tiles) == 4 # 2x2 grid + + # No staging dir left behind. + leftovers = [f for f in os.listdir(tmp_path) + if '.tmp-' in f] + assert leftovers == [] + + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + np.asarray(result.data).squeeze(), da.values, decimal=5) + + +def test_failed_tiled_write_leaves_no_poisoned_output_and_retry_succeeds( + tmp_path, monkeypatch): + """A tile write that raises partway through must leave neither the + final tile dir nor a staging dir behind, and a subsequent retry must + succeed.""" + da = _make_da() + vrt_path = str(tmp_path / 'atomic_2669_fail.vrt') + tiles_dir = _tiles_dir_for(vrt_path) + + real_write_single_tile = eager._write_single_tile + calls = {'n': 0} + + def flaky_write_single_tile(*args, **kwargs): + calls['n'] += 1 + if calls['n'] == 2: + raise RuntimeError("simulated tile write failure") + return real_write_single_tile(*args, **kwargs) + + monkeypatch.setattr(eager, '_write_single_tile', + flaky_write_single_tile) + + with pytest.raises(RuntimeError, match="simulated tile write failure"): + to_geotiff(da, vrt_path, tiled=True, tile_size=16, + compression='none') + + # No final tile dir, no VRT, no staging dir. + assert not os.path.exists(tiles_dir) + assert not os.path.exists(vrt_path) + leftovers = [f for f in os.listdir(tmp_path) if '.tmp-' in f] + assert leftovers == [], f"poisoned staging dirs left: {leftovers}" + + # Retry with a healthy writer must succeed -- the leftover-state guard + # no longer blocks it because nothing partial was committed. + monkeypatch.setattr(eager, '_write_single_tile', real_write_single_tile) + to_geotiff(da, vrt_path, tiled=True, tile_size=16, compression='none') + + assert os.path.isfile(vrt_path) + assert os.path.isdir(tiles_dir) + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + np.asarray(result.data).squeeze(), da.values, decimal=5) + + +def test_preexisting_empty_tiles_dir_is_reused(tmp_path): + """An empty ``*_tiles`` directory left over from a prior aborted run + passes the non-empty leftover-state guard, so the write must promote + the staging dir into that slot rather than raising.""" + da = _make_da() + vrt_path = str(tmp_path / 'atomic_2669_empty.vrt') + tiles_dir = _tiles_dir_for(vrt_path) + os.makedirs(tiles_dir) # empty, pre-existing + + to_geotiff(da, vrt_path, tiled=True, tile_size=16, compression='none') + + assert os.path.isfile(vrt_path) + assert os.path.isdir(tiles_dir) + tiles = [f for f in os.listdir(tiles_dir) if f.endswith('.tif')] + assert len(tiles) == 4 + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + np.asarray(result.data).squeeze(), da.values, decimal=5) + + +def test_dask_failed_tiled_write_cleans_up(tmp_path, monkeypatch): + """Same atomicity guarantee on the dask streaming path.""" + pytest.importorskip("dask.array") + da = _make_da().chunk({'y': 16, 'x': 16}) + vrt_path = str(tmp_path / 'atomic_2669_dask.vrt') + tiles_dir = _tiles_dir_for(vrt_path) + + real_write_single_tile = eager._write_single_tile + calls = {'n': 0} + + def flaky_write_single_tile(*args, **kwargs): + calls['n'] += 1 + if calls['n'] == 2: + raise RuntimeError("simulated dask tile failure") + return real_write_single_tile(*args, **kwargs) + + monkeypatch.setattr(eager, '_write_single_tile', + flaky_write_single_tile) + + with pytest.raises(RuntimeError): + to_geotiff(da, vrt_path, tiled=True, compression='none') + + assert not os.path.exists(tiles_dir) + assert not os.path.exists(vrt_path) + leftovers = [f for f in os.listdir(tmp_path) if '.tmp-' in f] + assert leftovers == [] + + # Retry succeeds. + monkeypatch.setattr(eager, '_write_single_tile', real_write_single_tile) + to_geotiff(da, vrt_path, tiled=True, compression='none') + assert os.path.isfile(vrt_path) + assert os.path.isdir(tiles_dir)