Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized#2644
Conversation
Greptile SummaryThis PR adds
Confidence Score: 5/5Safe to merge — all previously flagged blocking issues are resolved; remaining findings are P2 style suggestions. All previously raised P0/P1 concerns (duplicate recipe fields, recipe None crash, LayerNormMLP assertion message quality, unnecessary saved tensors, DelayedScaling interaction) are addressed in this revision. The feature is guarded by explicit assertions with clear error messages for unsupported combinations. The only new findings are a defensive getattr suggestion in fuser.py and a minor asymmetry in empty-tensor guards for MXFP8 storage, both P2. Comprehensive test coverage was added. transformer_engine/pytorch/ops/fuser.py (minor: direct attribute access on recipe.backward_override), transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py (minor: asymmetric empty-tensor guard in dequantize vs _FromMXFP8Func.forward) Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Forward Pass - quantized fprop] --> B{NVTE_BACKWARD_OVERRIDE}
B -->|None| C[Default: save rowwise+columnwise quantized tensors]
B -->|high_precision| D[Save original unquantized input and weight]
B -->|dequantized| E[Save rowwise-only quantized tensors]
C --> F[Backward: quantized dgrad and wgrad GEMMs]
D --> G[Backward: high-precision dgrad and wgrad using original fp16/bf16/fp32 operands]
E --> H[Backward: dequantize saved tensors then high-precision GEMMs]
subgraph Supported
L[Linear]
M[LayerNormLinear]
N[GroupedLinear]
O[ops.Linear / fused ops]
end
subgraph Unsupported - assertion error with clear message
P[LayerNormMLP]
Q[DelayedScaling recipe]
end
Reviews (43): Last reviewed commit: "Merge branch 'main' into keep-bwd" | Re-trigger Greptile |
|
I'll work on potential unit test breakage. |
| # Note: dgrad GEMM requires row-wise usage, wgrad GEMM | ||
| # requires column-wise usage | ||
| if ctx.grad_output_quantizer is not None: | ||
| if ctx.grad_output_quantizer is not None and use_fp8_bwd: |
There was a problem hiding this comment.
this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?
| not ctx.use_bias | ||
| and not ctx.requires_wgrad | ||
| and ctx.grad_output_quantizer is not None | ||
| and use_fp8_bwd |
| recipe = cls.get_fp8_recipe() | ||
| if recipe is not None and recipe.delayed(): | ||
| # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used | ||
| return False |
There was a problem hiding this comment.
Maybe it's better to assert an error for delayed scaling? Okay with both.
There was a problem hiding this comment.
I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
| # Note: dgrad GEMM requires row-wise usage, wgrad GEMM | ||
| # requires column-wise usage | ||
| if ctx.grad_output_quantizer is not None: | ||
| if ctx.grad_output_quantizer is not None and use_fp8_bwd: |
There was a problem hiding this comment.
this seems redundant too if we skip quant in grad_output_preprocess
Signed-off-by: Ziang Li <ziangli@umich.edu>
…zed` Signed-off-by: Ziang Li <ziangli@umich.edu>
NVTE_BACKWARD_MODE=default|unquant|dequantNVTE_BACKWARD_OVERRIDE=high_precision|dequantized
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@victordion I think you are describing the default TE 1d recipe or requantized behavior. |
Right. My mistake. My mental model assumed there is requantize happening. Thanks for responding! |
|
Regarding the env var design, since this feature is mainly used by RL, there has to be a way for the user to directly override the bwd behavior in RL framework instead of plumbing all the way through Megatron. |
|
/te-ci L0 L1 |
|
All pytorch ci passed. Some failed jax tests are due to |
|
/te-ci L0 L1 |
|
/te-ci L0 L1 |
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
|
/te-ci L0 L1 |
|
Failed JAX ci is unrelated to this PR: B200: L40: |
|
Merged, thank you for your contribution @zianglih! |
* Add NVTE_KEEP_BACKWARD_UNQUANTIZED Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Disable ub and clean up Signed-off-by: Ziang Li <ziangli@umich.edu> * Drop fuser changes Signed-off-by: Ziang Li <ziangli@umich.edu> * Replace use_quantized_bwd with use_fp8_bwd Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Ignore keep_backward_unquantized if delayed scaling Signed-off-by: Ziang Li <ziangli@umich.edu> * Refactor ignoring NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used Signed-off-by: Ziang Li <ziangli@umich.edu> * Add back missing ctx.debug Signed-off-by: Ziang Li <ziangli@umich.edu> * Refactor changes under fused Signed-off-by: Ziang Li <ziangli@umich.edu> * Clean up Signed-off-by: Ziang Li <ziangli@umich.edu> * Refactor high-precision overwrite if keep_backward_unquantized Signed-off-by: Ziang Li <ziangli@umich.edu> * Clean up Signed-off-by: Ziang Li <ziangli@umich.edu> * Drop redundant fp8_recipe_bwd Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Drop redundant ub changes Signed-off-by: Ziang Li <ziangli@umich.edu> * Drop more redundant ub changes Signed-off-by: Ziang Li <ziangli@umich.edu> * Drop redundant delayed scaling changes Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Drop unneeded backwards_needs_fc1_input Signed-off-by: Ziang Li <ziangli@umich.edu> * Drop and disallow LayerNormMLP implementation Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move interface changes to recipe Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move ub overrides to fwd Signed-off-by: Ziang Li <ziangli@umich.edu> * Remove duplication Signed-off-by: Ziang Li <ziangli@umich.edu> * Simplify use_fp8_bwd logic in bwd Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Set grad quantizers to none if keep bwd unquantized Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Drop delayed scaling change Signed-off-by: Ziang Li <ziangli@umich.edu> * Simplify env var logic Signed-off-by: Ziang Li <ziangli@umich.edu> * Move validation check to recipe Signed-off-by: Ziang Li <ziangli@umich.edu> * Simplify effective_enabled Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix inverted assertion logic Signed-off-by: Ziang Li <ziangli@umich.edu> * Simplify changes under ops Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Simplify ctx.keep_backward_unquantized Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix missing attribute Signed-off-by: Ziang Li <ziangli@umich.edu> * Add unit tests Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix bias errors in unit test Signed-off-by: Ziang Li <ziangli@umich.edu> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add more shapes to unit test Signed-off-by: Ziang Li <ziangli@umich.edu> * Refator interface to `NVTE_BACKWARD_MODE=default|unquant|dequant` Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix override and clean up Signed-off-by: Ziang Li <ziangli@umich.edu> * Clean up unit test Signed-off-by: Ziang Li <ziangli@umich.edu> * Clean up unit test Signed-off-by: Ziang Li <ziangli@umich.edu> * Override `ctx.reduce_and_update_bwd_fp8_tensors = False` Signed-off-by: Ziang Li <ziangli@umich.edu> * Expand unit test Signed-off-by: Ziang Li <ziangli@umich.edu> * Add `test_backward_mode_memory_peak_report` Signed-off-by: Ziang Li <ziangli@umich.edu> * Expand test coverage and fix Signed-off-by: Ziang Li <ziangli@umich.edu> * Use `numel()` Signed-off-by: Ziang Li <ziangli@umich.edu> * Refactor unit test Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix grouped linear to override `*_quantizers` instead of `*_quantizer` Signed-off-by: Ziang Li <ziangli@umich.edu> * Only save input/weight when `*_requires_grad` on unquant mode Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix Blackwell debug ci Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix sm89 and sm90 tests Signed-off-by: Ziang Li <ziangli@umich.edu> * Fix unquant mode memory saving Signed-off-by: Ziang Li <ziangli@umich.edu> * Refactor interface to `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized` Signed-off-by: Ziang Li <ziangli@umich.edu> * Rename unit test Signed-off-by: Ziang Li <ziangli@umich.edu> * Simplify env var parsing Signed-off-by: Ziang Li <ziangli@umich.edu> --------- Signed-off-by: Ziang Li <ziangli@umich.edu> Signed-off-by: Przemek Tredak <ptredak@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Description
@HumansAnd
Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.Add NVTE_BACKWARD_MODE=default|unquant|dequant env varAdd
NVTE_BACKWARD_OVERRIDE=high_precision|dequantizedenv var:high_precision: quantized fprop + high precision wgrad & dgrad using unquantized activation and weightdequantized: quantized fpop + high precision wgrad & dgrad using activation and weight dequantized directly from fprop quantized valueThe movitivation for this dequantized design is RL. Unlike pre-training which only needs to preserve coarse optimization direction and convergence, RL gradients are noisy and useful updates are small and delicate. If gradient quantization and chain rule violation are present, noise dominates the true and fragile update signal and model will collapse. This dequantized design avoids gradient quantization and effectively preserves chain rule.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: