Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 33 additions & 34 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
super().__init__(stats_name, report_format)
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())

@torch.no_grad()
def __call__(self, data):
# Input Validation Addition
if not isinstance(data, dict):
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
if self.image_key not in data:
raise KeyError(f"Key '{self.image_key}' not found in input data.")
image = data[self.image_key]
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
raise TypeError(
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
f"but got {type(image).__name__}."
)
if image.ndim < 3:
raise ValueError(
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
)
# --- End of validation ---
"""
Callable to execute the pre-defined functions
Callable to execute the pre-defined functions.

Returns:
A dictionary. The dict has the key in self.report_format. The value of
ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
has stats pre-defined by SampleOperations (max, min, ....).

Raises:
RuntimeError if the stats report generated is not consistent with the pre-
KeyError: if ``self.image_key`` is not present in the input data.
TypeError: if the input data is not a dictionary, or if the image value is
not a numpy array, torch.Tensor, or MetaTensor.
ValueError: if the image has fewer than 3 dimensions, or if pre-computed
``nda_croppeds`` is not a list/tuple with one entry per image channel.
RuntimeError: if the stats report generated is not consistent with the pre-
defined report_format.

Note:
The stats operation uses numpy and torch to compute max, min, and other
functions. If the input has nan/inf, the stats results will be nan/inf.

"""
if not isinstance(data, dict):
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
if self.image_key not in data:
raise KeyError(f"Key '{self.image_key}' not found in input data.")
image = data[self.image_key]
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
raise TypeError(
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
f"but got {type(image).__name__}."
)
if image.ndim < 3:
raise ValueError(
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
)

d = dict(data)
start = time.time()
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
if "nda_croppeds" in d:
nda_croppeds = d["nda_croppeds"]
if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
raise ValueError(
"Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
f"(expected {len(ndas)})."
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
else:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
Expand All @@ -284,7 +292,6 @@ def __call__(self, data):

d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get image stats spent {time.time() - start}")
return d

Expand Down Expand Up @@ -321,6 +328,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
super().__init__(stats_name, report_format)
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())

@torch.no_grad()
def __call__(self, data: Mapping) -> dict:
"""
Callable to execute the pre-defined functions
Expand All @@ -341,9 +349,6 @@ def __call__(self, data: Mapping) -> dict:

d = dict(data)
start = time.time()
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
ndas_label = d[self.label_key] # (H,W,D)

Expand All @@ -353,7 +358,6 @@ def __call__(self, data: Mapping) -> dict:
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.INTENSITY] = [
Expand All @@ -365,7 +369,6 @@ def __call__(self, data: Mapping) -> dict:

d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get foreground image stats spent {time.time() - start}")
return d

Expand Down Expand Up @@ -418,6 +421,7 @@ def __init__(
id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST])
self.update_ops_nested_label(id_seq, SampleOperations())

@torch.no_grad()
def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]:
"""
Callable to execute the pre-defined functions.
Expand Down Expand Up @@ -470,19 +474,15 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
start = time.time()
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
)
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
if using_cuda:
# Move both tensors to CUDA when mixing devices
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
Expand Down Expand Up @@ -548,7 +548,6 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe

d[self.stats_name] = report # type: ignore[assignment]

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get label stats spent {time.time() - start}")
return d # type: ignore[return-value]

Expand Down
78 changes: 77 additions & 1 deletion tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
SqueezeDimd,
ToDeviced,
)
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
from tests.test_utils import skip_if_no_cuda

device = "cpu"
Expand Down Expand Up @@ -322,6 +322,47 @@ def test_image_stats_case_analyzer(self):
report_format = analyzer.get_report_format()
assert verify_report_format(d["image_stats"], report_format)

def test_image_stats_uses_precomputed_nda_croppeds(self):
    """ImageStats should honor a well-formed pre-computed 'nda_croppeds' entry."""
    stats = ImageStats(image_key="image")
    # single-channel 4x4x4 image with known, distinct values
    img = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4)
    crops = [torch.ones((2, 2, 2), dtype=torch.float32)]

    out = stats({"image": img, "nda_croppeds": crops})["image_stats"]

    # the supplied crop (all ones, shape 2x2x2) must drive the cropped-shape
    # and intensity entries of the report, proving it was not recomputed
    assert verify_report_format(out, stats.get_report_format())
    assert out[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]]
    self.assertAlmostEqual(out[ImageStatsKeys.INTENSITY][0]["mean"], 1.0)

def test_image_stats_validates_precomputed_nda_croppeds(self):
    """ImageStats must raise ValueError for malformed 'nda_croppeds' input."""
    stats = ImageStats(image_key="image")
    img = torch.ones((2, 4, 4, 4), dtype=torch.float32)
    # each case: (label, malformed crops value)
    bad_inputs = (
        ("wrong_type", torch.ones((2, 2, 2), dtype=torch.float32)),  # tensor, not list/tuple
        ("wrong_length", [torch.ones((2, 2, 2), dtype=torch.float32)]),  # 1 entry for a 2-channel image
    )

    for label, crops in bad_inputs:
        with self.subTest(case=label), self.assertRaisesRegex(ValueError, "one entry per image channel"):
            stats({"image": img, "nda_croppeds": crops})

def test_image_stats_preserves_grad_state_after_call(self):
    """A successful ImageStats call must leave torch's grad mode untouched."""
    stats = ImageStats(image_key="image")
    sample = {"image": MetaTensor(torch.rand(1, 10, 10, 10))}
    saved_state = torch.is_grad_enabled()
    try:
        for enabled in (True, False):
            with self.subTest(grad_enabled=enabled):
                torch.set_grad_enabled(enabled)
                stats(sample)
                # the analyzer must restore whatever mode the caller set
                self.assertEqual(torch.is_grad_enabled(), enabled)
    finally:
        # never leak grad-mode changes into other tests
        torch.set_grad_enabled(saved_state)

def test_foreground_image_stats_cases_analyzer(self):
analyzer = FgImageStats(image_key="image", label_key="label")
transform_list = [
Expand Down Expand Up @@ -412,6 +453,41 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)

def test_case_analyzers_restore_grad_state_on_exception(self):
    """Even a failing analyzer call must leave torch's grad mode as the caller set it."""
    # each case: (name, analyzer instance, malformed input, expected exception)
    failing_cases = (
        (
            "image_stats",
            ImageStats(image_key="image"),
            # one crop entry for a two-channel image -> length mismatch
            {"image": torch.randn(2, 4, 4, 4), "nda_croppeds": [torch.ones((2, 2, 2))]},
            ValueError,
        ),
        (
            "fg_image_stats",
            FgImageStats(image_key="image", label_key="label"),
            # label shape incompatible with the image
            {"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)},
            ValueError,
        ),
        (
            "label_stats",
            LabelStats(image_key="image", label_key="label"),
            {"image": MetaTensor(torch.randn(1, 4, 4, 4)), "label": MetaTensor(torch.ones(3, 4, 4))},
            ValueError,
        ),
    )

    saved_state = torch.is_grad_enabled()
    try:
        for name, analyzer, sample, expected_error in failing_cases:
            for enabled in (True, False):
                with self.subTest(analyzer=name, grad_enabled=enabled):
                    torch.set_grad_enabled(enabled)
                    with self.assertRaises(expected_error):
                        analyzer(sample)
                    # grad mode must survive the exception path
                    self.assertEqual(torch.is_grad_enabled(), enabled)
    finally:
        torch.set_grad_enabled(saved_state)

def test_filename_case_analyzer(self):
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)
Expand Down
Loading