From e2287572050c529168042f34b55441afb489d273 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Mon, 7 Oct 2024 17:57:31 +0100 Subject: [PATCH] Add benchmark configurations (#127) * Refactor benchmark configurations * fix overflow in benchmarks * Rebase --------- Co-authored-by: atharva.dubey --- benchmarks/CMakeLists.txt | 2 - .../ampere/ampere_gemm_bf16_bf16_fp32.hpp | 119 ---- .../ampere/ampere_gemm_fp16_fp16_fp32.hpp | 122 ---- .../ampere/ampere_gemm_tf32_tf32_fp32.hpp | 119 ---- benchmarks/ampere/benchmarks.hpp | 115 +++- benchmarks/ampere/gemm_configuration.hpp | 636 +++++++++++++----- benchmarks/ampere/input.in | 77 ++- benchmarks/{common => }/benchmark_runner.hpp | 16 +- benchmarks/main.cpp | 16 +- benchmarks/pvc/benchmarks.hpp | 15 +- benchmarks/pvc/gemm_configuration.hpp | 151 +++++ benchmarks/pvc/input.in | 21 +- benchmarks/pvc/pvc_gemm_bf16_bf16_fp32.cpp | 114 ---- 13 files changed, 814 insertions(+), 709 deletions(-) delete mode 100644 benchmarks/ampere/ampere_gemm_bf16_bf16_fp32.hpp delete mode 100644 benchmarks/ampere/ampere_gemm_fp16_fp16_fp32.hpp delete mode 100644 benchmarks/ampere/ampere_gemm_tf32_tf32_fp32.hpp rename benchmarks/{common => }/benchmark_runner.hpp (96%) create mode 100644 benchmarks/pvc/gemm_configuration.hpp delete mode 100644 benchmarks/pvc/pvc_gemm_bf16_bf16_fp32.cpp diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index de609bc9b..f5f5b200c 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -35,8 +35,6 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(googlebenchmark) -set(CUTLASS_BENCHMARKS_COMMON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/common) - add_custom_target(cutlass_benchmarks) function(cutlass_benchmark_add_executable NAME) diff --git a/benchmarks/ampere/ampere_gemm_bf16_bf16_fp32.hpp b/benchmarks/ampere/ampere_gemm_bf16_bf16_fp32.hpp deleted file mode 100644 index c5c80b13f..000000000 --- a/benchmarks/ampere/ampere_gemm_bf16_bf16_fp32.hpp +++ /dev/null @@ -1,119 +0,0 @@ 
-/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#pragma once - -#include "gemm_configuration.hpp" - -template < - typename LayoutA, - typename LayoutB, - typename LayoutC, - typename LayoutD - > -struct AmpereGemmBF16BF16FP32 { - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ElementInputA = bfloat16_t; - using ElementInputB = bfloat16_t; - using ElementOutput = float; - - using TileShape = Shape<_128, _128, _32>; - - using TiledMma = TiledMMA< - MMA_Atom, - Layout>, - Tile<_32,_32,_16>>; - - static constexpr int kAlignmentA = 8; - using DefaultOperandA = DefaultGemm_TensorOpSm80_OperandA< - ElementInputA, LayoutA, kAlignmentA, 32>; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - static constexpr int kAlignmentB = 8; - using DefaultOperandB = DefaultGemm_TensorOpSm80_OperandB< - ElementInputB, LayoutB, kAlignmentB, 32>; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K - using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - using Stages = Int<3>; - - // This code section describes the epilogue part of the kernel - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementComputeEpilogue>; - - using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync; - - // Define strides (mixed) - using StrideA = cutlass::detail::TagToStrideA_t; - using StrideB = cutlass::detail::TagToStrideB_t; - using StrideC = cutlass::detail::TagToStrideC_t; - using StrideD = cutlass::detail::TagToStrideC_t; - - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - StrideC, - StrideD, - EpilogueOp, - 
cutlass::gemm::EpilogueDefault>; - - // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - DispatchPolicy, - TileShape, - ElementInputA, - StrideA, - ElementInputB, - StrideB, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; - -using AmpereGemmBF16BF16FP32_CCCC = AmpereGemmBF16BF16FP32< - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor>; diff --git a/benchmarks/ampere/ampere_gemm_fp16_fp16_fp32.hpp b/benchmarks/ampere/ampere_gemm_fp16_fp16_fp32.hpp deleted file mode 100644 index 16147bb96..000000000 --- a/benchmarks/ampere/ampere_gemm_fp16_fp16_fp32.hpp +++ /dev/null @@ -1,122 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "gemm_configuration.hpp" - -template < - typename LayoutA, - typename LayoutB, - typename LayoutC, - typename LayoutD - > -struct AmpereGemmFP16FP16FP32 { - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ElementInputA = half_t; - using ElementInputB = half_t; - using ElementOutput = float; - - using TileShape = Shape<_128, _128, _32>; - - using TiledMma = TiledMMA< - MMA_Atom, - Layout>, - Tile<_32,_32,_16>>; - - static constexpr int kAlignmentA = 8; - using DefaultOperandA = DefaultGemm_TensorOpSm80_OperandA< - ElementInputA, LayoutA, kAlignmentA, 32>; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - static constexpr int kAlignmentB = 8; - using DefaultOperandB = DefaultGemm_TensorOpSm80_OperandB< - ElementInputB, LayoutB, kAlignmentB, 32>; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K - using SmemCopyAtomB = typename 
DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - using Stages = Int<3>; - - // This code section describes the epilogue part of the kernel - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function - - using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync; - - // Define strides (mixed) - using StrideA = cutlass::detail::TagToStrideA_t; - using StrideB = cutlass::detail::TagToStrideB_t; - using StrideC = cutlass::detail::TagToStrideC_t; - using StrideD = cutlass::detail::TagToStrideC_t; - - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - StrideC, - StrideD, - EpilogueOp, - cutlass::gemm::EpilogueDefault>; - - // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - DispatchPolicy, - TileShape, - ElementInputA, - StrideA, - ElementInputB, - StrideB, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; - -using AmpereGemmFP16FP16FP32_CCCC = AmpereGemmFP16FP16FP32< - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor>; diff --git a/benchmarks/ampere/ampere_gemm_tf32_tf32_fp32.hpp b/benchmarks/ampere/ampere_gemm_tf32_tf32_fp32.hpp deleted file mode 100644 index c3460af41..000000000 --- 
a/benchmarks/ampere/ampere_gemm_tf32_tf32_fp32.hpp +++ /dev/null @@ -1,119 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#pragma once - -#include "gemm_configuration.hpp" - -template < - typename LayoutA, - typename LayoutB, - typename LayoutC, - typename LayoutD - > -struct AmpereGemmTF32TF32FP32 { - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ElementInputA = tfloat32_t; - using ElementInputB = tfloat32_t; - using ElementOutput = float; - - using TileShape = Shape<_128, _128, _32>; - - using TiledMma = TiledMMA< - MMA_Atom, - Layout, Stride<_2, _1, _1>>, - Tile<_32,_32,_8>>; - - static constexpr int kAlignmentA = 4; - using DefaultOperandA = DefaultGemm_TensorOpSm80_OperandA< - ElementInputA, LayoutA, kAlignmentA, 32>; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - static constexpr int kAlignmentB = 4; - using DefaultOperandB = DefaultGemm_TensorOpSm80_OperandB< - ElementInputB, LayoutB, kAlignmentB, 32>; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K - using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - using Stages = Int<3>; - - // This code section describes the epilogue part of the kernel - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementComputeEpilogue>; - - using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync; - - // Define strides (mixed) - using StrideA = cutlass::detail::TagToStrideA_t; - using StrideB = cutlass::detail::TagToStrideB_t; - using StrideC = cutlass::detail::TagToStrideC_t; - using StrideD = cutlass::detail::TagToStrideC_t; - - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - StrideC, - StrideD, - EpilogueOp, - 
cutlass::gemm::EpilogueDefault>; - - // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - DispatchPolicy, - TileShape, - ElementInputA, - StrideA, - ElementInputB, - StrideB, - TiledMma, - GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B - >; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; - -using AmpereGemmTF32TF32FP32_CCCC = AmpereGemmTF32TF32FP32< - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor, - cutlass::layout::ColumnMajor>; diff --git a/benchmarks/ampere/benchmarks.hpp b/benchmarks/ampere/benchmarks.hpp index 1723c69c1..686518bde 100644 --- a/benchmarks/ampere/benchmarks.hpp +++ b/benchmarks/ampere/benchmarks.hpp @@ -31,17 +31,112 @@ #pragma once -#include "../common/benchmark_runner.hpp" -#include "ampere_gemm_bf16_bf16_fp32.hpp" -#include "ampere_gemm_fp16_fp16_fp32.hpp" -#include "ampere_gemm_tf32_tf32_fp32.hpp" +#include "../benchmark_runner.hpp" +#include "gemm_configuration.hpp" -CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmBF16BF16FP32_CCCC); -CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmFP16FP16FP32_CCCC); -CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmTF32TF32FP32_CCCC); +using AmpereGemmBF16BF16FP32_CCC = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmBF16BF16FP32_CCC_kAlignmentA4 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 4, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmBF16BF16FP32_CCC_kAlignmentA1 = 
cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 1, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmBF16BF16FP32_CCC_kAlignment1 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 1, + cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 1, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmFP16FP16FP32_CCC = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::ColumnMajor, 8, + cutlass::half_t, cutlass::layout::ColumnMajor, 8, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmFP16FP16FP32_CCC_kAlignmentA4 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::ColumnMajor, 4, + cutlass::half_t, cutlass::layout::ColumnMajor, 8, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmFP16FP16FP32_CCC_kAlignmentA1 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::ColumnMajor, 1, + cutlass::half_t, cutlass::layout::ColumnMajor, 8, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmFP16FP16FP32_CCC_kAlignment1 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + cutlass::half_t, cutlass::layout::ColumnMajor, 1, + cutlass::half_t, cutlass::layout::ColumnMajor, 1, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmTF32TF32FP32_CCC = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + float, cutlass::layout::ColumnMajor, 4, + float, cutlass::layout::ColumnMajor, 4, + float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmTF32TF32FP32_CCC_kAlignmentA1 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + float, cutlass::layout::ColumnMajor, 1, + float, cutlass::layout::ColumnMajor, 4, 
+ float, cutlass::layout::ColumnMajor, + float>; + +using AmpereGemmTF32TF32FP32_CCC_kAlignment1 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::Sm80, + float, cutlass::layout::ColumnMajor, 1, + float, cutlass::layout::ColumnMajor, 1, + float, cutlass::layout::ColumnMajor, + float>; + +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmBF16BF16FP32_CCC); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmBF16BF16FP32_CCC_kAlignment1); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmBF16BF16FP32_CCC_kAlignmentA1); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmBF16BF16FP32_CCC_kAlignmentA4); + +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmFP16FP16FP32_CCC); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmFP16FP16FP32_CCC_kAlignment1); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmFP16FP16FP32_CCC_kAlignmentA1); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmFP16FP16FP32_CCC_kAlignmentA4); + +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmTF32TF32FP32_CCC); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmTF32TF32FP32_CCC_kAlignment1); +CUTLASS_CREATE_GEMM_BENCHMARK(AmpereGemmTF32TF32FP32_CCC_kAlignmentA1); static void register_benchmarks() { - CUTLASS_BENCHMARK(AmpereGemmBF16BF16FP32_CCCC); - CUTLASS_BENCHMARK(AmpereGemmFP16FP16FP32_CCCC); - CUTLASS_BENCHMARK(AmpereGemmTF32TF32FP32_CCCC); + CUTLASS_BENCHMARK(AmpereGemmBF16BF16FP32_CCC); + CUTLASS_BENCHMARK(AmpereGemmBF16BF16FP32_CCC_kAlignment1); + CUTLASS_BENCHMARK(AmpereGemmBF16BF16FP32_CCC_kAlignmentA1); + CUTLASS_BENCHMARK(AmpereGemmBF16BF16FP32_CCC_kAlignmentA4); + + CUTLASS_BENCHMARK(AmpereGemmFP16FP16FP32_CCC); + CUTLASS_BENCHMARK(AmpereGemmFP16FP16FP32_CCC_kAlignment1); + CUTLASS_BENCHMARK(AmpereGemmFP16FP16FP32_CCC_kAlignmentA1); + CUTLASS_BENCHMARK(AmpereGemmFP16FP16FP32_CCC_kAlignmentA4); + + CUTLASS_BENCHMARK(AmpereGemmTF32TF32FP32_CCC); + CUTLASS_BENCHMARK(AmpereGemmTF32TF32FP32_CCC_kAlignment1); + CUTLASS_BENCHMARK(AmpereGemmTF32TF32FP32_CCC_kAlignmentA1); } diff --git a/benchmarks/ampere/gemm_configuration.hpp b/benchmarks/ampere/gemm_configuration.hpp index 
f1767e543..27b4abf11 100644 --- a/benchmarks/ampere/gemm_configuration.hpp +++ b/benchmarks/ampere/gemm_configuration.hpp @@ -31,242 +31,508 @@ #pragma once -#include "cutlass/half.h" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" #include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cute/swizzle.hpp" -#include "cute/layout.hpp" -#include "cute/arch/copy_sm75.hpp" -#include "cute/arch/copy_sm80.hpp" -#include "cute/atom/copy_atom.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" using namespace cute; -template -struct DefaultGemm_TensorOpSm80_OperandA; - -template -struct DefaultGemm_TensorOpSm80_OperandB; +namespace cutlass { +namespace gemm { +namespace device { + +template< + class ArchTag, + class ElementA, class LayoutA, int kAlignmentA, + class ElementB, class LayoutB, int kAlignmentB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct GemmConfiguration { + static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists."); +}; ///////////////////////////////////////////////////////////////////////// // half -/// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride<_64, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, half_t>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); -}; +namespace detail { +template +struct Gemm_OperandA; + +template +struct Gemm_OperandB; /// Operand A - Column-major (M-major) -template -struct 
DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride< _1,_64>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, half_t>{}, - Layout, - Stride< _1,_16>>{}, - Layout>{})); +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_1, _64> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); +}; + +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_1, _64> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); +}; + +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_1, _64> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); }; /// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<2,3,3>{}, - Layout, - Stride<_32, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, half_t>{}, - Layout, - Stride< _4,_1>>{}, - Layout>{})); +template<> +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 3, 3>{}, + Layout, + Stride<_32, _1> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride<_4, _1> >{}, + Layout >{})); +}; + +template<> +struct 
Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 3, 3>{}, + Layout, + Stride<_32, _1> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride<_4, _1> >{}, + Layout >{})); }; -// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the Operands // Operand B - Column-Major (K-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; +template +struct Gemm_OperandB + : Gemm_OperandA { +}; // Operand B - Row-Major (N-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; +template +struct Gemm_OperandB + : Gemm_OperandA { +}; +} // namespace details + +template +struct GemmConfiguration< + arch::Sm80, + half_t, LayoutA, kAlignmentA, + half_t, LayoutB, kAlignmentB, + float, LayoutC, + float> { + using TileShape = Shape<_128, _128, _32>; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout >, + Tile<_32, _32, _16> >; + + // A + using OperandA = detail::Gemm_OperandA< + half_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename OperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename OperandA::SmemCopyAtom; + using GmemTiledCopyA = typename OperandA::GmemTiledCopy; + + // B + using OperandB = detail::Gemm_OperandB< + half_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename OperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename OperandB::SmemCopyAtom; + using GmemTiledCopyB = typename OperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + half_t, TagToStrideA_t, + half_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, identity, // A + GmemTiledCopyB, SmemLayoutAtomB, 
SmemCopyAtomB, identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + EpilogueDefault>; + + using GemmKernel = kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = GemmUniversalAdapter; +}; ///////////////////////////////////////////////////////////////////////// // Bfloat -/// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride<_64, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, bfloat16_t>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); +namespace detail { +/// Operand A - Column-major (M-major) +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_1, _64> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, bfloat16_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); }; -/// Operand A - Column-major (M-major) -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride< _1,_64>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, bfloat16_t>{}, - Layout, - Stride< _1,_16>>{}, - Layout>{})); +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_1, _64> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, bfloat16_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); +}; + +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + 
composition(Swizzle<3, 3, 3>{}, + Layout, + Stride<_1, _64> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, bfloat16_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); }; /// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<2,3,3>{}, - Layout, - Stride<_32, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, bfloat16_t>{}, - Layout, - Stride< _4,_1>>{}, - Layout>{})); +template<> +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 3, 3>{}, + Layout, + Stride<_32, _1> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, bfloat16_t>{}, + Layout, + Stride<_4, _1> >{}, + Layout >{})); }; -// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands +template<> +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 3, 3>{}, + Layout, + Stride<_32, _1> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, bfloat16_t>{}, + Layout, + Stride<_4, _1> >{}, + Layout >{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the Operands // Operand B - Column-Major (K-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; +template +struct Gemm_OperandB + : Gemm_OperandA { +}; // Operand B - Row-Major (N-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; +template +struct Gemm_OperandB + : Gemm_OperandA { +}; + +} // namespace detail + +template +struct GemmConfiguration< + arch::Sm80, + bfloat16_t, LayoutA, kAlignmentA, + bfloat16_t, LayoutB, kAlignmentB, + float, LayoutC, + float> { + using 
TileShape = Shape<_128, _128, _32>; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout >, + Tile<_32, _32, _16> >; + + // A + using OperandA = detail::Gemm_OperandA< + bfloat16_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename OperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename OperandA::SmemCopyAtom; + using GmemTiledCopyA = typename OperandA::GmemTiledCopy; + + // B + using OperandB = detail::Gemm_OperandB< + bfloat16_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename OperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename OperandB::SmemCopyAtom; + using GmemTiledCopyB = typename OperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + bfloat16_t, TagToStrideA_t, + bfloat16_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + EpilogueDefault>; + + using GemmKernel = kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = GemmUniversalAdapter; +}; ///////////////////////////////////////////////////////////////////////// // TFloat32 +namespace detail { /// Operand A - Row-major (K-major) (kBlock = 32) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<3,2,3>{}, - Layout, - Stride<_32, _1>>{})); - using SmemCopyAtom = Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, tfloat32_t>{}, - Layout, - Stride< _8,_1>>{}, - Layout>{})); +template<> +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 2, 3>{}, + Layout, + Stride<_32, _1> >{})); + 
using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride<_8, _1> >{}, + Layout >{})); }; -/// Operand A - Row-major (K-major) (kBlock = 16) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<2,2,3>{}, - Layout, - Stride<_16, _1>>{})); - using SmemCopyAtom = Copy_Atom; - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, tfloat32_t>{}, - Layout, - Stride< _4,_1>>{}, - Layout>{})); +template<> +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3, 2, 3>{}, + Layout, + Stride<_32, _1> >{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride<_8, _1> >{}, + Layout >{})); }; /// Operand A - Column-major (M-major) -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype( - composition(Swizzle<2,3,2>{}, - Layout, - Stride< _1,_32>>{})); - using SmemCopyAtom = Copy_Atom, tfloat32_t>; - // Gmem - using GmemTiledCopy = decltype( - make_tiled_copy(Copy_Atom, tfloat32_t>{}, - Layout, - Stride< _1,_16>>{}, - Layout>{})); +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 3, 2>{}, + Layout, + Stride<_1, _32> >{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); +}; + +template +struct Gemm_OperandA { + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2, 3, 2>{}, + Layout, + Stride<_1, _32> >{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride<_1, _16> >{}, + Layout >{})); }; -// Because the TF32 TiledMMA is A-B symmetric, we 
can reuse the DefaultOperands +// Because the TF32 TiledMMA is A-B symmetric, we can reuse the Operands // Operand B - Column-Major (K-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; +template +struct Gemm_OperandB + : Gemm_OperandA { +}; // Operand B - Row-Major (N-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{}; +template +struct Gemm_OperandB + : Gemm_OperandA { +}; +} // namespace detail + +template +struct GemmConfiguration< + arch::Sm80, + float, LayoutA, kAlignmentA, + float, LayoutB, kAlignmentB, + float, LayoutC, + float> { + using TileShape = Shape<_128, _128, _32>; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2, _1, _1>>, + Tile<_32,_32,_8> >; + + // A + using OperandA = detail::Gemm_OperandA< + tfloat32_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename OperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename OperandA::SmemCopyAtom; + using GmemTiledCopyA = typename OperandA::GmemTiledCopy; + + // B + using OperandB = detail::Gemm_OperandB< + tfloat32_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename OperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename OperandB::SmemCopyAtom; + using GmemTiledCopyB = typename OperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + tfloat32_t, TagToStrideA_t, + tfloat32_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + EpilogueDefault>; + + using GemmKernel = kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = 
GemmUniversalAdapter; +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/benchmarks/ampere/input.in b/benchmarks/ampere/input.in index 43c835865..52e27bf22 100644 --- a/benchmarks/ampere/input.in +++ b/benchmarks/ampere/input.in @@ -1,20 +1,67 @@ # BFloat16 benchmarks -AmpereGemmBF16BF16FP32_CCCC --bm_name=bf16_bf16_fp32 --l=1 --m=128 --k=128 --n=128 -AmpereGemmBF16BF16FP32_CCCC --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=1024 --n=1024 -AmpereGemmBF16BF16FP32_CCCC --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 -AmpereGemmBF16BF16FP32_CCCC --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192 -AmpereGemmBF16BF16FP32_CCCC --bm_name=bf16_bf16_fp32 --l=2 --m=2048 --k=96 --n=2048 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=128 --k=128 --n=128 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=1024 --n=1024 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=256 --m=2048 --k=96 --n=2048 +AmpereGemmBF16BF16FP32_CCC_kAlignment1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=379 --n=687 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=5120 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=13824 --n=5120 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA4 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=250880 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA4 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=16384 --n=4096 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=256 --k=4096 --n=4096 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=4096 --n=32000 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192 
+AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=16000 --k=12544 --n=1024 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=1024 --n=1024 +AmpereGemmBF16BF16FP32_CCC --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA4 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA4 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=14336 --n=5376 +AmpereGemmBF16BF16FP32_CCC_kAlignmentA4 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=8192 --n=2048 + # FP16 benchmarks -AmpereGemmFP16FP16FP32_CCCC --bm_name=fp16_fp16_fp32 --l=1 --m=128 --k=128 --n=128 -AmpereGemmFP16FP16FP32_CCCC --bm_name=fp16_fp16_fp32 --l=1 --m=1024 --k=1024 --n=1024 -AmpereGemmFP16FP16FP32_CCCC --bm_name=fp16_fp16_fp32 --l=1 --m=4096 --k=4096 --n=4096 -AmpereGemmFP16FP16FP32_CCCC --bm_name=fp16_fp16_fp32 --l=1 --m=8192 --k=8192 --n=8192 -AmpereGemmFP16FP16FP32_CCCC --bm_name=fp16_fp16_fp32 --l=2 --m=2048 --k=96 --n=2048 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=128 --k=128 --n=128 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=1024 --k=1024 --n=1024 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=4096 --k=4096 --n=4096 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=8192 --k=8192 --n=8192 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=256 --m=2048 --k=96 --n=2048 +AmpereGemmFP16FP16FP32_CCC_kAlignment1 --bm_name=fp16_fp16_fp32 --l=1 --m=512 --k=379 --n=687 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA1 --bm_name=fp16_fp16_fp32 --l=1 --m=1 --k=5120 --n=5120 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA1 --bm_name=fp16_fp16_fp32 --l=1 --m=1 --k=13824 --n=5120 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA1 --bm_name=fp16_fp16_fp32 --l=1 --m=1 --k=5120 --n=13824 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA4 --bm_name=fp16_fp16_fp32 --l=1 --m=4 --k=4096 --n=250880 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA4 --bm_name=fp16_fp16_fp32 --l=1 --m=4 --k=16384 
--n=4096 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=256 --k=4096 --n=4096 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=512 --k=4096 --n=32000 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=1024 --k=28672 --n=8192 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=16000 --k=12544 --n=1024 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=1024 --k=1024 --n=1024 +AmpereGemmFP16FP16FP32_CCC --bm_name=fp16_fp16_fp32 --l=1 --m=3072 --k=4096 --n=3072 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA4 --bm_name=fp16_fp16_fp32 --l=1 --m=4 --k=4096 --n=12288 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA4 --bm_name=fp16_fp16_fp32 --l=1 --m=4 --k=14336 --n=5376 +AmpereGemmFP16FP16FP32_CCC_kAlignmentA4 --bm_name=fp16_fp16_fp32 --l=1 --m=4 --k=8192 --n=2048 + # TF32 benchmarks -AmpereGemmTF32TF32FP32_CCCC --bm_name=tf32_tf32_fp32 --l=1 --m=128 --k=128 --n=128 -AmpereGemmTF32TF32FP32_CCCC --bm_name=tf32_tf32_fp32 --l=1 --m=1024 --k=1024 --n=1024 -AmpereGemmTF32TF32FP32_CCCC --bm_name=tf32_tf32_fp32 --l=1 --m=4096 --k=4096 --n=4096 -AmpereGemmTF32TF32FP32_CCCC --bm_name=tf32_tf32_fp32 --l=1 --m=8192 --k=8192 --n=8192 -AmpereGemmTF32TF32FP32_CCCC --bm_name=tf32_tf32_fp32 --l=2 --m=2048 --k=96 --n=2048 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=128 --k=128 --n=128 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=1024 --k=1024 --n=1024 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=4096 --k=4096 --n=4096 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=8192 --k=8192 --n=8192 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=256 --m=2048 --k=96 --n=2048 +AmpereGemmTF32TF32FP32_CCC_kAlignment1 --bm_name=tf32_tf32_fp32 --l=1 --m=512 --k=379 --n=687 +AmpereGemmTF32TF32FP32_CCC_kAlignmentA1 --bm_name=tf32_tf32_fp32 --l=1 --m=1 --k=5120 --n=5120 +AmpereGemmTF32TF32FP32_CCC_kAlignmentA1 --bm_name=tf32_tf32_fp32 --l=1 --m=1 --k=13824 --n=5120 
+AmpereGemmTF32TF32FP32_CCC_kAlignmentA1 --bm_name=tf32_tf32_fp32 --l=1 --m=1 --k=5120 --n=13824 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=4 --k=4096 --n=250880 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=4 --k=16384 --n=4096 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=256 --k=4096 --n=4096 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=512 --k=4096 --n=32000 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=1024 --k=28672 --n=8192 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=16000 --k=12544 --n=1024 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=1024 --k=1024 --n=1024 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=3072 --k=4096 --n=3072 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=4 --k=4096 --n=12288 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=4 --k=14336 --n=5376 +AmpereGemmTF32TF32FP32_CCC --bm_name=tf32_tf32_fp32 --l=1 --m=4 --k=8192 --n=2048 diff --git a/benchmarks/common/benchmark_runner.hpp b/benchmarks/benchmark_runner.hpp similarity index 96% rename from benchmarks/common/benchmark_runner.hpp rename to benchmarks/benchmark_runner.hpp index 7bac43119..11eed51ea 100644 --- a/benchmarks/common/benchmark_runner.hpp +++ b/benchmarks/benchmark_runner.hpp @@ -231,11 +231,15 @@ struct BenchmarkRunnerGemm { stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - block_A.reset(M * K * L); - block_B.reset(K * N * L); - block_C.reset(M * N * L); - block_D.reset(M * N * L); - block_ref_D.reset(M * N * L); + std::size_t block_A_size = std::size_t(M) * std::size_t(K) * std::size_t(L); + std::size_t block_B_size = std::size_t(K) * std::size_t(N) * std::size_t(L); + std::size_t block_C_size = std::size_t(M) * std::size_t(N) * std::size_t(L); + + block_A.reset(block_A_size); + 
block_B.reset(block_B_size); + block_C.reset(block_C_size); + block_D.reset(block_C_size); + block_ref_D.reset(block_C_size); initialize_block(block_A, seed + 2023); initialize_block(block_B, seed + 2022); @@ -276,7 +280,7 @@ struct BenchmarkRunnerGemm { // Verify that the result is correct bool passed = verify(problem_size, options.alpha, options.beta); if(not passed) { - throw std::runtime_error("Disposition Failed."); + state.SkipWithError("Disposition Failed."); } auto tflop = ((2.0 * options.m * options.n * options.k * options.l) * 1e-12); diff --git a/benchmarks/main.cpp b/benchmarks/main.cpp index 2f2a91fc9..59cfa51cd 100644 --- a/benchmarks/main.cpp +++ b/benchmarks/main.cpp @@ -29,21 +29,21 @@ * **************************************************************************************************/ -#include "common/benchmark_runner.hpp" +#include "cutlass/cutlass.h" #include "cutlass/kernel_hardware_info.h" #include "cutlass/util/command_line.h" -#include -#include -#include -#include - -#if (SYCL_NVIDIA_TARGET || !CUTLASS_ENABLE_SYCL) +#include "benchmark_runner.hpp" +#if defined(SYCL_NVIDIA_TARGET) || !defined(CUTLASS_ENABLE_SYCL) #include "ampere/benchmarks.hpp" -#elif (SYCL_INTEL_TARGET) +#elif defined(SYCL_INTEL_TARGET) #include "pvc/benchmarks.hpp" #endif +#include +#include +#include +#include /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/benchmarks/pvc/benchmarks.hpp b/benchmarks/pvc/benchmarks.hpp index ef51e1de4..a8ebc0b67 100644 --- a/benchmarks/pvc/benchmarks.hpp +++ b/benchmarks/pvc/benchmarks.hpp @@ -31,11 +31,18 @@ #pragma once -#include "../common/benchmark_runner.hpp" -#include "pvc_gemm_bf16_bf16_fp32.cpp" +#include "../benchmark_runner.hpp" +#include "gemm_configuration.hpp" -CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRRR); +using PvcGemmBF16BF16FP32_RRR = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, 
cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float>; + +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR); static void register_benchmarks() { - CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRRR); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR); } diff --git a/benchmarks/pvc/gemm_configuration.hpp b/benchmarks/pvc/gemm_configuration.hpp new file mode 100644 index 000000000..332f37395 --- /dev/null +++ b/benchmarks/pvc/gemm_configuration.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +using namespace cute; + +namespace cutlass { +namespace gemm { +namespace device { + +template< + class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct GemmConfiguration { + static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists."); +}; + +///////////////////////////////////////////////////////////////////////// + +// bfloat16 + +namespace detail { + +template +struct Gemm_OperandA; + +template +struct Gemm_OperandB; + +template<> +struct Gemm_OperandA { + using GmemTiledCopy = XE_2D_U16x8x16x4x2_LD_N; +}; + +template<> +struct Gemm_OperandB { + using GmemTiledCopy = XE_2D_U16x16x16x2x2_V; +}; + +} // namespace detail + +template +struct GemmConfiguration< + arch::IntelPVC, + 
bfloat16_t, LayoutA, + bfloat16_t, LayoutB, + float, LayoutC, + float> { + using TileShape = Shape<_256, _256, _32>; + using DispatchPolicy = MainloopIntelPVC<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, + Tile<_32,_64,_32>>; + + // A + using OperandA = detail::Gemm_OperandA; + using GmemTiledCopyA = typename OperandA::GmemTiledCopy; + + // B + using OperandB = detail::Gemm_OperandB; + using GmemTiledCopyB = typename OperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + bfloat16_t, TagToStrideA_t, + bfloat16_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, identity, // A + GmemTiledCopyB, void, void, identity // B + >; + + // Epilogue + using EpilogueDispatchPolicy = epilogue::IntelPVCEpilogue; + using EpilogueOp = epilogue::fusion::LinearCombination; + using FusionCallBacks = epilogue::fusion::FusionCallbacks; + + using CollectiveEpilogue = epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + float, + TagToStrideC_t, + float, + TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16x1x1_LD_N, + void, void, + XE_2D_U32x8x16x1x1_ST_N, + void, void>; + + using GemmKernel = kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = GemmUniversalAdapter; +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/benchmarks/pvc/input.in b/benchmarks/pvc/input.in index 0a21d269b..8e68fcd56 100644 --- a/benchmarks/pvc/input.in +++ b/benchmarks/pvc/input.in @@ -1,6 +1,17 @@ # BFloat16 benchmarks -PvcGemmBF16BF16FP32_RRRR --bm_name=bf16_bf16_fp32 --l=1 --m=128 --k=128 --n=128 -PvcGemmBF16BF16FP32_RRRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=1024 --n=1024 -PvcGemmBF16BF16FP32_RRRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 -PvcGemmBF16BF16FP32_RRRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192 -PvcGemmBF16BF16FP32_RRRR --bm_name=bf16_bf16_fp32 --l=2 --m=2048 
--k=96 --n=2048 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 diff --git a/benchmarks/pvc/pvc_gemm_bf16_bf16_fp32.cpp b/benchmarks/pvc/pvc_gemm_bf16_bf16_fp32.cpp deleted file mode 100644 index 178705c16..000000000 --- a/benchmarks/pvc/pvc_gemm_bf16_bf16_fp32.cpp +++ /dev/null @@ -1,114 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/collective/xe_epilogue.hpp" -#include "cutlass/epilogue/fusion/xe_callbacks.hpp" - -using namespace cute; - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename LayoutA, - typename LayoutB, - typename LayoutC, - typename LayoutD - > -struct PvcGemmBF16BF16FP32 { - using ElementAccumulator = float; // <- data type of accumulator - using ElementComputeEpilogue = float; // <- data type of epilogue operations - using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A - using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B - using ElementOutput = float; // <- data type of elements in output matrix D - - // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; - - using TiledMma = TiledMMA< - MMA_Atom, - Layout>, - Tile<_32,_64,_32>>; // Subgroup level-tile - - using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; - - using PipelineStages = Int<3>; - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; - - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16x1x1_LD_N, - void, void, - XE_2D_U32x8x16x1x1_ST_N, - void, void>; - - // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, - TileShape, - ElementInputA, - 
cutlass::gemm::TagToStrideA_t, - ElementInputB, - cutlass::gemm::TagToStrideB_t, - TiledMma, - GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; - -using PvcGemmBF16BF16FP32_RRRR = PvcGemmBF16BF16FP32< - cutlass::layout::RowMajor, - cutlass::layout::RowMajor, - cutlass::layout::RowMajor, - cutlass::layout::RowMajor>;