Skip to content
14 changes: 7 additions & 7 deletions monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def __call__(
`kwargs` supports other args for `Tensor.to()` API.
"""
image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
args_ = []
kwargs_ = {}
args_: tuple = ()
kwargs_: dict = {}

def _get_data(key: str) -> torch.Tensor:
data = batchdata[key]
Expand All @@ -231,13 +231,13 @@ def _get_data(key: str) -> torch.Tensor:
return data

if isinstance(self.extra_keys, (str, list, tuple)):
for k in ensure_tuple(self.extra_keys):
args_.append(_get_data(k))
args_ = tuple(_get_data(k) for k in ensure_tuple(self.extra_keys))

elif isinstance(self.extra_keys, dict):
for k, v in self.extra_keys.items():
kwargs_.update({k: _get_data(v)})
kwargs_ = {k: _get_data(v) for k, v in self.extra_keys.items()}


return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_
return cast(torch.Tensor, image), cast(torch.Tensor, label), args_, kwargs_


class DiffusionPrepareBatch(PrepareBatch):
Expand Down
Loading