diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index f94f11eca9..da95b0547b 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -224,6 +224,13 @@ def track_transform_meta( extra_info.pop(LazyAttr.AFFINE, None) info[TraceKeys.EXTRA_INFO] = extra_info + # update refer meta + if isinstance(data_t, MetaTensor): + if data_t.meta.get("refer_meta", None) is not None: + data_t.meta["refer_meta"]["spatial_shape"] = ( + sp_size if sp_size is not None else info.get(TraceKeys.ORIG_SIZE, []) + ) + # push the transform info to the applied_operation or pending_operation stack if lazy: if sp_size is None: diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 094afdd3c4..d4648c1387 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -36,7 +36,8 @@ affine_func, flip, orientation, - resize, + resize_image, + resize_point, rotate, rotate90, spatial_resample, @@ -51,6 +52,7 @@ create_scale, create_shear, create_translate, + get_input_shape, map_spatial_axes, resolves_modes, scale_affine, @@ -764,8 +766,9 @@ def __init__( self.anti_aliasing = anti_aliasing self.anti_aliasing_sigma = anti_aliasing_sigma self.dtype = dtype + self.operators = [resize_point, resize_image] # type: ignore - def __call__( + def __call__( # type: ignore[return] self, img: torch.Tensor, mode: str | None = None, @@ -806,10 +809,13 @@ def __call__( anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma - input_ndim = img.ndim - 1 # spatial ndim + input_shape = get_input_shape(img) # spatial shape + input_ndim = len(input_shape) # spatial ndim if self.size_mode == "all": output_ndim = len(ensure_tuple(self.spatial_size)) - if output_ndim > input_ndim: + # only works for pixel data + kind = img.meta.get("kind", "pixel") if isinstance(img, MetaTensor) else "pixel" + if output_ndim > input_ndim and kind == "pixel": input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) img = img.reshape(input_shape) elif output_ndim < input_ndim: @@ -817,10 +823,10 @@ def __call__( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." ) - _sp = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + _sp = get_input_shape(img) sp_size = fall_back_tuple(self.spatial_size, _sp) else: # for the "longest" mode - img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] + img_size = input_shape if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) @@ -830,18 +836,18 @@ def __call__( _align_corners = self.align_corners if align_corners is None else align_corners _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) lazy_ = self.lazy if lazy is None else lazy - return resize( # type: ignore - img, - tuple(int(_s) for _s in sp_size), - _mode, - _align_corners, - _dtype, - input_ndim, - anti_aliasing, - anti_aliasing_sigma, - lazy_, - self.get_transform_info(), - ) + kwargs = { + "mode": _mode, + "align_corners": _align_corners, + "anti_aliasing": anti_aliasing, + "anti_aliasing_sigma": anti_aliasing_sigma, + } + for operator in self.operators: + ret: torch.Tensor = operator( # type: ignore + img, tuple(int(_s) for _s in sp_size), _dtype, input_ndim, lazy_, self.get_transform_info(), **kwargs + ) + if ret is not None: + return ret def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) @@ -849,9 +855,9 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: orig_size = transform[TraceKeys.ORIG_SIZE] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] + mode = transform[TraceKeys.EXTRA_INFO].get("mode", None) + align_corners = transform[TraceKeys.EXTRA_INFO].get("align_corners", None) + dtype = transform[TraceKeys.EXTRA_INFO].get("dtype", None) xform = Resize( spatial_size=orig_size, mode=mode, diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index add4e7f5ea..9fce8c15b4 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -24,6 +24,7 @@ import monai from monai.config import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor +from monai.data.box_utils import COMPUTE_DTYPE, get_spatial_dims from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd @@ -31,7 +32,14 @@ from monai.transforms.croppad.array import ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import TraceableTransform -from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine +from monai.transforms.utils import ( + convert_data_type, + create_rotate, + create_scale, + create_translate, + resolves_modes, + scale_affine, +) from monai.transforms.utils_pytorch_numpy_unification import allclose from monai.utils import ( LazyAttr, @@ -50,7 +58,17 @@ cupy_ndi, _ = optional_import("cupyx.scipy.ndimage") np_ndi, _ = optional_import("scipy.ndimage") -__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"] +__all__ = [ + "spatial_resample", + "orientation", + "flip", + "resize_image", + "resize_point", + "rotate", + "zoom", + "rotate90", + "affine_func", +] def _maybe_new_metatensor(img, dtype=None, device=None): @@ -265,9 +283,7 @@ def flip(img, sp_axes, lazy, transform_info): return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out -def resize( - img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info -): +def resize_image(img, out_size, dtype, input_ndim, lazy, transform_info, **kwargs): """ Functional implementation of resize. This function operates eagerly or lazily according to @@ -292,23 +308,14 @@ def resize( lazy: a flag that indicates whether the operation should be performed lazily or not transform_info: a dictionary with the relevant information pertaining to an applied transform. """ - img = convert_to_tensor(img, track_meta=get_track_meta()) + # TODO + kind = img.meta.get("kind", "pixel") if isinstance(img, MetaTensor) else "pixel" + if kind != "pixel": + return None + anti_aliasing = kwargs.pop("anti_aliasing") + anti_aliasing_sigma = kwargs.pop("anti_aliasing_sigma") orig_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:] - extra_info = { - "mode": mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 - "new_dim": len(orig_size) - input_ndim, - } - meta_info = TraceableTransform.track_transform_meta( - img, - sp_size=out_size, - affine=scale_affine(orig_size, out_size), - extra_info=extra_info, - orig_size=orig_size, - transform_info=transform_info, - lazy=lazy, - ) + mode, align_corners, meta_info = resize_helper(img, orig_size, out_size, dtype, input_ndim, lazy, transform_info, **kwargs) if lazy: if anti_aliasing and lazy: warnings.warn("anti-aliasing is not compatible with lazy evaluation.") @@ -339,6 +346,92 @@ def resize( return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out +def _apply_affine_to_points(points, affine, include_shift: bool = True) -> torch.Tensor: + """ + This internal function applies affine matrices to the point coordinate + Args: + points: point coordinates, Nx2 or Nx3 torch tensor or ndarray, representing [x, y] or [x, y, z] + affine: affine matrix to be applied to the point coordinates, sized (spatial_dims+1,spatial_dims+1) + include_shift: default True, whether the function apply translation (shift) in the affine transform + Returns: + transformed point coordinates, with same data type as ``points``, does not share memory with ``points`` + """ + # convert numpy to tensor if needed + points_t, *_ = convert_data_type(points, torch.Tensor) + points_t = points_t.to(dtype=COMPUTE_DTYPE) + affine_t, *_ = convert_to_dst_type(src=affine, dst=points_t) + spatial_dims = get_spatial_dims(points=points_t) + + # compute new points + if include_shift: + # append 1 to form Nx(spatial_dims+1) vector, then transpose + points_affine = torch.cat( + [points_t, torch.ones(points_t.shape[0], 1, device=points_t.device, dtype=points_t.dtype)], dim=1 + ).transpose(0, 1) + # apply affine + points_affine = torch.matmul(affine_t, points_affine) + # remove appended 1 and transpose back + points_affine = points_affine[:spatial_dims, :].transpose(0, 1) + else: + points_affine = points_t.transpose(0, 1) + points_affine = torch.matmul(affine_t[:spatial_dims, :spatial_dims], points_affine) + points_affine = points_affine.transpose(0, 1) + + # convert tensor back to numpy if needed + points_affine, *_ = convert_to_dst_type(src=points_affine, dst=points) + + return points_affine + + +def resize_point(points, out_size, dtype, input_ndim, lazy, transform_info, **kwargs): + # TODO + kind = points.meta.get("kind", "pixel") if isinstance(points, MetaTensor) else "pixel" + if kind != "point": + return None + if points.meta.get("refer_meta", None) is not None: + src_spatial_size = points.meta["refer_meta"].get("spatial_shape", None) + else: + raise ValueError("Resize cannot be applied to a point without a reference meta.") + *_, meta_info = resize_helper(points, src_spatial_size, out_size, dtype, input_ndim, lazy, transform_info, **kwargs) + out = _maybe_new_metatensor(points) + if lazy: + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info + if tuple(convert_to_numpy(src_spatial_size)) == out_size: + out = _maybe_new_metatensor(points, dtype=dtype) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + spatial_dims = get_spatial_dims(points=points[0]) + scaling_factor = [out_size[axis] / float(src_spatial_size[axis]) for axis in range(spatial_dims)] + affine = create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor) + ret: torch.Tensor = _apply_affine_to_points(points[0], affine, include_shift=True) + + out, *_ = convert_to_dst_type(src=ret.unsqueeze(0), dst=points, dtype=dtype) + return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out + + +def resize_helper(data, src_spatial_size, out_size, dtype, input_ndim, lazy, transform_info, **kwargs): + data = convert_to_tensor(data, track_meta=get_track_meta()) + extra_info={ + "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 + "new_dim": len(src_spatial_size) - input_ndim, + } + mode = kwargs.pop("mode", None) + align_corners = kwargs.pop("align_corners", None) + if mode is not None: + extra_info["mode"] = mode + if align_corners is not None: + extra_info["align_corners"] = align_corners + + meta_info = TraceableTransform.track_transform_meta( + data, + sp_size=out_size, + affine=scale_affine(src_spatial_size, out_size), + extra_info=extra_info, + orig_size=src_spatial_size, + transform_info=transform_info, + lazy=lazy, + ) + return mode, align_corners, meta_info + def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info): """ Functional implementation of rotate. diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e282ecff24..2628e56e70 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -26,6 +26,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.meta_tensor import MetaTensor from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose @@ -2221,5 +2222,18 @@ def distance_transform_edt( return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0] +def get_input_shape(data): + if isinstance(data, MetaTensor): + if data.meta.get("refer_meta", None) is not None: + refer_shape = data.meta["refer_meta"].get("spatial_shape", None) + if refer_shape is not None: + input_shape = refer_shape + else: + input_shape = data.peek_pending_shape() + else: + input_shape = data.shape[1:] + return input_shape + + if __name__ == "__main__": print_transform_backends() diff --git a/tests/test_resize.py b/tests/test_resize.py index 65b33ea649..adf9e627f4 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -20,6 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Resize +from monai.utils import convert_to_dst_type from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import ( TEST_NDARRAYS_ALL, @@ -93,7 +94,7 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing): ] expected = np.stack(expected).astype(np.float32) - for p in TEST_NDARRAYS_ALL: + for p in TEST_NDARRAYS_ALL[:1]: im = p(self.imt[0]) call_param = {"img": im} out = resize(**call_param) @@ -136,6 +137,32 @@ def test_longest_infinite_decimals(self): ret = resize(np.random.randint(0, 2, size=[1, 2544, 3032])) self.assertTupleEqual(ret.shape, (1, 846, 1008)) + @parameterized.expand( + [ + ((32, -1), "all", [[[12, 6], [18, 9], [24, 8]]]), + ((32, 32, 32), "all", [[[12, 6, 32], [18, 9, 0], [24, 8, 18]]]), + ((128, 64), "all", [[[12, 6], [18, 9], [24, 64]]]), # already in a good shape + (32, "longest", [[[12, 6], [18, 9], [24, 8]]]), + ] + ) + def test_point(self, spatial_size, size_mode, data): + init_param = {"spatial_size": spatial_size, "dtype": np.int64, "size_mode": size_mode} + resize = Resize(**init_param) + if spatial_size == (32, -1): + spatial_size = (32, 64) + elif spatial_size == 32: + spatial_size = (32, 16) + refer_shape = (128, 64) if len(spatial_size) == 2 else (128, 64, 64) + data = MetaTensor(data, meta={"kind": "point", "refer_meta": {"spatial_shape": refer_shape}}) + expected = [data[0][..., i] * (spatial_size[i] / refer_shape[i]) for i in range(len(refer_shape))] + expected, *_ = convert_to_dst_type(torch.stack(expected, dim=1).unsqueeze(0), data) + out = resize(data) + im_inv = resize.inverse(out) + self.assertTrue(not im_inv.applied_operations) + assert_allclose(im_inv.shape, data.shape) + assert_allclose(out, expected, type_test="tensor") + assert_allclose(im_inv.affine, data.affine, atol=1e-3, rtol=1e-3) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_resized.py b/tests/test_resized.py index d62f29ab5c..0d2a901a62 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ -20,6 +20,7 @@ from monai.data import MetaTensor, set_track_meta from monai.transforms import Invertd, Resize, Resized +from monai.utils import convert_to_dst_type from tests.lazy_transforms_utils import test_resampler_lazy from tests.utils import ( TEST_NDARRAYS_ALL, @@ -158,6 +159,29 @@ def test_consistent_resize(self): assert_allclose(rescaler_1(test_input_1), rescaler_dict(test_input_dict)["img1"]) assert_allclose(rescaler_2(test_input_2), rescaler_dict(test_input_dict)["img2"]) + @parameterized.expand( + [ + ((32, -1), "all", [[[12, 6], [18, 9], [24, 8]]]), + ((32, 32, 32), "all", [[[12, 6, 32], [18, 9, 0], [24, 8, 18]]]), + ((128, 64), "all", [[[12, 6], [18, 9], [24, 64]]]), # already in a good shape + (32, "longest", [[[12, 6], [18, 9], [24, 8]]]), + ] + ) + def test_point(self, spatial_size, size_mode, data): + init_param = {"keys": "point", "spatial_size": spatial_size, "dtype": np.int64, "size_mode": size_mode} + resize = Resized(**init_param) + if spatial_size == (32, -1): + spatial_size = (32, 64) + elif spatial_size == 32: + spatial_size = (32, 16) + refer_shape = (128, 64) if len(spatial_size) == 2 else (128, 64, 64) + data = MetaTensor(data, meta={"kind": "point", "refer_meta": {"spatial_shape": refer_shape}}) + expected = [data[0][..., i] * (spatial_size[i] / refer_shape[i]) for i in range(len(refer_shape))] + expected, *_ = convert_to_dst_type(torch.stack(expected, dim=1).unsqueeze(0), data) + out = resize({"point": data}) + assert_allclose(out["point"], expected, type_test="tensor") + test_local_inversion(resize, out, {"point": data}, "point") + if __name__ == "__main__": unittest.main()