From a9a90bcec2188e05d5fb199b97097abd38535530 Mon Sep 17 00:00:00 2001
From: jiyang1011
Date: Thu, 26 Dec 2024 18:04:51 -0800
Subject: [PATCH] softmax with EVT (draft)

---
 examples/sycl/pvc/CMakeLists.txt              |   5 +
 .../pvc/pvc_gemm_with_epilogue_softmax.cpp    | 444 +++++++++++++++
 .../epilogue/collective/xe_epilogue.hpp       |  17 +-
 .../cutlass/epilogue/fusion/operations.hpp    |  11 +
 .../cutlass/epilogue/fusion/xe_callbacks.hpp  |  69 ++-
 .../epilogue/fusion/xe_visitor_softmax.hpp    | 513 ++++++++++++++++++
 6 files changed, 1054 insertions(+), 5 deletions(-)
 create mode 100644 examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp
 create mode 100644 include/cutlass/epilogue/fusion/xe_visitor_softmax.hpp

diff --git a/examples/sycl/pvc/CMakeLists.txt b/examples/sycl/pvc/CMakeLists.txt
index abc138172..588b9afe0 100644
--- a/examples/sycl/pvc/CMakeLists.txt
+++ b/examples/sycl/pvc/CMakeLists.txt
@@ -47,6 +47,11 @@ cutlass_example_add_executable(
   pvc_gemm_with_epilogue_lincombdeeltact.cpp
 )
 
+cutlass_example_add_executable(
+  pvc_gemm_with_epilogue_softmax
+  pvc_gemm_with_epilogue_softmax.cpp
+)
+
 cutlass_example_add_executable(
   pvc_collective_builder
   pvc_collective_builder.cpp
diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp
new file mode 100644
index 000000000..61d06f08d
--- /dev/null
+++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_softmax.cpp
@@ -0,0 +1,444 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/collective/xe_epilogue.hpp"
+#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/collective/collective_mma.hpp"
+#include "cutlass/util/GPU_Clock.hpp"
+
+#include <cute/tensor.hpp>
+#include <random>
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/reference/device/gemm_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "cutlass/util/reference/device/tensor_relu.h"
+#include "cutlass/tensor_view.h"
+#include "cutlass/coord.h"
+
+#include "common.hpp"
+
+using namespace cute;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Command line options parsing
+struct Options {
+
+  bool help;
+  bool error;
+
+  int m, n, k, l, iterations;
+  float alpha, beta;
+
+  Options():
+    help(false),
+    error(false),
+    m(5120), n(4096), k(4096), l(1), iterations(100),
+    alpha(1.f), beta(0.f)
+  { }
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+      return;
+    }
+
+    cmd.get_cmd_line_argument("m", m, 512);
+    cmd.get_cmd_line_argument("n", n, 512);
+    cmd.get_cmd_line_argument("k", k, 64);
+    cmd.get_cmd_line_argument("l", l, 32);
+    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+    cmd.get_cmd_line_argument("beta", beta, 0.f);
+    cmd.get_cmd_line_argument("iterations", iterations, 100);
+  }
+
+  /// Prints the usage statement.
+  std::ostream & print_usage(std::ostream &out) const {
+
+    out << "PVC GEMM Example\n\n"
+      << "Options:\n\n"
+      << "  --help                      If specified, displays this usage statement\n\n"
+      << "  --m=<int>                   Sets the M extent of the GEMM\n"
+      << "  --n=<int>                   Sets the N extent of the GEMM\n"
+      << "  --k=<int>                   Sets the K extent of the GEMM\n"
+      << "  --l=<int>                   Sets the L extent (batch count) of the GEMM\n"
+      << "  --alpha=<float>             Epilogue scalar alpha\n"
+      << "  --beta=<float>              Epilogue scalar beta\n\n"
+      << "  --iterations=<int>          Iterations\n\n";
+
+    return out;
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  class Gemm
+>
+struct ExampleRunner {
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  using LayoutA = typename Gemm::LayoutA;
+  using LayoutB = typename Gemm::LayoutB;
+  using LayoutC = typename Gemm::LayoutC;
+  using LayoutD = typename Gemm::LayoutD;
+
+  using ElementA = typename Gemm::ElementA;
+  using ElementB = typename Gemm::ElementB;
+  using ElementAcc = typename Gemm::ElementAccumulator;
+
+  using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
+  using ElementC = typename Gemm::ElementC;
+  using ElementOutput = typename CollectiveEpilogue::ElementOutput;
+  using ElementCompute = typename CollectiveEpilogue::ElementCompute;
+  using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
+
+  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
+
+  //
+  // Data members
+  //
+
+  /// Initialization
+  StrideA stride_A;
+  StrideB stride_B;
+  StrideC stride_C;
+  StrideD stride_D;
+  uint64_t seed = 0;
+
+  cutlass::DeviceAllocation<ElementA> block_A;
+  cutlass::DeviceAllocation<ElementB> block_B;
+  cutlass::DeviceAllocation<ElementC> block_C;
+  cutlass::DeviceAllocation<ElementOutput> block_D;
+  cutlass::DeviceAllocation<ElementOutput> block_ref_D;
+
+  //
+  // Methods
+  //
+
+  bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) {
+    auto [M, N, K, L] = problem_size;
+
+    cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K}));
+    cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N}));
+    cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N}));
+    cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N}));
+
+    cutlass::reference::device::GemmComplex(
+          {M, N, K},
+          alpha,
+          ref_A,
+          cutlass::ComplexTransform::kNone,
+          ref_B,
+          cutlass::ComplexTransform::kNone,
+          beta,
+          ref_C,
+          ref_D,
+          ElementAccumulator(0),
+          L,     // batch_count
+          M * K, // batch_stride_A
+          K * N, // batch_stride_B
+          M * N, // batch_stride_C
+          M * N  // batch_stride_D
+        );
+
+    syclcompat::wait();
+
+#define IDX (l * M * N + i * N + j)
+
+    ElementOutput *ptr =
+        (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput));
+    syclcompat::memcpy(ptr, block_ref_D.get(),
+                       M * N * L * sizeof(ElementOutput));
+    syclcompat::wait();
+
+    for (int l = 0; l < L; l++) {
+      for (int i = 0; i < M; i++) {
+
+        auto row_max = ptr[l * M * N + i * N];
+        for (int j = 0; j < N; j++) {
+          row_max = max(row_max, ptr[IDX]);
+        }
+
+        ElementOutput exp_sum = (ElementOutput)0;
+        for (int j = 0; j < N; j++) {
+          ptr[IDX] = ptr[IDX] - row_max;
+          ptr[IDX] = exp(ptr[IDX]);
+          exp_sum += ptr[IDX];
+        }
+
+        for (int j = 0; j < N; j++) {
+          ptr[IDX] = ptr[IDX] / exp_sum;
+        }
+      }
+    }
+#undef IDX
+
+    syclcompat::memcpy(block_ref_D.get(), ptr,
+                       M * N * L * sizeof(ElementOutput));
+    syclcompat::wait();
+
+    ElementOutput *ptr_refD =
+        (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput));
+    syclcompat::memcpy(ptr_refD, block_D.get(),
+                       (size_t)M * N * L * sizeof(ElementOutput));
+    syclcompat::wait();
+
+    uint32_t err_cnt = 0;
+
+    for (int b = 0; b < L; b++) {
+      for (int i = 0; i < M; i++) {
+        for (int j = 0; j < N; j++) {
+          int idx = b * M * N + i * N + j;
+          auto expect = ptr[idx];
+          auto val = ptr_refD[idx];
+
+          if (std::isnormal(ptr[idx]) && std::isnormal(ptr_refD[idx])) {
+            auto gap = fabs(fabs(val - expect) / expect);
+            if (gap > 0.001f) {
+              std::cout << "(" << b << ", " << i << ", " << j
+                        << "): " << "host: " << ptr[idx]
+                        << " and device: " << ptr_refD[idx]
+                        << ", gap: " << gap << std::endl;
+              err_cnt++;
+            }
+          } else {
+            std::cout << "(" << b << ", " << i << ", " << j
+                      << "): " << "host: " << expect << " and device: " << val
+                      << std::endl;
+            err_cnt++;
+          }
+        }
+      }
+    }
+
+    std::free(ptr_refD);
+    std::free(ptr);
+    std::cout << "err count: " << err_cnt
+              << ", pass rate: " << 100 - (100 * err_cnt / (M * N * L)) << "%"
+              << std::endl;
+    return err_cnt == 0;
+  }
+
+  /// Initialize operands to be used in the GEMM and reference GEMM
+  void initialize(const ProblemShapeType& problem_size) {
+    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+    auto [M, N, K, L] = problem_shape_MNKL;
+
+    stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
+    stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
+    stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
+    stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
+
+    block_A.reset(M * K * L);
+    block_B.reset(K * N * L);
+    block_C.reset(M * N * L);
+    block_D.reset(M * N * L);
+    block_ref_D.reset(M * N * L);
+
+    initialize_block(block_A, seed + 2023);
+    initialize_block(block_B, seed + 2022);
+    initialize_block(block_C, seed + 2021);
+  }
+
+  void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
+    ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
+
+    initialize(problem_size);
+    auto [M, N, K, L] = problem_size;
+    typename Gemm::GemmKernel::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      problem_size,
+      {block_A.get(), stride_A, block_B.get(), stride_B},
+      {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
+      hw_info
+    };
+
+    Gemm gemm_op;
+
+    size_t workspace_size = Gemm::get_workspace_size(arguments);
+    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+    gemm_op.can_implement(arguments);
+
+    gemm_op.initialize(arguments, workspace.get());
+
+    // Run the GEMM
+    gemm_op.run();
+
+    syclcompat::wait();
+
+    // Verify that the result is correct
+    bool passed = verify(problem_size, options.alpha, options.beta);
+    std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl;
"Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + double hbm = + L * + (M * K * sizeof(ElementA) + K * N * sizeof(ElementB) + + M * N * sizeof(ElementOutput)) * + 1e-9; + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]GB/s (%6.4f)ms\n", hbm / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x8x16_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_32, _512, _32>; + + using TiledMma = TiledMMA, + Layout>, + Tile<_16,_256,_32>>; // Subgroup level-tile + + using EpilogueTile = Shape<_16, _32>; + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinCombSoftmaxRow; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + 
+  CollectiveMainloop,
+  CollectiveEpilogue
+>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  ExampleRunner<Gemm> runner;
+
+  runner.run(options, hw_info);
+
+  return 0;
+}
diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp
index 63a79b20d..77371dfe6 100644
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp
+++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp
@@ -41,6 +41,7 @@
 #include "cutlass/epilogue/collective/detail.hpp"
 #include "cutlass/epilogue/fusion/callbacks.hpp"
 #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
+#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp"
 #include "cutlass/detail/layout.hpp"
 
 #include "cute/tensor.hpp"
@@ -303,7 +304,7 @@ class CollectiveEpilogue<
     bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();
 
     Tensor trC = make_tensor<typename TiledMma::ValTypeC>(Shape<Int<FragmentSize>>{});
-    Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>>{});
+    Tensor trD = make_tensor<typename TiledMma::ValTypeD>(Shape<Int<FragmentSize>, Int<FragsM>, Int<FragsN>>{});
     Tensor tOuti = params.xe_store_d.get_pvc_tensor(
         make_coord(m_offset, n_offset, 0),
         make_shape(_, Int<FragsM>{}, Int<FragsN>{}, L),
@@ -340,7 +341,8 @@
         FragsM * FragsN * FragmentSize * SubgroupSize * ATOM_M * ATOM_N * ATOM_K;
     constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{});
     static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" );
-
+
+    auto synchronize = [&] () {};
     CUTLASS_PRAGMA_UNROLL
     for (int epi_n = 0; epi_n < FragsN; epi_n++) {
       CUTLASS_PRAGMA_UNROLL
@@ -356,12 +358,19 @@
 
         CUTLASS_PRAGMA_UNROLL
         for (int epi_v = 0; epi_v < size(trD_frag); ++epi_v) {
-          trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
+          trD_frag(epi_v, epi_m, epi_n) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n);
         }
 
-        copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n));
+        cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, (epi_m == FragsM - 1 && epi_n == FragsN - 1), trD(_, epi_m, epi_n));
       }
     }
 
+    CUTLASS_PRAGMA_UNROLL
+    for (int epi_n = 0; epi_n < FragsN; epi_n++) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int epi_m = 0; epi_m < FragsM; epi_m++) {
+        copy(params.xe_store_d, trD(_, epi_m, epi_n), rw_coord(_, epi_m, epi_n));
+      }
+    }
     cst_callbacks.end();
   }
diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp
index 3aed32710..a167ff2a3 100644
--- a/include/cutlass/epilogue/fusion/operations.hpp
+++ b/include/cutlass/epilogue/fusion/operations.hpp
@@ -137,6 +137,17 @@ struct LinCombTopKSoftmaxCol
   : LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
 };
 
+// D = softmax(alpha * acc + beta * C)
+template<
+  class ElementOutput_,
+  class ElementCompute_,
+  class ElementSource_ = ElementOutput_,
+  class ElementScalar_ = ElementCompute_,
+  FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
+>
+struct LinCombSoftmaxRow
+  : LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
+};
 
 // D = alpha * acc + beta * C + per-row bias
 template<
diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
index bfacaeda6..bf907f24a 100644
--- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp
+++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp
@@ -47,6 +47,7 @@
 #include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
 #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp"
 #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
+#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -168,7 +169,73 @@
   using Impl::Impl;
 };
 
-/////////////////////////////////////////////////////////////////////////////////////////////////
+// D = softmax(alpha * acc + beta * C)
+template<
+  // int FragmentSize,
+  class CtaTileShapeMNK,
+  class EpilogueTile,
+  class ElementOutput,
+  class ElementCompute,
+  class ElementSource = ElementOutput,
+  class ElementScalar = ElementCompute,
+  FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
+>
+using XeLinCombSoftmaxRow =
+  Sm90EVT<XeSoftmaxRowReduction<CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, RoundStyle>, // softmax(beta * C + (alpha * acc))
+    Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle> // beta * C + (alpha * acc)
+  >;
+
+template <
+  // int FragmentSize,
+  class ElementOutput_,
+  class ElementCompute_,
+  class ElementSource_,
+  class ElementScalar_,
+  FloatRoundStyle RoundStyle,
+  class CtaTileShapeMNK,
+  class EpilogueTile
+>
+struct FusionCallbacks<
+    epilogue::IntelPVCEpilogue,
+    fusion::LinCombSoftmaxRow<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle>,
+    CtaTileShapeMNK,
+    EpilogueTile
+> : XeLinCombSoftmaxRow<CtaTileShapeMNK, EpilogueTile, ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle> {
+
+  using ElementOutput = ElementOutput_;
+  using ElementCompute = ElementCompute_;
+  using ElementSource = ElementSource_;
+  using ElementScalar = ElementScalar_;
+  using Impl = XeLinCombSoftmaxRow<CtaTileShapeMNK, EpilogueTile, typename cutlass::detail::get_unpacked_element_type<ElementOutput>::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>;
+  using Operation = fusion::LinCombSoftmaxRow<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle>;
+
+  struct Arguments {
+    ElementScalar alpha = ElementScalar(1);
+    ElementScalar beta = ElementScalar(0);
+    ElementScalar const* alpha_ptr = nullptr;
+    ElementScalar const* beta_ptr = nullptr;
+
+    operator typename Impl::Arguments() const {
+      return
+        {    // unary op: activation(beta * C + (alpha * acc))
+          {    // ternary op : beta * C + (alpha * acc)
+            {{beta}, {beta_ptr}}, // leaf args : beta
+            {},                   // leaf args : C
+            {                     // binary op : alpha * acc
+              {{alpha}, {alpha_ptr}}, // leaf args : alpha
+              {},                     // leaf args : acc
+              {}                      // binary args : multiplies
+            },                    // end binary op
+            {}                    // ternary args : multiply_add
+          },                      // end ternary op
+          {}                      // unary args: activation
+        };                        // end unary op
+    }
+  };
+
+  // Ctor inheritance
+  using Impl::Impl;
+};
 
 template<
   class CtaTileShapeMNK,
diff --git a/include/cutlass/epilogue/fusion/xe_visitor_softmax.hpp b/include/cutlass/epilogue/fusion/xe_visitor_softmax.hpp
new file mode 100644
index 000000000..9692dbfcb
--- /dev/null
+++ b/include/cutlass/epilogue/fusion/xe_visitor_softmax.hpp
@@ -0,0 +1,513 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+  \brief Visitor tree Softmax fusion operation for the Intel PVC epilogue
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/workspace.h"
+
+#include "cute/tensor.hpp"
+#include "sm90_visitor_tma_warpspecialized.hpp"
+
+#ifdef __SYCL_DEVICE_ONLY__
+#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x
+#else
+#define SYCL_DEVICE_OCL(x) \
+  inline x { assert(false); }
+#endif
+
+SYCL_DEVICE_OCL(float sub_group_reduce_add(float i));
+SYCL_DEVICE_OCL(float sub_group_reduce_max(float i));
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::epilogue::fusion {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+#undef MAX
+#undef EXP
+#undef DIV
+
+#define MAX sycl::max
+#define EXP sycl::native::exp
+#define DIV sycl::native::divide
+
+namespace detail {
+
+CUTLASS_DEVICE
+float item_reduce_sum(float val) {
+  float res = val;
+  return sub_group_reduce_add(res);
+}
+
+CUTLASS_DEVICE
+float item_reduce_max(float val) {
+  float res = val;
+  return sub_group_reduce_max(res);
+}
+
+template <int N>
+CUTLASS_DEVICE
+decltype(auto) sg_reduce_sum(float* vec) {
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 0; i < N; i++) {
+    vec[i] = item_reduce_sum(vec[i]);
+  }
+}
+
+template <int N>
+CUTLASS_DEVICE
+decltype(auto) sg_reduce_max(float* vec) {
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 0; i < N; i++) {
+    vec[i] = item_reduce_max(vec[i]);
+  }
+}
+
+template <int N, int sg_num, class mem_t>
+CUTLASS_DEVICE
+decltype(auto) work_group_reduce_sum(mem_t &mem, float* vec) {
+  auto item = sycl::ext::oneapi::experimental::this_nd_item<3>();
+  auto sg = item.get_sub_group();
+  auto group = item.get_group();
+  auto sg_group_id_n = sg.get_group_id() % sg_num;
+  auto sg_local_id = sg.get_local_id()[0];
+
+  static_assert((sg_num % IntelPVCEpilogue::SubgroupSize) == 0);
+
+  sycl::group_barrier(group);
+
+  static constexpr auto step = sg_num / IntelPVCEpilogue::SubgroupSize;
+
+  if constexpr (sg_num <= N) {
+    static constexpr auto n_step = N / sg_num;
+    auto base = sg_local_id * N * step + sg_group_id_n * n_step;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / sg_num; i++) {
+      auto sum = 0.f;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < step; j++) {
+        sum += mem[base + i + N * j];
+      }
+
+      auto group_sum = item_reduce_sum(sum);
+
+      if (sg_local_id == i) {
+        mem[sg_group_id_n * n_step + i] = group_sum;
+      }
+    }
+  }
+  else {
+    auto sum = 0.f;
+
+    auto base = sg_local_id * N * step + sg_group_id_n;
+
+    if (sg_group_id_n < N) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < step; j++) {
+        sum += mem[base + N * j];
+      }
+
+      auto group_sum = item_reduce_sum(sum);
+
+      if (sg_local_id == 0) {
+        mem[sg_group_id_n] = group_sum;
+      }
+    }
+  }
+  sycl::group_barrier(group);
+
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 0; i < N; i++) {
+    vec[i] = mem[i];
+  }
+}
+
+template <int N, int sg_num, class mem_t>
+CUTLASS_DEVICE
+void work_group_reduce_max(mem_t &mem, float* vec) {
+  auto item = sycl::ext::oneapi::experimental::this_nd_item<3>();
+  auto sg = item.get_sub_group();
+  auto group = item.get_group();
+  auto sg_group_id = sg.get_group_id();
+  auto sg_group_id_n = sg_group_id % sg_num;
+  auto sg_local_id = sg.get_local_id()[0];
+
+  static_assert((sg_num % IntelPVCEpilogue::SubgroupSize) == 0);
+
+  sycl::group_barrier(group);
+
+  static constexpr auto step = sg_num / IntelPVCEpilogue::SubgroupSize;
+
+  if constexpr (sg_num <= N) {
+    static constexpr auto n_step = N / sg_num;
+    auto base = sg_local_id * N * step + sg_group_id_n * n_step;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / sg_num; i++) {
+      auto local_max = mem[base + i];
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 1; j < step; j++) {
+        local_max = MAX(local_max, mem[base + i + N * j]);
+      }
+
+      auto group_max = item_reduce_max(local_max);
+
+      if (sg_local_id == i) {
+        mem[sg_group_id_n * n_step + i] = group_max;
+      }
+    }
+  }
+  else {
+    auto base = sg_local_id * N * step + sg_group_id_n;
+    auto local_max = mem[base];
+
+    if (sg_group_id_n < N) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 1; j < step; j++) {
+        local_max = MAX(local_max, mem[base + N * j]);
+      }
+
+      auto group_max = item_reduce_max(local_max);
+
+      if (sg_local_id == 0) {
+        mem[sg_group_id_n] = group_max;
+      }
+    }
+  }
+  sycl::group_barrier(group);
+
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 0; i < N; i++) {
+    vec[i] = mem[i];
+  }
+}
+
+template <int N, int sg_per_wg_n, class mem_t>
+CUTLASS_DEVICE
+void group_reduce_sum(mem_t smem, float *const vec,
+                      float *out) {
+  auto item = sycl::ext::oneapi::experimental::this_nd_item<3>();
+  auto sg = item.get_sub_group();
+
+  sg_reduce_sum<N>(vec);
+
+  auto sg_group_id = sg.get_group_id();
+  auto sg_group_id_n = sg_group_id % sg_per_wg_n;
+  auto sg_local_id = sg.get_local_id()[0];
+
+  auto slm_base = smem + (sg_group_id / sg_per_wg_n) * sg_per_wg_n * N;
+
+  if constexpr (N < IntelPVCEpilogue::SubgroupSize) {
+    if (sg_local_id < N) {
+      slm_base[sg_group_id_n * N + sg_local_id] = vec[sg_local_id];
+    }
+  }
+  else {
+    static constexpr auto step = N / IntelPVCEpilogue::SubgroupSize;
+    auto base = sg_local_id * step;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < step; i++) {
+      auto offset = base + i;
+      slm_base[sg_group_id_n * N + offset] = vec[offset];
+    }
+  }
+
+  work_group_reduce_sum<N, sg_per_wg_n>(slm_base, out);
+}
+
+template <int N, int sg_per_wg_n, class mem_t>
+CUTLASS_DEVICE
+void group_reduce_max(mem_t smem, float *const vec,
+                      float *out) {
+  auto item = sycl::ext::oneapi::experimental::this_nd_item<3>();
+  auto sg = item.get_sub_group();
+
+  sg_reduce_max<N>(vec);
+
+  auto sg_group_id = sg.get_group_id();
+  auto sg_group_id_n = sg_group_id % sg_per_wg_n;
+  auto sg_local_id = sg.get_local_id()[0];
+
+  auto slm_base = smem + (sg_group_id / sg_per_wg_n) * sg_per_wg_n * N;
+
+  if constexpr (N < IntelPVCEpilogue::SubgroupSize) {
+    if (sg_local_id < N) {
+      slm_base[sg_group_id_n * N + sg_local_id] = vec[sg_local_id];
+    }
+  }
+  else {
+    static constexpr auto step = N / IntelPVCEpilogue::SubgroupSize;
+    auto base = sg_local_id * step;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < step; i++) {
+      slm_base[sg_group_id_n * N + base + i] = vec[base + i];
+    }
+  }
+
+  work_group_reduce_max<N, sg_per_wg_n>(slm_base, out);
+}
+
+template <int sg_per_wg_n, class mem_t, class RTensor>
+CUTLASS_DEVICE
+auto group_reduce_sum1(mem_t smem, RTensor const &t, float *out) {
+  static constexpr auto row = decltype(size<0>(t))::value;
+  static constexpr auto col = decltype(size<1>(t))::value;
+
+  float local_sum[row];
+
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 0; i < row; i++) {
+    local_sum[i] = t(i, 0);
+  }
+
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 1; i < col; i++) {
+    auto tmp = t(_, i);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < row; j++) {
+      local_sum[j] += tmp(j);
+    }
+  }
+  group_reduce_sum<row, sg_per_wg_n>(smem, local_sum, out);
+}
+
+template <int sg_per_wg_n, class mem_t, class RTensor>
+CUTLASS_DEVICE
+auto group_reduce_max1(mem_t smem, RTensor const &t, float *out) {
+  static constexpr auto row = decltype(size<0>(t))::value;
+  static constexpr auto col = decltype(size<1>(t))::value;
+
+  float local_max[row];
+
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 0; i < row; i++) {
+    local_max[i] = t(i, 0);
+  }
+
+  CUTLASS_PRAGMA_UNROLL
+  for (int i = 1; i < col; i++) {
+    auto tmp = t(_, i);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < row; j++) {
+      local_max[j] = MAX(local_max[j], tmp(j));
+    }
+  }
+  group_reduce_max<row, sg_per_wg_n>(smem, local_max, out);
+}
+
+} // namespace detail
+
+template <
+  // int FragmentSize,
+  class CtaTileShapeMNK,
+  class EpilogueTile,
+  class ElementOutput,
+  class ElementCompute,
+  FloatRoundStyle RoundStyle
+>
+struct XeSoftmaxRowReduction
+{
+public:
+  static constexpr int FragmentSize = 8;
+  static constexpr auto Tile_M = get<0>(CtaTileShapeMNK{});
+  static constexpr auto Tile_N = get<1>(CtaTileShapeMNK{});
+  static constexpr auto Epi_M = get<0>(EpilogueTile{});
+  static constexpr auto Epi_N = get<1>(EpilogueTile{});
+  static constexpr auto Sg_M = Tile_M / Epi_M;
+  static constexpr auto Sg_N = Tile_N / Epi_N;
+  static constexpr auto Sg_Nums = Sg_M * Sg_N;
+
+  struct SharedStorage { };
+
+  struct Arguments { };
+
+  struct Params { };
+
+  template <class ProblemShape>
+  static constexpr Params
+  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+    return {};
+  }
+
+  template <class ProblemShape>
+  static bool
+  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
+    auto [M, N, K, L] = problem_shape;
+    auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
+    // Cross CTA reduction is not possible because there is no guarantee that all CTAs run
+    // concurrently.
+    // Cross epilogue tile reduction is possible, but re-visiting and applying reduction
+    // to accumulators is only possible for the current epilogue tile.
+    auto [epi_M, epi_N] = EpilogueTile{};
+    return N <= tile_N;
+  }
+
+  template <class ProblemShape>
+  static size_t
+  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+    return 0;
+  }
+
+  template <class ProblemShape>
+  static cutlass::Status
+  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
+                       CudaHostAdapter* cuda_adapter = nullptr) {
+    return Status::kSuccess;
+  }
+
+  CUTLASS_DEVICE bool
+  is_producer_load_needed() const {
+    return false;
+  }
+
+  CUTLASS_DEVICE bool
+  is_C_load_needed() const {
+    return false;
+  }
+
+  CUTLASS_HOST_DEVICE
+  XeSoftmaxRowReduction() { }
+
+  CUTLASS_HOST_DEVICE
+  XeSoftmaxRowReduction(Params const& params, SharedStorage const& shared_storage)
+      : params(params) { }
+
+  Params params;
+
+  template <class ProducerLoadArgs>
+  CUTLASS_DEVICE auto
+  get_producer_load_callbacks(ProducerLoadArgs const& args) {
+    return EmptyProducerLoadCallbacks{};
+  }
+
+  template <class ArgsTuple>
+  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
+
+    CUTLASS_DEVICE
+    ConsumerStoreCallbacks(Params const& params) : params(params) {}
+
+    // ArgsTuple args_tuple;
+    Params const& params;
+
+    template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
+    CUTLASS_DEVICE auto
+    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
+          Array<ElementInput, FragmentSize> const& frg_input) {
+
+      return frg_acc;
+    }
+
+    template <class STensor, class SyncFn, class VTensor>
+    CUTLASS_DEVICE void
+    reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) {
+      if(is_last_iteration) {
+        constexpr auto vec_size = min(Epi_M, Sg_N);
+        constexpr auto loop_cnt = Epi_M / vec_size;
+
+        auto smem = syclcompat::local_mem<float[Sg_Nums * vec_size]>();
+
+        auto t =
+            make_tensor(static_cast<VTensor&&>(visit_results).data() - epi_m * FragmentSize - epi_n,
+                        make_shape(Int<vec_size>{}, Int<loop_cnt>{}, Int<Sg_N / IntelPVCEpilogue::SubgroupSize>{}));
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int loop = 0; loop < loop_cnt; loop++) {
+          auto loop_t = t(_, loop, _);
+          float group_max[vec_size];
+          detail::group_reduce_max1<Sg_N>(smem, loop_t, group_max);
+          CUTLASS_PRAGMA_UNROLL
+          for (int i = 0; i < Sg_N / IntelPVCEpilogue::SubgroupSize; i++) {
+            auto tmp = loop_t(_, i);
+            CUTLASS_PRAGMA_UNROLL
+            for (int j = 0; j < vec_size; j++) {
+              tmp(j) -= group_max[j];
+            }
+          }
+        }
+        CUTLASS_PRAGMA_UNROLL
+        for (int loop = 0; loop < loop_cnt; loop++) {
+          auto loop_t = t(_, loop, _);
+          CUTLASS_PRAGMA_UNROLL
+          for (int i = 0; i < Sg_N / IntelPVCEpilogue::SubgroupSize; i++) {
+            auto tmp = loop_t(_, i);
+            CUTLASS_PRAGMA_UNROLL
+            for (int j = 0; j < vec_size; j++) {
+              tmp(j) = EXP(tmp(j));
+            }
+          }
+        }
+        CUTLASS_PRAGMA_UNROLL
+        for (int loop = 0; loop < loop_cnt; loop++) {
+          auto loop_t = t(_, loop, _);
+
+          float group_sum[vec_size];
+          detail::group_reduce_sum1<Sg_N>(smem, loop_t, group_sum);
+
+          CUTLASS_PRAGMA_UNROLL
+          for (int i = 0; i < Sg_N / IntelPVCEpilogue::SubgroupSize; i++) {
+            auto tmp = loop_t(_, i);
+            CUTLASS_PRAGMA_UNROLL
+            for (int j = 0; j < vec_size; j++) {
+              tmp(j) = DIV(tmp(j), group_sum[j]);
+            }
+          }
+        }
+      }
+    }
+  };
+
+  template <
+    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
+    class... Args
+  >
+  CUTLASS_DEVICE auto
+  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
+    auto args_tuple = make_tuple();
+    return ConsumerStoreCallbacks<decltype(args_tuple)>(params);
+  }
+
+};
+
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::epilogue::fusion
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
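
Note for reviewers: the fused epilogue and the host reference in verify() compute the same
numerically stable row softmax -- subtract the per-row maximum, exponentiate, then normalize by
the per-row sum. Below is a minimal standalone host sketch of that computation for a row-major
m x n buffer; the function and variable names here are illustrative and do not appear in the patch.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Numerically stable row softmax: for each row i,
    //   x[i][j] = exp(x[i][j] - max_j x[i][j]) / sum_j exp(x[i][j] - max_j x[i][j])
    void softmax_rows(std::vector<float>& x, int m, int n) {
      for (int i = 0; i < m; ++i) {
        float* row = x.data() + static_cast<size_t>(i) * n;
        float row_max = *std::max_element(row, row + n);  // pass 1: row max
        float sum = 0.f;
        for (int j = 0; j < n; ++j) {                     // pass 2: exponentiate, accumulate
          row[j] = std::exp(row[j] - row_max);
          sum += row[j];
        }
        for (int j = 0; j < n; ++j) {                     // pass 3: normalize
          row[j] /= sum;
        }
      }
    }

On the device side, XeSoftmaxRowReduction::reduce follows the same three passes, with the per-row
max and sum computed through subgroup reductions (sub_group_reduce_max / sub_group_reduce_add) and
work-group reductions staged in shared local memory.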