From 9600f40285a05336f90e0de9aae8b0ddc6d26b32 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 11:06:18 +0000 Subject: [PATCH 01/16] Better deprecation message --- src/diffusers/configuration_utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 45930431351a..28cd2bb50ff5 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -118,6 +118,29 @@ def register_to_config(self, **kwargs): self._internal_dict = FrozenDict(internal_dict) + def __getattr__(self, name): + try: + getattr(super(), name) + except AttributeError as e: + try: + is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) + is_attribute = self._safe_hasattr(name) + + if is_in_config and not is_attribute: + return self._internal_dict[name] + except AttributeError as e: + raise AttributeError(str(e).replace("super", self.__class__.__name__)) + + raise AttributeError(str(e).replace("super", self.__class__.__name__)) + + def _safe_hasattr(self, name): + try: + self.__getattribute__(name) + except AttributeError: + return False + + return True + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the From eab27d4985e171418bdfa45fa6e6e7b4c244167d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 11:06:26 +0000 Subject: [PATCH 02/16] Better deprecation message --- src/diffusers/configuration_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 28cd2bb50ff5..3ede931944e2 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -127,6 +127,8 @@ def __getattr__(self, name): is_attribute = 
self._safe_hasattr(name) if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] except AttributeError as e: raise AttributeError(str(e).replace("super", self.__class__.__name__)) From fd2df1bc561880a321b7bf62352ee14a6a65abef Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 11:15:03 +0000 Subject: [PATCH 03/16] Better doc string --- src/diffusers/configuration_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 3ede931944e2..fbe9298e2c4f 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -119,6 +119,8 @@ def register_to_config(self, **kwargs): self._internal_dict = FrozenDict(internal_dict) def __getattr__(self, name): + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. 
See https://github.com/huggingface/diffusers/pull/3129""" try: getattr(super(), name) except AttributeError as e: @@ -136,6 +138,8 @@ def __getattr__(self, name): raise AttributeError(str(e).replace("super", self.__class__.__name__)) def _safe_hasattr(self, name): + """This method should NOT be used and is only implemented for `def __getattr__` + for deprecation purposes""" try: self.__getattribute__(name) except AttributeError: From b007c65a5a58861dc561c1f5b0723152d6e3cf38 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 11:33:27 +0000 Subject: [PATCH 04/16] Fixes --- src/diffusers/configuration_utils.py | 3 +-- src/diffusers/models/modeling_utils.py | 19 +++++++++++++++++++ src/diffusers/models/unet_2d.py | 12 +----------- src/diffusers/pipelines/pipeline_utils.py | 2 +- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fbe9298e2c4f..dfb8478cdadc 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -122,12 +122,11 @@ def __getattr__(self, name): """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. See https://github.com/huggingface/diffusers/pull/3129""" try: - getattr(super(), name) + return getattr(super(), name) except AttributeError as e: try: is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) is_attribute = self._safe_hasattr(name) - if is_in_config and not is_attribute: deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." 
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 6a849f6f0e45..faa74a76678d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -32,6 +32,7 @@ WEIGHTS_NAME, _add_variant, _get_model_file, + deprecate, is_accelerate_available, is_safetensors_available, is_torch_version, @@ -156,6 +157,24 @@ class ModelMixin(torch.nn.Module): def __init__(self): super().__init__() + def __getattr__(self, name): + """The only reason we overwrite `getattr` here is to gracefully deprecate accessing + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129""" + try: + return super().__getattr__(name) + except AttributeError as e: + try: + is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) + is_attribute = self._safe_hasattr(name) + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." 
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] + except AttributeError as e: + raise AttributeError(str(e).replace("super", self.__class__.__name__)) + + raise AttributeError(str(e).replace("super", self.__class__.__name__)) + @property def is_gradient_checkpointing(self) -> bool: """ diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index a83e4917c143..2a6a1b9de5f2 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -18,7 +18,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -216,16 +216,6 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - @property - def in_channels(self): - deprecate( - "in_channels", - "1.0.0", - "Accessing `in_channels` directly via unet.in_channels is deprecated. 
Please use `unet.config.in_channels` instead", - standard_warn=False, - ) - return self.config.in_channels - def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c095da1665de..38f13fa11e66 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -508,7 +508,7 @@ def register_modules(self, **kwargs): setattr(self, name, module) def __setattr__(self, name: str, value: Any): - if hasattr(self, name) and hasattr(self.config, name): + if self._safe_hasattr(name) and hasattr(self.config, name): # We need to overwrite the config if name exists in config if isinstance(getattr(self.config, name), (tuple, list)): if value is not None and self.config[name][0] is not None: From c6cc8244fa97531263ea8525842e188457508b95 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 11:42:50 +0000 Subject: [PATCH 05/16] fix more --- src/diffusers/configuration_utils.py | 3 +-- src/diffusers/models/modeling_utils.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index dfb8478cdadc..f7b48ce471a3 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -126,8 +126,7 @@ def __getattr__(self, name): except AttributeError as e: try: is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) - is_attribute = self._safe_hasattr(name) - if is_in_config and not is_attribute: + if is_in_config: deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." 
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index faa74a76678d..14b8d6879c5e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -165,8 +165,7 @@ def __getattr__(self, name): except AttributeError as e: try: is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) - is_attribute = self._safe_hasattr(name) - if is_in_config and not is_attribute: + if is_in_config: deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] From 9a65cbe40e5e057d4d3a179a4432819d9eaaab87 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 12:00:26 +0000 Subject: [PATCH 06/16] fix more --- src/diffusers/configuration_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index f7b48ce471a3..2e713220e598 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -122,7 +122,7 @@ def __getattr__(self, name): """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. 
See https://github.com/huggingface/diffusers/pull/3129""" try: - return getattr(super(), name) + return super().__getattr__(name) except AttributeError as e: try: is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) @@ -131,9 +131,11 @@ def __getattr__(self, name): deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] except AttributeError as e: - raise AttributeError(str(e).replace("super", self.__class__.__name__)) + error = str(e).replace("super", self.__class__.__name__).replace("__getattr__", name) + raise AttributeError(error) - raise AttributeError(str(e).replace("super", self.__class__.__name__)) + error = str(e).replace("super", self.__class__.__name__).replace("__getattr__", name) + raise AttributeError(error) def _safe_hasattr(self, name): """This method should NOT be used and is only implemented for `def __getattr__` From e56d51a5e9968d37f528b8fe81fb523461b90480 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 12:14:47 +0000 Subject: [PATCH 07/16] Improve __getattr__ --- src/diffusers/configuration_utils.py | 41 ++++++++++---------------- src/diffusers/models/modeling_utils.py | 33 +++++++++++---------- 2 files changed, 32 insertions(+), 42 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 2e713220e598..d231ae0e8d9a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -118,34 +118,23 @@ def register_to_config(self, **kwargs): self._internal_dict = FrozenDict(internal_dict) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing - config attributes directly. 
See https://github.com/huggingface/diffusers/pull/3129""" - try: - return super().__getattr__(name) - except AttributeError as e: - try: - is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) - if is_in_config: - deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." - deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) - return self._internal_dict[name] - except AttributeError as e: - error = str(e).replace("super", self.__class__.__name__).replace("__getattr__", name) - raise AttributeError(error) - - error = str(e).replace("super", self.__class__.__name__).replace("__getattr__", name) - raise AttributeError(error) - - def _safe_hasattr(self, name): - """This method should NOT be used and is only implemented for `def __getattr__` - for deprecation purposes""" - try: - self.__getattribute__(name) - except AttributeError: - return False + config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 + + Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite: + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." 
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] - return True + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 14b8d6879c5e..533f1a3caacc 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -17,7 +17,7 @@ import inspect import os from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch import Tensor, device @@ -157,22 +157,23 @@ class ModelMixin(torch.nn.Module): def __init__(self): super().__init__() - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing - config attributes directly. See https://github.com/huggingface/diffusers/pull/3129""" - try: - return super().__getattr__(name) - except AttributeError as e: - try: - is_in_config = self._safe_hasattr("_internal_dict") and hasattr(self._internal_dict, name) - if is_in_config: - deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." - deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) - return self._internal_dict[name] - except AttributeError as e: - raise AttributeError(str(e).replace("super", self.__class__.__name__)) - - raise AttributeError(str(e).replace("super", self.__class__.__name__)) + config attributes directly. 
See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite + __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__': + https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + """ + + is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) + is_attribute = name in self.__dict__ + + if is_in_config and not is_attribute: + deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." + deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + return self._internal_dict[name] + + # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module + return super().__getattr__(name) @property def is_gradient_checkpointing(self) -> bool: From 2d45454b0a24778f9978def9c1beb4128cb0b0b4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 12:18:28 +0000 Subject: [PATCH 08/16] correct more --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 38f13fa11e66..c095da1665de 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -508,7 +508,7 @@ def register_modules(self, **kwargs): setattr(self, name, module) def __setattr__(self, name: str, value: Any): - if self._safe_hasattr(name) and hasattr(self.config, name): + if hasattr(self, name) and hasattr(self.config, name): # We need to overwrite the config if name exists in config if isinstance(getattr(self.config, name), (tuple, list)): if value is not None and self.config[name][0] is not None: From 
351bf978d70a3ed0020df07ed7893b9efedb1aa3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 12:21:46 +0000 Subject: [PATCH 09/16] fix more --- src/diffusers/configuration_utils.py | 1 + src/diffusers/pipelines/pipeline_utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index d231ae0e8d9a..7a564ca19f0a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -131,6 +131,7 @@ def __getattr__(self, name: str) -> Any: if is_in_config and not is_attribute: deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." + assert False deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c095da1665de..94725185e69d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -508,7 +508,7 @@ def register_modules(self, **kwargs): setattr(self, name, module) def __setattr__(self, name: str, value: Any): - if hasattr(self, name) and hasattr(self.config, name): + if name in self.__dict__ and hasattr(self.config, name): # We need to overwrite the config if name exists in config if isinstance(getattr(self.config, name), (tuple, list)): if value is not None and self.config[name][0] is not None: From 3fe30de42e3104d2e09ab67dc2a265726fed3c0c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 12:31:39 +0000 Subject: [PATCH 10/16] fix --- src/diffusers/configuration_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py 
b/src/diffusers/configuration_utils.py index 7a564ca19f0a..d231ae0e8d9a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -131,7 +131,6 @@ def __getattr__(self, name: str) -> Any: if is_in_config and not is_attribute: deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." - assert False deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] From d9149fcc2a5cdb66f095277e7702e535d48c89fc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 12:53:45 +0000 Subject: [PATCH 11/16] Improve more --- src/diffusers/configuration_utils.py | 2 +- src/diffusers/models/modeling_utils.py | 4 +- src/diffusers/pipelines/pipeline_utils.py | 49 +++++++++++------------ src/diffusers/utils/deprecation_utils.py | 4 +- 4 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index d231ae0e8d9a..fe9047137b4e 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -130,7 +130,7 @@ def __getattr__(self, name: str) -> Any: is_attribute = name in self.__dict__ if is_in_config and not is_attribute: - deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." + deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'scheduler.config.{name}'." 
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 533f1a3caacc..16941336eb2e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -168,8 +168,8 @@ def __getattr__(self, name: str) -> Any: is_attribute = name in self.__dict__ if is_in_config and not is_attribute: - deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}' or 'scheduler.config.{name}'." - deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) + deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}'." 
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) return self._internal_dict[name] # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 94725185e69d..0af7ba36dc0d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -648,26 +648,25 @@ def module_is_offloaded(module): ) module_names, _ = self._get_signature_keys(self) - module_names = [m for m in module_names if hasattr(self, m)] + module_names = [m for m in module_names if isinstance(m, torch.nn.Module)] is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for name in module_names: module = getattr(self, name) - if isinstance(module, torch.nn.Module): - module.to(torch_device, torch_dtype) - if ( - module.dtype == torch.float16 - and str(torch_device) in ["cpu"] - and not silence_dtype_warnings - and not is_offloaded - ): - logger.warning( - "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) + module.to(torch_device, torch_dtype) + if ( + module.dtype == torch.float16 + and str(torch_device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): + logger.warning( + "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. 
Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) return self @property @@ -677,12 +676,12 @@ def device(self) -> torch.device: `torch.device`: The torch device on which the pipeline is located. """ module_names, _ = self._get_signature_keys(self) - module_names = [m for m in module_names if hasattr(self, m)] + module_names = [m for m in module_names if isinstance(m, torch.nn.Module)] for name in module_names: module = getattr(self, name) - if isinstance(module, torch.nn.Module): - return module.device + return module.device + return torch.device("cpu") @classmethod @@ -1452,12 +1451,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): fn_recursive_set_mem_eff(child) module_names, _, _ = self.extract_init_dict(dict(self.config)) - module_names = [m for m in module_names if hasattr(self, m)] + module_names = [m for m in module_names if isinstance(m, torch.nn.Module)] for module_name in module_names: module = getattr(self, module_name) - if isinstance(module, torch.nn.Module): - fn_recursive_set_mem_eff(module) + fn_recursive_set_mem_eff(module) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -1485,9 +1483,10 @@ def disable_attention_slicing(self): def set_attention_slice(self, slice_size: Optional[int]): module_names, _, _ = self.extract_init_dict(dict(self.config)) - module_names = [m for m in module_names if hasattr(self, m)] + module_names = [ + m for m in module_names if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice") + ] for module_name in module_names: module = getattr(self, module_name) - if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size) + module.set_attention_slice(slice_size) diff --git 
a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py index 6bdda664e102..f482deddd2f4 100644 --- a/src/diffusers/utils/deprecation_utils.py +++ b/src/diffusers/utils/deprecation_utils.py @@ -5,7 +5,7 @@ from packaging import version -def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): from .. import __version__ deprecated_kwargs = take_from @@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn if warning is not None: warning = warning + " " if standard_warn else "" - warnings.warn(warning + message, FutureWarning, stacklevel=2) + warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel) if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: call_frame = inspect.getouterframes(inspect.currentframe())[1] From 041fd125de54cebc608696061a390a4e8718ff35 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 13:15:05 +0000 Subject: [PATCH 12/16] more improvements --- src/diffusers/pipelines/pipeline_utils.py | 30 +++++++++++------------ 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0af7ba36dc0d..ad8980944eac 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -648,11 +648,11 @@ def module_is_offloaded(module): ) module_names, _ = self._get_signature_keys(self) - module_names = [m for m in module_names if isinstance(m, torch.nn.Module)] + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for name in module_names: - module = getattr(self, name) + for module in modules: module.to(torch_device, torch_dtype) 
if ( module.dtype == torch.float16 @@ -676,10 +676,10 @@ def device(self) -> torch.device: `torch.device`: The torch device on which the pipeline is located. """ module_names, _ = self._get_signature_keys(self) - module_names = [m for m in module_names if isinstance(m, torch.nn.Module)] + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] - for name in module_names: - module = getattr(self, name) + for module in modules: return module.device return torch.device("cpu") @@ -1450,11 +1450,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): for child in module.children(): fn_recursive_set_mem_eff(child) - module_names, _, _ = self.extract_init_dict(dict(self.config)) - module_names = [m for m in module_names if isinstance(m, torch.nn.Module)] + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] - for module_name in module_names: - module = getattr(self, module_name) + for module in modules: fn_recursive_set_mem_eff(module) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): @@ -1482,11 +1482,9 @@ def disable_attention_slicing(self): self.enable_attention_slicing(None) def set_attention_slice(self, slice_size: Optional[int]): - module_names, _, _ = self.extract_init_dict(dict(self.config)) - module_names = [ - m for m in module_names if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice") - ] + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")] - for module_name in module_names: - module = getattr(self, module_name) + for module in modules: module.set_attention_slice(slice_size) From b07e42c74ca5f69e9b068a50c44a14182a8fa56e Mon Sep 17 00:00:00 2001 From: Patrick 
von Platen Date: Mon, 17 Apr 2023 13:20:24 +0000 Subject: [PATCH 13/16] fix more --- src/diffusers/models/unet_2d_condition.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 1b982aedc5de..b2814356939b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, deprecate, logging +from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin @@ -447,16 +447,6 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) - @property - def in_channels(self): - deprecate( - "in_channels", - "1.0.0", - "Accessing `in_channels` directly via unet.in_channels is deprecated. 
Please use `unet.config.in_channels` instead", - standard_warn=False, - ) - return self.config.in_channels - @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" From 72d9bd56052eb5ab709a9709f49d5a6d43bc9e8b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 14:22:46 +0100 Subject: [PATCH 14/16] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/configuration_utils.py | 2 +- src/diffusers/models/modeling_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fe9047137b4e..772e119fbe97 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -130,7 +130,7 @@ def __getattr__(self, name: str) -> Any: is_attribute = name in self.__dict__ if is_in_config and not is_attribute: - deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'scheduler.config.{name}'." + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." 
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 16941336eb2e..5363e6330623 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -168,7 +168,7 @@ def __getattr__(self, name: str) -> Any: is_attribute = name in self.__dict__ if is_in_config and not is_attribute: - deprecation_message = f"Accessing config attribute `{name}` directly via '{self.__class__.__name__}' object attribute is deprecated. Please access '{name}' over '{self.__class__.__name__}'s config object instead, e.g. 'unet.config.{name}'." + deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3) return self._internal_dict[name] From f19468647d6d3c13a9b2601eabaefb34b110cfd8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 13:23:18 +0000 Subject: [PATCH 15/16] make style --- .../versatile_diffusion/modeling_text_unet.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 35ddfcadc3cb..4377be1181a8 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -18,7 +18,7 @@ from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import deprecate, logging +from ...utils 
import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -544,19 +544,6 @@ def __init__( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) - @property - def in_channels(self): - deprecate( - "in_channels", - "1.0.0", - ( - "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use" - " `unet.config.in_channels` instead" - ), - standard_warn=False, - ) - return self.config.in_channels - @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" From b5de4e8b0194ca2e868a096d54a91dd923ac4a14 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 17 Apr 2023 15:54:22 +0000 Subject: [PATCH 16/16] Fix all rest & add tests & remove old deprecation fns --- .../community/unclip_image_interpolation.py | 12 ++--- .../community/unclip_text_interpolation.py | 12 ++--- src/diffusers/models/autoencoder_kl.py | 12 +---- src/diffusers/models/unet_1d.py | 12 +---- .../pipeline_text_to_video_zero.py | 2 +- .../pipelines/unclip/pipeline_unclip.py | 12 ++--- .../unclip/pipeline_unclip_image_variation.py | 12 ++--- ...ipeline_versatile_diffusion_dual_guided.py | 2 +- ...ine_versatile_diffusion_image_variation.py | 2 +- ...eline_versatile_diffusion_text_to_image.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 12 +---- tests/models/test_modeling_common.py | 47 ++++++++++++++++++- tests/pipelines/unclip/test_unclip.py | 8 ++-- .../unclip/test_unclip_image_variation.py | 13 +++-- tests/schedulers/test_schedulers.py | 44 +++++++++++++++++ 15 files changed, 133 insertions(+), 71 deletions(-) diff --git a/examples/community/unclip_image_interpolation.py b/examples/community/unclip_image_interpolation.py index d0b54125b688..453ac07af7c6 100644 --- a/examples/community/unclip_image_interpolation.py +++ b/examples/community/unclip_image_interpolation.py @@ -372,9 +372,9 @@ def __call__( self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) 
decoder_timesteps_tensor = self.decoder_scheduler.timesteps - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), @@ -425,9 +425,9 @@ def __call__( self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size super_res_latents = self.prepare_latents( (batch_size, channels, height, width), diff --git a/examples/community/unclip_text_interpolation.py b/examples/community/unclip_text_interpolation.py index ac6b73d974b6..290f45317004 100644 --- a/examples/community/unclip_text_interpolation.py +++ b/examples/community/unclip_text_interpolation.py @@ -452,9 +452,9 @@ def __call__( self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), @@ -505,9 +505,9 @@ def __call__( self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps - channels = self.super_res_first.in_channels // 2 - 
height = self.super_res_first.sample_size - width = self.super_res_first.sample_size + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size super_res_latents = self.prepare_latents( (batch_size, channels, height, width), diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 5d1c54a9af25..1a8a204d80ce 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -18,7 +18,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, apply_forward_hook, deprecate +from ..utils import BaseOutput, apply_forward_hook from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder @@ -123,16 +123,6 @@ def __init__( self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - @property - def block_out_channels(self): - deprecate( - "block_out_channels", - "1.0.0", - "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. 
Please use `vae.config.block_out_channels instead`", - standard_warn=False, - ) - return self.config.block_out_channels - def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (Encoder, Decoder)): module.gradient_checkpointing = value diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c7755bb3ed45..34a1d2b5160e 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -19,7 +19,7 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block @@ -190,16 +190,6 @@ def __init__( fc_dim=block_out_channels[-1] // 4, ) - @property - def in_channels(self): - deprecate( - "in_channels", - "1.0.0", - "Accessing `in_channels` directly via unet.in_channels is deprecated. 
Please use `unet.config.in_channels` instead", - standard_warn=False, - ) - return self.config.in_channels - def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py index cf5e6e399a77..5b163bbbc8f5 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -441,7 +441,7 @@ def __call__( timesteps = self.scheduler.timesteps # Prepare latent variables - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 3aac39b3a3b0..abbb48ce8f46 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -413,9 +413,9 @@ def __call__( self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), @@ -466,9 +466,9 @@ def __call__( self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size + channels = self.super_res_first.config.in_channels // 
2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size super_res_latents = self.prepare_latents( (batch_size, channels, height, width), diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index 56d522354d9a..30d74cd36bb0 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -339,9 +339,9 @@ def __call__( self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size if decoder_latents is None: decoder_latents = self.prepare_latents( @@ -393,9 +393,9 @@ def __call__( self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size if super_res_latents is None: super_res_latents = self.prepare_latents( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 0f385ed6612c..661a1bd3cf73 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ 
-533,7 +533,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels + num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 2b47184d7773..e3a2ee370362 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -378,7 +378,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels + num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index fdca625fd99d..26b9be2bfa76 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -452,7 +452,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. 
Prepare latent variables - num_channels_latents = self.image_unet.in_channels + num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 2bc34bb8b444..a8a71fe420aa 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, deprecate, randn_tensor +from ..utils import BaseOutput, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @@ -168,16 +168,6 @@ def __init__( self.variance_type = variance_type - @property - def num_train_timesteps(self): - deprecate( - "num_train_timesteps", - "1.0.0", - "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps instead`", - standard_warn=False, - ) - return self.config.num_train_timesteps - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 40aba3b24967..4a94a77fcabb 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -26,8 +26,8 @@ from diffusers.models import UNet2DConditionModel from diffusers.training_utils import EMAModel -from diffusers.utils import torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils import logging, torch_device +from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu class ModelUtilsTest(unittest.TestCase): @@ -155,6 +155,49 @@ def 
test_from_save_pretrained(self): max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + def test_getattr_is_correct(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + + # save some things to test + model.dummy_attribute = 5 + model.register_to_config(test_attribute=5) + + logger = logging.get_logger("diffusers.models.modeling_utils") + # 30 for warning + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + assert hasattr(model, "dummy_attribute") + assert getattr(model, "dummy_attribute") == 5 + assert model.dummy_attribute == 5 + + # no warning should be thrown + assert cap_logger.out == "" + + logger = logging.get_logger("diffusers.models.modeling_utils") + # 30 for warning + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + assert hasattr(model, "save_pretrained") + fn = model.save_pretrained + fn_1 = getattr(model, "save_pretrained") + + assert fn == fn_1 + # no warning should be thrown + assert cap_logger.out == "" + + # warning should be thrown + with self.assertWarns(FutureWarning): + assert model.test_attribute == 5 + + with self.assertWarns(FutureWarning): + assert getattr(model, "test_attribute") == 5 + + with self.assertRaises(AttributeError) as error: + model.does_not_exist + + assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'" + def test_from_save_pretrained_variant(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py index 4df3e4d3828b..d2c699ea501d 100644 --- a/tests/pipelines/unclip/test_unclip.py +++ b/tests/pipelines/unclip/test_unclip.py @@ -293,16 +293,16 @@ class DummyScheduler: prior_latents = pipe.prepare_latents( shape, dtype=dtype, device=device, generator=generator, latents=None, 
scheduler=DummyScheduler() ) - shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size) + shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size) decoder_latents = pipe.prepare_latents( shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() ) shape = ( batch_size, - super_res_first.in_channels // 2, - super_res_first.sample_size, - super_res_first.sample_size, + super_res_first.config.in_channels // 2, + super_res_first.config.sample_size, + super_res_first.config.sample_size, ) super_res_latents = pipe.prepare_latents( shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index 57d15559cc75..c1b8be9cd49e 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -379,16 +379,21 @@ class DummyScheduler: dtype = pipe.decoder.dtype batch_size = 1 - shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size) + shape = ( + batch_size, + pipe.decoder.config.in_channels, + pipe.decoder.config.sample_size, + pipe.decoder.config.sample_size, + ) decoder_latents = pipe.prepare_latents( shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() ) shape = ( batch_size, - pipe.super_res_first.in_channels // 2, - pipe.super_res_first.sample_size, - pipe.super_res_first.sample_size, + pipe.super_res_first.config.in_channels // 2, + pipe.super_res_first.config.sample_size, + pipe.super_res_first.config.sample_size, ) super_res_latents = pipe.prepare_latents( shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler() diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index 
bfbf5cbc798f..69cddb36dde2 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -596,3 +596,47 @@ def test_trained_betas(self): new_scheduler = scheduler_class.from_pretrained(tmpdirname) assert scheduler.betas.tolist() == new_scheduler.betas.tolist() + + def test_getattr_is_correct(self): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + # save some things to test + scheduler.dummy_attribute = 5 + scheduler.register_to_config(test_attribute=5) + + logger = logging.get_logger("diffusers.configuration_utils") + # 30 for warning + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + assert hasattr(scheduler, "dummy_attribute") + assert getattr(scheduler, "dummy_attribute") == 5 + assert scheduler.dummy_attribute == 5 + + # no warning should be thrown + assert cap_logger.out == "" + + logger = logging.get_logger("diffusers.schedulers.scheduling_utils") + # 30 for warning + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + assert hasattr(scheduler, "save_pretrained") + fn = scheduler.save_pretrained + fn_1 = getattr(scheduler, "save_pretrained") + + assert fn == fn_1 + # no warning should be thrown + assert cap_logger.out == "" + + # warning should be thrown + with self.assertWarns(FutureWarning): + assert scheduler.test_attribute == 5 + + with self.assertWarns(FutureWarning): + assert getattr(scheduler, "test_attribute") == 5 + + with self.assertRaises(AttributeError) as error: + scheduler.does_not_exist + + assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"