diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f4bf732b5322..c737bcc9dbd9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -519,6 +519,8 @@ title: DeepFloyd IF - local: api/pipelines/dit title: DiT + - local: api/pipelines/dreamlite + title: DreamLite - local: api/pipelines/easyanimate title: EasyAnimate - local: api/pipelines/ernie_image diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index 7ab053f10756..95dccd317e32 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -44,6 +44,10 @@ An attention processor is a class for applying different types of attention mech [[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0 +## DreamLite + +[[autodoc]] models.attention_processor.DreamLiteAttnProcessor2_0 + ## CrossFrameAttnProcessor [[autodoc]] pipelines.deprecated.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor diff --git a/docs/source/en/api/pipelines/dreamlite.md b/docs/source/en/api/pipelines/dreamlite.md new file mode 100644 index 000000000000..d673087e1a10 --- /dev/null +++ b/docs/source/en/api/pipelines/dreamlite.md @@ -0,0 +1,160 @@ + + +# DreamLite + +DreamLite is a text-to-image and image-editing model from ByteDance. It pairs a custom 2D U-Net +(`DreamLiteUNetModel`) with the `Qwen3-VL` multimodal encoder as its prompt / image-instruction encoder, +and uses an `AutoencoderTiny` (TAESD-style) VAE for fast latent encode/decode. 
+
+Two pipelines are exposed:
+
+| Pipeline | Modes | CFG | Use case |
+|---|---|---|---|
+| [`DreamLitePipeline`] | text-to-image **and** image-editing (auto-selected by whether `image` is `None`) | 3-branch dual CFG (`guidance_scale` on text branch, `image_guidance_scale` on image branch, à la InstructPix2Pix) | Highest quality |
+| [`DreamLiteMobilePipeline`] | text-to-image **and** image-editing (auto-selected by whether `image` is `None`) | None: distilled, single UNet forward per step | On-device / low-latency |
+
+Official checkpoints:
+
+* Base model: [carlofkl/DreamLite-base](https://huggingface.co/carlofkl/DreamLite-base)
+* Distilled mobile model: [carlofkl/DreamLite-mobile](https://huggingface.co/carlofkl/DreamLite-mobile)
+
+> [!TIP]
+> Both pipelines auto-detect text-to-image vs. image-editing mode from whether the `image` argument is
+> provided. There is no separate `Img2Img` class.
+
+> [!TIP]
+> When loading an input image for editing, prefer `diffusers.utils.load_image(...)` over raw `PIL.Image.open(...)`.
+> `load_image` enforces an RGB conversion and applies EXIF orientation, both of which the pipeline assumes.
+> A plain `Image.open` of an RGBA / palette / EXIF-rotated source will silently produce a different latent
+> conditioning and degrade output quality.
+
+## Text-to-image (Base)
+
+```python
+import torch
+from diffusers import DreamLitePipeline
+
+pipe = DreamLitePipeline.from_pretrained("carlofkl/DreamLite-base", revision="diffusers", torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+image = pipe(
+    prompt="a dog running on the grass",
+    negative_prompt="",
+    height=1024,
+    width=1024,
+    num_inference_steps=28,
+    guidance_scale=3.5,
+    generator=torch.Generator("cpu").manual_seed(42),
+).images[0]
+image.save("dreamlite_t2i.png")
+```
+
+## Image editing (Base)
+
+Pass an `image` to enter edit mode. Both `guidance_scale` (text branch) and `image_guidance_scale`
+(image branch) are active here.
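The two scales can be read as weights in an InstructPix2Pix-style combination of three noise predictions. The sketch below is a minimal illustration under that assumption, not DreamLite's actual implementation: the helper name `combine_dual_cfg`, the branch names, and the ordering of the terms are hypothetical and follow the InstructPix2Pix formulation.

```python
def combine_dual_cfg(uncond, image_cond, full_cond, guidance_scale, image_guidance_scale):
    """Combine three noise predictions InstructPix2Pix-style.

    Hypothetical helper for illustration only. It works on plain floats and on
    array/tensor types that support elementwise arithmetic.
    """
    # Move from the unconditional prediction towards the image-conditioned one.
    image_term = image_guidance_scale * (image_cond - uncond)
    # Move from the image-conditioned prediction towards the text+image one.
    text_term = guidance_scale * (full_cond - image_cond)
    return uncond + image_term + text_term


# Toy scalar "predictions" standing in for the three U-Net outputs.
combined = combine_dual_cfg(0.0, 1.0, 3.0, guidance_scale=3.5, image_guidance_scale=1.5)
print(combined)
```

With both scales set to 1.0 this expression reduces to the fully conditioned prediction; larger values extrapolate further along each direction.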
+
+```python
+import torch
+from diffusers import DreamLitePipeline
+from diffusers.utils import load_image
+
+pipe = DreamLitePipeline.from_pretrained("carlofkl/DreamLite-base", revision="diffusers", torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+source = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cat.png")
+
+image = pipe(
+    prompt="turn the cat into a corgi",
+    image=source,
+    height=1024,
+    width=1024,
+    num_inference_steps=28,
+    guidance_scale=3.5,
+    image_guidance_scale=1.5,
+    generator=torch.Generator("cpu").manual_seed(42),
+).images[0]
+image.save("dreamlite_edit.png")
+```
+
+## Text-to-image (Mobile)
+
+The mobile pipeline is distilled and skips CFG entirely: a single UNet forward per step. It accepts the
+same `prompt` / `height` / `width` / `num_inference_steps` arguments, but **ignores** `guidance_scale` and
+`image_guidance_scale` if passed (a warning is logged).
+
+```python
+import torch
+from diffusers import DreamLiteMobilePipeline
+
+pipe = DreamLiteMobilePipeline.from_pretrained("carlofkl/DreamLite-mobile", revision="diffusers", torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+image = pipe(
+    prompt="a dog running on the grass",
+    height=1024,
+    width=1024,
+    num_inference_steps=4,
+    generator=torch.Generator("cpu").manual_seed(42),
+).images[0]
+image.save("dreamlite_mobile_t2i.png")
+```
+
+## Image editing (Mobile)
+
+```python
+import torch
+from diffusers import DreamLiteMobilePipeline
+from diffusers.utils import load_image
+
+pipe = DreamLiteMobilePipeline.from_pretrained("carlofkl/DreamLite-mobile", revision="diffusers", torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+source = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cat.png")
+
+image = pipe(
+    prompt="turn the cat into a corgi",
+    image=source,
+    height=1024,
+    width=1024,
+    num_inference_steps=4,
+    generator=torch.Generator("cpu").manual_seed(42),
+).images[0]
+image.save("dreamlite_mobile_edit.png") +``` + +## Notes and limitations + +* Both pipelines force `batch_size = 1` internally; `num_images_per_prompt` controls how many samples + are drawn from the same prompt rather than parallel batching. +* The prompt encoder is `Qwen3-VL`, which is a multimodal model. Loading the full pipeline therefore + requires sufficient GPU memory for both the U-Net and the Qwen3-VL text encoder (~4 GB + ~0.7 GB + in bf16 for the base release). +* The VAE is `AutoencoderTiny` and exposes `encoder_block_out_channels`; `vae_scale_factor` is derived + from it at pipeline init time. + +## DreamLitePipeline + +[[autodoc]] DreamLitePipeline + - all + - __call__ + +## DreamLiteMobilePipeline + +[[autodoc]] DreamLiteMobilePipeline + - all + - __call__ + +## DreamLitePipelineOutput + +[[autodoc]] pipelines.dreamlite.pipeline_output.DreamLitePipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3a8332dc0c3a..d63cfedef9db 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -238,6 +238,8 @@ "CosmosControlNetModel", "CosmosTransformer3DModel", "DiTTransformer2DModel", + "DreamLiteTransformer2DModel", + "DreamLiteUNetModel", "EasyAnimateTransformer3DModel", "ErnieImageTransformer2DModel", "Flux2Transformer2DModel", @@ -546,6 +548,9 @@ "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", + "DreamLiteMobilePipeline", + "DreamLitePipeline", + "DreamLitePipelineOutput", "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", @@ -1071,6 +1076,8 @@ CosmosControlNetModel, CosmosTransformer3DModel, DiTTransformer2DModel, + DreamLiteTransformer2DModel, + DreamLiteUNetModel, EasyAnimateTransformer3DModel, ErnieImageTransformer2DModel, Flux2Transformer2DModel, @@ -1354,6 +1361,9 @@ CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, + DreamLiteMobilePipeline, + DreamLitePipeline, + DreamLitePipelineOutput, 
EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index a4aea6361ece..e270a643ea21 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -94,6 +94,7 @@ _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.transformer_2d_dreamlite"] = ["DreamLiteTransformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_anyflow"] = ["AnyFlowTransformer3DModel"] _import_structure["transformers.transformer_anyflow_far"] = ["AnyFlowFARTransformer3DModel"] @@ -141,6 +142,7 @@ _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"] + _import_structure["unets.unet_dreamlite"] = ["DreamLiteUNetModel"] _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"] _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] @@ -229,6 +231,7 @@ ConsisIDTransformer3DModel, CosmosTransformer3DModel, DiTTransformer2DModel, + DreamLiteTransformer2DModel, DualTransformer2DModel, EasyAnimateTransformer3DModel, ErnieImageTransformer2DModel, @@ -274,6 +277,7 @@ ZImageTransformer2DModel, ) from .unets import ( + DreamLiteUNetModel, I2VGenXLUNet, Kandinsky3UNet, MotionAdapter, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2ece5cb3685..d86bf7c35d9c 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2787,6 +2787,114 @@ 
def __call__( return hidden_states +class DreamLiteAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention with Grouped Query Attention (GQA / MQA) support (enabled + by default if you're using PyTorch 2.0). + + Identical to :class:`AttnProcessor2_0` except the key/value reshape branch correctly handles ``attn.kv_heads != + attn.heads`` by reshaping K/V to ``kv_heads`` and then ``repeat_interleave``-ing them up to ``attn.heads``. This is + required by the DreamLite UNet, which combines GQA with ``qk_norm``, a combination the default + :class:`AttnProcessor2_0` does not handle. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "DreamLiteAttnProcessor2_0 requires PyTorch 2.0 or later. To use it, please upgrade PyTorch." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + temb: torch.Tensor | None = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # --- GQA-aware reshape (the only real difference vs AttnProcessor2_0) --- + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if kv_heads != attn.heads: + # GQA / MQA: repeat K/V heads up to query heads for SDPA. 
+ heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head) + value = torch.repeat_interleave( + value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head + ) + # ------------------------------------------------------------------------ + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class XLAFlashAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`. diff --git a/src/diffusers/models/resnet_dreamlite.py b/src/diffusers/models/resnet_dreamlite.py new file mode 100644 index 000000000000..80238f9da926 --- /dev/null +++ b/src/diffusers/models/resnet_dreamlite.py @@ -0,0 +1,254 @@ +# Copyright 2026 ByteDance Ltd. and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils import deprecate +from .activations import get_activation +from .downsampling import Downsample2D, downsample_2d +from .upsampling import Upsample2D, upsample_2d + + +class DepthwiseSeparableConv(nn.Module): + """ + Depthwise separable convolution used by DreamLite mobile-friendly ResNet blocks. + + A depthwise convolution (groups == in_channels) followed by a 1x1 pointwise convolution. The pointwise output + channel count is multiplied by `expand_ratio` to support inverted-residual style expansion / contraction inside + [`ResnetBlock2DDreamLite`]. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = False, + expand_ratio: float = 1, + ): + super().__init__() + self.depthwise = nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=in_channels, + bias=bias, + ) + self.pointwise = nn.Conv2d(in_channels, int(out_channels * expand_ratio), kernel_size=1, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.depthwise(hidden_states) + hidden_states = self.pointwise(hidden_states) + return hidden_states + + +class ResnetBlock2DDreamLite(nn.Module): + r""" + A ResNet block used by DreamLite. 
Mirrors [`diffusers.models.resnet.ResnetBlock2D`] with one extra option: + + use_sep_conv (`bool`, *optional*, defaults to `False`): + Replace the two 3x3 convolutions with [`DepthwiseSeparableConv`]. The first conv expands the channel count + by 2x; the second conv contracts it back. Used by the mobile-friendly DreamLite checkpoints. + + All other parameters behave identically to [`diffusers.models.resnet.ResnetBlock2D`]. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + pre_norm: bool = True, + eps: float = 1e-6, + non_linearity: str = "swish", + skip_time_act: bool = False, + time_embedding_norm: str = "default", + kernel: Optional[torch.Tensor] = None, + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + up: bool = False, + down: bool = False, + conv_shortcut_bias: bool = True, + conv_2d_out_channels: Optional[int] = None, + use_sep_conv: bool = False, + ): + super().__init__() + if time_embedding_norm in ("ada_group", "spatial"): + raise ValueError( + f"`time_embedding_norm`={time_embedding_norm!r} is not supported by `ResnetBlock2DDreamLite`. " + "Use `diffusers.models.resnet.ResnetBlockCondNorm2D` instead." + ) + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.time_embedding_norm = time_embedding_norm + self.skip_time_act = skip_time_act + + if groups_out is None: + groups_out = groups + + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + # Inverted-residual style expansion when `use_sep_conv=True`: conv1 expands channels by 2x, + # conv2 contracts them back. 
For the standard branch this is just a regular 3x3 conv. + if use_sep_conv: + expand_ratio = 2 + self.conv1 = DepthwiseSeparableConv( + in_channels, out_channels, kernel_size=3, stride=1, padding=1, expand_ratio=expand_ratio + ) + out_channels = out_channels * expand_ratio + else: + expand_ratio = 1 + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm}") + else: + self.time_emb_proj = None + + self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = nn.Dropout(dropout) + conv_2d_out_channels = conv_2d_out_channels or out_channels + if use_sep_conv: + self.conv2 = DepthwiseSeparableConv( + out_channels, + conv_2d_out_channels, + kernel_size=3, + stride=1, + padding=1, + expand_ratio=1 / expand_ratio, + ) + conv_2d_out_channels = conv_2d_out_channels // expand_ratio + else: + self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D(in_channels, 
use_conv=False, padding=1, name="op") + + self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + conv_2d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = ( + "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise " + "an error in the future. `scale` should directly be passed while calling the underlying pipeline " + "component i.e., via `cross_attention_kwargs`." + ) + deprecate("scale", "1.0.0", deprecation_message) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if self.time_embedding_norm == "default": + if temb is not None: + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + elif self.time_embedding_norm == "scale_shift": + if temb is None: + raise ValueError(f"`temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}") + time_scale, time_shift = torch.chunk(temb, 2, dim=1) + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + time_scale) + time_shift + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + # Only call .contiguous() under training, to avoid DDP gradient-stride warnings while keeping + # inference fast (especially on CPU). Mirrors the upstream fix from huggingface/diffusers#12975. 
+ if self.training: + input_tensor = input_tensor.contiguous() + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index bb10b101c1b9..c4f1bb38d034 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,6 +17,7 @@ from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel + from .transformer_2d_dreamlite import DreamLiteTransformer2DModel from .transformer_allegro import AllegroTransformer3DModel from .transformer_anyflow import AnyFlowTransformer3DModel from .transformer_anyflow_far import AnyFlowFARTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_2d_dreamlite.py b/src/diffusers/models/transformers/transformer_2d_dreamlite.py new file mode 100644 index 000000000000..da8ed1cad04d --- /dev/null +++ b/src/diffusers/models/transformers/transformer_2d_dreamlite.py @@ -0,0 +1,599 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DreamLite 2D transformer. 
+ +This module is intentionally self-contained: it defines + +* ``BasicTransformerBlockDreamLite`` — a DreamLite-flavoured variant of + :class:`~diffusers.models.attention.BasicTransformerBlock` with four additional knobs (``use_self_attention``, + ``qk_norm``, ``num_kv_heads``, ``ff_mult``); and +* ``DreamLiteTransformer2DModel`` — a continuous-input-only counterpart of + :class:`~diffusers.models.transformers.transformer_2d.Transformer2DModel` that wires those knobs all the way down to + each block. + +Keeping everything here means the DreamLite integration never touches the upstream ``attention.py`` / +``transformer_2d.py``, which is the convention followed by other ported pipelines (SD3, Flux, Chroma, …). + +The numerical behaviour mirrors the original DreamLite reference implementation at ``dreamlite/models/{attention.py, +transformers/transformer_2d.py}`` — specifically, when ``use_self_attention=False`` the block keeps ``norm1``'s output +as the post-self-attn hidden state instead of running ``attn1``, matching the "Remove self-attention" path used by +DreamLite's ``CrossAttnDownRemoveSelfAttnBlock2D`` and ``CrossAttnUpRemoveSelfAttnBlock2DV1``. +""" + +from typing import Any + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import FeedForward, GatedSelfAttentionDense, _chunked_feed_forward +from ..attention_processor import Attention +from ..embeddings import SinusoidalPositionalEmbedding +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero +from .transformer_2d import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class BasicTransformerBlockDreamLite(nn.Module): + r"""DreamLite variant of :class:`BasicTransformerBlock`. 
+ + Adds four constructor knobs on top of the upstream block: + + * ``use_self_attention`` — when ``False``, ``attn1`` is *not* instantiated and the self-attention residual branch + in ``forward`` is replaced by ``norm1``'s output (no add-residual). This implements DreamLite's "Remove + self-attention" trick used inside ``CrossAttnDownRemoveSelfAttnBlock2DDreamLite`` / + ``CrossAttnUpRemoveSelfAttnBlock2DV1DreamLite``. + * ``qk_norm`` — propagated to both attention layers' ``qk_norm``. + * ``num_kv_heads`` — propagated to both attention layers' ``kv_heads`` (enables Grouped-Query Attention). + * ``ff_mult`` — propagated to :class:`FeedForward.mult` (DreamLite uses a non-default expansion factor). + + Only the ``norm_type`` values actually exercised by DreamLite are supported in detail (``layer_norm`` and + ``ada_norm``); the other branches are preserved verbatim from the upstream block so that callers writing new + variants do not have to re-port them. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: str | None = None, + num_positional_embeddings: int | None = None, + ada_norm_continous_conditioning_embedding_dim: int | None = None, + ada_norm_bias: int | None = None, + ff_inner_dim: int | None = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + use_self_attention: bool = True, + qk_norm: str | None = None, + num_kv_heads: int | None = None, + ff_mult: int = 4, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = 
num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + self.use_self_attention = use_self_attention + + if not use_self_attention and norm_type in ("ada_norm_zero", "ada_norm_single"): + raise ValueError( + f"`use_self_attention=False` is incompatible with `norm_type={norm_type}` because " + "the gate/shift/scale modulation tuple is derived from `norm1`. " + "Use `norm_type='layer_norm'` or `'ada_norm'` instead." + ) + + # Backward-compatible boolean flags (kept for parity with BasicTransformerBlock). + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. " + f"Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined." 
+ ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # 1. Self-Attn (or its replacement) + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + if use_self_attention: + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + qk_norm=qk_norm, + kv_heads=num_kv_heads, + ) + else: + self.attn1 = None + + # 2. 
Cross-Attn + if cross_attention_dim is not None or double_self_attention: + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + qk_norm=qk_norm, + kv_heads=num_kv_heads, + ) + else: + if norm_type == "ada_norm_single": + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + mult=ff_mult, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha (kept for completeness; DreamLite does not use it). 
+ if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: int | None, dim: int = 0): + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + timestep: torch.LongTensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + class_labels: torch.LongTensor | None = None, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # 0. Self-Attention norm + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states =
self.pos_embed(norm_hidden_states) + + # 1. GLIGEN kwargs split + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + if self.use_self_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + else: + # DreamLite "Remove self-attention" path: drop attn1 entirely and let + # the normalized state propagate as-is to cross-attn / FF. Matches + # upstream DreamLite `BasicTransformerBlock.forward` when + # `use_self_attention=False`. + hidden_states = norm_hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. 
Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type != "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class DreamLiteTransformer2DModel(ModelMixin, ConfigMixin): +
r"""Continuous-input 2D transformer used by the DreamLite U-Net. + + Equivalent to :class:`Transformer2DModel` restricted to the ``is_input_continuous`` branch (``in_channels`` set, + ``patch_size`` and ``num_vector_embeds`` both ``None``), with four extra knobs that are propagated into every + :class:`BasicTransformerBlockDreamLite`: + + * ``use_self_attention`` — set ``False`` from ``CrossAttn*RemoveSelfAttnBlock2D*DreamLite`` to enable DreamLite's + "Remove self-attention" path. + * ``qk_norm`` — RMS/LayerNorm applied to Q and K projections. + * ``num_kv_heads`` — enables Grouped-Query Attention when fewer than ``num_attention_heads``. + * ``ff_mult`` — feed-forward expansion factor (DreamLite uses a non-default value). + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlockDreamLite"] + _skip_layerwise_casting_patterns = ["norm"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int | None = None, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: int | None = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + use_self_attention: bool = True, + qk_norm: str | None = None, + num_kv_heads: int | None = None, + ff_mult: int = 4, + ): + super().__init__() + + if in_channels is None: + raise ValueError( + "`DreamLiteTransformer2DModel` only supports continuous inputs; `in_channels` must be provided." 
+ ) + + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + + self.norm = torch.nn.GroupNorm( + num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True + ) + if self.use_linear_projection: + self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim) + else: + self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlockDreamLite( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + use_self_attention=self.config.use_self_attention, + qk_norm=self.config.qk_norm, + num_kv_heads=self.config.num_kv_heads, + ff_mult=self.config.ff_mult, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.use_linear_projection: + self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) + else: + self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) + + def _operate_on_continuous_inputs(self, hidden_states): + batch, _, height, width = hidden_states.shape + 
hidden_states = self.norm(hidden_states) + + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + return hidden_states, inner_dim + + def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + return output + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + timestep: torch.LongTensor | None = None, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + class_labels: torch.LongTensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_dict: bool = True, + ): + """Forward pass of :class:`DreamLiteTransformer2DModel`. + + Args: + hidden_states: Input latent tensor of shape ``(batch, channels, height, width)``. + encoder_hidden_states: Cross-attention conditioning embeddings. + timestep: Diffusion timestep(s); broadcast to batch if scalar. + added_cond_kwargs: Extra conditioning; accepted for API compatibility but not forwarded to the blocks. + class_labels: Optional class labels for class-conditional generation. + cross_attention_kwargs: Optional kwargs forwarded to the cross-attention processor.
+ Note: passing ``scale`` is deprecated and will be ignored. + attention_mask: Optional self-attention mask; 2D masks are converted to additive biases. + encoder_attention_mask: Optional cross-attention mask; 2D masks are converted to additive biases. + return_dict: If ``True``, returns a :class:`Transformer2DModelOutput`; otherwise a 1-tuple ``(sample,)``. + + Returns: + :class:`~diffusers.models.transformers.transformer_2d.Transformer2DModelOutput` (or a 1-tuple of the + sample) — kept output-compatible with the upstream class so callers don't have to special-case DreamLite. + """ + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Convert masks to additive biases (broadcast-friendly). + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size, _, height, width = hidden_states.shape + residual = hidden_states + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + + # 2. 
Blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + output = self._get_output_for_continuous_inputs( + hidden_states=hidden_states, + residual=residual, + batch_size=batch_size, + height=height, + width=width, + inner_dim=inner_dim, + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/unets/__init__.py b/src/diffusers/models/unets/__init__.py index 9ef04fb62606..394df72261c6 100644 --- a/src/diffusers/models/unets/__init__.py +++ b/src/diffusers/models/unets/__init__.py @@ -6,6 +6,7 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel + from .unet_dreamlite import DreamLiteUNetModel from .unet_i2vgen_xl import I2VGenXLUNet from .unet_kandinsky3 import Kandinsky3UNet from .unet_motion_model import MotionAdapter, UNetMotionModel diff --git a/src/diffusers/models/unets/unet_2d_blocks_dreamlite.py b/src/diffusers/models/unets/unet_2d_blocks_dreamlite.py new file mode 100644 index 000000000000..6c6b2d458ef1 --- /dev/null +++ b/src/diffusers/models/unets/unet_2d_blocks_dreamlite.py @@ -0,0 +1,705 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DreamLite-specific UNet 2D blocks. + +These mirror the upstream ``unet_2d_blocks`` Down/Mid/Up cross-attention blocks but additionally thread the +DreamLite-specific knobs: + +- ``use_sep_conv``: replace standard convs in ResnetBlock2DDreamLite with depthwise-separable convs (mobile-friendly). +- ``qk_norm``, ``num_kv_heads``, ``ff_mult``: passed to DreamLiteTransformer2DModel / BasicTransformerBlockDreamLite. +- ``RemoveSelfAttn`` variants hard-code ``use_self_attention=False`` in their DreamLiteTransformer2DModel calls. +""" + +from __future__ import annotations + +from typing import Any + +import torch +from torch import nn + +from ..attention_processor import Attention # noqa: F401 (re-export friendliness) +from ..resnet_dreamlite import ResnetBlock2DDreamLite +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d_dreamlite import DreamLiteTransformer2DModel +from .unet_2d_blocks import Downsample2D, Upsample2D, apply_freeu + + +# --------------------------------------------------------------------------- +# Mid block +# --------------------------------------------------------------------------- +class UNetMidBlock2DCrossAttnDreamLite(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: int
| None = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + # DreamLite extras + qk_norm: str | None = None, + use_sep_conv: bool = False, + ff_mult: int = 4, + num_kv_heads: int | None = None, + num_mid_layers: int = 1, + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + resnets = [ + ResnetBlock2DDreamLite( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_sep_conv=use_sep_conv, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + DreamLiteTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + qk_norm=qk_norm, + ff_mult=ff_mult, + num_kv_heads=num_kv_heads, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + 
out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2DDreamLite( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_sep_conv=use_sep_conv, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# --------------------------------------------------------------------------- +# Down blocks +# 
--------------------------------------------------------------------------- +def _make_down_block_class(class_name: str, *, remove_self_attn: bool): + class _DownBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + # DreamLite extras + qk_norm: str | None = None, + use_sep_conv: bool = False, + ff_mult: int = 4, + num_kv_heads: int | None = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_ch = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2DDreamLite( + in_channels=in_ch, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_sep_conv=use_sep_conv, + ) + ) + if not dual_cross_attention: + tf_kwargs = { + "num_attention_heads": num_attention_heads, + "attention_head_dim": out_channels // num_attention_heads, + "in_channels": out_channels, + "num_layers": transformer_layers_per_block[i], + 
"cross_attention_dim": cross_attention_dim, + "norm_num_groups": resnet_groups, + "use_linear_projection": use_linear_projection, + "only_cross_attention": only_cross_attention, + "upcast_attention": upcast_attention, + "attention_type": attention_type, + "qk_norm": qk_norm, + "ff_mult": ff_mult, + "num_kv_heads": num_kv_heads, + } + if remove_self_attn: + tf_kwargs["use_self_attention"] = False + attentions.append(DreamLiteTransformer2DModel(**tf_kwargs)) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + additional_residuals: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states: tuple[torch.Tensor, ...] 
= () + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + _DownBlock.__name__ = class_name + _DownBlock.__qualname__ = class_name + return _DownBlock + + +CrossAttnDownBlock2DDreamLite = _make_down_block_class("CrossAttnDownBlock2DDreamLite", remove_self_attn=False) +CrossAttnDownRemoveSelfAttnBlock2DDreamLite = _make_down_block_class( + "CrossAttnDownRemoveSelfAttnBlock2DDreamLite", remove_self_attn=True +) + + +# --------------------------------------------------------------------------- +# Up blocks +# --------------------------------------------------------------------------- +def _make_up_block_class(class_name: str, *, remove_self_attn: bool): + class _UpBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + 
num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + # DreamLite extras + qk_norm: str | None = None, + use_sep_conv: bool = False, + ff_mult: int = 4, + num_kv_heads: int | None = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2DDreamLite( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_sep_conv=use_sep_conv, + ) + ) + if not dual_cross_attention: + tf_kwargs = { + "num_attention_heads": num_attention_heads, + "attention_head_dim": out_channels // num_attention_heads, + "in_channels": out_channels, + "num_layers": transformer_layers_per_block[i], + "cross_attention_dim": cross_attention_dim, + "norm_num_groups": resnet_groups, + "use_linear_projection": use_linear_projection, + "only_cross_attention": only_cross_attention, + "upcast_attention": upcast_attention, + "attention_type": attention_type, + "qk_norm": qk_norm, + "ff_mult": ff_mult, + "num_kv_heads": num_kv_heads, + } + if remove_self_attn: + tf_kwargs["use_self_attention"] = False + 
attentions.append(DreamLiteTransformer2DModel(**tf_kwargs)) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + 
encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + _UpBlock.__name__ = class_name + _UpBlock.__qualname__ = class_name + return _UpBlock + + +CrossAttnUpBlock2DDreamLite = _make_up_block_class("CrossAttnUpBlock2DDreamLite", remove_self_attn=False) +CrossAttnUpRemoveSelfAttnBlock2DV1DreamLite = _make_up_block_class( + "CrossAttnUpRemoveSelfAttnBlock2DV1DreamLite", remove_self_attn=True +) + + +# --------------------------------------------------------------------------- +# Plain resnet-only blocks (no attention) +# --------------------------------------------------------------------------- +class DownBlock2DDreamLite(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + use_sep_conv: bool = False, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_ch = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2DDreamLite( + in_channels=in_ch, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_sep_conv=use_sep_conv, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + 
self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states: tuple[torch.Tensor, ...] = () + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UpBlock2DDreamLite(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + use_sep_conv: bool = False, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2DDreamLite( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_sep_conv=use_sep_conv, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = 
nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + **kwargs, + ) -> torch.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +__all__ = [ + "UNetMidBlock2DCrossAttnDreamLite", + "CrossAttnDownBlock2DDreamLite", + "CrossAttnDownRemoveSelfAttnBlock2DDreamLite", + "CrossAttnUpBlock2DDreamLite", + "CrossAttnUpRemoveSelfAttnBlock2DV1DreamLite", + "DownBlock2DDreamLite", + "UpBlock2DDreamLite", +] diff --git a/src/diffusers/models/unets/unet_dreamlite.py b/src/diffusers/models/unets/unet_dreamlite.py new file mode 100644 index 000000000000..2060a6c1f501 --- /dev/null +++ b/src/diffusers/models/unets/unet_dreamlite.py @@ -0,0 +1,703 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DreamLite UNet model. + +This module defines :class:`DreamLiteUNetModel`, a subclass of :class:`UNet2DConditionModel` that: + +* swaps every Down / Mid / Up block for the DreamLite variants defined in :mod:`unet_2d_blocks_dreamlite`, which + support ``use_sep_conv``, ``qk_norm``, ``num_kv_heads`` and ``ff_mult``; +* defaults its attention processors to :class:`DreamLiteAttnProcessor2_0` (GQA-aware SDPA), which is required because + the upstream ``AttnProcessor2_0`` does not handle ``kv_heads != heads`` correctly. + +Everything else (forward pass, time / class / additional / encoder-hid embeddings, conv-in / conv-out, GLIGEN +positional net, etc.) is inherited unchanged from :class:`UNet2DConditionModel`. 
+""" + +from __future__ import annotations + +from torch import nn + +from ...configuration_utils import register_to_config +from ..activations import get_activation +from ..attention_processor import Attention, DreamLiteAttnProcessor2_0 +from ..normalization import RMSNorm +from .unet_2d_blocks_dreamlite import ( + CrossAttnDownBlock2DDreamLite, + CrossAttnDownRemoveSelfAttnBlock2DDreamLite, + CrossAttnUpBlock2DDreamLite, + CrossAttnUpRemoveSelfAttnBlock2DV1DreamLite, + DownBlock2DDreamLite, + UNetMidBlock2DCrossAttnDreamLite, + UpBlock2DDreamLite, +) +from .unet_2d_condition import UNet2DConditionModel + + +# --------------------------------------------------------------------------- +# Local block dispatch (DreamLite-only) +# --------------------------------------------------------------------------- +def _get_down_block_dreamlite( + down_block_type: str, + *, + num_layers, + transformer_layers_per_block, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + resnet_groups, + cross_attention_dim, + num_attention_heads, + downsample_padding, + dual_cross_attention, + use_linear_projection, + only_cross_attention, + upcast_attention, + resnet_time_scale_shift, + attention_type, + dropout, + qk_norm, + use_sep_conv, + ff_mult, + num_kv_heads, +): + if down_block_type == "DownBlock2D": + return DownBlock2DDreamLite( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + use_sep_conv=use_sep_conv, + ) + if down_block_type in ( + "CrossAttnDownBlock2D", + "CrossAttnDownRemoveSelfAttnBlock2D", + ): + if cross_attention_dim is None: + raise ValueError(f"cross_attention_dim must be specified for {down_block_type}") + cls = ( + 
CrossAttnDownBlock2DDreamLite + if down_block_type == "CrossAttnDownBlock2D" + else CrossAttnDownRemoveSelfAttnBlock2DDreamLite + ) + return cls( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + qk_norm=qk_norm, + use_sep_conv=use_sep_conv, + ff_mult=ff_mult, + num_kv_heads=num_kv_heads, + ) + raise ValueError(f"DreamLite does not support down_block_type={down_block_type!r}") + + +def _get_mid_block_dreamlite( + mid_block_type, + *, + temb_channels, + in_channels, + resnet_eps, + resnet_act_fn, + resnet_groups, + output_scale_factor, + transformer_layers_per_block, + num_attention_heads, + cross_attention_dim, + dual_cross_attention, + use_linear_projection, + upcast_attention, + resnet_time_scale_shift, + attention_type, + dropout, + qk_norm, + use_sep_conv, + ff_mult, + num_kv_heads, + num_mid_layers=1, +): + if mid_block_type is None: + return None + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttnDreamLite( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + 
dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + qk_norm=qk_norm, + use_sep_conv=use_sep_conv, + ff_mult=ff_mult, + num_kv_heads=num_kv_heads, + num_layers=num_mid_layers, + ) + raise ValueError(f"DreamLite does not support mid_block_type={mid_block_type!r}") + + +def _get_up_block_dreamlite( + up_block_type, + *, + num_layers, + transformer_layers_per_block, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + resolution_idx, + resnet_groups, + cross_attention_dim, + num_attention_heads, + dual_cross_attention, + use_linear_projection, + only_cross_attention, + upcast_attention, + resnet_time_scale_shift, + attention_type, + dropout, + qk_norm, + use_sep_conv, + ff_mult, + num_kv_heads, +): + if up_block_type == "UpBlock2D": + return UpBlock2DDreamLite( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + use_sep_conv=use_sep_conv, + ) + if up_block_type in ( + "CrossAttnUpBlock2D", + "CrossAttnUpRemoveSelfAttnBlock2DV1", + ): + if cross_attention_dim is None: + raise ValueError(f"cross_attention_dim must be specified for {up_block_type}") + cls = ( + CrossAttnUpBlock2DDreamLite + if up_block_type == "CrossAttnUpBlock2D" + else CrossAttnUpRemoveSelfAttnBlock2DV1DreamLite + ) + return cls( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + 
add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + qk_norm=qk_norm, + use_sep_conv=use_sep_conv, + ff_mult=ff_mult, + num_kv_heads=num_kv_heads, + ) + raise ValueError(f"DreamLite does not support up_block_type={up_block_type!r}") + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- +class DreamLiteUNetModel(UNet2DConditionModel): + r""" + DreamLite variant of :class:`UNet2DConditionModel`. + + Differences vs the parent class: + + * Down / Mid / Up blocks are dispatched to the DreamLite variants (``unet_2d_blocks_dreamlite``), which support + depthwise-separable convolutions in resnets and Grouped Query Attention with RMSNorm ``qk_norm`` in attention. + * ``default_attn_processor`` returns :class:`DreamLiteAttnProcessor2_0` so SDPA is GQA-aware out of the box. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int | tuple[int, int] | None = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: tuple[str, ...] = ( + "CrossAttnDownRemoveSelfAttnBlock2D", + "CrossAttnDownRemoveSelfAttnBlock2D", + "CrossAttnDownBlock2D", + ), + mid_block_type: str | None = "UNetMidBlock2DCrossAttn", + up_block_types: tuple[str, ...] = ( + "CrossAttnUpBlock2D", + "CrossAttnUpRemoveSelfAttnBlock2DV1", + "UpBlock2D", + ), + only_cross_attention: bool | tuple[bool, ...] = False, + block_out_channels: tuple[int, ...] 
= (320, 640, 1280), + layers_per_block: int | tuple[int, ...] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int | tuple[int, ...] = 2048, + transformer_layers_per_block: int | tuple[int, ...] | tuple[tuple, ...] = 1, + reverse_transformer_layers_per_block: tuple[tuple[int, ...], ...] | None = None, + encoder_hid_dim: int | None = None, + encoder_hid_dim_type: str | None = None, + attention_head_dim: int | tuple[int, ...] = 64, + num_attention_heads: int | tuple[int, ...] | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: str | None = None, + addition_embed_type: str | None = None, + addition_time_embed_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + time_embedding_act_fn: str | None = None, + timestep_post_act: str | None = None, + time_cond_proj_dim: int | None = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: int | None = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: bool | None = None, + cross_attention_norm: str | None = None, + addition_embed_type_num_heads: int = 64, + # ---- DreamLite extras ---- + qk_norm: str | None = "rms_norm", + use_sep_conv: bool = True, + ff_mult: int = 6, + num_kv_heads: int | None = 1, + num_mid_layers: int = 1, + ): + # NOTE: deliberately skip UNet2DConditionModel.__init__ and call nn.Module directly, + # because we replicate the body with DreamLite block dispatch. 
+ nn.Module.__init__(self) + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via " + "`num_attention_heads` because of a naming issue as described in " + "https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. " + "Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + num_attention_heads = num_attention_heads or attention_head_dim + + # Reuse parent helpers (they only touch self, no super().__init__ required). + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + self.projection_class_embeddings_input_dim = projection_class_embeddings_input_dim + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + from ..embeddings import TimestepEmbedding # local import to avoid cycle + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + self._set_class_embedding( + 
class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + self.time_embed_act = None if time_embedding_act_fn is None else get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + # Normalize per-stage args + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + only_cross_attention = [only_cross_attention] * len(down_block_types) + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim * 2 if class_embeddings_concat else time_embed_dim + + # ---- Down ---- + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + 
output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + self.down_blocks.append( + _get_down_block_dreamlite( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + dropout=dropout, + qk_norm=qk_norm, + use_sep_conv=use_sep_conv, + ff_mult=ff_mult, + num_kv_heads=num_kv_heads, + ) + ) + + # ---- Mid ---- + self.mid_block = _get_mid_block_dreamlite( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + dropout=dropout, + qk_norm=qk_norm, + use_sep_conv=use_sep_conv, + ff_mult=ff_mult, + num_kv_heads=num_kv_heads, + num_mid_layers=num_mid_layers, + ) + + # ---- Up ---- + self.num_upsamplers = 0 + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + 
reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + self.up_blocks.append( + _get_up_block_dreamlite( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + dropout=dropout, + qk_norm=qk_norm, + use_sep_conv=use_sep_conv, + ff_mult=ff_mult, + num_kv_heads=num_kv_heads, + ) + ) + + # ---- Out ---- + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = get_activation(act_fn) + else: 
+ self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + # ---- DreamLite: install GQA-aware processor everywhere ---- + for module in self.modules(): + if isinstance(module, Attention): + module.set_processor(DreamLiteAttnProcessor2_0()) + + # ----- override default processor so set_attn_processor("default") restores GQA ---- + @property + def default_attn_processor(self): # type: ignore[override] + return DreamLiteAttnProcessor2_0() + + # ----- DreamLite extension: support `text_proj_rms` encoder_hid_proj ----- + def _set_encoder_hid_proj( # type: ignore[override] + self, + encoder_hid_dim_type, + cross_attention_dim, + encoder_hid_dim, + ): + """ + Override to support DreamLite's `text_proj_rms` variant (Linear → RMSNorm). All other variants fall back to the + parent implementation, preserving full compatibility with upstream configs (`text_proj`, `text_image_proj`, + `image_proj`, ...). + """ + if encoder_hid_dim_type == "text_proj_rms": + if encoder_hid_dim is None: + raise ValueError( + "`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to 'text_proj_rms'." 
+ ) + self.encoder_hid_proj = nn.Sequential( + nn.Linear(encoder_hid_dim, cross_attention_dim), + RMSNorm(cross_attention_dim, eps=1e-5, elementwise_affine=True), + ) + return + super()._set_encoder_hid_proj( + encoder_hid_dim_type=encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # ----- DreamLite extension: dispatch `text_proj_rms` like `text_proj` ----- + def process_encoder_hidden_states( # type: ignore[override] + self, encoder_hidden_states, added_cond_kwargs + ): + """ + For `text_proj_rms`, the projection is a plain `nn.Sequential` applied to `encoder_hidden_states` (same call + signature as `text_proj`). All other variants are delegated to the parent. + """ + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj_rms": + return self.encoder_hid_proj(encoder_hidden_states) + return super().process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + ) + + # ----- DreamLite extension: support `addition_embed_type == "time"` ----- + def _set_add_embedding( # type: ignore[override] + self, + addition_embed_type, + addition_embed_type_num_heads, + addition_time_embed_dim, + flip_sin_to_cos, + freq_shift, + cross_attention_dim, + encoder_hid_dim, + projection_class_embeddings_input_dim, + time_embed_dim, + ): + """ + Override to support DreamLite's `addition_embed_type == "time"` variant (same module layout as `text_time` but + `get_aug_embed` does not require `text_embeds`). All other variants delegate to the parent implementation. 
+ """ + if addition_embed_type == "time": + from ..embeddings import TimestepEmbedding, Timesteps + + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + return + super()._set_add_embedding( + addition_embed_type=addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + # ----- DreamLite extension: dispatch `addition_embed_type == "time"` ----- + def get_aug_embed( # type: ignore[override] + self, emb, encoder_hidden_states, added_cond_kwargs + ): + """ + For `addition_embed_type == "time"`, build aug_emb from `time_ids` only (no `text_embeds`). All other variants + are delegated to the parent. 
+ """ + if self.config.addition_embed_type == "time": + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'time' " + "which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((-1, self.config.projection_class_embeddings_input_dim)) + add_embeds = time_embeds.to(emb.dtype) + return self.add_embedding(add_embeds) + return super().get_aug_embed( + emb=emb, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + ) + + +__all__ = ["DreamLiteUNetModel"] diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c0d12121d5e8..8a9cd324d2de 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -270,6 +270,7 @@ "IFPipeline", "IFSuperResolutionPipeline", ] + _import_structure["dreamlite"] = ["DreamLitePipeline", "DreamLiteMobilePipeline", "DreamLitePipelineOutput"] _import_structure["easyanimate"] = [ "EasyAnimatePipeline", "EasyAnimateInpaintPipeline", @@ -708,6 +709,11 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) + from .dreamlite import ( + DreamLiteMobilePipeline, + DreamLitePipeline, + DreamLitePipelineOutput, + ) from .easyanimate import ( EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, diff --git a/src/diffusers/pipelines/dreamlite/__init__.py b/src/diffusers/pipelines/dreamlite/__init__.py new file mode 100644 index 000000000000..01a0609265c1 --- /dev/null +++ b/src/diffusers/pipelines/dreamlite/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, +) + + +_import_structure = { + "pipeline_dreamlite": ["DreamLitePipeline"], + "pipeline_dreamlite_mobile": ["DreamLiteMobilePipeline"], + "pipeline_output": ["DreamLitePipelineOutput"], +} + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_dreamlite import DreamLitePipeline + from .pipeline_dreamlite_mobile import DreamLiteMobilePipeline + from .pipeline_output import DreamLitePipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/src/diffusers/pipelines/dreamlite/pipeline_dreamlite.py b/src/diffusers/pipelines/dreamlite/pipeline_dreamlite.py new file mode 100644 index 000000000000..f729b7e83ab8 --- /dev/null +++ b/src/diffusers/pipelines/dreamlite/pipeline_dreamlite.py @@ -0,0 +1,538 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderTiny +from ...models.unets.unet_dreamlite import DreamLiteUNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents +from .pipeline_output import DreamLitePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DreamLitePipeline + >>> from diffusers.utils import load_image + + >>> pipe = DreamLitePipeline.from_pretrained( + ... "carlofkl/DreamLite-base", revision="diffusers", torch_dtype=torch.float16 + ... 
) + >>> pipe.to("cuda") + + >>> # Text-to-image + >>> image = pipe(prompt="A serene mountain lake at sunrise").images[0] + + >>> # Image-to-image (instruction-based edit) + >>> init_image = load_image("input.png") + >>> edited = pipe(prompt="make it snowy", image=init_image).images[0] + ``` +""" + + +def calculate_shift( + image_seq_len: int, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +) -> float: + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.flux.pipeline_flux.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class DreamLitePipeline(DiffusionPipeline, FromSingleFileMixin, TextualInversionLoaderMixin): + r"""DreamLite pipeline for text-to-image and instruction-based image editing. + + The same pipeline supports both modes; the operating mode is auto-detected from the inputs: + + * ``image is None`` -> text-to-image (single CFG on text). + * ``image is not None`` -> image-to-image / instruction edit (dual CFG: text + image). 
+ + Components: + text_encoder ([`~transformers.Qwen3VLForConditionalGeneration`]): + Multimodal text/vision encoder used to produce conditioning embeddings. + tokenizer ([`~transformers.AutoTokenizer`]): + Tokenizer for text-only (generate) mode. + processor ([`~transformers.Qwen3VLProcessor`]): + Multimodal processor for edit mode (text + image template). + vae ([`~diffusers.AutoencoderTiny`]): + Mobile-friendly tiny VAE for latent encode/decode. + unet ([`~diffusers.DreamLiteUNetModel`]): + DreamLite UNet (GQA + qk_norm + depthwise-separable convs). + scheduler ([`~diffusers.FlowMatchEulerDiscreteScheduler`]): + Flow-matching Euler scheduler with dynamic shift. + + Note: + ``batch_size`` is currently forced to ``1``; ``num_images_per_prompt`` is supported. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: AutoTokenizer, + processor: Qwen3VLProcessor, + vae: AutoencoderTiny, + unet: DreamLiteUNetModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + vae=vae, + unet=unet, + scheduler=scheduler, + ) + + # Safe VAE scale factor: AutoencoderTiny exposes `encoder_block_out_channels`; fall back to 8. 
+ if self.vae is not None and hasattr(self.vae.config, "encoder_block_out_channels"): + self.vae_scale_factor = 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + else: + self.vae_scale_factor = 8 + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + + # --------------------------------------------------------------------- + # Helpers + # --------------------------------------------------------------------- + @staticmethod + def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor) -> List[torch.Tensor]: + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1).tolist() + selected = hidden_states[bool_mask] + return torch.split(selected, valid_lengths, dim=0) + + def encode_prompt( + self, + mode: str, + prompts: List[str], + device: torch.device, + dtype: torch.dtype, + image: Optional[Image.Image] = None, + max_sequence_length: int = 500, + text_pad_embedding: Optional[torch.Tensor] = None, + ): + if mode == "edit": + drop_idx = 64 + template = ( + "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, " + "texture, objects, background), then explain how the user's text instruction should alter " + "or modify the image. 
Generate a new image that meets the user's requirements while maintaining " + "consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n" + "<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + ) + + txts = [template.format(p) for p in prompts] + images = [image.resize((512, 512), Image.Resampling.LANCZOS)] * len(prompts) + + tk_out = self.processor(text=txts, images=images, padding=True, return_tensors="pt").to(device) + + outputs = self.text_encoder( + input_ids=tk_out.input_ids, + attention_mask=tk_out.attention_mask, + pixel_values=tk_out.pixel_values, + image_grid_thw=tk_out.image_grid_thw, + output_hidden_states=True, + ) + + elif mode == "generate": + drop_idx = 34 + template = ( + "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ) + + txts = [template.format(p) for p in prompts] + tk_out = self.tokenizer(text=txts, padding=True, return_tensors="pt").to(device) + + outputs = self.text_encoder( + input_ids=tk_out.input_ids, + attention_mask=tk_out.attention_mask, + output_hidden_states=True, + ) + else: + raise ValueError(f"Unknown mode: {mode!r}; expected 'generate' or 'edit'.") + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, tk_out.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + + prompt_embeds = pad_sequence(split_hidden_states, batch_first=True, padding_value=0).to( + dtype=dtype, device=device + ) + + B, L, _ = prompt_embeds.shape + prompt_embeds_mask = torch.zeros((B, L), dtype=torch.long, device=device) + for i, seq in enumerate(split_hidden_states): + prompt_embeds_mask[i, : seq.shape[0]] = 1 + + if text_pad_embedding is not None: + pad_emb = text_pad_embedding.to(dtype=dtype, device=device) + if pad_emb.ndim 
== 1: + pad_emb = pad_emb.unsqueeze(0).unsqueeze(0) + elif pad_emb.ndim == 2: + pad_emb = pad_emb.unsqueeze(0) + + mask_expanded = prompt_embeds_mask.unsqueeze(-1).to(dtype=dtype) + prompt_embeds = prompt_embeds * mask_expanded + pad_emb * (1 - mask_expanded) + + return prompt_embeds, prompt_embeds_mask + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator], + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError("Generator list length must match batch size.") + + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + def prepare_image_latents( + self, + image: Union[torch.Tensor, Image.Image, List[Image.Image]], + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + if not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError(f"`image` must be of type `torch.Tensor`, `PIL.Image.Image` or `list`, got {type(image)}") + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + + return image_latents + + # --------------------------------------------------------------------- + # Properties + # --------------------------------------------------------------------- + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + # 
--------------------------------------------------------------------- + # Main entry + # --------------------------------------------------------------------- + @torch.no_grad() + def __call__( + self, + prompt: Optional[str] = None, + negative_prompt: Optional[str] = None, + image: Optional[Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.0, + num_inference_steps: int = 30, + sigmas: Optional[List[float]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + max_sequence_length: int = 200, + text_pad_embedding: Optional[torch.Tensor] = None, + ): + r"""Run the DreamLite pipeline. + + Args: + prompt: Text prompt. + negative_prompt: Negative text prompt (defaults to empty string). + image: Optional input image. If provided, the pipeline runs in **edit / image-to-image** mode + with dual classifier-free guidance; otherwise it runs in **text-to-image** mode. + height: Output resolution (height). Defaults to ``default_sample_size * vae_scale_factor`` (1024). + The same default applies in both T2I and I2I; pass an explicit value to override. + width: Output resolution (width). Defaults to ``default_sample_size * vae_scale_factor`` (1024). + The same default applies in both T2I and I2I; pass an explicit value to override. + guidance_scale: CFG scale on the text branch (both modes). + image_guidance_scale: Additional CFG scale on the image branch (edit mode only). + num_inference_steps: Number of denoising steps. + sigmas: Optional explicit FlowMatch sigmas; defaults to + ``np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)``. + num_images_per_prompt: Output images per prompt (note: ``batch_size`` is forced to 1). + generator: A ``torch.Generator`` (or list of generators) used to sample the initial noise latents. + output_type: ``"pil"``, ``"np"``, ``"pt"`` or ``"latent"``.
+ return_dict: If True, returns a :class:`DreamLitePipelineOutput`; else a tuple ``(images,)``. + max_sequence_length: Reserved; currently unused and not forwarded to ``encode_prompt``. + text_pad_embedding: Optional learned pad embedding substituted at padded positions; currently + applied in text-to-image mode only. + + Returns: + :class:`DreamLitePipelineOutput` or ``tuple``. + """ + # 1. Init pipeline parameters + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + self._guidance_scale = guidance_scale + self._image_guidance_scale = image_guidance_scale + + task = "generate" if image is None else "edit" + device = self._execution_device + dtype = self.text_encoder.dtype + batch_size = 1 # Note: pipeline currently forces batch_size = 1. + negative_prompt = negative_prompt or "" + + if sigmas is None: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + # 2. Prepare Time IDs (carries the target (width, height) as additional conditioning) + original_size = (width, height) + add_time_ids = torch.tensor([list(original_size)], device=device, dtype=dtype) + + # 3. Prepare Noise Latents (x_t) + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dtype, + device, + generator, + ) + + # 4. Prepare Timesteps (FlowMatch with dynamic shift) + image_seq_len = latents.shape[2] * latents.shape[3] // 4 + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 5.
Prepare Conditions (Text & Image) + if task == "generate": + prompt_str = f"[Generate]: {prompt}" + prompt_embeds, text_attention_mask = self.encode_prompt( + mode="generate", + prompts=[negative_prompt, prompt_str], + device=device, + dtype=dtype, + text_pad_embedding=text_pad_embedding, + ) + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_attention_mask = text_attention_mask.repeat_interleave(num_images_per_prompt, dim=0) + image_latents = torch.zeros_like(latents) + else: + prompt_str = ( + f"[Edit]: A diptych with two side-by-side images of the same scene. " + f"Compared to the right side, the left one has {prompt}" + ) + prompt_embeds, text_attention_mask = self.encode_prompt( + mode="edit", + prompts=[negative_prompt, negative_prompt, prompt_str], + image=image, + device=device, + dtype=dtype, + ) + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_attention_mask = text_attention_mask.repeat_interleave(num_images_per_prompt, dim=0) + image_processed = self.image_processor.preprocess(image.resize((width, height), Image.Resampling.LANCZOS)) + image_latents = self.prepare_image_latents( + image_processed, + dtype=dtype, + device=device, + ) + uncond_image_latents = torch.zeros_like(latents) + + # 6. 
Denoising Loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand latents for classifier-free guidance + if task == "generate": + latents_in = torch.cat([latents] * 2) + cond_img_in = torch.cat([image_latents] * 2) + model_input = torch.cat([latents_in, cond_img_in], dim=3) + time_ids_in = torch.cat([add_time_ids] * 2) + else: # edit + latents_in = torch.cat([latents] * 3) + cond_img_in = torch.cat([uncond_image_latents, image_latents, image_latents]) + model_input = torch.cat([latents_in, cond_img_in], dim=3) + time_ids_in = torch.cat([add_time_ids] * 3) + + # UNet Forward + noise_pred = self.unet( + model_input, + timestep=t.expand(model_input.shape[0]).to(latents.dtype), + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=text_attention_mask, + added_cond_kwargs={"time_ids": time_ids_in}, + return_dict=False, + )[0] + + # Classifier-Free Guidance (single for T2I, dual for I2I) + noise_pred = noise_pred[..., : latents.shape[-1]] + if task == "generate": + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: # edit + noise_pred_uncond, noise_pred_image, noise_pred_text = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + self._guidance_scale * (noise_pred_text - noise_pred_image) + + self._image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + # Scheduler Step + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 7. 
Decode Latents + if output_type == "latent": + image_out = latents + else: + shift_factor = getattr(self.vae.config, "shift_factor", 0.0) or 0.0 + latents = (latents / self.vae.config.scaling_factor) + shift_factor + image_out = self.vae.decode(latents, return_dict=False)[0] + image_out = self.image_processor.postprocess(image_out, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image_out,) + + return DreamLitePipelineOutput(images=image_out) diff --git a/src/diffusers/pipelines/dreamlite/pipeline_dreamlite_mobile.py b/src/diffusers/pipelines/dreamlite/pipeline_dreamlite_mobile.py new file mode 100644 index 000000000000..6361fad00add --- /dev/null +++ b/src/diffusers/pipelines/dreamlite/pipeline_dreamlite_mobile.py @@ -0,0 +1,440 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderTiny +from ...models.unets.unet_dreamlite import DreamLiteUNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents +from .pipeline_dreamlite import calculate_shift, retrieve_timesteps +from .pipeline_output import DreamLitePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import DreamLiteMobilePipeline + + >>> pipe = DreamLiteMobilePipeline.from_pretrained( + ... "carlofkl/DreamLite-mobile", revision="diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> # Text-to-image (4 steps, no CFG) + >>> image = pipe(prompt="A serene mountain lake at sunrise").images[0] + + >>> # Image-to-image (instruction-based edit, 4 steps) + >>> init_image = Image.open("input.png").convert("RGB") + >>> edited = pipe(prompt="make it snowy", image=init_image).images[0] + ``` +""" + + +class DreamLiteMobilePipeline(DiffusionPipeline, FromSingleFileMixin, TextualInversionLoaderMixin): + r"""DreamLite **Mobile** pipeline: a distilled, classifier-free-guidance-free variant of + :class:`DreamLitePipeline` for fast few-step inference (default 4 steps). 
+ + The operating mode is auto-detected from inputs (same as the base pipeline): + + * ``image is None`` -> text-to-image. + * ``image is not None`` -> image-to-image / instruction edit. + + Because classifier-free guidance is **distilled away**, ``guidance_scale`` and ``image_guidance_scale`` are + accepted for API parity with :class:`DreamLitePipeline` but are ignored in the denoising loop. ``negative_prompt`` + is intentionally absent. + + Components (identical to the base pipeline): + text_encoder ([`~transformers.Qwen3VLForConditionalGeneration`]): + Multimodal text/vision encoder. + tokenizer ([`~transformers.AutoTokenizer`]): + Tokenizer for text-only (generate) mode. + processor ([`~transformers.Qwen3VLProcessor`]): + Multimodal processor for edit mode. + vae ([`~diffusers.AutoencoderTiny`]): + Mobile-friendly tiny VAE. + unet ([`~diffusers.DreamLiteUNetModel`]): + DreamLite UNet. + scheduler ([`~diffusers.FlowMatchEulerDiscreteScheduler`]): + Flow-matching Euler scheduler with dynamic shift. + + Note: + ``batch_size`` is currently forced to ``1``; ``num_images_per_prompt`` is supported. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: AutoTokenizer, + processor: Qwen3VLProcessor, + vae: AutoencoderTiny, + unet: DreamLiteUNetModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + vae=vae, + unet=unet, + scheduler=scheduler, + ) + + # Safe VAE scale factor: AutoencoderTiny exposes `encoder_block_out_channels`; fall back to 8. 
+ if self.vae is not None and hasattr(self.vae.config, "encoder_block_out_channels"): + self.vae_scale_factor = 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + else: + self.vae_scale_factor = 8 + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + + # --------------------------------------------------------------------- + # Helpers (identical to DreamLitePipeline) + # --------------------------------------------------------------------- + @staticmethod + def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor) -> List[torch.Tensor]: + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1).tolist() + selected = hidden_states[bool_mask] + return torch.split(selected, valid_lengths, dim=0) + + def encode_prompt( + self, + mode: str, + prompts: List[str], + device: torch.device, + dtype: torch.dtype, + image: Optional[Image.Image] = None, + max_sequence_length: int = 500, + text_pad_embedding: Optional[torch.Tensor] = None, + ): + if mode == "edit": + drop_idx = 64 + template = ( + "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, " + "texture, objects, background), then explain how the user's text instruction should alter " + "or modify the image. 
Generate a new image that meets the user's requirements while maintaining " + "consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n" + "<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + ) + + txts = [template.format(p) for p in prompts] + images = [image.resize((512, 512), Image.Resampling.LANCZOS)] * len(prompts) + + tk_out = self.processor(text=txts, images=images, padding=True, return_tensors="pt").to(device) + + outputs = self.text_encoder( + input_ids=tk_out.input_ids, + attention_mask=tk_out.attention_mask, + pixel_values=tk_out.pixel_values, + image_grid_thw=tk_out.image_grid_thw, + output_hidden_states=True, + ) + + elif mode == "generate": + drop_idx = 34 + template = ( + "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ) + + txts = [template.format(p) for p in prompts] + tk_out = self.tokenizer(text=txts, padding=True, return_tensors="pt").to(device) + + outputs = self.text_encoder( + input_ids=tk_out.input_ids, + attention_mask=tk_out.attention_mask, + output_hidden_states=True, + ) + else: + raise ValueError(f"Unknown mode: {mode!r}; expected 'generate' or 'edit'.") + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, tk_out.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + + prompt_embeds = pad_sequence(split_hidden_states, batch_first=True, padding_value=0).to( + dtype=dtype, device=device + ) + + B, L, _ = prompt_embeds.shape + prompt_embeds_mask = torch.zeros((B, L), dtype=torch.long, device=device) + for i, seq in enumerate(split_hidden_states): + prompt_embeds_mask[i, : seq.shape[0]] = 1 + + if text_pad_embedding is not None: + pad_emb = text_pad_embedding.to(dtype=dtype, device=device) + if pad_emb.ndim 
== 1: + pad_emb = pad_emb.unsqueeze(0).unsqueeze(0) + elif pad_emb.ndim == 2: + pad_emb = pad_emb.unsqueeze(0) + + mask_expanded = prompt_embeds_mask.unsqueeze(-1).to(dtype=dtype) + prompt_embeds = prompt_embeds * mask_expanded + pad_emb * (1 - mask_expanded) + + return prompt_embeds, prompt_embeds_mask + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator], + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError("Generator list length must match batch size.") + + return randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + def prepare_image_latents( + self, + image: Union[torch.Tensor, Image.Image, List[Image.Image]], + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + if not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError(f"`image` must be of type `torch.Tensor`, `PIL.Image.Image` or `list`, got {type(image)}") + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + + return image_latents + + # --------------------------------------------------------------------- + # Main entry + # --------------------------------------------------------------------- + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 4, + 
guidance_scale: Optional[float] = None, + image_guidance_scale: Optional[float] = None, + sigmas: Optional[List[float]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + max_sequence_length: int = 200, + text_pad_embedding: Optional[torch.Tensor] = None, + ): + r"""Run the distilled DreamLite Mobile pipeline. + + Args: + prompt: Text prompt. + image: Optional input image. If provided, runs in **edit / image-to-image** mode; + otherwise runs in **text-to-image** mode. + height: Output resolution (height). Defaults to ``default_sample_size * vae_scale_factor`` (1024). + width: Output resolution (width). Defaults to ``default_sample_size * vae_scale_factor`` (1024). + num_inference_steps: Number of denoising steps. Defaults to **4** (distilled). + guidance_scale: Accepted for API parity with :class:`DreamLitePipeline`; **ignored** + because CFG was distilled away. A warning is logged if a value is passed. + image_guidance_scale: Accepted for API parity with :class:`DreamLitePipeline`; **ignored** + because CFG was distilled away. A warning is logged if a value is passed. + sigmas: Optional explicit FlowMatch sigmas; defaults to + ``np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)``. + num_images_per_prompt: Output images per prompt (note: ``batch_size`` is forced to 1). + generator: A ``torch.Generator`` (or list of generators) used to sample the initial noise latents. + output_type: ``"pil"``, ``"np"``, ``"pt"`` or ``"latent"``. + return_dict: If True, returns a :class:`DreamLitePipelineOutput`; else ``(images,)``. + max_sequence_length: Reserved; currently unused and not forwarded to ``encode_prompt``. + text_pad_embedding: Optional learned pad embedding substituted at padded positions; currently + applied in text-to-image mode only. + + Returns: + :class:`DreamLitePipelineOutput` or ``tuple``. + """ + # 1.
Init pipeline parameters + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + task = "generate" if image is None else "edit" + device = self._execution_device + dtype = self.text_encoder.dtype + batch_size = 1 # Note: pipeline currently forces batch_size = 1. + + if guidance_scale is not None or image_guidance_scale is not None: + logger.warning( + "`guidance_scale` / `image_guidance_scale` are ignored by DreamLiteMobilePipeline " + "because classifier-free guidance was distilled away." + ) + + if sigmas is None: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + # 2. Prepare Time IDs + original_size = (width, height) + add_time_ids = torch.tensor([list(original_size)], device=device, dtype=dtype) + + # 3. Prepare Noise Latents (x_t) + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dtype, + device, + generator, + ) + + # 4. Prepare Timesteps (FlowMatch with dynamic shift) + image_seq_len = latents.shape[2] * latents.shape[3] // 4 + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 5. 
Prepare Conditions (Text & Image) — no negatives because CFG is distilled away + if task == "generate": + prompt_str = f"[Generate]: {prompt}" + prompt_embeds, text_attention_mask = self.encode_prompt( + mode="generate", + prompts=[prompt_str], + device=device, + dtype=dtype, + text_pad_embedding=text_pad_embedding, + ) + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_attention_mask = text_attention_mask.repeat_interleave(num_images_per_prompt, dim=0) + image_latents = torch.zeros_like(latents) + else: + prompt_str = ( + f"[Edit]: A diptych with two side-by-side images of the same scene. " + f"Compared to the right side, the left one has {prompt}" + ) + prompt_embeds, text_attention_mask = self.encode_prompt( + mode="edit", + prompts=[prompt_str], + image=image, + device=device, + dtype=dtype, + ) + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_attention_mask = text_attention_mask.repeat_interleave(num_images_per_prompt, dim=0) + image_processed = self.image_processor.preprocess(image.resize((width, height), Image.Resampling.LANCZOS)) + image_latents = self.prepare_image_latents( + image_processed, + dtype=dtype, + device=device, + ) + + # 6. 
Denoising Loop (no CFG: single forward per step, no cat/chunk) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([latents, image_latents], dim=3) + time_ids_in = add_time_ids + + # UNet Forward + noise_pred = self.unet( + model_input, + timestep=t.expand(model_input.shape[0]).to(latents.dtype), + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=text_attention_mask, + added_cond_kwargs={"time_ids": time_ids_in}, + return_dict=False, + )[0] + + # Drop extra channels (image-conditioning half of the spatial concat) + noise_pred = noise_pred[..., : latents.shape[-1]] + + # Scheduler Step + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 7. Decode Latents + if output_type == "latent": + image_out = latents + else: + shift_factor = getattr(self.vae.config, "shift_factor", 0.0) or 0.0 + latents = (latents / self.vae.config.scaling_factor) + shift_factor + image_out = self.vae.decode(latents, return_dict=False)[0] + image_out = self.image_processor.postprocess(image_out, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image_out,) + + return DreamLitePipelineOutput(images=image_out) diff --git a/src/diffusers/pipelines/dreamlite/pipeline_output.py b/src/diffusers/pipelines/dreamlite/pipeline_output.py new file mode 100644 index 000000000000..6d4cf15aed6e --- /dev/null +++ b/src/diffusers/pipelines/dreamlite/pipeline_output.py @@ -0,0 +1,35 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class DreamLitePipelineOutput(BaseOutput): + """ + Output class for DreamLite pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`): + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. The PIL images or NumPy array represent the denoised images of the diffusion pipeline. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8317a58b3cd6..3c8df92b6542 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1140,6 +1140,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DreamLiteTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class DreamLiteUNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class EasyAnimateTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d8965054560c..3cb96ce98721 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1367,6 +1367,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class DreamLiteMobilePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", 
"transformers"]) + + +class DreamLitePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class DreamLitePipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class EasyAnimateControlPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/dreamlite/__init__.py b/tests/pipelines/dreamlite/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/dreamlite/test_pipeline_dreamlite.py b/tests/pipelines/dreamlite/test_pipeline_dreamlite.py new file mode 100644 index 000000000000..cfb2b6286c19 --- /dev/null +++ b/tests/pipelines/dreamlite/test_pipeline_dreamlite.py @@ -0,0 +1,429 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for ``DreamLitePipeline``. + +Test design +----------- +``DreamLitePipeline`` depends on Qwen3-VL as its text/image encoder, which is a +large multimodal model that cannot be reasonably miniaturised for CPU tests. +To keep the fast tests CPU-friendly and CI-compatible, we mock out the +``encode_prompt`` method and the ``text_encoder`` / ``tokenizer`` / ``processor`` +sub-modules, and exercise everything else (UNet forward, scheduler, CFG branching, +VAE encode/decode) with real (tiny) modules. + +For end-to-end verification against the original repo, see the +``parity_run_*.py`` scripts shipped with the integration. +""" + +import gc +import os +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import torch +from PIL import Image + +from diffusers import ( + AutoencoderTiny, + DreamLitePipeline, + DreamLiteUNetModel, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + nightly, + require_torch_gpu, + torch_device, +) + +from ..test_pipelines_common import ( + PipelineTesterMixin, + to_np, +) + + +enable_full_determinism() + + +# Cross-attention dim used by the tiny UNet below. ``encode_prompt`` is mocked +# to return embeddings with this final dim, so unet ``encoder_hidden_states`` +# shape matches. +_CROSS_ATTN_DIM = 32 +_DUMMY_SEQ_LEN = 8 + + +def _make_fake_encode_prompt(cross_attn_dim: int = _CROSS_ATTN_DIM, seq_len: int = _DUMMY_SEQ_LEN): + """Build a stand-in for ``DreamLitePipeline.encode_prompt``. + + Returns deterministic ``(prompt_embeds, prompt_embeds_mask)`` with the + correct shapes / dtypes / device so the UNet forward pass type-checks. 
+ """ + + def fake_encode_prompt( + self, + mode, + prompts, + device, + dtype, + image=None, + max_sequence_length=500, + text_pad_embedding=None, + ): + batch = len(prompts) + prompt_embeds = torch.randn(batch, seq_len, cross_attn_dim, device=device, dtype=dtype) + prompt_embeds_mask = torch.ones(batch, seq_len, device=device, dtype=torch.long) + return prompt_embeds, prompt_embeds_mask + + return fake_encode_prompt + + +class DreamLitePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = DreamLitePipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "negative_prompt", + "num_inference_steps", + ] + ) + batch_params = frozenset(["prompt", "negative_prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "output_type", + "return_dict", + ] + ) + # We mock encode_prompt, so embed-related test conveniences are not applicable. + test_xformers_attention = False + test_attention_slicing = False + test_layerwise_casting = False + test_group_offloading = False + + def get_dummy_components(self): + torch.manual_seed(0) + unet = DreamLiteUNetModel( + sample_size=8, + in_channels=4, + out_channels=4, + down_block_types=( + "CrossAttnDownRemoveSelfAttnBlock2D", + "CrossAttnDownBlock2D", + ), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + block_out_channels=(32, 64), + cross_attention_dim=_CROSS_ATTN_DIM, + attention_head_dim=8, + layers_per_block=1, + norm_num_groups=8, + transformer_layers_per_block=1, + ) + + torch.manual_seed(0) + vae = AutoencoderTiny( + in_channels=3, + out_channels=3, + encoder_block_out_channels=(32, 32), + decoder_block_out_channels=(32, 32), + num_encoder_blocks=(1, 1), + num_decoder_blocks=(1, 1), + latent_channels=4, + ) + + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + + # text_encoder must expose a real torch.dtype because pipeline does + # ``dtype = self.text_encoder.dtype``. Everything else is mocked. 
+ text_encoder = MagicMock() + text_encoder.dtype = torch.float32 + # Must look like an nn.Module for register_modules; give it a stub. + text_encoder.to = MagicMock(return_value=text_encoder) + text_encoder.eval = MagicMock(return_value=text_encoder) + + tokenizer = MagicMock() + processor = MagicMock() + + return { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "processor": processor, + "vae": vae, + "unet": unet, + "scheduler": scheduler, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + return { + "prompt": "a small dog", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.5, + "height": 64, + "width": 64, + "output_type": "np", + } + + def get_dummy_i2i_inputs(self, device, seed=0): + inputs = self.get_dummy_inputs(device, seed) + # 64x64 RGB image -- will be processed by VaeImageProcessor. + inputs["image"] = Image.fromarray((np.random.RandomState(seed).rand(64, 64, 3) * 255).astype(np.uint8)) + inputs["image_guidance_scale"] = 1.5 + return inputs + + # ---- mixin compatibility: auto-patch encode_prompt for ALL tests ------ + # PipelineTesterMixin tests (test_cfg, test_inference_batch_*, etc.) do + # not know about ``_patch_encode_prompt`` and instantiate the pipeline + # themselves, then call it. Without a patched ``encode_prompt`` they hit + # the real Qwen3-VL code path (``drop_idx=34`` slice on a MagicMock + # tokenizer output) and crash inside ``pad_sequence``. Patching at the + # class level via ``unittest.mock.patch.object`` covers every pipeline + # instance built during a test method, with automatic teardown. 
+ def setUp(self): + super().setUp() + self._encode_prompt_patcher = patch.object( + self.pipeline_class, + "encode_prompt", + _make_fake_encode_prompt(_CROSS_ATTN_DIM, _DUMMY_SEQ_LEN), + ) + self._encode_prompt_patcher.start() + + def tearDown(self): + self._encode_prompt_patcher.stop() + super().tearDown() + + # ---- patching helpers -------------------------------------------------- + def _patch_encode_prompt(self, pipe): + fake = _make_fake_encode_prompt(_CROSS_ATTN_DIM, _DUMMY_SEQ_LEN) + pipe.encode_prompt = fake.__get__(pipe, type(pipe)) + + # ---- override mixin tests that don't apply to DreamLite --------------- + # The following inherited PipelineTesterMixin tests are skipped because + # they make assumptions that don't fit DreamLite's design: + # * MagicMock text_encoder cannot be moved between dtypes/devices + # (test_to_dtype, test_torch_dtype_dict) + # * MagicMock components cannot be serialised + # (test_save_load_dduf, test_loading_with_variants, + # test_pipeline_with_accelerator_device_map) + # * UNet uses a custom DreamLiteAttnProcessor2_0 that is not in + # UNet2DConditionModel's ADDED_KV / CROSS_ATTENTION processor sets + # (test_dict_tuple_outputs_equivalent calls set_default_attn_processor) + # * encode_prompt returns (embeds, mask) tuple, not a single tensor + # (test_encode_prompt_works_in_isolation) + # This mirrors what SD3 / Flux do for the same incompatibilities. + @unittest.skip("MagicMock text_encoder has no real dtype propagation.") + def test_to_dtype(self): + pass + + @unittest.skip("MagicMock text_encoder has no real dtype propagation.") + def test_torch_dtype_dict(self): + pass + + @unittest.skip( + "DreamLite intentionally limits ``batch_size`` to 1 (CFG memory blow-up); " + "only ``num_images_per_prompt > 1`` is supported. The mixin sweep over " + "batch_size=[1, 2] x num_images_per_prompt=[1, 2] would fail on " + "batch_size=2 cases." 
+ ) + def test_num_images_per_prompt(self): + pass + + @unittest.skip("MagicMock components cannot be serialised via save_pretrained.") + def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): + pass + + @unittest.skip("MagicMock components cannot be serialised via save_pretrained.") + def test_loading_with_variants(self): + pass + + @unittest.skip("MagicMock components cannot be serialised via save_pretrained.") + def test_pipeline_with_accelerator_device_map(self): + pass + + @unittest.skip( + "DreamLite UNet uses DreamLiteAttnProcessor2_0 which is not in " + "UNet2DConditionModel's default processor set; set_default_attn_processor raises." + ) + def test_dict_tuple_outputs_equivalent(self, expected_max_difference=0.0001): + pass + + @unittest.skip( + "DreamLite encode_prompt returns (embeds, mask) tuple, not a single tensor; " + "the mixin's test_encode_prompt_works_in_isolation assumes single tensor return." + ) + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4): + pass + + # ---- actual tests ------------------------------------------------------ + def test_dreamlite_t2i_default_case(self): + device = torch_device + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + self._patch_encode_prompt(pipe) + + inputs = self.get_dummy_inputs(device) + out = pipe(**inputs).images + out_np = to_np(out) + + # shape: (B=1, H, W, C=3) + self.assertEqual(out_np.shape, (1, 64, 64, 3)) + self.assertFalse(np.isnan(out_np).any()) + + def test_dreamlite_i2i_default_case(self): + device = torch_device + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + self._patch_encode_prompt(pipe) + + inputs = self.get_dummy_i2i_inputs(device) + out = pipe(**inputs).images + out_np = to_np(out) + + self.assertEqual(out_np.shape, (1, 64, 64, 3)) + 
self.assertFalse(np.isnan(out_np).any()) + + def test_dreamlite_cfg_branch_count(self): + """In edit mode the pipeline must run a 3-way CFG concat (uncond/img/text).""" + device = torch_device + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + self._patch_encode_prompt(pipe) + + original_forward = pipe.unet.forward + seen_batches = [] + + def spy_forward(*args, **kwargs): + x = args[0] if args else kwargs["sample"] + seen_batches.append(x.shape[0]) + return original_forward(*args, **kwargs) + + pipe.unet.forward = spy_forward + inputs = self.get_dummy_i2i_inputs(device) + inputs["num_inference_steps"] = 1 + pipe(**inputs) + + self.assertTrue(all(b == 3 for b in seen_batches), f"expected all 3-way, got {seen_batches}") + + # ---- skips for mixin tests that don't apply --------------------------- + @unittest.skip("DreamLite uses mocked text_encoder; save/load round-trip is N/A.") + def test_save_load_local(self): + pass + + @unittest.skip("DreamLite uses mocked text_encoder; save/load round-trip is N/A.") + def test_save_load_optional_components(self): + pass + + @unittest.skip("DreamLite uses mocked text_encoder; save/load round-trip is N/A.") + def test_save_load_float16(self): + pass + + @unittest.skip("DreamLite uses mocked text_encoder; from_pretrained round-trip is N/A.") + def test_from_pipe_consistent_config(self): + pass + + @unittest.skip("DreamLite uses mocked text_encoder; serialization is N/A.") + def test_serialization(self): + pass + + @unittest.skip("DreamLite forces batch_size=1 internally.") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("DreamLite forces batch_size=1 internally.") + def test_inference_batch_single_identical(self): + pass + + +@nightly +@require_torch_gpu +class DreamLitePipelineSlowTests(unittest.TestCase): + """End-to-end test against the real DreamLite-base checkpoint on the Hub. 
+ + By default this loads ``carlofkl/DreamLite-base`` (``diffusers`` branch) + from the HF Hub. To run against a local copy during development, set the + ``DREAMLITE_BASE_PATH`` env var to that path. + """ + + repo_id = "carlofkl/DreamLite-base" + revision = "diffusers" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def _from_pretrained_kwargs(self): + local = os.getenv("DREAMLITE_BASE_PATH") + if local: + return {"pretrained_model_name_or_path": local} + return {"pretrained_model_name_or_path": self.repo_id, "revision": self.revision} + + def test_dreamlite_t2i_real_checkpoint(self): + pipe = DreamLitePipeline.from_pretrained(**self._from_pretrained_kwargs(), torch_dtype=torch.bfloat16).to( + "cuda" + ) + out = pipe( + prompt="a dog running on the grass", + num_inference_steps=2, + guidance_scale=3.5, + height=1024, + width=1024, + generator=torch.Generator("cpu").manual_seed(0), + output_type="np", + ).images + + self.assertEqual(out.shape, (1, 1024, 1024, 3)) + self.assertFalse(np.isnan(out).any()) + + def test_dreamlite_i2i_real_checkpoint(self): + pipe = DreamLitePipeline.from_pretrained(**self._from_pretrained_kwargs(), torch_dtype=torch.bfloat16).to( + "cuda" + ) + + src = Image.fromarray((np.random.RandomState(0).rand(1024, 1024, 3) * 255).astype(np.uint8)) + out = pipe( + prompt="make it look like a painting", + image=src, + num_inference_steps=2, + guidance_scale=3.5, + image_guidance_scale=1.5, + height=1024, + width=1024, + generator=torch.Generator("cpu").manual_seed(0), + output_type="np", + ).images + + self.assertEqual(out.shape, (1, 1024, 1024, 3)) + self.assertFalse(np.isnan(out).any()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/dreamlite/test_pipeline_dreamlite_mobile.py b/tests/pipelines/dreamlite/test_pipeline_dreamlite_mobile.py new file mode 100644 index 000000000000..339d0bed30a7 --- 
/dev/null +++ b/tests/pipelines/dreamlite/test_pipeline_dreamlite_mobile.py @@ -0,0 +1,402 @@ +# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for ``DreamLiteMobilePipeline``. + +The mobile pipeline is a distilled, no-CFG sibling of ``DreamLitePipeline``. +It runs a single UNet forward per step (no 3-way concat) and ignores +``guidance_scale`` / ``image_guidance_scale``. Test layout mirrors +``test_pipeline_dreamlite.py``; see that file for the rationale behind +mocking ``encode_prompt``. 
+""" + +import gc +import os +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import torch +from PIL import Image + +from diffusers import ( + AutoencoderTiny, + DreamLiteMobilePipeline, + DreamLiteUNetModel, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + nightly, + require_torch_gpu, + torch_device, +) + +from ..test_pipelines_common import ( + PipelineTesterMixin, + to_np, +) + + +enable_full_determinism() + + +_CROSS_ATTN_DIM = 32 +_DUMMY_SEQ_LEN = 8 + + +def _make_fake_encode_prompt(cross_attn_dim: int = _CROSS_ATTN_DIM, seq_len: int = _DUMMY_SEQ_LEN): + def fake_encode_prompt( + self, + mode, + prompts, + device, + dtype, + image=None, + max_sequence_length=500, + text_pad_embedding=None, + ): + batch = len(prompts) + prompt_embeds = torch.randn(batch, seq_len, cross_attn_dim, device=device, dtype=dtype) + prompt_embeds_mask = torch.ones(batch, seq_len, device=device, dtype=torch.long) + return prompt_embeds, prompt_embeds_mask + + return fake_encode_prompt + + +class DreamLiteMobilePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = DreamLiteMobilePipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "num_inference_steps", + ] + ) + batch_params = frozenset(["prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "output_type", + "return_dict", + ] + ) + test_xformers_attention = False + test_attention_slicing = False + test_layerwise_casting = False + test_group_offloading = False + + def get_dummy_components(self): + torch.manual_seed(0) + unet = DreamLiteUNetModel( + sample_size=8, + in_channels=4, + out_channels=4, + down_block_types=( + "CrossAttnDownRemoveSelfAttnBlock2D", + "CrossAttnDownBlock2D", + ), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + block_out_channels=(32, 64), + cross_attention_dim=_CROSS_ATTN_DIM, + attention_head_dim=8, + layers_per_block=1, + 
norm_num_groups=8, + transformer_layers_per_block=1, + ) + + torch.manual_seed(0) + vae = AutoencoderTiny( + in_channels=3, + out_channels=3, + encoder_block_out_channels=(32, 32), + decoder_block_out_channels=(32, 32), + num_encoder_blocks=(1, 1), + num_decoder_blocks=(1, 1), + latent_channels=4, + ) + + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000) + + text_encoder = MagicMock() + text_encoder.dtype = torch.float32 + text_encoder.to = MagicMock(return_value=text_encoder) + text_encoder.eval = MagicMock(return_value=text_encoder) + + tokenizer = MagicMock() + processor = MagicMock() + + return { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "processor": processor, + "vae": vae, + "unet": unet, + "scheduler": scheduler, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + return { + "prompt": "a small dog", + "generator": generator, + "num_inference_steps": 2, + "height": 64, + "width": 64, + "output_type": "np", + } + + def get_dummy_i2i_inputs(self, device, seed=0): + inputs = self.get_dummy_inputs(device, seed) + inputs["image"] = Image.fromarray((np.random.RandomState(seed).rand(64, 64, 3) * 255).astype(np.uint8)) + return inputs + + # ---- mixin compatibility: auto-patch encode_prompt for ALL tests ------ + # PipelineTesterMixin tests (test_cfg, test_inference_batch_*, etc.) do + # not know about ``_patch_encode_prompt`` and instantiate the pipeline + # themselves, then call it. Without a patched ``encode_prompt`` they hit + # the real Qwen3-VL code path (``drop_idx=34`` slice on a MagicMock + # tokenizer output) and crash inside ``pad_sequence``. Patching at the + # class level via ``unittest.mock.patch.object`` covers every pipeline + # instance built during a test method, with automatic teardown. 
+ def setUp(self): + super().setUp() + self._encode_prompt_patcher = patch.object( + self.pipeline_class, + "encode_prompt", + _make_fake_encode_prompt(_CROSS_ATTN_DIM, _DUMMY_SEQ_LEN), + ) + self._encode_prompt_patcher.start() + + def tearDown(self): + self._encode_prompt_patcher.stop() + super().tearDown() + + def _patch_encode_prompt(self, pipe): + fake = _make_fake_encode_prompt(_CROSS_ATTN_DIM, _DUMMY_SEQ_LEN) + pipe.encode_prompt = fake.__get__(pipe, type(pipe)) + + # ---- override mixin tests that don't apply to DreamLite --------------- + # The following inherited PipelineTesterMixin tests are skipped because + # they make assumptions that don't fit DreamLite's design. + # This mirrors what SD3 / Flux do for the same incompatibilities. + @unittest.skip("MagicMock text_encoder has no real dtype propagation.") + def test_to_dtype(self): + pass + + @unittest.skip("MagicMock text_encoder has no real dtype propagation.") + def test_torch_dtype_dict(self): + pass + + @unittest.skip("MagicMock components cannot be serialised via save_pretrained.") + def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): + pass + + @unittest.skip("MagicMock components cannot be serialised via save_pretrained.") + def test_loading_with_variants(self): + pass + + @unittest.skip("MagicMock components cannot be serialised via save_pretrained.") + def test_pipeline_with_accelerator_device_map(self): + pass + + @unittest.skip( + "DreamLite UNet uses DreamLiteAttnProcessor2_0 which is not in " + "UNet2DConditionModel's default processor set; set_default_attn_processor raises." + ) + def test_dict_tuple_outputs_equivalent(self, expected_max_difference=0.0001): + pass + + @unittest.skip( + "DreamLite encode_prompt returns (embeds, mask) tuple, not a single tensor; " + "the mixin's test_encode_prompt_works_in_isolation assumes single tensor return." 
+ ) + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4): + pass + + @unittest.skip( + "DreamLiteMobile intentionally limits ``batch_size`` to 1; " + "only ``num_images_per_prompt > 1`` is supported. The mixin sweep over " + "batch_size=[1, 2] x num_images_per_prompt=[1, 2] would fail on " + "batch_size=2 cases." + ) + def test_num_images_per_prompt(self): + pass + + def test_mobile_t2i_default_case(self): + device = torch_device + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + self._patch_encode_prompt(pipe) + + inputs = self.get_dummy_inputs(device) + out = pipe(**inputs).images + out_np = to_np(out) + + self.assertEqual(out_np.shape, (1, 64, 64, 3)) + self.assertFalse(np.isnan(out_np).any()) + + def test_mobile_i2i_default_case(self): + device = torch_device + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + self._patch_encode_prompt(pipe) + + inputs = self.get_dummy_i2i_inputs(device) + out = pipe(**inputs).images + out_np = to_np(out) + + self.assertEqual(out_np.shape, (1, 64, 64, 3)) + self.assertFalse(np.isnan(out_np).any()) + + def test_mobile_single_forward_per_step(self): + """Mobile pipeline must run exactly ONE UNet forward per step (no CFG concat).""" + device = torch_device + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + self._patch_encode_prompt(pipe) + + original_forward = pipe.unet.forward + seen_batches = [] + + def spy_forward(*args, **kwargs): + x = args[0] if args else kwargs["sample"] + seen_batches.append(x.shape[0]) + return original_forward(*args, **kwargs) + + pipe.unet.forward = spy_forward + inputs = self.get_dummy_i2i_inputs(device) + inputs["num_inference_steps"] = 2 + pipe(**inputs) + 
+ self.assertTrue(all(b == 1 for b in seen_batches), f"expected all 1-way, got {seen_batches}") + self.assertEqual(len(seen_batches), 2, "expected exactly 2 unet calls (1 per step)") + + def test_mobile_guidance_scale_ignored(self): + """Passing guidance_scale to the mobile pipeline should be accepted but ignored (with warning).""" + device = torch_device + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(device) + pipe.set_progress_bar_config(disable=None) + self._patch_encode_prompt(pipe) + + inputs = self.get_dummy_inputs(device) + inputs["guidance_scale"] = 7.5 # should not raise + inputs["image_guidance_scale"] = 1.5 # should not raise + out = pipe(**inputs).images + self.assertEqual(to_np(out).shape, (1, 64, 64, 3)) + + @unittest.skip("DreamLiteMobile uses mocked text_encoder; save/load round-trip is N/A.") + def test_save_load_local(self): + pass + + @unittest.skip("DreamLiteMobile uses mocked text_encoder; save/load round-trip is N/A.") + def test_save_load_optional_components(self): + pass + + @unittest.skip("DreamLiteMobile uses mocked text_encoder; save/load round-trip is N/A.") + def test_save_load_float16(self): + pass + + @unittest.skip("DreamLiteMobile uses mocked text_encoder; from_pretrained round-trip is N/A.") + def test_from_pipe_consistent_config(self): + pass + + @unittest.skip("DreamLiteMobile uses mocked text_encoder; serialization is N/A.") + def test_serialization(self): + pass + + @unittest.skip("DreamLiteMobile forces batch_size=1 internally.") + def test_inference_batch_consistent(self): + pass + + @unittest.skip("DreamLiteMobile forces batch_size=1 internally.") + def test_inference_batch_single_identical(self): + pass + + +@nightly +@require_torch_gpu +class DreamLiteMobilePipelineSlowTests(unittest.TestCase): + """End-to-end test against the real DreamLite-mobile checkpoint on the Hub. + + By default this loads ``carlofkl/DreamLite-mobile`` (``diffusers`` branch) + from the HF Hub. 
To run against a local copy during development, set the + ``DREAMLITE_MOBILE_PATH`` env var to that path. + """ + + repo_id = "carlofkl/DreamLite-mobile" + revision = "diffusers" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def _from_pretrained_kwargs(self): + local = os.getenv("DREAMLITE_MOBILE_PATH") + if local: + return {"pretrained_model_name_or_path": local} + return {"pretrained_model_name_or_path": self.repo_id, "revision": self.revision} + + def test_mobile_t2i_real_checkpoint(self): + pipe = DreamLiteMobilePipeline.from_pretrained( + **self._from_pretrained_kwargs(), torch_dtype=torch.bfloat16 + ).to("cuda") + out = pipe( + prompt="a dog running on the grass", + num_inference_steps=4, + height=1024, + width=1024, + generator=torch.Generator("cpu").manual_seed(0), + output_type="np", + ).images + + self.assertEqual(out.shape, (1, 1024, 1024, 3)) + self.assertFalse(np.isnan(out).any()) + + def test_mobile_i2i_real_checkpoint(self): + pipe = DreamLiteMobilePipeline.from_pretrained( + **self._from_pretrained_kwargs(), torch_dtype=torch.bfloat16 + ).to("cuda") + + src = Image.fromarray((np.random.RandomState(0).rand(1024, 1024, 3) * 255).astype(np.uint8)) + out = pipe( + prompt="make it look like a painting", + image=src, + num_inference_steps=4, + height=1024, + width=1024, + generator=torch.Generator("cpu").manual_seed(0), + output_type="np", + ).images + + self.assertEqual(out.shape, (1, 1024, 1024, 3)) + self.assertFalse(np.isnan(out).any()) + + +if __name__ == "__main__": + unittest.main()
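
Review note: both fast-test classes patch `encode_prompt` at the class level in `setUp` and undo it in `tearDown`, so pipelines built inside the mixin's own tests are covered too. The sketch below shows that pattern in isolation using only the standard library. `TinyPipeline` and its `fake_encode_prompt` are hypothetical stand-ins invented for illustration; they are not part of diffusers or of this PR.

```python
import unittest
from unittest.mock import patch


class TinyPipeline:
    """Hypothetical stand-in for a pipeline whose real encoder is too heavy for CI."""

    def encode_prompt(self, prompts):
        # Plays the role of the real encoder path that fast tests must avoid.
        raise RuntimeError("real encoder is not available in fast tests")

    def __call__(self, prompt):
        embeds, mask = self.encode_prompt([prompt])
        return len(embeds), len(mask)


def fake_encode_prompt(self, prompts):
    # A plain function stored on the class binds like a normal method,
    # so ``self`` is the pipeline instance, as in the tests above.
    embeds = [[0.0] * 4 for _ in prompts]
    mask = [[1] * 4 for _ in prompts]
    return embeds, mask


class TinyPipelineTests(unittest.TestCase):
    def setUp(self):
        # Patch the class attribute, not one instance: every pipeline
        # created during the test method sees the fake.
        self._patcher = patch.object(TinyPipeline, "encode_prompt", fake_encode_prompt)
        self._patcher.start()

    def tearDown(self):
        # Restore the original method so other test classes are unaffected.
        self._patcher.stop()

    def test_every_instance_is_patched(self):
        self.assertEqual(TinyPipeline()("a small dog"), (1, 1))
        self.assertEqual(TinyPipeline()("another prompt"), (1, 1))
```

Because the class-level patch already covers every instance, the per-instance `_patch_encode_prompt(pipe)` calls in the DreamLite tests look redundant; they are harmless, though, and keep each hand-written test explicit about its dependency on the fake.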