
Implement TeaCache #12652

Open
LawJarp-A wants to merge 43 commits into huggingface:main from LawJarp-A:teacache-flux

Conversation

@LawJarp-A

@LawJarp-A LawJarp-A commented Nov 13, 2025

What does this PR do?

What is TeaCache?

TeaCache (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.

Architecture & Design

TeaCache uses a ModelHook to intercept transformer forward passes without modifying model code. The algorithm:

  1. Extracts modulated input from first transformer block (after norm1 + timestep embedding)
  2. Computes relative L1 distance vs previous timestep
  3. Applies model-specific polynomial rescaling: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]
  4. Accumulates rescaled distance across timesteps
  5. If accumulated < threshold → Reuses cached residual (FAST)
  6. If accumulated >= threshold → Full transformer pass (SLOW, update cache)
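
The accumulate-and-threshold rule in steps 3-6 can be sketched as follows. This is a simplified, hypothetical illustration of the decision logic, not the PR's actual code; the function names, state dict, and coefficient/threshold values are placeholders:

```python
def rescale(x, coeffs):
    # Model-specific 4th-degree polynomial: c0*x^4 + c1*x^3 + c2*x^2 + c3*x + c4
    c0, c1, c2, c3, c4 = coeffs
    return c0 * x**4 + c1 * x**3 + c2 * x**2 + c3 * x + c4


def should_reuse_cache(rel_l1, state, coeffs, threshold, is_boundary):
    """Return True when the cached residual can be reused for this timestep."""
    if is_boundary:
        # First/last timesteps are always computed fully (boundary guarantee)
        state["acc"] = 0.0
        return False
    state["acc"] += rescale(rel_l1, coeffs)
    if state["acc"] < threshold:
        return True       # fast path: reuse cached residual
    state["acc"] = 0.0    # slow path: full transformer pass, reset accumulator
    return False
```

Raising the threshold lets the accumulator stay under it for more consecutive steps, which is why larger `rel_l1_thresh` values trade quality for speedup in the benchmarks below.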

Key Design Features:

  • Hook-based: Integrates with HookRegistry and CacheMixin for lifecycle management
  • State Isolation: StateManager with context-aware state for CFG conditional/unconditional branches
  • Model Auto-Detection: Detects model type from class name and config path (specific variants checked first)
  • Boundary Guarantee: First and last timesteps always computed fully for quality
  • Specialized Strategies: Dual residual caching (CogVideoX), per-sequence-length caching (Lumina2)
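
The state-isolation idea can be illustrated with a minimal sketch (hypothetical, not the PR's actual StateManager): each cache context, such as the CFG conditional and unconditional branches, gets its own independent TeaCache state so the two branches never clobber each other's accumulators or cached residuals.

```python
class SimpleStateManager:
    """Toy context-keyed state store illustrating CFG state isolation."""

    def __init__(self):
        self._states = {}
        self._context = "_default"

    def set_context(self, name):
        # Called before each branch, e.g. "cond" / "uncond"
        self._context = name

    def get_state(self):
        # Lazily create an isolated state dict per context
        return self._states.setdefault(
            self._context,
            {"cnt": 0, "accumulated_rel_l1_distance": 0.0,
             "previous_modulated_input": None, "previous_residual": None},
        )
```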

Supported Models

| Model | Coefficients | Status |
| --- | --- | --- |
| FLUX | Auto-detected | Tested |
| FLUX-Kontext | Auto-detected | Ready |
| Mochi | Auto-detected | Ready |
| Lumina2 | Auto-detected | Ready |
| CogVideoX (2b/5b/1.5-5B) | Auto-detected | Ready |

All models support automatic coefficient detection based on model class name and config path. Custom coefficients can also be provided via TeaCacheConfig.


Benchmark Results (FLUX.1-dev)

| Threshold | Time | Speedup |
| --- | --- | --- |
| Baseline | 9.26s | 1.00x |
| 0.2 | 6.85s | 1.35x |
| 0.4 | 5.24s | 1.77x |
| 0.6 | 4.64s | 2.00x |
| 0.8 | 4.18s | 2.22x |

Benchmark Results (Lumina2)

| Threshold | Time | Speedup |
| --- | --- | --- |
| Baseline | 3.45s | 1.00x |
| 0.2 | 3.07s | 1.12x |
| 0.4 | 2.27s | 1.52x |
| 0.6 | 1.84s | 1.88x |

Benchmark Results (CogVideoX-2b)

| Threshold | Time | Speedup |
| --- | --- | --- |
| Baseline | 26.27s | 1.00x |
| 0.3 | 23.97s | 1.10x |
| 0.5 | 22.57s | 1.16x |
| 0.7 | 19.31s | 1.36x |
| 0.9 | 17.38s | 1.51x |

Benchmark Results (Mochi)

| Threshold | Time | Speedup |
| --- | --- | --- |
| Baseline | 7.71s | 1.00x |
| 0.05 | 6.27s | 1.23x |
| 0.06 | 6.03s | 1.28x |
| 0.08 | 5.73s | 1.35x |
| 0.10 | 5.41s | 1.42x |

Test Hardware: NVIDIA H100
Framework: Diffusers with TeaCache hooks
All tests: Same seed (42) for reproducibility

Usage

import torch

from diffusers import FluxPipeline
from diffusers.hooks import TeaCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable TeaCache (1.75x speedup with 0.4 threshold)
config = TeaCacheConfig(rel_l1_thresh=0.4)
pipe.transformer.enable_cache(config)

image = pipe("A dragon on a crystal mountain", num_inference_steps=20).images[0]

pipe.transformer.disable_cache()

Configuration Options

The TeaCacheConfig supports the following parameters:

  • rel_l1_thresh (float, default=0.2): Threshold for accumulated relative L1 distance. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. Mochi models require lower thresholds (0.06-0.09).
  • coefficients (List[float], optional): Polynomial coefficients for rescaling L1 distance. Auto-detected based on model type if not provided.
  • num_inference_steps (int, optional): Total inference steps. Ensures first/last timesteps are always computed. Auto-detected if not provided.
  • num_inference_steps_callback (Callable[[], int], optional): Callback returning total inference steps. Alternative to num_inference_steps.
  • current_timestep_callback (Callable[[], int], optional): Callback returning current timestep. Used for debugging/statistics.
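
For reference, a fully explicit configuration might look like the following. The coefficient and step values are illustrative placeholders, not the tuned per-model coefficients, which this PR auto-detects:

```python
from diffusers.hooks import TeaCacheConfig

# All values shown are illustrative; coefficients are normally auto-detected.
config = TeaCacheConfig(
    rel_l1_thresh=0.4,                       # higher = more skipping, more speedup
    coefficients=[1.0, 0.0, 0.0, 1.0, 0.0],  # placeholder polynomial c0..c4
    num_inference_steps=28,                  # ensures first/last steps run fully
)
```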

Files Changed

  • src/diffusers/hooks/teacache.py - Core implementation with model-specific forward functions
  • src/diffusers/models/cache_utils.py - CacheMixin integration
  • src/diffusers/hooks/__init__.py - Export TeaCacheConfig and apply_teacache
  • tests/hooks/test_teacache.py - Comprehensive unit tests

Fixes #12589, #12635

Who can review?

@sayakpaul @yiyixuxu @DN6

@sayakpaul sayakpaul requested a review from DN6 November 13, 2025 16:49
@LawJarp-A
Author

LawJarp-A commented Nov 13, 2025

Work done

  • Implement teacache for FLUX architecture using hooks (only flux for now)
  • add logging
  • add compatible tests

Waiting for feedback and review :)
cc: @dhruvrnaik @sayakpaul @yiyixuxu

@LawJarp-A LawJarp-A marked this pull request as ready for review November 14, 2025 08:23
@LawJarp-A
Author

Hi @sayakpaul @dhruvrnaik any updates?

@sayakpaul
Member

@LawJarp-A sorry about the delay on our end. @DN6 will review it soon.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Collaborator

DN6 commented Nov 24, 2025

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

@LawJarp-A
Author

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

Yep @DN6 , I agree, I wanted to first implement it just for a single model and get feedback on that before I work on Model agnostic full implementation. I'm sort of working on it, didn't push it yet. I'll take a look at First block cache for reference as well.
On the same note, lemme know if there is anything to add to the current implementation

@LawJarp-A
Author

LawJarp-A commented Nov 26, 2025

@DN6 updated it in a more model agnostic way.
Requesting review and feedback

@LawJarp-A
Author

Added multi model support, testing it thoroughly though.

@LawJarp-A
Author

Hi @DN6 @sayakpaul
Two questions — I'm almost done testing; I'll update the PR with more descriptive results and changes, and do the final cleanup/merging.

  1. Any tests I should write and anything I can refer to for the same?
  2. Added support for other models, I'll add pictures comparison with speedup and threshold to the PR as well?

In the meantime any feedback would be appreciated

@sayakpaul
Member

Thanks @LawJarp-A!

Any tests I should write and anything I can refer to for the same?

You can refer to #12569 for testing

Added support for other models, I'll add pictures comparison with speedup and threshold to the PR as well?

Yes, I think that is informative for users.

Member

@sayakpaul sayakpaul left a comment


Some initial feedback. Most important question is it seems like we need to craft different logic based on different model? Can we not keep it model agnostic?

@LawJarp-A
Author

LawJarp-A commented Dec 8, 2025

Some initial feedback. Most important question is it seems like we need to craft different logic based on different model? Can we not keep it model agnostic?

I am trying to think of ways we can avoid having a separate forward for each model. Initially that seemed like the best approach: it was fine when I wrote it for FLUX, but Lumina2 needed multi-stage preprocessing, so keeping a generic forward might not work very well :/
FirstBlockCache and similar caches work at the block level, but TeaCache is more model-level.
Definitely open to ideas :)

LawJarp-A and others added 2 commits January 12, 2026 16:31
Member

@sayakpaul sayakpaul left a comment


I left some comments. LMK if they make sense.

Comment on lines +45 to +47
# Fallback to default context for backward compatibility with
# pipelines that don't call cache_context()
context = "_default"


Should this branch not error out like previous?

Comment on lines +109 to +110
if prev_mean.item() > 1e-9:
    return ((current - previous).abs().mean() / prev_mean).item()


Do we need to make it data-dependent (item() call)? Raising it because it makes torch.compile cry.
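
One way to keep this particular computation graph-friendly (an illustrative sketch, not code from this PR) is to guard the division with `clamp_min` instead of a Python-level `item()` check, so the distance stays a tensor with no CPU sync at this point:

```python
import torch


def rel_l1_distance(current: torch.Tensor, previous: torch.Tensor) -> torch.Tensor:
    # Guard the division with clamp_min rather than `prev_mean.item() > eps`,
    # avoiding a data-dependent branch (and graph break) here. The final
    # threshold comparison elsewhere may still need one synchronization.
    prev_mean = previous.abs().mean().clamp_min(1e-9)
    return (current - previous).abs().mean() / prev_mean
```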

@sayakpaul sayakpaul requested a review from Copilot January 20, 2026 09:09
@sayakpaul
Member

@bot /style

@github-actions
Contributor

github-actions bot commented Jan 20, 2026

Style bot fixed some files and pushed the changes.


Copilot AI left a comment


Pull request overview

This PR implements TeaCache (Timestep Embedding Aware Cache), a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.

Changes:

  • Adds TeaCache hook system with model-specific forward implementations for FLUX, Mochi, Lumina2, and CogVideoX models
  • Integrates TeaCache with the existing CacheMixin infrastructure for unified cache management
  • Implements StateManager improvements for context-aware state isolation (CFG support)

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.

Show a summary per file

| File | Description |
| --- | --- |
| src/diffusers/hooks/teacache.py | Core TeaCache implementation with polynomial rescaling, model auto-detection, and specialized forward functions for each supported model |
| src/diffusers/models/cache_utils.py | Integration of TeaCacheConfig into enable_cache/disable_cache methods |
| src/diffusers/hooks/__init__.py | Export TeaCacheConfig, apply_teacache, and StateManager |
| src/diffusers/hooks/hooks.py | StateManager enhancement with default context fallback for backward compatibility |
| src/diffusers/models/transformers/transformer_lumina2.py | Add CacheMixin to Lumina2Transformer2DModel |
| tests/hooks/test_teacache.py | Comprehensive unit tests for config validation, state management, and model detection |


state.cnt = 0
state.accumulated_rel_l1_distance = 0.0
state.previous_modulated_input = None
state.previous_residual = None

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _maybe_reset_state_for_new_inference method doesn't reset cache_dict and uncond_seq_len which are used by Lumina2. This could cause stale cache data to persist across inference runs when using Lumina2 models. Consider calling state.reset() instead of manually resetting individual fields, or add these Lumina2-specific fields to the reset logic.

Suggested change
state.previous_residual = None
state.previous_residual = None
# Reset Lumina2-specific state to avoid stale cache/data between inference runs
if hasattr(state, "cache_dict") and state.cache_dict is not None:
    # Clear in-place to preserve any existing references to the cache dict
    state.cache_dict.clear()
if hasattr(state, "uncond_seq_len"):
    state.uncond_seq_len = None

@LawJarp-A
Author

Thanks for the review. Taking a look

@sayakpaul
Member

sayakpaul commented Feb 16, 2026

@LawJarp-A I am guessing the Copilot review comments were resolved? There also seems to be a couple of unresolved comments.

@sayakpaul
Member

Cc: @LiewFeng would you like to give this a review as well?

@sayakpaul
Member

@LawJarp-A a gentle ping.

@LawJarp-A
Author

@LawJarp-A a gentle ping.

Noted @sayakpaul
Working on a timeline of another week or two.
I'll work on the unresolved comments.

@sayakpaul
Member

@LawJarp-A sounds good. Let us know whenever you're ready.

@sayakpaul
Member

@LawJarp-A let us know.

@LawJarp-A
Author

@LawJarp-A let us know.

Yessir. Will ping here

@LawJarp-A
Author

Hi @sayakpaul, here’s a quick update addressing the remaining review comments.

Fixes in the latest commit:

  • Fixed infinite recursion in _is_peft_adapter. It now checks PeftAdapterMixin correctly using a lazy import.
  • Handled a Lumina2 edge case. _maybe_reset_state_for_new_inference now re-reads num_steps at cnt=0 and resets if it changes. This covers cases where cnt wraps without a normal reset. (In standard usage, maybe_free_model_hooks() already resets correctly.)
  • Fixed test assertion: updated message from “must be positive” to “must be non-negative” to match the actual error.
  • Updated apply_teacache docstring to recommend pipe.transformer.enable_cache(config) for pipeline usage.

Responses to open comments:

  • StateManager fallback (hooks.py:47): The default fallback is needed since some pipelines (like Lumina2) don’t use cache_context() in their loop. Throwing an error would break them. I can add a logger.debug if you prefer visibility.
  • Calibration utility: Custom coefficients can already be passed via TeaCacheConfig. A calibration step like [Feat] TaylorSeer Cache #12648 would be useful—I'll open a separate issue.
  • Configurable rescaling: The current 4th-degree polynomial comes from the paper and works across tested models. We can revisit this if needed.
  • .item() and torch.compile: All operations stay in tensor form except the final .item() for branching, which is required for Python control flow. This is the minimal graph break.
  • Copilot comments: All addressed—boundary checks, class name matching, zero threshold validation, and Lumina2 state reset.
