[Draft] Use vendored cuDNN frontend for Python #3148
Build and install the cuDNN frontend Python package from the vendored 3rdparty/cudnn-frontend submodule instead of relying on a separately installed PyPI package.

Add a build_tools.cudnn_frontend helper that can report the vendored version, build a wheel, or install from the submodule. The helper validates the submodule checkout, uses a writable pip cache, detects installed direct-url metadata, and passes pybind11's CMake package location through pybind11_DIR/CMAKE_PREFIX_PATH for non-isolated builds.

Add a transformer_engine.common.cudnn_frontend runtime helper that centralizes cudnn imports, prefers a built vendored checkout when present, validates source-tree version matches, exposes version checks, and keeps the JAX Python/C++ cuDNN frontend version check shared.

Route PyTorch flex attention, JAX score_mod attention, grouped MLP fused kernels, tests, and the Mixtral example through the shared helper instead of direct cudnn imports or package metadata checks.

Update source builds, QA scripts, and release wheel assembly to build/install the vendored nvidia-cudnn-frontend artifact, with release wheel installation pinned to the vendored submodule version.

Validation: python3 -m build_tools.cudnn_frontend version; python3 -m compileall on touched Python helpers and call sites; non-isolated vendored wheel build; isolated vendored wheel import via PYTHONPATH; focused PyTorch flex-attention cache tests.

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
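For readers unfamiliar with the "direct-url metadata" mentioned above: pip records the origin of a package installed from a local path or URL in a `direct_url.json` file inside the distribution's metadata directory (PEP 610). The following is a minimal sketch of how an "already installed from the vendored checkout?" check could be written on top of that file. The function names `direct_url_source_dir` and `installed_from` are hypothetical and are not taken from this PR; the actual helper in build_tools.cudnn_frontend may work differently.

```python
import json
from importlib import metadata
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from urllib.request import url2pathname


def direct_url_source_dir(direct_url_json: Optional[str]) -> Optional[Path]:
    """Return the local path recorded in a PEP 610 direct_url.json, if any.

    Hypothetical helper: returns None when there is no metadata or when the
    recorded origin is not a file:// URL.
    """
    if not direct_url_json:
        return None
    url = json.loads(direct_url_json).get("url", "")
    parsed = urlparse(url)
    if parsed.scheme != "file":
        return None
    return Path(url2pathname(parsed.path)).resolve()


def installed_from(dist_name: str, source_dir: Path) -> bool:
    """Whether dist_name is installed and pip recorded source_dir as its origin."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return False
    # read_text returns None when the distribution has no direct_url.json,
    # which is the case for ordinary index (PyPI) installs.
    recorded = direct_url_source_dir(dist.read_text("direct_url.json"))
    return recorded is not None and recorded == Path(source_dir).resolve()
```

A caller could combine `installed_from("nvidia-cudnn-frontend", submodule_dir)` with a version comparison to decide whether a reinstall is needed; `submodule_dir` here is a placeholder for the path to the vendored checkout.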
Greptile Summary
This PR replaces the PyPI
Confidence Score: 3/5
The build infrastructure and centralized import helper are well-designed, but a version-mismatch scenario in source checkouts causes
The core logic: submodule detection, idempotency guard, pybind11 CMake forwarding,
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A([caller: import_cudnn_frontend]) --> B{vendored _compiled_module* present?}
B -- yes --> C[sys.path.insert vendored python path]
B -- no --> D
C --> D{find_spec cudnn?}
D -- yes --> E[importlib.import_module cudnn]
E --> F{check_cudnn_frontend_vendored_version_match}
F -- no vendored source --> G([return module])
F -- versions match --> G
F -- mismatch --> H([raise RuntimeError])
D -- no --> I([raise ImportError])
J([setup.py build/install]) --> K{NVTE_RELEASE_BUILD=0 and pytorch/jax?}
K -- no --> L([skip])
K -- yes --> M[install_from_submodule]
M --> N{version match AND installed from vendored?}
N -- yes --> O([skip - already current])
N -- no --> P[pip install --no-deps --force-reinstall from submodule]
P --> Q([installed])
style H fill:#f88,color:#000
style I fill:#f88,color:#000
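The import path in the flowchart (nodes A through I) can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the function name `import_with_vendored_preference` is hypothetical, the `_compiled_module*` artifact is assumed to live inside the package directory, and the version is assumed to be exposed as `__version__`. The real helper in transformer_engine.common.cudnn_frontend may differ on all three points.

```python
import importlib
import importlib.util
import sys
from pathlib import Path
from typing import Optional


def import_with_vendored_preference(
    module_name: str,
    vendored_python_dir: Optional[Path],
    expected_version: Optional[str],
):
    """Import module_name, preferring a built vendored checkout when one exists."""
    # Nodes B/C: only put the vendored directory on sys.path if it has been built
    # (assumption: a built checkout contains <package>/_compiled_module*).
    if vendored_python_dir is not None:
        package_dir = vendored_python_dir / module_name
        built = any(package_dir.glob("_compiled_module*"))
        if built and str(vendored_python_dir) not in sys.path:
            sys.path.insert(0, str(vendored_python_dir))
            importlib.invalidate_caches()

    # Nodes D/I: a missing module is reported as ImportError.
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(f"{module_name} is not installed")

    # Node E: perform the import.
    module = importlib.import_module(module_name)

    # Nodes F/G/H: with no vendored source to compare against, return as-is;
    # otherwise a version mismatch is reported as RuntimeError.
    if expected_version is not None:
        found = getattr(module, "__version__", None)
        if found != expected_version:
            raise RuntimeError(
                f"{module_name} version {found!r} does not match "
                f"vendored version {expected_version!r}"
            )
    return module
```

Note that the two failure modes use different exception types (ImportError for "not installed", RuntimeError for "installed but mismatched"), which matters for any caller that catches only one of them.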
def cudnn_frontend_version_at_least(min_version: str) -> bool:
    """Whether the imported cuDNN frontend Python package is at least ``min_version``."""
    try:
        return cudnn_frontend_version() >= PkgVersion(min_version)
    except ImportError:
        return False
RuntimeError from version mismatch escapes ImportError-only try/except blocks
cudnn_frontend_version_at_least catches only ImportError, but import_cudnn_frontend() → check_cudnn_frontend_vendored_version_match() raises RuntimeError when the installed version doesn't match the vendored source. In a source checkout with a stale install, every version guard in grouped_mlp.py (_cudnn_frontend_version_supported, _cudnn_frontend_geglu_runtime_params, etc.) would crash instead of returning False.
The same escape path also affects grouped_gemm_glu_hadamard_kernel and _grouped_gemm_dsrelu_backward_supported in grouped_mlp.py and the skip guard in tests/pytorch/test_grouped_mlp.py — all of which catch only ImportError, so a version mismatch turns a "feature unavailable" graceful path into an unhandled exception. Catching (ImportError, RuntimeError) in this function would preserve the intended boolean contract.
    try:
        from cudnn import (
            grouped_gemm_glu_hadamard_wrapper_sm100,
        )  # pylint: disable=no-name-in-module,import-outside-toplevel
        return _get_cudnn_frontend_symbol("grouped_gemm_glu_hadamard_wrapper_sm100")
    except ImportError:
        return None
RuntimeError from version mismatch not caught — None-return contract broken
grouped_gemm_glu_hadamard_kernel is designed as a feature probe: it returns a callable if the kernel is available and None if it is not. _get_cudnn_frontend_symbol calls import_cudnn_frontend(), which raises RuntimeError (not ImportError) when the installed package version does not match the vendored source. In a source checkout with a stale nvidia-cudnn-frontend install, this except ImportError does not fire and the RuntimeError propagates to callers expecting Callable | None. The same pattern applies to _grouped_gemm_dsrelu_backward_supported at lines 324–328 and the skip guard in tests/pytorch/test_grouped_mlp.py.