[Pytorch][Common] Hybrid quantization#2817
Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR introduces hybrid (per-direction) quantization for Transformer Engine, allowing rowwise and columnwise GEMM operands to use different quantization formats (e.g., MXFP8 forward + NVFP4 backward). The implementation is substantial (~14K lines) and covers the full ecosystem: FSDP2 buffer protocols, CPU offloading, activation recompute, TP/SP, and distributed optimizer integration.
Confidence Score: 4/5Safe to merge after fixing MXFP8 columnwise scale floor-division truncation in fsdp_extract_buffers. The MXFP8 columnwise scale strip uses integer floor division, silently discarding the last scale block when M per rank is not a multiple of 32, corrupting dequantization under FSDP2 for those boundary rows. transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py — fsdp_extract_buffers columnwise scale truncation at line 366. Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant M as Module
participant HQ as HybridQuantizer
participant RQ as rowwise_quantizer
participant CQ as columnwise_quantizer
participant GG as general_gemm
M->>HQ: quantize(tensor)
HQ->>RQ: quantize(tensor)
RQ-->>HQ: rowwise_storage
alt "columnwise_source == rowwise_dequantized"
HQ->>RQ: dequantize(rowwise_storage)
RQ-->>HQ: bf16_tensor
HQ->>CQ: quantize(bf16_tensor)
else "columnwise_source == original"
HQ->>CQ: quantize(tensor)
end
CQ-->>HQ: columnwise_storage
HQ-->>M: HybridQuantizedTensor(row_sub, col_sub)
M->>GG: "general_gemm(A=HybridTensor, layout=TN)"
GG->>GG: _unwrap_hybrid_A(A, TN) row sub
GG->>GG: _materialize_high_precision
GG->>GG: cuBLAS GEMM with native sub-storage
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant M as Module
participant HQ as HybridQuantizer
participant RQ as rowwise_quantizer
participant CQ as columnwise_quantizer
participant GG as general_gemm
M->>HQ: quantize(tensor)
HQ->>RQ: quantize(tensor)
RQ-->>HQ: rowwise_storage
alt "columnwise_source == rowwise_dequantized"
HQ->>RQ: dequantize(rowwise_storage)
RQ-->>HQ: bf16_tensor
HQ->>CQ: quantize(bf16_tensor)
else "columnwise_source == original"
HQ->>CQ: quantize(tensor)
end
CQ-->>HQ: columnwise_storage
HQ-->>M: HybridQuantizedTensor(row_sub, col_sub)
M->>GG: "general_gemm(A=HybridTensor, layout=TN)"
GG->>GG: _unwrap_hybrid_A(A, TN) row sub
GG->>GG: _materialize_high_precision
GG->>GG: cuBLAS GEMM with native sub-storage
Reviews (18): Last reviewed commit: "Fix attention factories and align with e..." | Re-trigger Greptile |
timmoon10
left a comment
There was a problem hiding this comment.
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | ||
| columnwise_result = self.columnwise_quantizer.quantize(tensor) |
There was a problem hiding this comment.
Do we handle the case where not all usages are needed? I'd expect something like:
| rowwise_result = self.rowwise_quantizer.quantize(tensor) | |
| columnwise_result = self.columnwise_quantizer.quantize(tensor) | |
| rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None | |
| columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None |
| requires_grad: bool = False, | ||
| pin_memory: bool = False, | ||
| ) -> HybridQuantizedTensor: | ||
| self.rowwise_quantizer.internal = True |
There was a problem hiding this comment.
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
There was a problem hiding this comment.
This would not work under FSDP2.
| def factory(role): | ||
| if role == "linear_weight": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_mxfp8_quantizer(), | ||
| ) | ||
| if role == "linear_input": | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_fp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| if role in ("linear_grad_output", "linear_grad_input"): | ||
| return HybridQuantizer( | ||
| rowwise_quantizer=_make_mxfp8_quantizer(), | ||
| columnwise_quantizer=_make_nvfp4_quantizer(), | ||
| ) | ||
| return None |
There was a problem hiding this comment.
This is horrifying. Good test.
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
| # DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories | ||
| # (lambdas, inner functions referencing captured state) are not picklable, | ||
| # so the qfactory must live at module scope. See | ||
| # ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``. |
There was a problem hiding this comment.
This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?
| for param in model.parameters(): | ||
| state = optimizer.state[param] | ||
| assert state["exp_avg"].dtype == torch.float32 | ||
| assert state["exp_avg_sq"].dtype == torch.float32 | ||
| if "master_param" in state: | ||
| assert state["master_param"].dtype == torch.float32 | ||
|
|
||
| assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}" |
There was a problem hiding this comment.
That's not a very strict test, is there a way for us to do some numerical correctness comparisons?
There was a problem hiding this comment.
Enabled check for the monotonic loss decrease (still mostly sanity), and also enabled hybrid vs vanilla bitwise recipe comparizon, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci pytorch L1 |
Signed-off-by: Evgeny <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
|
/te-ci pytorch L1 |
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
|
Enable columnwise_source and hybrid recipes Respect quantizer veto for save_original_inp |
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Evgeny <etsykunov@nvidia.com>
| elif name == "_columnwise_scale_inv" and t is not None: | ||
| expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE | ||
| if t.size(0) != expected: | ||
| t = t[:expected] | ||
| buffers.append(t) | ||
| return tuple(buffers), {"field_names": names} |
There was a problem hiding this comment.
The columnwise scale truncation uses floor division (
flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE) instead of ceiling. For a sharded tensor where M is not a multiple of 32, ceil(M/32) scale entries are valid but M//32 are retained — the entry covering the last partial block is silently dropped. After all-gather, dequantization for those boundary rows uses a stale or zero scale. For example with M=48: 2 scale entries valid, but 48//32=1 is used, discarding row 32–47's scale.
| elif name == "_columnwise_scale_inv" and t is not None: | |
| expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE | |
| if t.size(0) != expected: | |
| t = t[:expected] | |
| buffers.append(t) | |
| return tuple(buffers), {"field_names": names} | |
| elif name == "_columnwise_scale_inv" and t is not None: | |
| expected = math.ceil(flattened_in_shape0 / MXFP8_BLOCK_SCALING_SIZE) | |
| if t.size(0) != expected: | |
| t = t[:expected] | |
| buffers.append(t) | |
| return tuple(buffers), {"field_names": names} |
Description
Hybrid (per-direction) quantization. Hybrid means rowwise/colwise can use different formats via CustomRecipe(qfactory).
This is an experimental feature.
The main problem that it tries to solve is that precision requirements are non-uniform.
Current recipes set one format for both rowwise and colwise directions.
Hybrid quantization enables, e.g. MXFP8 fwd and NVFP4 bwd (or vice versa) or any other valid combination. No need for a hardcoded recipe for every combination.
Composer-style (Composer 2 paper) grouped GEMM recipe, e.g. row-scaled NVFP4 fwd + MXFP8 bwd:
By default, the above factory uses
columnwise_source="original", so MXFP8 backward operands are quantized from the original high-precision tensor. Usecolumnwise_source="rowwise_dequantized"when the backward operand should be derived from the dequantized rowwise NVFP4 forward value.C++ optimizations (fusions, etc.) will come as standalone PRs. cc @kainzhong
TODO:
Integration
Ecosystem integration (all functional, unit-tested):
Megatron-LM integration status:
--fp{4,8}-param-gather+ dist opt (persistent low-precision params viaquantized_model_init+ sharded-master FP32 → quantized cast viaquantize_master_weights.)- [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination
including same-format, cross-format Float8, single-direction)
- [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by
quantize_master_weights; unblocker is TE-side cast-helper / kernel.--fp{4,8}-param-gather(fix private attribute access)--fp{4,8}-param-gather- [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
- [TODO] NVFP4 sub-storage FSDP2 hooks
_hybrid_split_quantizeunder Megatron MoE)Review
Total diff +14000
New hybrid source (
hybrid_tensor.py,hybrid_tensor_storage.py,identity_tensor.py,identity_tensor_storage.py) ~1800Adjacent modifications ~1500
Tests are the rest (~10K)
Suggested reading order
-columnwise_source controls whether columnwise quantization uses the original input or the rowwise-dequantized value.
1.1 Identity passthrough — b99277a
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: