Full MXFP4 Training Recipe #537

Merged
sudhu2k merged 65 commits into dev from feature/mxfp4-recipe-212 on Apr 28, 2026

Conversation

@sarthak-amd (Contributor) commented Apr 13, 2026

Summary

  1. Introduces nvte_cast_transpose_mxfp4_fused_shuffle — a fused HIP kernel for MXFP4 cast+transpose with an optional Hadamard transform and memory-layout shuffling for GEMM.

  2. Adds fp4_gemm_handler, which dispatches A4W4 GEMM calls to the AITER backend with the right layouts.

  3. Adds MXFP4 weight caching in Linear.forward() and LayerNormLinear.forward() that persists quantized MXFP4TensorStorage weights across forward passes (sketched below).
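
As a rough illustration of the weight caching in (3) — a sketch with hypothetical attribute and helper names (`_mxfp4_weight_cache`, `_get_mxfp4_weight`), not the PR's exact code:

```python
# Hypothetical sketch of the weight-cache idea: quantize the weight to MXFP4 once,
# then reuse the cached MXFP4TensorStorage on later forward passes until the
# weight tensor's version counter changes (i.e. until an optimizer step updates it).
def _get_mxfp4_weight(module, weight, mxfp4_quantizer):
    cached = getattr(module, "_mxfp4_weight_cache", None)
    if cached is not None and cached["version"] == weight._version:
        return cached["storage"]
    storage = mxfp4_quantizer(weight)  # cast + transpose (+ optional Hadamard/shuffle)
    module._mxfp4_weight_cache = {"storage": storage, "version": weight._version}
    return storage
```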

The training recipe can be enabled using:

export FP4_RECIPE=mxfp4
export NVTE_MXFP4_USE_HADAMARD=1
export FP4=True
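
For reference, a minimal end-to-end sketch of enabling the recipe from Python. The environment variables come from this PR; the import path of `MXFP4BlockScaling` and the `fp8_autocast` arguments are assumptions about how the new recipe plugs into the existing API, not confirmed usage:

```python
import os
# Set before importing Transformer Engine so the recipe is picked up.
os.environ["FP4_RECIPE"] = "mxfp4"
os.environ["NVTE_MXFP4_USE_HADAMARD"] = "1"
os.environ["FP4"] = "True"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP4BlockScaling  # recipe class added by this PR

layer = te.Linear(4096, 4096, bias=False).cuda()
x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=MXFP4BlockScaling()):
    y = layer(x)
y.sum().backward()
```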

MXFP4 Flow: [image]

Loss graph: [image]

Test plan

…rd kernel

Ports the MXFP4 training recipe from TE 2.8 (mxfp-fused-2.8) to the dev
branch (TE 2.10). Uses the HIP-based cast+transpose kernel with fused
shuffle and Hadamard transform for FP4 quantization, and AITER ASM a4w4
kernels for FP4 GEMM.

Changes:
- csrc: Add mxfp4_hip.cpp wrapper + cast_transpose_mxfp4_kernel_shuffled.cu
- csrc: Register cast_transpose_mxfp4_fused_shuffle in pybind + extensions.h
- tensor/mxfp4_tensor.py: Replace Triton quantize with HIP kernel, add
  quantize_impl(), respect USE_HADAMARD env var
- module/linear.py, layernorm_linear.py: Add MXFP4 weight cache with
  rowwise-only optimization and quantized-norm bypass
- module/fp4_handler_gemm.py: AITER ASM a4w4 GEMM handler with shape-aware
  kernel selection and float4_e2m1fn_x2 dtype conversion
- cpp_extensions/gemm.py: Route MXFP4TensorStorage to fp4_handler_gemm
- build_tools/pytorch.py: Collect .cu files for hipify/hipcc compilation

Made-with: Cursor
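
As context for the gemm.py routing bullet above, the dispatch idea is roughly the following sketch; `general_gemm_dispatch` and `_default_gemm` are illustrative names, and only `MXFP4TensorStorage` and `fp4_gemm_handler` come from the PR:

```python
# Illustrative sketch: send MXFP4 operands to the AITER a4w4 handler and
# leave every other tensor type on the existing GEMM paths.
def general_gemm_dispatch(A, B, out=None, **kwargs):
    if isinstance(A, MXFP4TensorStorage) or isinstance(B, MXFP4TensorStorage):
        # The handler picks an AITER ASM a4w4 kernel based on the problem shape
        # and presents the packed FP4 data as float4_e2m1fn_x2 for AITER.
        return fp4_gemm_handler(A, B, out=out, **kwargs)
    return _default_gemm(A, B, out=out, **kwargs)
```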
@sudhu2k sudhu2k added the ci-level 3 CI test level 3 label Apr 13, 2026
sarthak-amd and others added 22 commits April 13, 2026 14:42
Made-with: Cursor
- Updated .gitignore to exclude specific MXFP4 HIP files.
- Introduced new scaling mode NVTE_MXFP4_1D_SCALING in common headers.
- Enhanced scaling checks in CheckScaleTensorShape and CheckInputTensor functions to accommodate MXFP4.
- Added MXFP4Quantizer class for handling MXFP4 tensor quantization, including tensor creation and parameter setting.
- Updated quantization implementation in MXFP4Quantizer to utilize a new quantization function based on environment settings.

This commit improves the handling of MXFP4 tensors and their quantization process, ensuring compatibility with the latest scaling modes and tensor operations.
For LoRA SFT, frozen base weights have is_grad_enabled=False, so
columnwise quantization is skipped (no wgrad is needed). For full
pretraining, all weights have gradients, so this is a no-op.

Made-with: Cursor
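
The gate described above amounts to something like the following sketch; the real logic sits in the quantizer usage setup inside Linear.forward(), and `weight_quantizer` here is illustrative:

```python
# Sketch: only build the columnwise (transposed) MXFP4 copy of the weight when a
# weight gradient will actually be computed; frozen LoRA base weights skip it.
needs_wgrad = is_grad_enabled and weight.requires_grad
weight_quantizer.set_usage(rowwise=True, columnwise=needs_wgrad)
```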
- Introduced MXFP4 quantization logic in the quantization dispatch, including Hadamard transform and data shuffling options.
- Added new MXFP4 tensor storage and quantizer classes to manage MXFP4 data formats and operations.
- Updated CMakeLists.txt to include new MXFP4 source files and dependencies.
- Enhanced common headers to define new quantization configuration attributes for MXFP4.
- Implemented MXFP4 quantization in the PyTorch interface, allowing for flexible tensor operations.

This commit significantly improves the MXFP4 support in the transformer engine, enabling advanced quantization techniques and optimizing performance for AMD architectures.
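
For readers unfamiliar with the format: MXFP4 in the OCP microscaling sense stores FP4 E2M1 elements in blocks of 32 that share one power-of-two (E8M0) scale. A pure-Python reference of that scheme, ignoring the Hadamard transform and AITER shuffling this PR adds, might look like the sketch below (not the PR's kernel):

```python
import torch

def mxfp4_quantize_ref(x: torch.Tensor, block: int = 32):
    """Reference MXFP4-style quantization along the last dim (no Hadamard, no shuffle).

    Returns dequantized float values plus per-block scales instead of packed
    nibbles, so the result is easy to compare against a real kernel's output.
    """
    rows, cols = x.shape
    assert cols % block == 0
    blocks = x.float().reshape(rows, cols // block, block)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    # Shared power-of-two (E8M0-style) scale; E2M1's largest exponent is 2 (max value 6.0).
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)
    # Snap scaled values to the signed E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
    grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device)
    q = blocks / scale
    q = grid[(q.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)] * q.sign()
    return (q * scale).reshape(rows, cols), scale.squeeze(-1)
```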
Move FP4 AITER GEMM handler from module/fp4_handler_gemm.py into
cpp_extensions/gemm.py alongside other GEMM dispatch paths. Remove
mxfp4_hip.cpp which is no longer needed after the nvte_quantize_v2
refactor.
…handling for MXFP4 alongside existing FP8/NVFP4 paths.

Expose MXFP4 device support via check_mxfp4_support / is_mxfp4_available
(and FP8GlobalStateManager), validate MXFP4 in check_recipe_support, and
return a larger alignment when recipe.mxfp4().

Teach TransformerEngineBase.set_meta_tensor and LayerNormMLP activation
fusion gating to treat MXFP4 like other non-fused recipes.

Extend test_numerics fp8_recipes with MXFP4BlockScaling when supported.

Add default fp8_format on MXFP4BlockScaling for callers expecting
recipe.fp8_format.
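
A sketch of what such a device gate can look like on ROCm; the gfx target check here is an assumption for illustration (the commit does not state which architectures qualify):

```python
import torch

def check_mxfp4_support() -> tuple[bool, str]:
    """Illustrative capability check; the PR's real check sits next to the FP8 ones."""
    if not torch.cuda.is_available():
        return False, "No GPU available."
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    if "gfx950" in arch:  # assumption: MXFP4 a4w4 kernels target MI350-class devices
        return True, ""
    return False, f"MXFP4 is not supported on {arch or 'this device'}."

def is_mxfp4_available() -> bool:
    return check_mxfp4_support()[0]
```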
- Introduced `is_mxfp4_available` import in the PyTorch interface.
- Added `check_fp8_block_scaling_support` function to validate FP8 block scaling availability based on device compute capability and CUDA version.
- Cleaned up imports in `gemm.py` by moving them to appropriate locations.
- Implemented support for new MXFP4 quantization configuration attributes: `mxfp4_use_hadamard` and `mxfp4_shuffle`.
- Updated `nvte_get_quantization_config_attribute` and `nvte_set_quantization_config_attribute` functions to handle these new attributes, enhancing the flexibility of quantization settings in the transformer engine.
- Introduced a new helper function `_round_up` to round values up to the nearest multiple, aiding in scale padding.
- Updated the MXFP4 quantization process to pad scales to match the native allocator layout, ensuring compatibility with the expected tensor dimensions.
- Adjusted the handling of scales in the `MXFP4QuantizerRef` class to accommodate padded scales, improving the robustness of the quantization process.
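
The helper and the padding step are straightforward; a sketch follows, where the 128x4 padding multiples are placeholders for whatever the native allocator layout actually requires (not values taken from the PR):

```python
import torch

def _round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple."""
    return ((value + multiple - 1) // multiple) * multiple

def pad_scales(scales: torch.Tensor, row_multiple: int = 128, col_multiple: int = 4):
    """Pad a (rows, cols) scale tensor to the expected allocator dimensions.

    The multiples above are illustrative; the real values come from the allocator layout.
    """
    rows, cols = scales.shape
    padded = scales.new_zeros(_round_up(rows, row_multiple), _round_up(cols, col_multiple))
    padded[:rows, :cols] = scales
    return padded
```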
The LAUNCH_KERNEL macro hardcoded SHUFFLE_SCALES=true, ignoring the
runtime shuffle_scales parameter. This caused scales to be written in
shuffled layout even when shuffle was disabled, producing incorrect
output. Refactor dispatch into do-while-wrapped macros to also fix
dangling-else issues. Add tolerant comparison helpers to the MXFP4
quantize test for C++/HIP backend rounding differences (±1 nibble).
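
A "±1 nibble" tolerance can be implemented along these lines (helper names are made up for illustration; the test's real helpers may differ):

```python
import torch

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Split packed FP4 bytes (two values per uint8) into individual nibble codes."""
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    return torch.stack([lo, hi], dim=-1).flatten(-2)

def assert_close_fp4(test: torch.Tensor, ref: torch.Tensor, max_mismatch_frac: float = 0.0):
    """Allow off-by-one nibble codes to absorb C++/HIP vs. reference rounding differences."""
    t, r = unpack_nibbles(test).int(), unpack_nibbles(ref).int()
    mismatches = ((t - r).abs() > 1).float().mean().item()
    assert mismatches <= max_mismatch_frac, f"{mismatches:.2%} nibbles differ by more than 1"
```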
…ions

- Updated the QuantizationConfig structure to include separate flags for scale and data shuffling.
- Modified the MXFP4 quantization logic to utilize these new flags, enhancing flexibility in quantization settings.
- Adjusted related functions and classes to accommodate the new shuffling parameters, ensuring correct behavior during quantization.
- Updated tests and kernel dispatch to reflect these changes, improving the overall robustness of the MXFP4 quantization process.
…fling options

- Introduced a new function `un_shuffle_scales` to invert the AITER scale shuffle permutation, improving test accuracy.
- Updated `check_quantization_mxfp4_versus_reference` to conditionally unshuffle scales based on the `shuffle_scales` parameter, ensuring correct comparisons.
- Added new parameters for `use_hadamard`, `shuffle_B_matrix_for_aiter`, and `shuffle_scales` in the quantization tests to enhance flexibility and coverage.
- Implemented a new function `_shuffle_fp4_data` in the quantization logic to support shuffling of packed FP4 data for AITER GEMM kernels.
- Adjusted the `MXFP4QuantizerRef` class to utilize the new shuffling function, ensuring compatibility with the updated quantization process.
- Modified the test to view rowwise scales as `torch.uint8` for consistency.
- Implemented conditional unshuffling of scales based on the `shuffle_scales` parameter, enhancing test accuracy and flexibility.
- Ensured that both contiguous and non-contiguous scale tensors are correctly compared in the quantization tests.
- Introduced a new test file `test_mxfp4_gemm_exact.py` to validate the accuracy of MXFP4 GEMM operations against a Python reference implementation.
- Implemented a parameterized test function `test_mxfp4_gemm_versus_reference` to cover various matrix dimensions and data types.
- Enhanced the quantization process by integrating native MXFP4 quantization and reference quantization methods, ensuring robust comparisons.
- Added checks for NaN values and ensured proper handling of output accumulation in the tests.
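
In outline, such a test quantizes both operands, runs the native GEMM, and compares against a dequantize-then-matmul reference in higher precision. A sketch with hypothetical helper names (`quantize_mxfp4_native`, `dequantize_mxfp4`, `mxfp4_gemm`):

```python
import pytest
import torch

@pytest.mark.parametrize("m,n,k", [(256, 256, 256), (1024, 4096, 4096)])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_mxfp4_gemm_versus_reference(m, n, k, dtype):
    a = torch.randn(m, k, device="cuda", dtype=dtype)
    b = torch.randn(n, k, device="cuda", dtype=dtype)

    a_q = quantize_mxfp4_native(a)   # hypothetical: native HIP quantization path
    b_q = quantize_mxfp4_native(b)
    out = mxfp4_gemm(a_q, b_q, out_dtype=dtype)  # hypothetical AITER a4w4 entry point

    # Reference: dequantize both operands and run the matmul in float32.
    ref = dequantize_mxfp4(a_q).float() @ dequantize_mxfp4(b_q).float().t()

    assert not torch.isnan(out).any()
    torch.testing.assert_close(out.float(), ref, rtol=0.05, atol=0.05)
```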
@sudhu2k sudhu2k requested a review from wangye805 April 22, 2026 17:47
@sudhu2k sudhu2k added ci-level 1 CI test level 1 and removed ci-level 3 CI test level 3 labels Apr 22, 2026
Comment thread tests/cpp/operator/test_cast_mxfp4_transpose.cu
Comment thread tests/cpp/operator/test_cast_mxfp4_transpose.cu Outdated
Comment thread tests/cpp/operator/test_cast_mxfp4_transpose.cu
Comment thread transformer_engine/pytorch/cpp_extensions/gemm.py
Comment thread transformer_engine/common/recipe/__init__.py
@sudhu2k sudhu2k requested a review from ipanfilo April 23, 2026 21:34
Comment thread ci/pytorch.sh Outdated
@sudhu2k sudhu2k requested a review from ipanfilo April 23, 2026 22:03
Comment thread ci/pytorch.sh Outdated
@sudhu2k sudhu2k requested a review from ipanfilo April 24, 2026 02:27
Comment thread ci/pytorch.sh Outdated
@sudhu2k sudhu2k requested a review from ipanfilo April 24, 2026 05:06
@sudhu2k sudhu2k added ci-level 3 CI test level 3 and removed ci-level 1 CI test level 1 labels Apr 26, 2026
@sudhu2k (Contributor) commented Apr 28, 2026

All level 3 tests passed with the new aiter installed image:
https://github.com/ROCm/TransformerEngine/actions/runs/25014847020/job/73259852413

@sudhu2k sudhu2k merged commit dcfae3e into dev Apr 28, 2026
7 of 9 checks passed