tile_flatmm_shape.hpp Source File

tile_flatmm_shape.hpp Source File

Composable Kernel: tile_flatmm_shape.hpp Source File
tile_flatmm_shape.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
13{
17
18 static constexpr auto idxM = number<0>{};
19 static constexpr auto idxN = number<1>{};
20 static constexpr auto idxK = number<2>{};
21
23
24 static constexpr index_t kM = BlockTile::at(idxM);
25 static constexpr index_t kN = BlockTile::at(idxN);
26 static constexpr index_t kK = BlockTile::at(idxK);
27
28 static constexpr index_t flatNPerWarp = BlockWarps::at(idxN);
29 static constexpr index_t flatKPerWarp = WarpTile::at(idxK) * WarpTile::at(idxN);
30 static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(idxK);
31
32 static constexpr bool PermuteA = false;
33 static constexpr bool PermuteB = false;
34
35 CK_TILE_HOST static std::string GetName()
36 {
37 // clang-format off
38 return concat('_', "tile_flatmm_shape",
39 concat('x', kM, kN, kK, NumWarps),
40 concat('x', BlockWarps::at(idxM), BlockWarps::at(idxN), BlockWarps::at(idxK)),
41 concat('x', (WarpTile::at(idxM)), WarpTile::at(idxN), WarpTile::at(idxK)));
42 // clang-format on
43 }
44};
45
46} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
Definition tile_flatmm_shape.hpp:13
static constexpr index_t flatKPerWarp
Definition tile_flatmm_shape.hpp:29
static constexpr index_t kK
Definition tile_flatmm_shape.hpp:26
static constexpr index_t flatKPerBlock
Definition tile_flatmm_shape.hpp:30
static constexpr auto idxK
Definition tile_flatmm_shape.hpp:20
static constexpr index_t NumWarps
Definition tile_flatmm_shape.hpp:22
static constexpr index_t kN
Definition tile_flatmm_shape.hpp:25
static constexpr bool PermuteB
Definition tile_flatmm_shape.hpp:33
static constexpr auto idxN
Definition tile_flatmm_shape.hpp:19
static constexpr bool PermuteA
Definition tile_flatmm_shape.hpp:32
static constexpr index_t flatNPerWarp
Definition tile_flatmm_shape.hpp:28
static constexpr auto idxM
Definition tile_flatmm_shape.hpp:18
remove_cvref_t< WarpTile_ > WarpTile
Definition tile_flatmm_shape.hpp:16
remove_cvref_t< BlockTile_ > BlockTile
Definition tile_flatmm_shape.hpp:14
remove_cvref_t< BlockWarps_ > BlockWarps
Definition tile_flatmm_shape.hpp:15
static CK_TILE_HOST std::string GetName()
Definition tile_flatmm_shape.hpp:35
static constexpr index_t kM
Definition tile_flatmm_shape.hpp:24
Definition tile/core/numeric/math.hpp:98