Skip to content

[Pytorch][Common] Hybrid quantization#2817

Open
negvet wants to merge 46 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization
Open

[Pytorch][Common] Hybrid quantization#2817
negvet wants to merge 46 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization

Conversation

@negvet

@negvet negvet commented Mar 31, 2026

Copy link
Copy Markdown
Collaborator

Description

Hybrid (per-direction) quantization. Hybrid means rowwise/colwise can use different formats via CustomRecipe(qfactory).
This is an experimental feature.
The main problem that it tries to solve is that precision requirements are non-uniform.

Current recipes set one format for both rowwise and colwise directions.
Hybrid quantization enables, e.g. MXFP8 fwd and NVFP4 bwd (or vice versa) or any other valid combination. No need for a hardcoded recipe for every combination.

Composer-style (Composer 2 paper) grouped GEMM recipe, e.g. row-scaled NVFP4 fwd + MXFP8 bwd:

# CustomRecipe calls quantization_factory(role) for each quantized tensor
# Factory chooses formats

def hybrid_factory(role):
    is_grouped_linear = role is not None and role.module_type == "grouped_linear"
    is_linear = role is not None and role.module_type == "linear"

    if is_grouped_linear and role.tensor_type == "input":
        return HybridQuantizer(
            rowwise_quantizer=NVFP4Quantizer(row_scaled_nvfp4=True, ...),
            columnwise_quantizer=MXFP8Quantizer(...),
        )

    if is_grouped_linear and role.tensor_type == "weight":
        return HybridQuantizer(
            rowwise_quantizer=NVFP4Quantizer(...),
            columnwise_quantizer=MXFP8Quantizer(...),
        )

    if is_grouped_linear and role.tensor_type == "grad_output":
        return MXFP8Quantizer(...)

    if is_linear:
        return MXFP8Quantizer(...)

    return MXFP8Quantizer(...)

recipe = CustomRecipe(qfactory=hybrid_factory)
with autocast(recipe=recipe):
    y = model(x)

By default, the above factory uses columnwise_source="original", so MXFP8 backward operands are quantized from the original high-precision tensor. Use columnwise_source="rowwise_dequantized" when the backward operand should be derived from the dequantized rowwise NVFP4 forward value.

C++ optimizations (fusions, etc.) will come as standalone PRs. cc @kainzhong

TODO:

  • Convergence of base (non-hybrid) recipes
  • HybridFloat8BlockScaling is xfailed under FSDP2 because dim-0 shards can split 128-row block-scale tiles, producing all-gathered scale buffers whose shape does not match the global tensor.
  • Delayed scaling
  • Mid-training recipe change

Integration

Ecosystem integration (all functional, unit-tested):

  • [Done] quantized_model_init
  • [Done] FSDP2 (TODO: optimize communication buffers)
  • [Done] CPU offloading
  • [Done] Activation recomputation
  • [Done] TP/SP (TODO: enable quantized AG)

Megatron-LM integration status:

  • [Done] 1 GPU baseline
  • [Done] DP + distributed optimizer
  • [TODO] quantized_model_init + --fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
    - [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination
    including same-format, cross-format Float8, single-direction)
    - [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
  • [TODO] Megatron-FSDP + --fp{4,8}-param-gather (fix private attribute access)
  • [TODO] Torch FSDP2 + --fp{4,8}-param-gather
    - [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
    - [TODO] NVFP4 sub-storage FSDP2 hooks
  • [Done] Activation recompute
  • [Done] CPU offload
  • [Done] TP/SP/PP
  • [Done] MoE + EP + grouped GEMM (qwen3 MoE; _hybrid_split_quantize under Megatron MoE)

Review

Total diff +14000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py, identity_tensor.py, identity_tensor_storage.py) ~1800
Adjacent modifications ~1500
Tests are the rest (~10K)

Suggested reading order

  1. Foundation — 7553e6a: Python containers + quantize/gemm dispatch/unwrap
  • tensor/hybrid_tensor.py — HybridQuantizer + HybridQuantizedTensor
    -columnwise_source controls whether columnwise quantization uses the original input or the rowwise-dequantized value.
  • tensor/storage/hybrid_tensor_storage.py
  • cpp_extensions/gemm.py — _unwrap_hybrid_A/B
  • common/transpose/quantize_transpose_square_blockwise.cu - Block FP8 columnwise-only null-checks
  • Module hooks in module/{base,grouped_linear,layernorm_linear,layernorm_mlp}.py
  • Tests: TestHybridQuantizer*, TestHybridGemmBitwiseIdentical* (proves zero-overhead vs vanilla recipes when both formats match), TestHybridDirectionUnwrap*, TestHybridGroupedLinear*

1.1 Identity passthrough — b99277a

  • tensor/identity_tensor.py and tensor/storage/identity_tensor_storage.py — IdentityQuantizer / IdentityTensor high-precision passthrough
  • custom_recipes/quantization_factory_zoo.py — examples for high-precision fwd/bwd directions and columnwise_source="rowwise_dequantized"
  • Tests: test_identity_quantizer.py plus hybrid tests covering Identity inside HybridQuantizer
  1. quantized_model_init + FusedAdam — f80f5d0
  • hybrid_tensor.py::HybridQuantizer.update_quantized — delegates to each sub-quantizer; unblocks workspace-cache quantize_() and FusedAdam writeback
  • module/base.py workspace-cache invalidation
  • Tests: TestHybridQuantizedModelInit, TestHybridFusedAdam, TestHybridQuantizedParamsEndToEnd, TestHybridCheckpoint, TestQuantizedParamsEquivalence*
  1. FSDP2 support — 2185b30
  • New base FSDP2 buffer protocol on QuantizedTensorStorage: fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered. Generic, reusable beyond hybrid.
  • Per-format overrides on Float8TensorStorage (direction-aware) and MXFP8TensorStorage (trips/re-applies scale alignment padding around the gather)
  • hybrid_tensor.py::fsdp_pre/post_all_gather + torch_dispatch for the FSDP2 op set (view, split, as_strided, slice, copy_, new_zeros, clone, detach)
  • Non-safety in float8_tensor.py and mxfp8_tensor.py for single-direction sub-storages (columnwise-only on Hopper/L40)
  • Tests: TestHybridTorchDispatchFSDP2Ops, TestHybridFsdpPreAllGatherProtocol, TestHybridFsdpRoundtrip (bitwise-exact against manual all_gather(dequantize(shard))), plus tests/pytorch/distributed/fsdp2_tests/
  1. CPU offloading — 103fffe
  • hybrid_tensor_storage.py::clear() (v1 path) + prepare_for_saving / restore_from_saved chain (v2 path)
  • hybrid_tensor.py::detach() re-wraps each sub-storage via make_like (required by cpu_offload_v2's detach → prepare_for_saving pattern; sharing sub-storage objects would null-out fields on the original)
  • TestHybridCpuOffloadPushPop, plus updates to test_cpu_offloading*.py
  1. Activation recomputation — 16fb371
  • Uses existing QuantizedTensorStorage::prepare_for_saving / restore_from_saved protocol, preserving ordering across both sub-storages
  • Tests: 20 bitwise tests in TestHybridActivationRecompute
  1. TP/SP — a50fd63
  • hybrid_tensor.py::HybridQuantizer.supports_only_rowwise_all_gather — overrides to handle the NVFP4 columnwise-dequantize gap in the BF16 fallback path
  • distributed.py::gather_along_first_dim — hybrid branch re-quantizes with both directions after AG (since hybrid has no _create_transpose synthesis path)
  • Tests: 9 distributed tests in run_hybrid_tp_sp.py / test_hybrid_tp_sp.py
  1. Megatron-LM integration — a164cd3
  • tensor/utils.py::_route_hybrid_to_buckets — per-direction dispatch for quantize_master_weights: iterates both sub-storages, routes each independently into the per-format bucket matching its own sub-quantizer type
  • Hybrid branches in replace_raw_data and post_all_gather_processing
  • Today: per-tensor Float8 sub-quantizers (delayed + current) work in any per-direction combination. Per-block sub-quantizers raise per-direction with in-code TODOs naming the unblocker.
  • Tests: TestHybridQuantizeMasterWeights, TestHybridPostAllGatherProcessing

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented Mar 31, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces hybrid (per-direction) quantization for Transformer Engine, allowing rowwise and columnwise GEMM operands to use different quantization formats (e.g., MXFP8 forward + NVFP4 backward). The implementation is substantial (~14K lines) and covers the full ecosystem: FSDP2 buffer protocols, CPU offloading, activation recompute, TP/SP, and distributed optimizer integration.

  • New tensor types: HybridQuantizedTensor / HybridQuantizer compose two existing sub-quantizers per direction; IdentityTensor / IdentityQuantizer provide high-precision passthrough for unquantized directions within a hybrid recipe.
  • FSDP2 sub-storage protocol: New fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered methods on Float8TensorStorage, MXFP8TensorStorage, and Float8BlockwiseQTensorStorage enable per-direction all-gather with format-specific padding strip/re-apply.
  • GroupedLinear dispatch: _hybrid_split_quantize runs tex.split_quantize twice (once per direction) and _unwrap_hybrid_A/B in general_gemm extracts the direction-appropriate sub-storage before the C++ kernel dispatch.

Confidence Score: 4/5

Safe to merge after fixing MXFP8 columnwise scale floor-division truncation in fsdp_extract_buffers.

The MXFP8 columnwise scale strip uses integer floor division, silently discarding the last scale block when M per rank is not a multiple of 32, corrupting dequantization under FSDP2 for those boundary rows.

transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py — fsdp_extract_buffers columnwise scale truncation at line 366.

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/hybrid_tensor.py New 1042-line file implementing HybridQuantizer and HybridQuantizedTensor; FSDP2 protocol and full torch_dispatch set look well-structured; _sync_usage after fsdp_assign_gathered correctly invalidates stale transpose caches.
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py Adds FSDP2 sub-storage protocol; fsdp_extract_buffers strips alignment padding before all-gather but uses floor division for columnwise scale truncation — valid scale entries are dropped when M is not a multiple of MXFP8_BLOCK_SCALING_SIZE (32).
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Adds direction-aware fsdp_buffer_fields and fsdp_assign_gathered clearing _transpose_invalid; view updated to handle _data=None columnwise-only sub-storages cleanly.
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py Adds FSDP2 sub-storage protocol with fp8_transpose-based M-major extraction; fsdp_extract_buffers handles only one direction at a time but fsdp_buffer_fields claims to list both — inconsistency could mislead future bidirectional callers.
transformer_engine/pytorch/tensor/utils.py Adds _route_hybrid_to_buckets, _update_transpose_only_float8_flat_fragment for Hopper columnwise-only distopt, and _cast_master_weights_to_identity; per-direction routing and scatter-write logic are correct.
transformer_engine/pytorch/module/grouped_linear.py _hybrid_split_quantize and identity-fallback helpers are well-guarded; None+Hybrid mixed lists correctly rejected at classifier level.
transformer_engine/pytorch/cpp_extensions/gemm.py _unwrap_hybrid_A/B correctly maps GEMM layout flags to direction-appropriate sub-storage; _reject_unsupported_output_quantizer guards HybridQuantizer/IdentityQuantizer output paths.
transformer_engine/pytorch/tensor/float8_tensor.py Updated to handle _data=None columnwise-only sub-storages in view, split, and clone; split shape inference for columnwise-only correctly recovers the logical shape.
transformer_engine/pytorch/tensor/identity_tensor.py Clean IdentityQuantizer/IdentityTensor implementation; _maybe_cast detaches to avoid spurious autograd edges.
transformer_engine/pytorch/distributed.py gather_along_first_dim hybrid branch saves and restores usage flags around re-quantize, correctly using both directions since hybrid has no _create_transpose synthesis path.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant M as Module
    participant HQ as HybridQuantizer
    participant RQ as rowwise_quantizer
    participant CQ as columnwise_quantizer
    participant GG as general_gemm

    M->>HQ: quantize(tensor)
    HQ->>RQ: quantize(tensor)
    RQ-->>HQ: rowwise_storage
    alt "columnwise_source == rowwise_dequantized"
        HQ->>RQ: dequantize(rowwise_storage)
        RQ-->>HQ: bf16_tensor
        HQ->>CQ: quantize(bf16_tensor)
    else "columnwise_source == original"
        HQ->>CQ: quantize(tensor)
    end
    CQ-->>HQ: columnwise_storage
    HQ-->>M: HybridQuantizedTensor(row_sub, col_sub)

    M->>GG: "general_gemm(A=HybridTensor, layout=TN)"
    GG->>GG: _unwrap_hybrid_A(A, TN) row sub
    GG->>GG: _materialize_high_precision
    GG->>GG: cuBLAS GEMM with native sub-storage
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant M as Module
    participant HQ as HybridQuantizer
    participant RQ as rowwise_quantizer
    participant CQ as columnwise_quantizer
    participant GG as general_gemm

    M->>HQ: quantize(tensor)
    HQ->>RQ: quantize(tensor)
    RQ-->>HQ: rowwise_storage
    alt "columnwise_source == rowwise_dequantized"
        HQ->>RQ: dequantize(rowwise_storage)
        RQ-->>HQ: bf16_tensor
        HQ->>CQ: quantize(bf16_tensor)
    else "columnwise_source == original"
        HQ->>CQ: quantize(tensor)
    end
    CQ-->>HQ: columnwise_storage
    HQ-->>M: HybridQuantizedTensor(row_sub, col_sub)

    M->>GG: "general_gemm(A=HybridTensor, layout=TN)"
    GG->>GG: _unwrap_hybrid_A(A, TN) row sub
    GG->>GG: _materialize_high_precision
    GG->>GG: cuBLAS GEMM with native sub-storage
Loading

Reviews (18): Last reviewed commit: "Fix attention factories and align with e..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None

@negvet negvet May 21, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 4858491

requires_grad: bool = False,
pin_memory: bool = False,
) -> HybridQuantizedTensor:
self.rowwise_quantizer.internal = True

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not work under FSDP2.

Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Comment on lines +1339 to +1355
def factory(role):
if role == "linear_weight":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_mxfp8_quantizer(),
)
if role == "linear_input":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
if role in ("linear_grad_output", "linear_grad_input"):
return HybridQuantizer(
rowwise_quantizer=_make_mxfp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
return None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is horrifying. Good test.

negvet and others added 10 commits April 6, 2026 10:26
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py
negvet and others added 2 commits April 29, 2026 16:02
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
negvet added 3 commits May 13, 2026 12:34
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from ksivaman as a code owner May 21, 2026 13:53
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
Comment on lines +27 to +30
# DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories
# (lambdas, inner functions referencing captured state) are not picklable,
# so the qfactory must live at module scope. See
# ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +1177 to +1184
for param in model.parameters():
state = optimizer.state[param]
assert state["exp_avg"].dtype == torch.float32
assert state["exp_avg_sq"].dtype == torch.float32
if "master_param" in state:
assert state["master_param"].dtype == torch.float32

assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not a very strict test, is there a way for us to do some numerical correctness comparisons?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enabled check for the monotonic loss decrease (still mostly sanity), and also enabled hybrid vs vanilla bitwise recipe comparizon, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.

@negvet negvet requested a review from vthumbe1503 June 9, 2026 16:14
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from cyanguwa as a code owner June 10, 2026 16:49
negvet and others added 2 commits June 10, 2026 16:53
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet

negvet commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet

negvet commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

negvet added 2 commits June 12, 2026 13:15
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet

negvet commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@negvet

negvet commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@negvet

negvet commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

Enable columnwise_source and hybrid recipes
columnwise_source makes the origin of the columnwise operand explicit in HybridQuantizer (options are {"original", "rowwise_dequantized"}). With columnwise_source="rowwise_dequantized", the backward columnwise operand is built from dequantize(rowwise_fprop_quantized(x)), so backward sees the forward quantization error instead of re-reading the original high-precision tensor. The same mechanism also supports double-quantization, where colwise direction is quantized from the dequantized rowwise (if colwise quantizer in HybridQuantizer is a non-Identity quantizer).

Respect quantizer veto for save_original_inp
save_original_input is now treated as an optimization hint that can be rejected by the quantizer if it would violate recipe semantics. Hybrid quantizers that require the forward-quantized value can force the save-forward path.

negvet and others added 3 commits June 23, 2026 07:33
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment on lines +365 to +370
elif name == "_columnwise_scale_inv" and t is not None:
expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE
if t.size(0) != expected:
t = t[:expected]
buffers.append(t)
return tuple(buffers), {"field_names": names}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 The columnwise scale truncation uses floor division (flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE) instead of ceiling. For a sharded tensor where M is not a multiple of 32, ceil(M/32) scale entries are valid but M//32 are retained — the entry covering the last partial block is silently dropped. After all-gather, dequantization for those boundary rows uses a stale or zero scale. For example with M=48: 2 scale entries valid, but 48//32=1 is used, discarding row 32–47's scale.

Suggested change
elif name == "_columnwise_scale_inv" and t is not None:
expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE
if t.size(0) != expected:
t = t[:expected]
buffers.append(t)
return tuple(buffers), {"field_names": names}
elif name == "_columnwise_scale_inv" and t is not None:
expected = math.ceil(flattened_in_shape0 / MXFP8_BLOCK_SCALING_SIZE)
if t.size(0) != expected:
t = t[:expected]
buffers.append(t)
return tuple(buffers), {"field_names": names}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants