[TE] Enable deterministic mode for fused attention#508
Conversation
Pull request overview
Enables deterministic mode propagation for ROCm fused-attention backward (CK backend) and adds JAX coverage to validate bitwise reproducibility and gradient correctness when non-deterministic algorithms are disallowed.
Changes:
- Forward the `deterministic` flag from NVTE ROCm fused-attn backward entrypoints into CK backend calls.
- Add JAX tests that (on HIP/AMD) verify backward gradients are bitwise reproducible across runs and match an unfused JAX reference.
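For context, the bitwise-reproducibility check in such a test amounts to running the backward pass twice on identical inputs and comparing raw bits rather than using a tolerance. A minimal sketch of the idea (helper names are illustrative, not the actual test utilities):

```python
import os
# Disallow non-deterministic algorithms; must be set before the fused-attn
# backend is selected (how the real tests configure this may differ).
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

import jax
import jax.numpy as jnp


def _bits(x):
    # Reinterpret bf16/fp16 (2-byte) or fp32 (4-byte) values as unsigned ints
    # so equality is exact at the bit level.
    uint = {2: jnp.uint16, 4: jnp.uint32}[x.dtype.itemsize]
    return jax.lax.bitcast_convert_type(x, uint)


def assert_bitwise_reproducible(grad_fn, *inputs):
    # Run the backward pass twice on identical inputs and require identical bits.
    first = grad_fn(*inputs)
    second = grad_fn(*inputs)
    for g1, g2 in zip(first, second):
        assert bool(jnp.all(_bits(g1) == _bits(g2)))
```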
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| transformer_engine/common/fused_attn_rocm/fused_attn.cpp | Passes the deterministic argument into CK fused-attn backward implementations (qkvpacked/kvpacked/separate). |
| tests/jax/test_fused_attn.py | Adds HIP-only deterministic-backward tests and imports global_shard_guard to ensure mesh resource context is set. |
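For readers wondering about the new import: TE's JAX extensions read the mesh resource from a context manager, so even a single-device test needs to enter one. A minimal sketch of the usage (the test-body name is hypothetical):

```python
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

# Provide a (trivial) mesh resource context so sharding-aware code paths have
# something to read, even without an actual device mesh.
with global_shard_guard(MeshResource()):
    run_deterministic_bwd_case()  # hypothetical stand-in for the test body
```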
Unless we want to support non-deterministic CK only for the JAX integration, we should probably also add some tests on the PyTorch integration side, since it'll be enabled there too. Also, I think you still need to adjust TransformerEngine/transformer_engine/pytorch/attention/dot_product_attention/utils.py (lines 1070 to 1078 at 82617fe).
wangye805 left a comment:
BTW, add some deterministic test cases on the PyTorch side as well.
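For reference, a PyTorch-side check could mirror the JAX tests. A rough sketch, assuming `transformer_engine.pytorch.DotProductAttention`; the shapes, constructor arguments, and qkv format are illustrative, not the configuration such tests would actually use:

```python
import os
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # set before TE picks a backend

import torch
import transformer_engine.pytorch as te


def run_backward(seed: int = 1234):
    torch.manual_seed(seed)
    # sbhd layout: (seq, batch, heads, head_dim); values are illustrative.
    q, k, v = (torch.randn(2048, 2, 12, 128, dtype=torch.bfloat16,
                           device="cuda", requires_grad=True) for _ in range(3))
    dpa = te.DotProductAttention(num_attention_heads=12, kv_channels=128)
    dpa(q, k, v).sum().backward()
    return q.grad.clone(), k.grad.clone(), v.grad.clone()


for ga, gb in zip(run_backward(), run_backward()):
    assert torch.equal(ga, gb)  # expect bitwise-identical gradients across runs
```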
if check_numerical is None:
    check_numerical = seq_len <= 256
Why do we skip the numerical check for cases with seq_len <= 256?
from transformer_engine.jax.cpp_extensions.misc import is_hip_extension
from transformer_engine.jax import autocast
-from transformer_engine.jax.sharding import MeshResource
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
if check_numerical is None:
    check_numerical = seq_len <= 256
s = seq_len
dtype = jnp.bfloat16
Let's check for both bf16 and fp16
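One way to cover both is to parametrize the dtype, assuming the helper can accept it; a sketch (the `dtype` parameter on `_run_deterministic_bwd_case` is an assumption):

```python
import jax.numpy as jnp
import pytest

@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["bf16", "fp16"])
def test_deterministic_bwd_gqa(attn_mask_type, dtype):
    _run_deterministic_bwd_case(
        qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
        attn_mask_type=attn_mask_type,
        b=2, seq_len=2048, h_q=12, h_kv=4, d=128,
        dtype=dtype,  # assumed: the helper takes dtype instead of hard-coding bf16
    )
```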
backend = FusedAttnHelper(
    True, dtype, dtype, qkv_layout, AttnBiasType.NO_BIAS, attn_mask_type,
    0.0, h_q, h_kv, s, s, d, d, (-1, -1),
).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
    pytest.skip("No fused attention backend available for this config")
assert backend == NVTE_Fused_Attn_Backend.NVTE_CK, (
    f"Expected CK backend but got {backend}."
)
Technically, if we specify NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, the backend selection should take this env var into account and choose a deterministic backend for us, not restrict to CK. As I recall, aotriton is deterministic by nature @xinyazhang
Well, the backend selection API does not support a deterministic flag. And yes, TE considers AOTriton as deterministic. The question is whether we want to test it for AOTriton as well.
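If the test were to cover AOTriton as well, the backend assertion could accept any backend considered deterministic rather than pinning CK. A sketch; the exact enum member name for the AOTriton backend is an assumption and should be checked against `NVTE_Fused_Attn_Backend`:

```python
# Both CK (with the deterministic flag) and AOTriton are treated as deterministic.
deterministic_backends = (
    NVTE_Fused_Attn_Backend.NVTE_CK,
    NVTE_Fused_Attn_Backend.NVTE_AOTriton,  # assumed member name; verify against the enum
)
assert backend in deterministic_backends, f"Expected a deterministic backend but got {backend}."
```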
    ],
)
def test_deterministic_bwd_gqa(attn_mask_type):
    """GQA variant: BSHD_BSHD_BSHD with h_q != h_kv."""
Also, extend this to non-GQA cases as well.
_run_deterministic_bwd_case(
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    attn_mask_type=attn_mask_type,
    b=2, seq_len=2048, h_q=12, h_kv=4, d=128,
Also, check some sequence packing cases
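A packed variant could reuse the same helper, roughly as sketched below; the `THD_THD_THD` layout value follows the usual QKVLayout naming and the other arguments are illustrative:

```python
# Hypothetical sequence-packing case: packed THD layout instead of BSHD.
_run_deterministic_bwd_case(
    qkv_layout=QKVLayout.THD_THD_THD,   # packed (ragged) layout
    attn_mask_type=attn_mask_type,
    b=2, seq_len=2048, h_q=12, h_kv=12, d=128,  # h_q == h_kv also covers the non-GQA ask
)
```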
wangye805 left a comment:
My only concern right now is that the aotriton backend is also deterministic. But your current PR should be okay since the CK backend has higher priority than the aotriton backend.
:Type: ``int`` (0 or 1)
:Default: ``1``
-:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations.
+:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations. On AMD/HIP builds, setting this to ``0`` enables the deterministic backward pass of the CK FusedAttention backend (which uses a split-accumulator workspace for deterministic ``dQ``); on NVIDIA builds it disables FusedAttention paths that are known to be non-deterministic.
As far as I know, our aotriton backend is also deterministic. Why does setting this only enable the CK backend?
Since AOTriton is always deterministic, there is no action required when NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, so this env var description is technically accurate: only CK needs that flag to change its behavior.
I think the wording is accurate, but confusing. The behavior really isn't different from NV. I don't think the description needs to be changed from the original.
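For completeness, opting in looks the same on both platforms, e.g. from Python, before Transformer Engine picks its backend:

```python
import os

# Request deterministic algorithms only (default is "1", i.e. non-deterministic allowed).
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
```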
Currently it remains untested -- we still need to add it to the CI tests.
Two questions regarding tests:
If we don't make them mutually exclusive, I think they should be level 3. If we do, then level 1 is fine. cc: @wangye805
Previously Ilya and I agreed that level 1 should be running things with the default config, like hip cast transpose instead of triton cast transpose. So if this deterministic flow is not set to true by default, we can put it into level 3 CI.
:Type: ``int`` (0 or 1)
:Default: ``1``
-:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed. This is relevant for both PyTorch and JAX attention implementations.
+:Description: Allow non-deterministic algorithms for Transformer Engine execution. When set to ``0``, only deterministic algorithms are allowed.
Let's minimize the diff here
Let's make sure to wait for a green level 3 CI first.
CI failure is unrelated.
Description
Fixes https://github.com/ROCm/frameworks-internal/issues/15875