Conversation
…rd kernel

Ports the MXFP4 training recipe from TE 2.8 (mxfp-fused-2.8) to the dev branch (TE 2.10). Uses the HIP-based cast+transpose kernel with fused shuffle and Hadamard transform for FP4 quantization, and AITER ASM a4w4 kernels for FP4 GEMM.

Changes:
- csrc: Add mxfp4_hip.cpp wrapper + cast_transpose_mxfp4_kernel_shuffled.cu
- csrc: Register cast_transpose_mxfp4_fused_shuffle in pybind + extensions.h
- tensor/mxfp4_tensor.py: Replace Triton quantize with HIP kernel, add quantize_impl(), respect USE_HADAMARD env var
- module/linear.py, layernorm_linear.py: Add MXFP4 weight cache with rowwise-only optimization and quantized-norm bypass
- module/fp4_handler_gemm.py: AITER ASM a4w4 GEMM handler with shape-aware kernel selection and float4_e2m1fn_x2 dtype conversion
- cpp_extensions/gemm.py: Route MXFP4TensorStorage to fp4_handler_gemm
- build_tools/pytorch.py: Collect .cu files for hipify/hipcc compilation

Made-with: Cursor
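MXFP4 quantization snaps each (pre-scaled) value to the nearest FP4 E2M1 code within a 32-element block that shares one power-of-two scale. As a hedged, kernel-independent illustration of that rounding step (function names here are made up for the sketch, not taken from the PR), a pure-Python reference could look like:

```python
# All positive magnitudes representable in E2M1 (2 exponent bits, 1 mantissa bit),
# per the OCP Microscaling (MX) element format.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_e2m1(x: float) -> float:
    """Round a (pre-scaled) float to the nearest E2M1 value.

    Ties round toward the smaller magnitude in this sketch; the HIP kernel's
    tie-breaking may differ.
    """
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), 6.0)  # saturate at the E2M1 maximum
    nearest = min(E2M1_VALUES, key=lambda v: abs(v - mag))
    return sign * nearest

def quantize_block(block, scale):
    """Quantize one block of values sharing a single (power-of-two) scale."""
    return [quantize_e2m1(v / scale) for v in block]
```

This only models the element rounding; the fused kernel additionally applies the optional Hadamard transform and layout shuffle before/after this step.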
- Updated .gitignore to exclude specific MXFP4 HIP files.
- Introduced new scaling mode NVTE_MXFP4_1D_SCALING in common headers.
- Enhanced scaling checks in CheckScaleTensorShape and CheckInputTensor functions to accommodate MXFP4.
- Added MXFP4Quantizer class for handling MXFP4 tensor quantization, including tensor creation and parameter setting.
- Updated quantization implementation in MXFP4Quantizer to utilize a new quantization function based on environment settings.

This commit improves the handling of MXFP4 tensors and their quantization process, ensuring compatibility with the latest scaling modes and tensor operations.
For LoRA SFT, frozen base weights have is_grad_enabled=False, so columnwise quantization is skipped (no wgrad needed). For full pretraining, all weights have gradients, so this is a no-op.

Made-with: Cursor
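The gating described above can be sketched as a tiny predicate: the rowwise copy always feeds the forward GEMM, while the columnwise copy only feeds wgrad and can be dropped for frozen weights. This is an illustrative sketch with hypothetical names, not the actual TE code:

```python
def quantization_usages(weight_requires_grad: bool) -> dict:
    """Decide which quantized layouts a linear layer's weight needs.

    Per the commit message above: columnwise quantization exists to serve
    the wgrad path, so a frozen weight (LoRA SFT base) can skip it.
    """
    return {
        "rowwise": True,                     # always needed for the forward GEMM
        "columnwise": weight_requires_grad,  # only needed when wgrad is computed
    }
```

In full pretraining every weight requires grad, so the predicate always returns both layouts, matching the "no-op" claim above.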
- Introduced MXFP4 quantization logic in the quantization dispatch, including Hadamard transform and data shuffling options.
- Added new MXFP4 tensor storage and quantizer classes to manage MXFP4 data formats and operations.
- Updated CMakeLists.txt to include new MXFP4 source files and dependencies.
- Enhanced common headers to define new quantization configuration attributes for MXFP4.
- Implemented MXFP4 quantization in the PyTorch interface, allowing for flexible tensor operations.

This commit significantly improves MXFP4 support in the transformer engine, enabling advanced quantization techniques and optimizing performance for AMD architectures.
Move FP4 AITER GEMM handler from module/fp4_handler_gemm.py into cpp_extensions/gemm.py alongside other GEMM dispatch paths. Remove mxfp4_hip.cpp which is no longer needed after the nvte_quantize_v2 refactor.
…handling for MXFP4 alongside the existing FP8/NVFP4 paths. Expose MXFP4 device support via check_mxfp4_support / is_mxfp4_available (and FP8GlobalStateManager), validate MXFP4 in check_recipe_support, and return a larger alignment when recipe.mxfp4() is set. Teach TransformerEngineBase.set_meta_tensor and LayerNormMLP activation fusion gating to treat MXFP4 like other non-fused recipes. Extend test_numerics fp8_recipes with MXFP4BlockScaling when supported. Add a default fp8_format on MXFP4BlockScaling for callers expecting recipe.fp8_format.
…cipe-212-refactor
- Introduced `is_mxfp4_available` import in the PyTorch interface.
- Added `check_fp8_block_scaling_support` function to validate FP8 block scaling availability based on device compute capability and CUDA version.
- Cleaned up imports in `gemm.py` by moving them to appropriate locations.
- Implemented support for new MXFP4 quantization configuration attributes: `mxfp4_use_hadamard` and `mxfp4_shuffle`.
- Updated `nvte_get_quantization_config_attribute` and `nvte_set_quantization_config_attribute` functions to handle these new attributes, enhancing the flexibility of quantization settings in the transformer engine.
- Introduced a new helper function `_round_up` to round values up to the nearest multiple, aiding in scale padding.
- Updated the MXFP4 quantization process to pad scales to match the native allocator layout, ensuring compatibility with the expected tensor dimensions.
- Adjusted the handling of scales in the `MXFP4QuantizerRef` class to accommodate padded scales, improving the robustness of the quantization process.
The LAUNCH_KERNEL macro hardcoded SHUFFLE_SCALES=true, ignoring the runtime shuffle_scales parameter. This caused scales to be written in shuffled layout even when shuffle was disabled, producing incorrect output. Refactor dispatch into do-while-wrapped macros to also fix dangling-else issues. Add tolerant comparison helpers to the MXFP4 quantize test for C++/HIP backend rounding differences (±1 nibble).
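A tolerant comparison like the one described above must compare individual 4-bit codes, since packed FP4 stores two codes per byte and a byte-level compare would conflate neighbors. This is a hedged sketch with illustrative helper names, not the actual test code (note that raw-code distance ignores the sign-bit boundary, which the real helpers may handle differently):

```python
def unpack_nibbles(packed: bytes):
    """Split packed FP4 bytes into a flat list of 4-bit codes (low nibble first)."""
    out = []
    for b in packed:
        out.append(b & 0xF)
        out.append((b >> 4) & 0xF)
    return out

def nibbles_close(a: bytes, b: bytes, tol: int = 1) -> bool:
    """True if every pair of 4-bit codes differs by at most `tol` (default: 1 nibble)."""
    na, nb = unpack_nibbles(a), unpack_nibbles(b)
    return len(na) == len(nb) and all(abs(x - y) <= tol for x, y in zip(na, nb))
```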
…ions

- Updated the QuantizationConfig structure to include separate flags for scale and data shuffling.
- Modified the MXFP4 quantization logic to utilize these new flags, enhancing flexibility in quantization settings.
- Adjusted related functions and classes to accommodate the new shuffling parameters, ensuring correct behavior during quantization.
- Updated tests and kernel dispatch to reflect these changes, improving the overall robustness of the MXFP4 quantization process.
…fling options

- Introduced a new function `un_shuffle_scales` to invert the AITER scale shuffle permutation, improving test accuracy.
- Updated `check_quantization_mxfp4_versus_reference` to conditionally unshuffle scales based on the `shuffle_scales` parameter, ensuring correct comparisons.
- Added new parameters for `use_hadamard`, `shuffle_B_matrix_for_aiter`, and `shuffle_scales` in the quantization tests to enhance flexibility and coverage.
- Implemented a new function `_shuffle_fp4_data` in the quantization logic to support shuffling of packed FP4 data for AITER GEMM kernels.
- Adjusted the `MXFP4QuantizerRef` class to utilize the new shuffling function, ensuring compatibility with the updated quantization process.
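Inverting a scale shuffle like `un_shuffle_scales` does is generic permutation inversion: any shuffle expressible as an index permutation can be undone the same way. The sketch below does not reproduce the actual AITER permutation (that layout is kernel-specific); it only illustrates the round-trip property the test relies on:

```python
def invert_permutation(perm):
    """Return inv such that inv[perm[i]] == i for all i."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

def apply_permutation(values, perm):
    """Gather: output[i] = values[perm[i]]."""
    return [values[p] for p in perm]

# Round-trip: shuffling then unshuffling recovers the original order.
perm = [2, 0, 3, 1]          # stand-in for the AITER scale permutation
data = ["a", "b", "c", "d"]
shuffled = apply_permutation(data, perm)
restored = apply_permutation(shuffled, invert_permutation(perm))
```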
- Modified the test to view rowwise scales as `torch.uint8` for consistency.
- Implemented conditional unshuffling of scales based on the `shuffle_scales` parameter, enhancing test accuracy and flexibility.
- Ensured that both contiguous and non-contiguous scale tensors are correctly compared in the quantization tests.
- Introduced a new test file `test_mxfp4_gemm_exact.py` to validate the accuracy of MXFP4 GEMM operations against a Python reference implementation.
- Implemented a parameterized test function `test_mxfp4_gemm_versus_reference` to cover various matrix dimensions and data types.
- Enhanced the quantization process by integrating native MXFP4 quantization and reference quantization methods, ensuring robust comparisons.
- Added checks for NaN values and ensured proper handling of output accumulation in the tests.
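A Python reference for such a GEMM test needs to dequantize the packed FP4 operands before a plain matmul comparison. Assuming the standard float4_e2m1fn encoding (1 sign, 2 exponent, 1 mantissa bit) and low-nibble-first packing, a hedged sketch of the dequantize step (helper names illustrative, not from the test file) is:

```python
# Magnitude for each 3-bit exponent/mantissa pattern in E2M1.
_E2M1_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (bit 3 = sign) to a float."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * _E2M1_MAG[code & 0x7]

def dequantize_packed(packed: bytes, scales, block: int = 32):
    """Dequantize packed FP4 bytes (two codes per byte, low nibble first),
    applying one shared scale per `block` elements."""
    vals = []
    for b in packed:
        vals.append(decode_e2m1(b & 0xF))
        vals.append(decode_e2m1((b >> 4) & 0xF))
    return [v * scales[i // block] for i, v in enumerate(vals)]
```

The reference result is then an ordinary float matmul over the dequantized operands, compared against the AITER a4w4 output within tolerance.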
wangye805 approved these changes (Apr 23, 2026)
ipanfilo reviewed (Apr 23, 2026)
ipanfilo reviewed (Apr 23, 2026)
…n fused attention setting
…all based on fused attention setting
ipanfilo reviewed (Apr 23, 2026)
ipanfilo requested changes (Apr 24, 2026)
…l only occurs if installation succeeds
ipanfilo approved these changes (Apr 27, 2026)
Contributor
All level 3 tests passed with the new aiter installed image:
Summary

Introduces:
- nvte_cast_transpose_mxfp4_fused_shuffle: fused HIP kernel for MXFP4 cast+transpose with optional Hadamard transform and memory-layout shuffling for GEMM.
- fp4_gemm_handler: dispatches A4W4 GEMM calls to the AITER backend with the right layouts.
- MXFP4 weight caching in Linear.forward() and LayerNormLinear.forward() that persists quantized MXFP4TensorStorage weights across forward passes.

Training recipe can be enabled using:
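The weight-caching idea above, quantize once and reuse the MXFP4 storage until the underlying weight changes, can be sketched as follows. This is an illustrative sketch, not the actual TE code; in particular, keying invalidation on `torch.Tensor._version` (which bumps on in-place optimizer updates) is an assumption about how staleness could be detected:

```python
class MXFP4WeightCache:
    """Cache quantized weights keyed by parameter identity and version (sketch)."""

    def __init__(self, quantize_fn):
        self._quantize_fn = quantize_fn  # e.g. a HIP-backed quantizer callable
        self._cache = {}

    def get(self, weight):
        # (id, version) detects in-place updates; plain objects default to 0.
        key = (id(weight), getattr(weight, "_version", 0))
        if key not in self._cache:
            self._cache.clear()  # drop entries for stale versions
            self._cache[key] = self._quantize_fn(weight)
        return self._cache[key]

# Example: the second lookup hits the cache instead of re-quantizing.
calls = []
cache = MXFP4WeightCache(lambda w: calls.append(w) or "mxfp4_storage")
weight = object()
first = cache.get(weight)
second = cache.get(weight)
```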
MXFP4 Flow:
Loss graph:
Test plan

- test_cast_mxfp4_transpose.cu passes on gfx950
  - mxfp4_cpp_gfx942.log
  - mxfp4_cpp_gfx950.log
- test_mxfp4_* tests pass on gfx950
  - mxfp4_all_gfx942.log
  - mxfp4_all_gfx950.log