
bug: fuse_mul_addsub missing c - a*b FMA fusion pattern#82

Open

shauray8 wants to merge 1 commit into NVIDIA:main from shauray8:shauray8/bug-fuse_mul_addsub

Conversation

@shauray8

Description

fuse_mul_addsub in _passes/rewrite_patterns.py fuses four of the five natural FMA patterns but misses the case where the mul is on the rhs of a subtraction (c - a*b). The guard on the rhs-mul branch:

elif op.fn == "add" and (mul_op := ctx.get_match(op.rhs, match_float_mul))

excludes sub, so c - a*b falls through to NoMatch and stays as RawBinaryArithmeticOperation(mul) + RawBinaryArithmeticOperation(sub) in the IR instead of a single FusedMulAddOperation.

This PR drops the op.fn == "add" guard, tracks rhs_is_mul, and negates mul_op.lhs instead of acc in the sub case, so c - a*b emits fma(-a, b, c).
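A minimal standalone sketch of the intended matching logic (the `BinOp` class, `try_fuse_mul_addsub` name, and tuple-based negation are illustrative stand-ins, not the actual rewrite_patterns.py code):

```python
from dataclasses import dataclass

@dataclass
class BinOp:
    fn: str          # "add", "sub", or "mul"
    lhs: object
    rhs: object

def try_fuse_mul_addsub(op):
    """Fuse a mul feeding an add/sub into a single fma(x, y, acc) triple.

    Returns (x, y, acc) or None. Mirrors the patched behavior: the
    rhs-mul branch no longer requires fn == "add"; a sub with the mul
    on the rhs negates the multiplicand instead of the accumulator,
    so c - a*b -> fma(-a, b, c).
    """
    if op.fn not in ("add", "sub"):
        return None
    if isinstance(op.lhs, BinOp) and op.lhs.fn == "mul":
        # a*b + c -> fma(a, b, c);  a*b - c -> fma(a, b, -c)
        mul, acc = op.lhs, op.rhs
        if op.fn == "sub":
            acc = ("neg", acc)
        return (mul.lhs, mul.rhs, acc)
    if isinstance(op.rhs, BinOp) and op.rhs.fn == "mul":
        # c + a*b -> fma(a, b, c);  c - a*b -> fma(-a, b, c)
        mul, acc = op.rhs, op.lhs
        x = ("neg", mul.lhs) if op.fn == "sub" else mul.lhs
        return (x, mul.rhs, acc)
    return None

# The previously missed pattern now matches:
print(try_fuse_mul_addsub(BinOp("sub", "c", BinOp("mul", "a", "b"))))
# -> (('neg', 'a'), 'b', 'c')
```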

TileIR has its own FMA contraction pass, so the final SASS is identical either way for single-use intermediates. The gap is at the CuTile IR level.
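On the fix's sign convention: fma(-a, b, c) equals c - a*b, whereas negating the accumulator instead would produce a*b - c, the wrong sign. A quick plain-float sanity check of the identity (this ignores the single-rounding detail of a real FMA):

```python
a, b, c = 3.0, 4.0, 5.0

# Negating the multiplicand gives c - a*b ...
assert (-a) * b + c == c - a * b        # both are -7.0
# ... while negating the accumulator flips the sign:
assert a * b + (-c) == -(c - a * b)     # 7.0, i.e. a*b - c
```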

Reproducer
import sys
import torch
from io import BytesIO
import cuda.tile as ct
import cuda.tile._passes.rewrite_patterns as rp
import cuda.tile._compile as _compile_mod
from cuda.tile._ir.ops import RawBinaryArithmeticOperation, FusedMulAddOperation
from cuda.tile.compilation import CallingConvention, KernelSignature
from cuda.tile._compile import get_sm_arch

_captured_ops = []
_orig = rp.rewrite_patterns

def _capture(block):
    _orig(block)
    _captured_ops.clear()
    _captured_ops.extend(
        op for op in block.traverse()
        if isinstance(op, (RawBinaryArithmeticOperation, FusedMulAddOperation))
    )

_compile_mod.rewrite_patterns = _capture


def mul_sub_kernel(x, y, z, output, TILE: ct.Constant[int], DIM: ct.Constant[int]):
    bidx = ct.bid(0)
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM))
    tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM))
    ct.store(output, index=(bidx, 0), tile=tx * ty - tz)

def sub_mul_kernel(x, y, z, output, TILE: ct.Constant[int], DIM: ct.Constant[int]):
    bidx = ct.bid(0)
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM))
    tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM))
    ct.store(output, index=(bidx, 0), tile=tz - tx * ty)


shape = (128, 32)
t = [torch.randn(shape, dtype=torch.float32, device="cuda") for _ in range(4)]

def run(fn):
    k = ct.kernel(fn)
    sig = KernelSignature.from_kernel_args(
        k, (*t, 32, 32), CallingConvention.cutile_python_v1()
    )
    ct.compilation.export_kernel(k, [sig], BytesIO(),
                                 gpu_code=get_sm_arch(),
                                 output_format="tileir_bytecode")
    return list(_captured_ops)

for fn, label in [(mul_sub_kernel, "a*b - c"), (sub_mul_kernel, "c - a*b")]:
    ops = run(fn)
    fused = any(isinstance(op, FusedMulAddOperation) for op in ops)
    print(f"{label}: {'fma' if fused else ' | '.join(getattr(op,'fn','fma') for op in ops)}")

sys.exit(0 if not all(
    any(isinstance(op, FusedMulAddOperation) for op in run(fn))
    for fn in (mul_sub_kernel, sub_mul_kernel)
) else 1)
Output:

a*b - c: fma
c - a*b: mul | sub

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Signed-off-by: shauray8 <shauray9@gmail.com>
