Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions monai/losses/cldice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

from monai.losses.dice import DiceLoss
from monai.networks import one_hot
from monai.utils import LossReduction
from monai.utils import LossReduction, optional_import
from monai.utils.deprecate_utils import deprecated_arg

binary_thinning_3d, _has_thinning = optional_import("binary_thinning_3d")


def soft_erode(img: torch.Tensor) -> torch.Tensor: # type: ignore
"""
Expand Down Expand Up @@ -129,6 +131,7 @@ def __init__(
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_hard_target: bool = False,
) -> None:
"""
Args:
Expand All @@ -151,6 +154,8 @@ def __init__(
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization.
Requires binary_thinning_3d_cuda package and a CUDA 3D target. Defaults to False.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand Down Expand Up @@ -181,6 +186,7 @@ def __init__(
self.sigmoid = sigmoid
self.softmax = softmax
self.other_act = other_act
self.use_hard_target = use_hard_target

@deprecated_arg("y_pred", since="1.5", removed="1.8", new_name="input", msg_suffix="please use `input` instead.")
@deprecated_arg("y_true", since="1.5", removed="1.8", new_name="target", msg_suffix="please use `target` instead.")
Expand Down Expand Up @@ -226,7 +232,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

skel_pred = soft_skel(input, self.iter)
skel_true = soft_skel(target, self.iter)
if self.use_hard_target:
if not (target.dim() == 5 and _has_thinning and target.is_cuda):
raise ValueError(
"use_hard_target=True but conditions not met. "
"Requires 5D CUDA tensor and binary_thinning_3d_cuda package."
)
skel_true = (target > 0).to(torch.uint8).contiguous()
for b in range(target.shape[0]):
for c in range(target.shape[1]):
binary_thinning_3d.binary_thinning(skel_true[b, c], 0)
skel_true = skel_true.to(target.dtype)
else:
skel_true = soft_skel(target, self.iter)

# Compute per-batch clDice by reducing over channel and spatial dimensions
# reduce_axis includes all dimensions except batch (dim 0)
Expand Down Expand Up @@ -279,6 +297,7 @@ def __init__(
softmax: bool = False,
other_act: Callable | None = None,
reduction: LossReduction | str = LossReduction.MEAN,
use_hard_target: bool = False,
) -> None:
"""
Args:
Expand All @@ -304,6 +323,8 @@ def __init__(
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
use_hard_target: if True, use the exact CUDA 3D binary thinning for the target skeleton instead of soft skeletonization.
Requires MONAI C++ extensions and a 3D target. Defaults to False.

Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand Down Expand Up @@ -336,6 +357,7 @@ def __init__(
softmax=softmax,
other_act=other_act,
reduction=reduction,
use_hard_target=use_hard_target,
)
self.alpha = alpha
self.to_onehot_y = to_onehot_y
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ all =
nvidia-ml-py
huggingface_hub
pyamg>=5.0.0, <5.3.0
binary_thinning_3d_cuda
nibabel =
nibabel
ninja =
Expand Down Expand Up @@ -179,6 +180,8 @@ huggingface_hub =
huggingface_hub
pyamg =
pyamg>=5.0.0, <5.3.0
binary_thinning =
binary_thinning_3d_cuda
# segment-anything =
# segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything

Expand Down
13 changes: 13 additions & 0 deletions tests/losses/test_cldice_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ def test_cuda(self):
result = loss(ONES_2D["input"].cuda(), ONES_2D["target"].cuda())
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)

@skip_if_no_cuda
def test_hard_target(self):
Comment thread
sychen52 marked this conversation as resolved.
"""Test SoftclDiceLoss with use_hard_target=True using binary thinning on 3D CUDA tensors."""
# Skip if binary_thinning not available
from monai.losses.cldice import _has_thinning
if not _has_thinning:
self.skipTest("binary_thinning_3d_cuda not available")

loss = SoftclDiceLoss(use_hard_target=True)
# MUST BE 3D for hard target logic to trigger! (shape: B, N, H, W, D)
result = loss(ONES_3D["input"].cuda(), ONES_3D["target"].cuda())
np.testing.assert_allclose(result.detach().cpu().numpy(), 0.0, atol=1e-4)
Comment thread
sychen52 marked this conversation as resolved.

def test_reduction_shapes(self):
input_tensor = torch.ones((4, 2, 8, 8))
target = torch.ones((4, 2, 8, 8))
Expand Down
Loading