
Add fsdp2 fp8 unit tests TE 2.10#492

Open
sudhu2k wants to merge 16 commits into dev from sudhu/FSDP2_unit_tests_fix_2.10

Conversation

Contributor

@sudhu2k sudhu2k commented Mar 17, 2026

Description

This PR adds unit tests covering the following configurations:

  1. delayed scaling + fp8 autocast only
  2. delayed scaling + fp8 init only (new)
  3. delayed scaling + fp8 init + fp8 autocast (new)
  4. current scaling + fp8 autocast only
  5. current scaling + fp8 init only (new)
  6. current scaling + fp8 init + fp8 autocast (new)
  7. MXFP8 scaling + fp8 autocast only
  8. MXFP8 scaling + fp8 init only (new)
  9. MXFP8 scaling + fp8 init + fp8 autocast (new)
  10. fp32 (new)

All the unit tests compare FSDP2 vs. DDP gradients and outputs; a parametrization sketch of this matrix follows below.
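A hypothetical sketch of how this matrix could be parametrized (the recipe/mode names are illustrative placeholders, not the PR's actual identifiers):

```python
# Illustrative sketch only; the real tests live in
# tests/pytorch/distributed/test_torch_fsdp2_fp8.py and may be structured differently.
import itertools

import pytest

RECIPES = ["delayed", "current", "mxfp8"]  # hypothetical recipe labels
MODES = ["autocast_only", "init_only", "init_and_autocast"]  # hypothetical modes

@pytest.mark.parametrize(
    "recipe,mode",
    list(itertools.product(RECIPES, MODES)) + [("fp32", "none")],
)
def test_fsdp2_matches_ddp(recipe, mode):
    # Each case would run the same model under FSDP2 and DDP and compare
    # gradients/outputs within recipe-dependent tolerances.
    ...
```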

This PR also cleans up `fsdp2_all_gather_tensor` to match upstream's methods:

  1. Removes the `keep_fp8_weight_transpose_cache` dependency.
  2. Removes the stored module reference from the tensor.

This PR also fixes an issue with `fused_adam` when used with FSDP2.

Fixes https://github.com/ROCm/frameworks-internal/issues/15291

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhu2k and others added 5 commits March 17, 2026 01:23
…g and refined test case generation for various configurations.

- Cleaned up unused variables and improved code readability in the FSDPAGTensor class by removing unnecessary parameters.
… FusedAdam. Added debug print for DTensor in MultiTensorApply.
… tolerances for tensor comparisons. Updated test logic to accommodate new tolerance parameters for improved accuracy in floating-point comparisons.
…l differences in gradient calculations. Clean up unused debug print statements in MultiTensorApply and ensure proper newline at the end of the FSDPAGTensor serialization method.
```diff
- if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache:
-     quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index]
+ if not isinstance(quantizer, MXFP8Quantizer):
      quantizer.set_usage(columnwise=False)
```
Contributor Author

For FSDP2 with FP8, keep_fp8_weight_transpose_cache should be False. Caching the transposed weight would imply an all-gather of the transposed tensor as well, increasing memory and communication and negating the advantages of FSDP2’s sharded parameter layout.

```diff
      data = torch.zeros_like(param, dtype=torch.int16)
  else:
-     data = torch.empty(param.shape, dtype=dtype, device=param.device)
+     data = torch.empty_like(param, dtype=dtype)
```
Contributor Author

When using FSDP2, parameters are DTensors, but `torch.zeros()` / `torch.empty()` create regular PyTorch Tensors. This caused:

```
[rank1]: RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

[rank7]:   File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 422, in initialize_state
[rank7]:     self.set_scaled_state(param, "master_param", param.clone().detach().float())
[rank7]:   File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 363, in set_scaled_state
[rank7]:     state[state_name].copy_(unscaled_state)
```

Fix:
Keep optimizer state consistent with the parameter type: when parameters are DTensors, the state should be DTensors as well. Using `torch.empty_like(param, ...)` (and the same idea for the other state buffers) creates the state as a DTensor with the same placement as `param`, so both sides of `copy_` are DTensors and the error is avoided.
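A minimal standalone sketch of the difference (assumes torch >= 2.5 for the public `torch.distributed.tensor` module and a multi-GPU `torchrun` launch; this is illustrative, not the PR's code):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
param = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

# torch.empty() knows nothing about param's placement -> plain Tensor
plain = torch.empty(param.shape, dtype=torch.float32, device="cuda")
# torch.empty_like() dispatches through DTensor -> same placement as param
state = torch.empty_like(param, dtype=torch.float32)

assert not isinstance(plain, DTensor)
assert isinstance(state, DTensor)
state.copy_(param.detach().float())  # both sides are DTensors: no mixed-tensor error
```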

Collaborator

Is this a cherry-pick of the upstream fix?

Contributor Author

Upstream fixes this in TE v2.12, along with a few other fixes: NVIDIA/TransformerEngine@fe8fad5#diff-0801a8d92a56d458946da1439b62e0add1613b7da83d31bc218a852b6b9e42b1. This wasn't cherry-picked.

…by adding a newline character after the pass statement in the test_dummy function.
@sudhu2k sudhu2k marked this pull request as ready for review March 17, 2026 21:45
@sudhu2k sudhu2k self-assigned this Mar 17, 2026

```python
# Zero the parameter gradients
optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
```
Collaborator

Does `with te.fp8_autocast(enabled=args.fp8_autocast, ...)` do the same?

Contributor Author

It does, but since TE v2.10 replaces `te.fp8_autocast` with `te.autocast`, I've made the change to be consistent.

Collaborator

So will `with te.autocast(enabled=args.fp8_autocast, recipe=...)` do the same as the if/else?

Contributor Author

Yes, it should. I'll make the changes.

Contributor Author

Done.
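For reference, a sketch of the consolidated form discussed above (`args.fp8_autocast` and `fp8_recipe` are the test script's own variables; the `te.autocast` keyword names follow this conversation and should be checked against TE 2.10):

```python
import transformer_engine.pytorch as te

# Single code path for FP8 and non-FP8 runs: with enabled=False the
# context manager is a no-op, so the previous if/else is unnecessary.
with te.autocast(enabled=args.fp8_autocast, recipe=fp8_recipe):
    out = model(inp)
```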

```diff
  assert len(l1) == len(l2), "Unequal number of outputs."
  for i, (t1, t2) in enumerate(zip(l1, l2)):
-     result = torch.allclose(t1, t2, atol=0, rtol=0)
+     tols = dict(atol=atol)
```
Collaborator

Move the `tols` calculation out of the loop.
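A minimal sketch of the suggested hoist (the comparison call is illustrative; the actual helper in the test may differ):

```python
tols = dict(atol=atol)  # computed once, outside the loop
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
    assert torch.allclose(t1, t2, **tols), f"Mismatch at output {i}"
```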

@sudhu2k sudhu2k requested a review from ipanfilo March 18, 2026 15:10
Manually ported fix from upstream commit 139c863
The full commit was not cherry-picked due to unrelated changes across many files.

Addressed PR comments
sudhu2k added 2 commits April 6, 2026 04:04
- Updated model loading to conditionally use weights_only based on quantized_init.
- Modified optimizer initialization to remove master_weights parameter when not using quantized_init.
- Improved test setup for quantized models by integrating quantized_model_init with appropriate scaling recipes.
- Adjusted tolerance comments in tests to clarify FP8 behavior with FSDP2 and DDP.
- Added checks in base module to skip FP8 reduction for FSDP2 based on primary weights status.
@sudhu2k sudhu2k requested a review from wangye805 April 7, 2026 03:43
@sudhu2k sudhu2k added the ci-level 3 CI test level 3 label Apr 7, 2026
sudhu2k added 2 commits April 7, 2026 16:52
- Updated tolerance logic in assert_allclose to use the second tensor for relative tolerance calculations.
- Adjusted tolerance values based on quantization initialization conditions to ensure accurate testing of FP8 behavior with FSDP2.
- Add preserving of amax/scales when copying fp32 tensor to already existing fp8 tensor.
- Removed unnecessary model loading logic for quantized initialization in the training script, since we already use the same random seed.
- Excluded the quantized_init + non-autocast test case.
- Updated comments to clarify tolerance handling in FP8 tests.
```python
# DDP broadcast path: _broadcast_coalesced dequantizes Float8Tensors
# (via aten::cat fallback) then copies the plain tensor back.
# Re-quantize while preserving the original quantizer state.
if isinstance(dst, Float8Tensor) and not isinstance(src, Float8Tensor):
```
Contributor Author

When PyTorch's native DDP broadcasts parameters, `_broadcast_coalesced` concatenates all module parameters (FP8 weights and FP32 biases) via `aten::cat`, which triggers `Float8Tensor`'s fallback dispatch path and dequantizes the FP8 weights to FP32. After the broadcast, `aten::copy_` writes the FP32 data back into the `Float8Tensor` parameter, calling `quantize_()`, which recomputes amax and scale from the dequantized values. These differ slightly from the originals due to FP8 round-trip error, causing numerical divergence from FSDP2.

The fix intercepts this plain-tensor-to-`Float8Tensor` `copy_` path, saves the quantizer's amax and scale before re-quantization, and restores them afterward, so the DDP broadcast becomes a no-op with respect to quantizer state. This makes the test pass with zero tolerance for a model with LayerNormMLP, LayerNormLinear, and Linear modules.
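A hypothetical sketch of the save/restore idea (the attribute names `_quantizer`, `scale`, and `amax` are illustrative stand-ins for TE internals, not necessarily the exact fields the PR touches):

```python
def copy_preserving_quantizer_state(dst, src):
    # dst: Float8Tensor parameter; src: the dequantized FP32 tensor that
    # DDP's broadcast writes back. Snapshot the quantizer state first,
    # since quantize_() would otherwise recompute amax/scale from src.
    saved_scale = dst._quantizer.scale.clone()
    saved_amax = dst._quantizer.amax.clone()
    dst.quantize_(src)  # re-quantize the broadcast data
    dst._quantizer.scale.copy_(saved_scale)  # restore original state
    dst._quantizer.amax.copy_(saved_amax)
```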

sudhu2k added 2 commits April 13, 2026 17:04
- Added `linear_only` parameter to `SimpleNet` to allow usage of only the final linear layer.
- Updated model initialization and forward pass logic to conditionally skip LayerNorm layers when `linear_only` is enabled.
- Modified argument parsing to include a `--linear-only` flag for the test scripts, ensuring compatibility with the quantized initialization scenario.