|
| 1 | +// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +// Fused dual-GEMV + SiLU + elementwise multiply kernel for AIE2+. |
| 5 | +// |
| 6 | +// Computes: output = silu(W1 @ x) * (W2 @ x) |
| 7 | +// |
| 8 | +// Two entry points called from the NPU design's core body: |
| 9 | +// 1. dual_gemv_matvec_bf16: GEMV writing to FIFO buffer c_out + row_offset |
| 10 | +// 2. dual_gemv_silu_mul_bf16: reads from static left_buf/right_buf, writes to FIFO c_out |
| 11 | +// |
| 12 | +// The static buffers are written via scalar stores (from matvec) and read |
| 13 | +// via aie::load_v in the silu_mul phase. Aligned to 64 bytes for safe vector access. |
| 14 | + |
| 15 | +#define NOCPP |
| 16 | + |
| 17 | +#include "../aie_kernel_utils.h" |
| 18 | + |
| 19 | +#include <aie_api/aie.hpp> |
| 20 | +#include <stdint.h> |
| 21 | +#include <type_traits> |
| 22 | + |
| 23 | +static bfloat16 left_buf[1024] __attribute__((aligned(64))); |
| 24 | +static bfloat16 right_buf[1024] __attribute__((aligned(64))); |
| 25 | + |
| 26 | +template <uint32_t r> |
| 27 | +void matvec_vectorized(uint32_t m, |
| 28 | + uint32_t k, |
| 29 | + const bfloat16 *__restrict a, |
| 30 | + const bfloat16 *__restrict b, |
| 31 | + bfloat16 *__restrict c) |
| 32 | +{ |
| 33 | + ::aie::set_rounding(aie::rounding_mode::conv_even); |
| 34 | + bfloat16 *c_end = c + m; |
| 35 | + const bfloat16 *b_end = b + k; |
| 36 | + for (; c < c_end; c++) { |
| 37 | + aie::accum acc = aie::zeros<accfloat, r>(); |
| 38 | + AIE_LOOP_MIN_ITERATION_COUNT(2) |
| 39 | + for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) { |
| 40 | + aie::vector<bfloat16, r> a_vec = aie::load_v<r>(a); |
| 41 | + aie::vector<bfloat16, r> b_vec = aie::load_v<r>(b_cur); |
| 42 | + acc = aie::mac(acc, a_vec, b_vec); |
| 43 | + } |
| 44 | + *c = static_cast<bfloat16>(aie::reduce_add(acc.template to_vector<float>())); |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +extern "C" { |
| 49 | + |
| 50 | +// Phase 1 & 2: GEMV writing to a static buffer (left_buf or right_buf) |
| 51 | +// phase=0 writes to left_buf, phase=1 writes to right_buf |
| 52 | +void dual_gemv_matvec_bf16(uint32_t m, |
| 53 | + uint32_t k, |
| 54 | + uint32_t row_offset, |
| 55 | + const bfloat16 *__restrict a_in, |
| 56 | + const bfloat16 *__restrict b_in, |
| 57 | + uint32_t phase) |
| 58 | +{ |
| 59 | + bfloat16 *dst = (phase == 0) ? left_buf : right_buf; |
| 60 | + dst += row_offset; |
| 61 | + matvec_vectorized<64>(m, k, a_in, b_in, dst); |
| 62 | +} |
| 63 | + |
| 64 | +// Phase 3: silu(left_buf) * right_buf -> c_out (FIFO buffer) |
| 65 | +void dual_gemv_silu_mul_bf16(bfloat16 *__restrict c_out, int32_t m_output) |
| 66 | +{ |
| 67 | + event0(); |
| 68 | + |
| 69 | + aie::vector<bfloat16, 16> register_0_5 = aie::broadcast<bfloat16, 16>(0.5f); |
| 70 | + aie::vector<bfloat16, 16> register_1 = aie::broadcast<bfloat16, 16>(1.0f); |
| 71 | + AIE_PREPARE_FOR_PIPELINING |
| 72 | + for (int i = 0; i < m_output; i += 16) { |
| 73 | + aie::vector<bfloat16, 16> left_val = aie::load_v<16>(left_buf + i); |
| 74 | + aie::vector<bfloat16, 16> right_val = aie::load_v<16>(right_buf + i); |
| 75 | + |
| 76 | + // SiLU(x) = x * sigmoid(x) = x * 0.5 * (1 + tanh(x/2)) |
| 77 | + auto half_x = aie::mul(left_val, register_0_5); |
| 78 | + auto tanh_half_x = aie::tanh<bfloat16>(half_x.to_vector<float>()); |
| 79 | + auto tanh_half_x_approx = aie::add(tanh_half_x, register_1); |
| 80 | + aie::vector<bfloat16, 16> sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5); |
| 81 | + auto silu_output = aie::mul(left_val, sigmoid_approx); |
| 82 | + |
| 83 | + auto fused_output = aie::mul(silu_output.to_vector<bfloat16>(), right_val); |
| 84 | + aie::store_v(c_out + i, fused_output.to_vector<bfloat16>()); |
| 85 | + } |
| 86 | + |
| 87 | + event1(); |
| 88 | +} |
| 89 | + |
| 90 | +} // extern "C" |
0 commit comments