Conversation
…g and refined test case generation for various configurations. - Cleaned up unused variables and improved code readability in the FSDPAGTensor class by removing unnecessary parameters.
… FusedAdam. Added debug print for DTensor in MultiTensorApply.
… tolerances for tensor comparisons. Updated test logic to accommodate new tolerance parameters for improved accuracy in floating-point comparisons.
…l differences in gradient calculations. Clean up unused debug print statements in MultiTensorApply and ensure proper newline at the end of the FSDPAGTensor serialization method.
```python
if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache:
quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index]
if not isinstance(quantizer, MXFP8Quantizer):
    quantizer.set_usage(columnwise=False)
```
For FSDP2 with FP8, keep_fp8_weight_transpose_cache should be False. Caching the transposed weight would imply an all-gather of the transposed tensor as well, increasing memory and communication and negating the advantages of FSDP2’s sharded parameter layout.
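As a back-of-the-envelope illustration (sizes hypothetical; only the 2x factor matters), caching the FP8 transpose means a second full copy of each weight must be all-gathered every step:

```python
# Hypothetical sizes for a single square weight matrix.
hidden = 4096
fp8_bytes_per_elem = 1
weight_bytes = hidden * hidden * fp8_bytes_per_elem

no_cache_traffic = weight_bytes        # all-gather the weight only
with_cache_traffic = 2 * weight_bytes  # weight plus cached transpose gathered too

assert with_cache_traffic == 2 * no_cache_traffic
```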
```diff
      data = torch.zeros_like(param, dtype=torch.int16)
  else:
-     data = torch.empty(param.shape, dtype=dtype, device=param.device)
+     data = torch.empty_like(param, dtype=dtype)
```
When using FSDP2, parameters are DTensors, but `torch.zeros()` and `torch.empty()` create regular PyTorch tensors. This was causing:

```
[rank1]: RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
[rank7]:   File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 422, in initialize_state
[rank7]:     self.set_scaled_state(param, "master_param", param.clone().detach().float())
[rank7]:   File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 363, in set_scaled_state
[rank7]:     state[state_name].copy_(unscaled_state)
```
Fix:
Keep optimizer state consistent with the parameter type: when parameters are DTensors, state should be DTensors as well. Using torch.empty_like(param, ...) (and the same idea for other state buffers) creates state as a DTensor with the same placement as param, so both sides of copy_ are DTensors and the error is avoided.
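A minimal repro of the dispatch difference, using a plain `torch.Tensor` subclass as a hypothetical stand-in for DTensor (a real DTensor needs an initialized process group):

```python
import torch

class TaggedTensor(torch.Tensor):
    """Hypothetical subclass standing in for DTensor."""
    pass

param = torch.ones(4).as_subclass(TaggedTensor)

# Factory call: only sees a shape, so the subclass (placement) is lost.
plain = torch.empty(param.shape, dtype=param.dtype, device=param.device)
# *_like call: dispatches through the input tensor, so the subclass survives.
like = torch.empty_like(param)

assert type(plain) is torch.Tensor
assert type(like) is TaggedTensor
```

With DTensor parameters, the same mechanism makes `torch.empty_like(param)` produce state with the same placement as `param`, so both sides of `copy_` are DTensors.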
Is this a cherry-pick of an upstream fix?
Upstream fixes this in TEv2.12, along with a few other fixes:
NVIDIA/TransformerEngine@fe8fad5#diff-0801a8d92a56d458946da1439b62e0add1613b7da83d31bc218a852b6b9e42b1
This wasn't cherry-picked.
…by adding a newline character after the pass statement in the test_dummy function.
```python
# Zero the parameter gradients
optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
```
Does `with te.fp8_autocast(enabled=args.fp8_autocast, ...)` do the same?
It does, but since TEv2.10 `te.fp8_autocast` has been replaced with `te.autocast`, so I've made the change to be consistent.
So will `with te.autocast(enabled=args.fp8_autocast, recipe=...)` do the same as the if/else?
Yes, it should. I'll make the changes.
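The pattern can be sketched with a toy context manager (a hypothetical stand-in for `te.autocast`, nothing TE-specific): passing `enabled=args.fp8_autocast` collapses the if/else into a single `with` statement.

```python
from contextlib import contextmanager

calls = []

@contextmanager
def autocast(enabled=True, recipe=None):
    # Stand-in: a real implementation would enable FP8 execution only when
    # enabled=True; with enabled=False it is a plain pass-through.
    calls.append((enabled, recipe))
    yield

# One call handles both cases, no if/else needed:
for fp8_flag in (True, False):
    with autocast(enabled=fp8_flag, recipe="some_recipe" if fp8_flag else None):
        pass

assert calls == [(True, "some_recipe"), (False, None)]
```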
```python
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
    result = torch.allclose(t1, t2, atol=0, rtol=0)
    tols = dict(atol=atol)
```
Move the `tols` calculation out of the loop.
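A sketch of the suggestion (function and variable names assumed; `math.isclose` stands in for `torch.allclose` to keep the example dependency-free): the tolerance dict is identical on every iteration, so it can be built once before the loop.

```python
import math

def assert_all_outputs_close(l1, l2, atol=1e-3):
    assert len(l1) == len(l2), "Unequal number of outputs."
    tols = dict(abs_tol=atol)  # hoisted: does not change per iteration
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        for a, b in zip(t1, t2):
            assert math.isclose(a, b, **tols), f"Mismatch in output {i}"

assert_all_outputs_close([[1.0, 2.0]], [[1.0004, 2.0]])  # within atol, passes
```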
…s for improved clarity and consistency.
Manually ported the fix from upstream commit 139c863. The full commit was not cherry-picked due to unrelated changes across many files. Addressed PR comments.
- Updated model loading to conditionally use weights_only based on quantized_init.
- Modified optimizer initialization to remove the master_weights parameter when not using quantized_init.
- Improved test setup for quantized models by integrating quantized_model_init with appropriate scaling recipes.
- Adjusted tolerance comments in tests to clarify FP8 behavior with FSDP2 and DDP.
- Added checks in the base module to skip FP8 reduction for FSDP2 based on primary weights status.
- Updated tolerance logic in assert_allclose to use the second tensor for relative tolerance calculations.
- Adjusted tolerance values based on quantization initialization conditions to ensure accurate testing of FP8 behavior with FSDP2.
- Added preservation of amax/scales when copying an fp32 tensor to an already existing fp8 tensor.
- Removed unnecessary model loading logic for quantized initialization in the training script, since we already use the same random seed.
- Excluded the quantized_init + non-autocast test case.
- Updated comments to clarify tolerance handling in FP8 tests.
```python
# DDP broadcast path: _broadcast_coalesced dequantizes Float8Tensors
# (via aten::cat fallback) then copies the plain tensor back.
# Re-quantize while preserving the original quantizer state.
if isinstance(dst, Float8Tensor) and not isinstance(src, Float8Tensor):
```
When PyTorch's native DDP broadcasts parameters, `broadcast_coalesced` concatenates all module parameters (FP8 weights and FP32 biases) via `aten::cat`, which triggers Float8Tensor's fallback dispatch path and dequantizes the FP8 weights to FP32. After the broadcast, `aten::copy_` writes the FP32 data back into the Float8Tensor parameter, calling `quantize_()`, which recomputes amax and scale from the dequantized values. These differ slightly from the originals due to FP8 round-trip error, causing numerical divergence from FSDP2.

The fix intercepts this plain-tensor-to-Float8Tensor `copy_` path, saves the quantizer's amax and scale before re-quantization, and restores them afterward, so the DDP broadcast becomes a no-op with respect to quantizer state. This makes the test pass with zero tolerance for a model with LayerNormMLP, LayerNormLinear, and Linear modules.
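The round-trip effect can be mimicked with a toy grid quantizer (all values hypothetical; real FP8 uses amax-based scaling, but the mechanism is analogous): recomputing amax from dequantized data yields a different value, while re-quantizing with the saved state reproduces the stored values exactly.

```python
def quantize(values, scale):
    # Toy quantizer: snap each value to a grid of step `scale`.
    return [round(v / scale) * scale for v in values]

orig = [0.37, -1.12, 0.05]
amax = max(abs(v) for v in orig)   # quantizer state before broadcast: 1.12
scale = 0.25                       # hypothetical fixed scale

q = quantize(orig, scale)          # stored FP8-like weights
deq = list(q)                      # DDP broadcast dequantizes to plain floats

# Naive copy_: recompute amax from the dequantized data -> state drifts.
amax_recomputed = max(abs(v) for v in deq)
assert amax_recomputed != amax

# Fix: restore the saved (amax, scale) before re-quantizing; the broadcast
# then becomes a no-op for both the values and the quantizer state.
assert quantize(deq, scale) == q
```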
- Added `linear_only` parameter to `SimpleNet` to allow usage of only the final linear layer.
- Updated model initialization and forward pass logic to conditionally skip LayerNorm layers when `linear_only` is enabled.
- Modified argument parsing to include a `--linear-only` flag for test scripts, ensuring compatibility with the quantized initialization scenario.
Description
This PR adds unit tests covering different configurations such as:
All the unit tests compare FSDP2 vs DDP grads/outputs.
This PR also cleans up fsdp2_all_gather_tensor to match upstream's methods.
This PR also fixes an issue with fused_adam when used with FSDP2.
Fixes # (https://github.com/ROCm/frameworks-internal/issues/15291)
Type of change
Checklist: