
[TE] Enable deterministic mode for fused attention #508

Merged
AllenFarcas merged 20 commits into dev from alfarcas/aima60-fix on Apr 27, 2026

Conversation

AllenFarcas (Contributor) commented Mar 27, 2026

Description

Fixes https://github.com/ROCm/frameworks-internal/issues/15875

Type of change

  • Documentation change (changes only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added deterministic functionality to fused attention
  • Added test for the introduced deterministic functionality

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI left a comment

Pull request overview

Enables deterministic mode propagation for ROCm fused-attention backward (CK backend) and adds JAX coverage to validate bitwise reproducibility and gradient correctness when non-deterministic algorithms are disallowed.

Changes:

  • Forward the deterministic flag from NVTE ROCm fused-attn backward entrypoints into CK backend calls.
  • Add JAX tests that (on HIP/AMD) verify backward gradients are bitwise reproducible across runs and match an unfused JAX reference (sketched below).
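
For orientation, a minimal sketch of the bitwise-reproducibility check the second bullet describes, using a plain unfused JAX attention as a stand-in for TE's fused kernel. The helper name, shapes, and inline env-var setup are illustrative assumptions; the real tests live in tests/jax/test_fused_attn.py.

import os

# Assumption: deterministic mode is requested via this env var before TE/JAX
# initialize; in the real suite the CI configuration drives this.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

import jax
import jax.numpy as jnp

def compute_grads(q, k, v):
    # Unfused reference attention; a real test would call TE's fused attention.
    def loss(q, k, v):
        scale = q.shape[-1] ** -0.5
        scores = jnp.einsum("bshd,bthd->bhst", q, k) * scale
        probs = jax.nn.softmax(scores, axis=-1)
        out = jnp.einsum("bhst,bthd->bshd", probs, v)
        return out.sum()
    return jax.grad(loss, argnums=(0, 1, 2))(q, k, v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 128, 8, 64), jnp.bfloat16)
k = jax.random.normal(kk, (2, 128, 8, 64), jnp.bfloat16)
v = jax.random.normal(kv, (2, 128, 8, 64), jnp.bfloat16)

# Deterministic mode promises bitwise equality across runs, not mere closeness.
grads_a = compute_grads(q, k, v)
grads_b = compute_grads(q, k, v)
for a, b in zip(grads_a, grads_b):
    assert (a == b).all()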

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Files reviewed:

  • transformer_engine/common/fused_attn_rocm/fused_attn.cpp: Passes the deterministic argument into CK fused-attn backward implementations (qkvpacked/kvpacked/separate).
  • tests/jax/test_fused_attn.py: Adds HIP-only deterministic-backward tests and imports global_shard_guard to ensure mesh resource context is set.

Copilot AI left a comment

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

@AllenFarcas AllenFarcas added the ci-level 1 CI test level 1 label Mar 27, 2026

Micky774 (Contributor) commented Mar 27, 2026

Unless we want to support deterministic CK only for the JAX integration, we should probably also add some tests on the PyTorch integration side, since it'll be enabled there too.

Also, I think you still need to adjust this filter:

# TODO: remove the filtering after ck team tells us how to enable more deterministic bwd kernels
if use_fused_attention and deterministic and IS_HIP_EXTENSION:
    if (
        fused_attention_backend == FusedAttnBackend["CK"]
        and is_training
    ):
        logger.debug("Disabling FusedAttention for determinism reasons")
        use_fused_attention = False
        fused_attention_backend = None  # TODO: switch to AOTriton when supported
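
For context, a hedged sketch of how this filter could be relaxed once the deterministic CK backward introduced in this PR is trusted; this shows the direction only and is not the PR's actual diff:

# Hypothetical relaxation: keep FusedAttention enabled and forward the
# deterministic flag into the CK backward kernels instead of falling back.
if use_fused_attention and deterministic and IS_HIP_EXTENSION:
    if fused_attention_backend == FusedAttnBackend["CK"] and is_training:
        logger.debug("Running CK FusedAttention backward in deterministic mode")
        # No fallback needed: `deterministic` is now plumbed through to the
        # CK backward entrypoints in fused_attn.cpp.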

wangye805 (Collaborator) left a comment

BTW, add some deterministic test cases on the PyTorch side as well.

Comment thread on tests/jax/test_fused_attn.py (outdated)
Comment on lines +1273 to +1274:

if check_numerical is None:
    check_numerical = seq_len <= 256

Collaborator:

Why do we skip the numerical check for cases with seq_len > 256?

Comment thread on tests/jax/test_fused_attn.py (outdated)

 from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
 from transformer_engine.jax import autocast
-from transformer_engine.jax.sharding import MeshResource
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard

Collaborator:

Why do we need this?

Comment thread on tests/jax/test_fused_attn.py (outdated)

if check_numerical is None:
    check_numerical = seq_len <= 256
s = seq_len
dtype = jnp.bfloat16

Collaborator:

Let's check for both bf16 and fp16.
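
One way to pick up this suggestion is a pytest parametrization over both dtypes. A sketch only: it assumes `_run_deterministic_bwd_case` (the harness in this test file) grows a `dtype` keyword instead of hard-coding jnp.bfloat16, and the mask-type enum member name is assumed.

import jax.numpy as jnp
import pytest

@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
def test_deterministic_bwd_dtypes(dtype):
    # Assumed signature: same harness as shown later in this thread, plus a dtype knob.
    _run_deterministic_bwd_case(
        qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
        attn_mask_type=AttnMaskType.NO_MASK,  # assumed enum member name
        b=2, seq_len=2048, h_q=12, h_kv=4, d=128,
        dtype=dtype,
    )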

Comment thread on tests/jax/test_fused_attn.py (outdated)
Comment on lines +1285 to +1293:

backend = FusedAttnHelper(
    True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type,
    0.0, h_q, h_kv, s, s, d, d, (-1, -1),
).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
    pytest.skip("No fused attention backend available for this config")
assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
    f"Expected CK backend but got {backend}."
)

Collaborator:

Technically, if we specify NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, the backend selection should honor this env var and choose a deterministic backend for us, not restrict to CK. As I recall, AOTriton is deterministic by nature. @xinyazhang

Collaborator:

Well, the backend selection API does not support a deterministic flag. And yes, TE considers AOTriton deterministic. The question is whether we want to test AOTriton as well.
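
Following this discussion, a small sketch of the softer check: skip rather than assert when the selected backend is not CK (for example, if AOTriton were ever chosen), reusing the names from the snippet above.

# Sketch: tolerate non-CK deterministic backends instead of failing hard.
if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
    pytest.skip("No fused attention backend available for this config")
if backend != NVTE_Fused_Attn_Backend.NVTE_CK:
    pytest.skip(f"Deterministic-bwd test targets the CK backend; got {backend}")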

Comment thread on tests/jax/test_fused_attn.py (outdated)

    ],
)
def test_deterministic_bwd_gqa(attn_mask_type):
    """GQA variant: BSHD_BSHD_BSHD with h_q != h_kv."""

Collaborator:

Also, extend to non-GQA cases as well.

Comment thread on tests/jax/test_fused_attn.py (outdated)

_run_deterministic_bwd_case(
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    attn_mask_type=attn_mask_type,
    b=2, seq_len=2048, h_q=12, h_kv=4, d=128,

Collaborator:

Also, check some sequence packing cases.
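
A hedged sketch of such a case, reusing the harness above with a packed (THD) layout; the exact QKVLayout member name is an assumption to verify against the enum.

# Sketch: same deterministic-backward check, but with sequence packing.
_run_deterministic_bwd_case(
    qkv_layout=QKVLayout.THD_THD_THD,  # assumed member name for the packed layout
    attn_mask_type=attn_mask_type,
    b=2, seq_len=2048, h_q=12, h_kv=4, d=128,
)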

@AllenFarcas AllenFarcas requested a review from Micky774 April 22, 2026 22:13
@AllenFarcas AllenFarcas requested review from Copilot and removed request for Copilot April 22, 2026 22:40
@AllenFarcas AllenFarcas requested a review from Micky774 April 23, 2026 15:06

wangye805 (Collaborator) left a comment

My only concern right now is that the AOTriton backend is also deterministic. But your current PR should be okay, since the CK backend has higher priority than the AOTriton backend.

Comment thread on docs/envvars.rst (outdated)

 :Type: ``int`` (0 or 1)
 :Default: ``1``
-:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations.
+:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations. On AMD/HIP builds, setting this to ``0`` enables the deterministic backward pass of the CK FusedAttention backend (which uses a split-accumulator workspace for deterministic ``dQ``); on NVIDIA builds it disables FusedAttention paths that are known to be non-deterministic.

Collaborator:

As far as I know, our AOTriton backend is also deterministic. Why does setting this only enable the CK backend?

AllenFarcas (Contributor, Author):

Since AOTriton is always deterministic, no action is required when NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, so this envvar description is technically accurate: only CK needs the flag to change its behavior.

Contributor:

I think the wording is accurate, but confusing. The behavior really isn't different from NVIDIA's. I don't think the description needs to be changed from the original.

Micky774 (Contributor):

Currently it remains untested -- we still need to add it to the CI tests.

@AllenFarcas AllenFarcas added ci-level 3 CI test level 3 and removed ci-level 1 CI test level 1 labels Apr 23, 2026

Micky774 (Contributor) commented Apr 23, 2026

Two questions regarding tests:

  1. Do we want to have them be level 1 or 3?
  2. Do we want to have determinism tests be mutually exclusive w/ regular tests, i.e. add a decorator to skip normal tests for AMD if the flag is set (see the sketch after this comment)?

If we don't make them mutually exclusive, I think they should be level 3. If we do, then level 1 is fine.

cc: @wangye805
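
A minimal sketch of the opt-out decorator floated in question 2. The marker name is hypothetical, reading the env var directly is an assumption about how the determinism flag is surfaced, and `is_hip_extension` is assumed callable as imported in this test file.

import os

import pytest
from transformer_engine.jax.cpp_extensions.misc import is_hip_extension

# Hypothetical marker: in a determinism-only CI flavor, skip the regular
# (non-deterministic) attention tests on AMD.
skip_regular_when_deterministic = pytest.mark.skipif(
    is_hip_extension() and os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0",
    reason="determinism-only run: regular attention tests are skipped on AMD",
)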

wangye805 (Collaborator):

> Two questions regarding tests:
>
>   1. Do we want to have them be level 1 or 3?
>   2. Do we want to have determinism tests be mutually exclusive w/ regular tests, i.e. add a decorator to skip normal tests for AMD if the flag is set?
>
> If we don't make them mutually exclusive, I think they should be level 3. If we do, then level 1 is fine.

Previously Ilya and I agreed that level 1 should run things with the default config (e.g., hip cast transpose instead of triton cast transpose). So if this deterministic flow is not enabled by default, we can put it into level 3 CI.

Comment thread on docs/envvars.rst (outdated)

 :Type: ``int`` (0 or 1)
 :Default: ``1``
-:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations.
+:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed.

Contributor:

Let's minimize the diff here.

@AllenFarcas AllenFarcas requested a review from Micky774 April 23, 2026 19:26

Micky774 (Contributor):

Let's make sure to wait for a green level 3 CI first.

Micky774 (Contributor):

CI failure is unrelated.

@AllenFarcas AllenFarcas merged commit 8943023 into dev Apr 27, 2026
2 of 3 checks passed
@AllenFarcas AllenFarcas deleted the alfarcas/aima60-fix branch April 27, 2026 18:41