block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp Source File

block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp Source File

Composable Kernel: block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp Source File
block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
12{
13 template <index_t NumWarps, index_t M, index_t N, typename DataType>
15 {
16 static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
17
18 constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size());
19 if constexpr(0 < ElemPerThread)
20 {
21 return NumWarps;
22 }
23 else
24 { // try dividing tile by smaller # of warps
25 return GetMaxNumWarpsForTile<NumWarps / 2, M, N, DataType>();
26 }
27 }
28
29 template <index_t NumWarps, index_t M, index_t N, typename DataType>
31 {
33
34 constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size());
35
36 constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
37 return min(MaxNPerThread, ElemPerThread);
38 }
39
40 // alignment for dram lse tile (shape=[kMaxSplits, kM0])
41 template <typename Problem>
42 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
43 {
44 return GetVectorSizeForTile<Problem::kNumWarps,
45 Problem::kMaxSplits,
46 Problem::kM0,
47 typename Problem::LSEDataType>();
48 }
49
50 template <typename Problem>
51 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
52 {
54
55 constexpr index_t kNumWarps = Problem::kNumWarps;
56 constexpr index_t kMPerBlock = Problem::kM0;
57 constexpr index_t kNPerBlock = Problem::kN1;
58
59 constexpr index_t M1 = kNumWarps;
60 constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
61 constexpr index_t N0 = get_warp_size() / M2;
62 constexpr index_t N1 = kNPerBlock / N0;
63
64 return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
65 }
66
67 template <typename Problem>
68 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
69 {
71 }
72
73 template <typename Problem>
75 {
76 return sizeof(typename Problem::LSEDataType) *
77 MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
78 }
79
80 template <typename Problem>
82 {
83 return sizeof(typename Problem::OaccDataType) *
84 MakeOaccLdsBlockDescriptor<Problem>().get_element_space_size();
85 }
86
87 template <typename Problem>
92
93 // shape=[kMaxSplits, kM0]
94 template <typename Problem>
96 {
98
99 constexpr index_t kMPerBlock = Problem::kMaxSplits;
100 constexpr index_t kNPerBlock = Problem::kM0;
101
102 constexpr index_t MaxNumWarps =
104 constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
105
106 constexpr index_t NPerThread =
108 constexpr index_t NThreads = kNPerBlock / NPerThread;
109
110 constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
111 constexpr index_t MPerThread = kMPerBlock / (MaxNumWarps * MThreadsPerWarp);
112
113 static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock);
114 static_assert(NThreads * NPerThread == kNPerBlock);
115
123 sequence<0, 1>>{});
124 }
125
126 // 3d + padding, shape=[kMaxSplits, kM0]
127 template <typename Problem>
129 {
131
132 constexpr index_t kMPerBlock = Problem::kMaxSplits;
133 constexpr index_t kNPerBlock = Problem::kM0;
134 constexpr index_t NPack =
136
137 constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
139 make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
141 number<1>{});
142
143 constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
144 lse_acc_lds_block_desc_0,
150
151 return lse_acc_lds_block_desc;
152 }
153
154 // 3d + padding, shape=[kM0, kMaxSplits]
155 template <typename Problem>
157 {
159
160 constexpr index_t kMPerBlock = Problem::kMaxSplits;
161 constexpr index_t kNPerBlock = Problem::kM0;
162 constexpr index_t NPack =
164
165 constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
167 make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
169 number<1>{});
170
171 constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
172 lse_acc_lds_block_desc_0,
178
179 return lse_acc_t_lds_block_desc;
180 }
181
182 // 3d + padding, shape=[kNumWarps * kM0, kN1]
183 template <typename Problem>
185 {
187
188 constexpr index_t kNumWarps = Problem::kNumWarps;
189 constexpr index_t kMPerBlock = kNumWarps * Problem::kM0;
190 constexpr index_t kNPerBlock = Problem::kN1;
191 constexpr index_t NPack =
193
194 constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
196 make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
198 number<1>{});
199
200 constexpr auto o_acc_lds_block_desc = transform_tensor_descriptor(
201 o_acc_lds_block_desc_0,
203 make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
206
207 return o_acc_lds_block_desc;
208 }
209
210 // shape=[kM0, kMaxSplits]
211 template <typename Problem>
213 {
214 constexpr index_t kMPerBlock = Problem::kM0;
215 constexpr index_t kNPerBlock = Problem::kMaxSplits;
216
217 constexpr index_t MaxNThreads = 8;
218 constexpr index_t NThreads = min(kNPerBlock, MaxNThreads);
219 constexpr index_t NPerThread = kNPerBlock / NThreads;
220
221 constexpr index_t MPerThread = 1;
222 constexpr index_t MThreads = kMPerBlock / MPerThread;
223 constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
224
225 constexpr index_t MaxNumWarps = (MThreads * NThreads) / get_warp_size();
226 constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
227
228 static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock);
229 static_assert(NThreads * NPerThread == kNPerBlock);
230
238 sequence<2, 1>>{});
239 }
240
241 // similar to MakeOaccResultDramTileDistribution(), but duplicate same 1-warp encoding kNumWarps
242 // times on M direction
243 template <typename Problem>
245 {
246 constexpr index_t kNumWarps = Problem::kNumWarps;
247 constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (kNumWarps * kM0)
248 constexpr index_t kNPerBlock = Problem::kN1;
249 static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
250
251 constexpr index_t M1 = 1; // compose encoding base on 1 warp
252 constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
253 constexpr index_t N0 = get_warp_size() / M2;
254 constexpr index_t N1 = kNPerBlock / N0;
255 constexpr index_t M0 = kMPerBlock / (M2 * M1);
256
263 sequence<1, 1>>{});
264 }
265
266 template <typename Problem>
268 {
269 constexpr index_t kNumWarps = Problem::kNumWarps;
270 constexpr index_t kMPerBlock = Problem::kM0;
271 constexpr index_t kNPerBlock = Problem::kN1;
272 static_assert(kNumWarps * get_warp_size() <= kMPerBlock * kNPerBlock);
273
274 constexpr index_t M1 = kNumWarps;
275 constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
276 constexpr index_t N0 = get_warp_size() / M2;
277 constexpr index_t N1 = kNPerBlock / N0;
278 constexpr index_t M0 = kMPerBlock / (M2 * M1);
279
286 sequence<0, 1>>{});
287 }
288};
289
290} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:12
static CK_TILE_HOST_DEVICE constexpr auto MakeOaccLdsBlockDescriptor()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:184
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeOacc()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:81
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeForTile()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:30
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentLSE()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:42
static CK_TILE_HOST_DEVICE constexpr auto MakeOaccResultDramTileDistribution()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:267
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEaccDramTileDistribution()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:95
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeLSEacc()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:74
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEaccLdsBlockDescriptor()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:156
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEaccRegTileDistribution()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:212
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOacc()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:51
static CK_TILE_HOST_DEVICE constexpr auto MakeOaccDramTileDistribution()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:244
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEaccLdsStoreBlockDescriptor()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:128
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:88
static CK_TILE_HOST_DEVICE constexpr auto GetMaxNumWarpsForTile()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp:68
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192