Skip to content

[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 23 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax
Open

[JAX] Expert Parallelism: JAX primitives + VJPs#3036
phu0ngng wants to merge 23 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax

Conversation

@phu0ngng

@phu0ngng phu0ngng commented May 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch 
    ep_combine,      # custom_vjp-wrapped combine

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented May 22, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR lands the JAX bindings for the NCCL Expert Parallelism (EP) feature: an XLA FFI layer over the nvte_ep_* C API, jax.custom_vjp-wrapped ep_dispatch / ep_combine with mesh-aware SPMD sharding rules, a multi-process test suite, and an end-to-end MoE example. The build system is also consolidated — nccl_ep_enabled() now lives in build_tools/utils.py and is shared by both setup.py and build_tools/jax.py, removing the previous inconsistency in arch-guard logic.

  • transformer_engine/jax/ep.py: ep_bootstrap bootstraps the NCCL communicator via a UID all-gather, then registers ep_dispatch and ep_combine as custom_vjp functions whose backward calls the matching nvte_ep_*_bwd FFI primitives.
  • transformer_engine/jax/cpp_extensions/ep.py: Five JAX primitives (EpPrepare, EpDispatch, EpCombine, and their bwd counterparts) with abstract_eval, lowering, SPMD partition, and Shardy sharding rules; the inner/outer primitive split allows the SPMD layer to run per-shard FFI calls.
  • transformer_engine/jax/csrc/extensions/ep.cpp: Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries backed by a EpResources singleton (NCCL comm + EP backend) managed via shared_ptr / weak_ptr; all handlers carry FFI_CudaGraph_Traits for CUDA-graph capture.

Confidence Score: 4/5

The change is safe to merge with awareness of a few gaps: a missing ep_size >= 2 guard in ep_bootstrap, non-differentiable outputs exposed as custom_vjp primals without stop_gradient, and the EpCombinePrimitive.partition None-spec path allocating with the wrong per-shard shape.

The core forward/backward math and SPMD sharding rules are well-reasoned and thoroughly tested. Three non-blocking quality gaps were found: ep_bootstrap accepts ep_size=1 silently before failing in C++; handle_mem and token_counts returned as differentiable primals from ep_dispatch create latent friction with JAX's autodiff type machinery; and ep_combine_fwd's partition falls back to a semantically incorrect per-shard shape when out_partition_spec=None. None of these affect the primary custom_vjp code path used in the example and tests.

transformer_engine/jax/ep.py (ep_bootstrap validation, ep_dispatch primal outputs), transformer_engine/jax/cpp_extensions/ep.py (EpCombinePrimitive.partition None-spec branch), build_tools/utils.py (native arch heuristic).

Important Files Changed

Filename Overview
transformer_engine/jax/ep.py Core public API: ep_bootstrap, ep_dispatch, ep_combine custom_vjp wrappers. Missing ep_size >= 2 guard; handle_mem/token_counts exposed as differentiable primal outputs in ep_dispatch return.
transformer_engine/jax/cpp_extensions/ep.py JAX primitive layer: EpPreparePrimitive, EpDispatchPrimitive, EpCombinePrimitive, and their bwd counterparts with SPMD partition rules. EpCombinePrimitive.partition with out_partition_spec=None passes global shape as per-shard shape, causing unwritten slots in SPMD context.
transformer_engine/jax/csrc/extensions/ep.cpp XLA FFI C++ layer: five handler symbols wrapping nvte_ep_* C API. Correct guard/lifecycle via EpInstanceState weak_ptr + anchor. Int32→int64 topk_idx upcast on-stream handled cleanly.
build_tools/utils.py Adds nccl_ep_enabled() used by both setup.py and build_tools/jax.py. 'native' arch token unconditionally accepted as Hopper+ regardless of build host GPU.
build_tools/jax.py Wires nccl_ep_enabled() into the JAX extension build. Defers arch/flag logic to utils.py cleanly; eliminates the previous inconsistency between setup.py and jax.py.
transformer_engine/jax/sharding.py Adds ep_resource field to MeshResource and ep_axis_size() helper. Clean, minimal changes consistent with existing parallelism resource pattern.
tests/jax/test_multi_process_ep.py 690-line multi-process test suite covering bootstrap, prepare shape contracts, dispatch/combine identity, custom_vjp fwd+bwd correctness, and HLO reshard guard. Good coverage of the happy path.
examples/jax/ep/ep_moe.py End-to-end MoE example with numerical reference check. Well-structured and self-contained.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Py as Python (ep.py)
    participant Prim as JAX Primitives (cpp_extensions/ep.py)
    participant FFI as XLA FFI (ep.cpp)
    participant NCCL as NCCL EP (nvte_ep_*)

    Note over Py: ep_bootstrap()
    Py->>FFI: set_ep_bootstrap_params(uid, ep_size, …)
    FFI->>NCCL: ncclCommInitRank + nvte_ep_initialize

    Note over Py: ep_dispatch (custom_vjp fwd)
    Py->>Prim: ep_prepare(cfg, topk_idx)
    Prim->>FFI: EpPrepareHandler (token_counts, handle_mem)
    FFI->>NCCL: nvte_ep_prepare
    Py->>Prim: ep_dispatch_fwd(cfg, handle_mem, tokens, weights)
    Prim->>FFI: EpDispatchHandler (recv_tokens, recv_topk_weights)
    FFI->>NCCL: nvte_ep_dispatch

    Note over Py: ep_combine (custom_vjp fwd)
    Py->>Prim: ep_combine_fwd(cfg, handle_mem, expert_out, T)
    Prim->>FFI: EpCombineHandler (result)
    FFI->>NCCL: nvte_ep_combine

    Note over Py: Backward pass
    Py->>Prim: ep_combine_bwd(cfg, handle_mem, g_result)
    Prim->>FFI: EpCombineBwdHandler (grad_expert_out)
    FFI->>NCCL: nvte_ep_combine_bwd
    Py->>Prim: ep_dispatch_bwd(cfg, handle_mem, g_recv_tokens, g_weights)
    Prim->>FFI: EpDispatchBwdHandler (grad_tokens, grad_topk_weights)
    FFI->>NCCL: nvte_ep_dispatch_bwd

    Note over FFI: EpResources lifecycle via weak_ptr + anchor (Python atexit)
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Py as Python (ep.py)
    participant Prim as JAX Primitives (cpp_extensions/ep.py)
    participant FFI as XLA FFI (ep.cpp)
    participant NCCL as NCCL EP (nvte_ep_*)

    Note over Py: ep_bootstrap()
    Py->>FFI: set_ep_bootstrap_params(uid, ep_size, …)
    FFI->>NCCL: ncclCommInitRank + nvte_ep_initialize

    Note over Py: ep_dispatch (custom_vjp fwd)
    Py->>Prim: ep_prepare(cfg, topk_idx)
    Prim->>FFI: EpPrepareHandler (token_counts, handle_mem)
    FFI->>NCCL: nvte_ep_prepare
    Py->>Prim: ep_dispatch_fwd(cfg, handle_mem, tokens, weights)
    Prim->>FFI: EpDispatchHandler (recv_tokens, recv_topk_weights)
    FFI->>NCCL: nvte_ep_dispatch

    Note over Py: ep_combine (custom_vjp fwd)
    Py->>Prim: ep_combine_fwd(cfg, handle_mem, expert_out, T)
    Prim->>FFI: EpCombineHandler (result)
    FFI->>NCCL: nvte_ep_combine

    Note over Py: Backward pass
    Py->>Prim: ep_combine_bwd(cfg, handle_mem, g_result)
    Prim->>FFI: EpCombineBwdHandler (grad_expert_out)
    FFI->>NCCL: nvte_ep_combine_bwd
    Py->>Prim: ep_dispatch_bwd(cfg, handle_mem, g_recv_tokens, g_weights)
    Prim->>FFI: EpDispatchBwdHandler (grad_tokens, grad_topk_weights)
    FFI->>NCCL: nvte_ep_dispatch_bwd

    Note over FFI: EpResources lifecycle via weak_ptr + anchor (Python atexit)
Loading

Reviews (23): Last reviewed commit: "Fix test_two_layer_dispatch_no_handle_al..." | Re-trigger Greptile

Comment thread build_tools/jax.py Outdated
Comment thread build_tools/jax.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp Outdated
Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

I would appreciate your help to review this PR @tdophung @jberchtold-nvidia!
Please focus on the changes in the JAX side, as the TE/Common ones will be discussed in #3034

Comment thread examples/jax/ep/ep_moe.py Outdated
Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread examples/jax/ep/ep_moe.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI

Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions.h Outdated
jberchtold-nvidia pushed a commit to jberchtold-nvidia/TransformerEngine that referenced this pull request Jun 5, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-jax branch 2 times, most recently from 06f8a13 to c34771d Compare June 10, 2026 15:24
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg.
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
EOF
)
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

phu0ngng and others added 17 commits June 24, 2026 00:22
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ck via nccl_ep_enabled()

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
@phu0ngng phu0ngng added the 2.7.0 label Jun 24, 2026
…ition methods

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants