block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp Source File

block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp Source File

Composable Kernel: block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp Source File
block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
12
13// Enabling this can remove all bank conflicts, but it reduces performance in some cases.
14// This is probably due to limitations in compiler optimization.
15#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
16namespace ck_tile {
17// In this pipeline Q, K and V are all located in LDS
19 : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
20 /* AsyncCopy = */ false,
21 /* NumPrefetchK = */ 1,
22 /* NumPrefetchV = */ 1>
23{
// BasePolicy: alias for the base-class instantiation, spelled with the same template
// arguments as the base clause directly above (QLoadOnce = true, AsyncCopy = false,
// NumPrefetchK = 1, NumPrefetchV = 1).
24 using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
25 /* AsyncCopy = */ false,
26 /* NumPrefetchK = */ 1,
27 /* NumPrefetchV = */ 1>;
28
// GetAlignmentQ: vector width, in elements, chosen for Q.
// Returns the smaller of
//   - ElemPerThread: the kM0 x kSubQKHeaddim tile divided evenly over kBlockSize threads, and
//   - MaxVectorSize: the number of QDataType elements that fit in 16 bytes.
// The static_assert rejects configurations in which the tile holds fewer elements than
// there are threads (the integer division would then yield 0).
29 template <typename Problem>
30 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
31 {
32 constexpr index_t kBlockSize = Problem::kBlockSize;
33 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
34 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
35
36 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
37
38 // this should align with MakeQDramTileDistribution()
39 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
40 static_assert(0 < ElemPerThread);
41 return min(ElemPerThread, MaxVectorSize);
42 }
43
// GetAlignmentOacc: returns, as index_t, the number of OaccDataType elements that fit
// in 16 bytes.
// NOTE(review): listing line 47, which presumably declares OaccDataType from Problem,
// is missing from this extract -- confirm against the original header.
44 template <typename Problem>
45 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
46 {
48
49 return static_cast<index_t>(16 / sizeof(OaccDataType));
50 }
51
// GetAlignmentK: vector width, in elements, chosen for K.
// Returns the smaller of
//   - ElemPerThread: the kN0 x kSubQKHeaddim tile divided evenly over kBlockSize threads, and
//   - MaxVectorSize: the number of KDataType elements that fit in 16 bytes.
// The static_assert rejects configurations where ElemPerThread would be 0.
52 template <typename Problem>
53 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
54 {
55 constexpr index_t kBlockSize = Problem::kBlockSize;
56 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
57 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
58
59 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
60
61 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
62 static_assert(0 < ElemPerThread);
63 return min(ElemPerThread, MaxVectorSize);
64 }
65
// GetAlignmentV: vector width, in elements, chosen for V.
// Returns the smaller of
//   - ElemPerThread: the kN1 x kN0 tile divided evenly over kBlockSize threads, and
//   - MaxVectorSize: the number of VDataType elements that fit in 16 bytes.
// Note that for V the "N" extent is kN1 and the "K" extent is kN0.
// The static_assert rejects configurations where ElemPerThread would be 0.
66 template <typename Problem>
67 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
68 {
69 constexpr index_t kBlockSize = Problem::kBlockSize;
70 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
71 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
72
73 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
74
75 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
76 static_assert(0 < ElemPerThread);
77 return min(ElemPerThread, MaxVectorSize);
78 }
79
// MakeQDramTileDistribution (name taken from the cross-reference list at the end of this
// page; the signature on listing line 81 is missing from this extract).
// Builds a tile distribution for Q over the kM0 x kSubQKHeaddim tile.
//   BypassLDS == false: thread-layout factors (KThreads, MThreadPerWarp, NumWarps,
//       MPerThread) are derived from the same vector width that GetAlignmentQ() computes.
//   BypassLDS == true: the distribution is derived from the warp GEMM's A-operand encoding
//       (WarpGemm::AWarpDstrEncoding) embedded into an outer block-level encoding.
// NOTE(review): the tile_distribution_encoding arguments (listing lines 101-107 and
// 126-131) and the BlockGemm alias (line 112) are not visible here, so the exact layout
// cannot be checked from this extract.
80 template <typename Problem, bool BypassLDS = false>
82 {
83 if constexpr(!BypassLDS)
84 {
85 constexpr index_t kBlockSize = Problem::kBlockSize;
86 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
87 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
88
89 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
90
91 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
92 static_assert(0 < ElemPerThread);
93 constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
94
// Split the tile: KPerThread elements per thread along K, KThreads threads across K,
// the rest of each warp along M, then NumWarps warps, then MPerThread repeats.
95 constexpr index_t KPerThread = kMaxVecLoad;
96 constexpr index_t KThreads = kKPerBlock / KPerThread;
97 constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
98 constexpr index_t NumWarps = kBlockSize / get_warp_size();
99 constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
100
108 sequence<0, 1>>{});
109 }
110 else
111 {
// config.at<0>() is the warp-GEMM type selected by the block GEMM policy.
113 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
114 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
115
116 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
117 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
118
119 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
120 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
121
// Number of warp-GEMM tiles each warp covers along M and along K.
122 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
123 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
124
125 constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding<
132
133 constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
134 q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
135
136 constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
137
138 return q_block_dstr;
139 }
140 }
141
// MakeKDramTileDistribution (name from the cross-reference list; the signature on listing
// line 143 is missing from this extract).
// Computes thread-layout factors for a kN0 x kKPerBlock K tile, where kKPerBlock is
// kSubQKHeaddim when LoadOnce is true and kK0 otherwise:
//   K1 = min(KDataType elements in 16 bytes, tile elements / kBlockSize)
//   K0 = kKPerBlock / K1
//   N2 = get_warp_size() / K0,  N1 = warps per block,  N0 = kNPerBlock / (N2 * N1)
// NOTE(review): unlike GetAlignmentK(), there is no static_assert(0 < ElemPerThread) here;
// an ElemPerThread of 0 would make K1 zero and K0 a compile-time division by zero.
// NOTE(review): the KDataType alias (line 145) and the encoding built from these factors
// (lines 161-166) are not visible in this extract.
142 template <typename Problem, bool LoadOnce = false>
144 {
146
147 constexpr index_t kBlockSize = Problem::kBlockSize;
148 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
149 constexpr index_t kKPerBlock =
150 LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
151
152 constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
153 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
154
155 constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
156 constexpr index_t K0 = kKPerBlock / K1;
157 constexpr index_t N2 = get_warp_size() / K0;
158 constexpr index_t N1 = kBlockSize / get_warp_size();
159 constexpr index_t N0 = kNPerBlock / (N2 * N1);
160
167 sequence<0, 1>>{});
168 }
169
// MakeQRegTileDistribution (name from the cross-reference list; the signature on listing
// line 171 is missing from this extract).
// Builds the Q tile distribution by embedding the warp GEMM's A-operand encoding
// (WarpGemm::AWarpDstrEncoding) into an outer block-level encoding for the
// kM0 x kSubQKHeaddim tile, using the warp counts in Gemm0BlockWarps.
// NOTE(review): the BlockGemm alias (line 173) and the outer encoding's arguments
// (lines 189-194) are not visible in this extract.
170 template <typename Problem>
172 {
174 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
175 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
176
177 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
178 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
179
180 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
181 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
182
// Number of warp-GEMM tiles each warp covers along M and along K.
183 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
184 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
185
186 // Read M first, then K
187 // This is the same data consume order as BlockGEMM
188 constexpr auto q_block_outer_dstr_encoding =
195
196 constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
197 q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
198
199 constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
200
201 return q_block_dstr;
202 }
203
// GetSmemKPackQ: returns, as index_t, the number of QDataType elements that fit in
// 16 bytes; MakeQLdsBlockDescriptor() uses it as kKPack.
// NOTE(review): listing line 208, which presumably declares QDataType from Problem,
// is missing from this extract -- confirm against the original header.
204 template <typename Problem>
205 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
206 {
207 // TODO: this is for 3d layout
209 return static_cast<index_t>(16 / sizeof(QDataType));
210 }
211
// MakeQLdsBlockDescriptor (name from the cross-reference list; the signature on listing
// line 213 is missing from this extract).
// Builds and returns the tensor descriptor for the kM0 x kSubQKHeaddim Q tile, with the
// K extent split into groups of kKPack = GetSmemKPackQ<Problem>() elements.
//   Xor == true : a naive descriptor is built and then reshaped through one or more
//                 transform_tensor_descriptor steps ("permuted", then the final form).
//   Xor == false: a single descriptor is built directly.
// The "XorLengthFold" variant is compiled only when CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD is
// non-zero; this file defines it as 0, so the Xor path uses the unfolded layout.
// NOTE(review): most argument lines of the descriptor/transform calls are missing from
// this extract, so the exact strides and transforms cannot be verified here.
212 template <typename Problem, bool Xor = false>
214 {
215 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
216 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
217
218 constexpr index_t kKPack = GetSmemKPackQ<Problem>();
219
220 constexpr auto q_lds_block_desc = [&]() {
221 if constexpr(Xor)
222 {
223#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
// LDSLayerSize: QDataType elements in 256 bytes.
// XorLengthFold: how many kKPerBlock-element rows fit in that span.
224 constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType);
225 constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
226
227 if constexpr(XorLengthFold > 1)
228 {
229 constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
231 number<LDSLayerSize / kKPack>{},
235 number<1>{});
236
237 constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
238 q_lds_block_desc_naive,
241 number<LDSLayerSize / kKPack>{})),
245
246 constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor(
247 q_lds_block_desc_permuted,
251 make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
255
257 q_lds_block_desc_tmp,
265 }
266 else
267#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
268 {
// Unfolded layout: lengths are (kMPerBlock, kKPerBlock / kKPack, kKPack).
269 constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
271 number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
274 number<1>{});
275
276 constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
277 q_lds_block_desc_naive,
279 number<kKPerBlock / kKPack>{})),
283
285 q_lds_block_desc_permuted,
291 }
292 }
293 else
294 {
299 number<1>{});
300 }
301 }();
302
303 return q_lds_block_desc;
304 }
305
// MakeKLdsBlockDescriptor (name from the cross-reference list; the signature on listing
// line 307 is missing from this extract).
// Builds and returns the tensor descriptor for the kN0 x kKPerBlock K tile, where
// kKPerBlock is kSubQKHeaddim when LoadOnce is true and kK0 otherwise. The K extent is
// split into groups of kKPack = GetSmemKPackK<Problem>() elements.
//   Xor == true : a naive descriptor is built and then reshaped through one or more
//                 transform_tensor_descriptor steps ("permuted", then the final form).
//   Xor == false: a single descriptor is built directly.
// The "XorLengthFold" variant is compiled only when CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD is
// non-zero; this file defines it as 0, so the Xor path uses the unfolded layout.
// NOTE(review): most argument lines of the descriptor/transform calls are missing from
// this extract, so the exact strides and transforms cannot be verified here.
306 template <typename Problem, bool LoadOnce = false, bool Xor = false>
308 {
309 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
310 constexpr index_t kKPerBlock =
311 LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
312
313 constexpr index_t kKPack = GetSmemKPackK<Problem>();
314
315 constexpr auto k_lds_block_desc = [&]() {
316 if constexpr(Xor)
317 {
318#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
// LDSLayerSize: KDataType elements in 256 bytes.
// XorLengthFold: how many kKPerBlock-element rows fit in that span.
319 constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType);
320 constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
321
322 if constexpr(XorLengthFold > 1)
323 {
324 constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
326 number<LDSLayerSize / kKPack>{},
330 number<1>{});
331
332 constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
333 k_lds_block_desc_naive,
336 number<LDSLayerSize / kKPack>{})),
340
341 constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor(
342 k_lds_block_desc_permuted,
346 make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
350
352 k_lds_block_desc_tmp,
360 }
361 else
362#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
363 {
// Unfolded layout: lengths are (kNPerBlock, kKPerBlock / kKPack, kKPack).
364 constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
366 number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
369 number<1>{});
370
371 constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
372 k_lds_block_desc_naive,
374 number<kKPerBlock / kKPack>{})),
378
380 k_lds_block_desc_permuted,
386 }
387 }
388 else
389 {
394 number<1>{});
395 }
396 }();
397
398 return k_lds_block_desc;
399 }
400
// MakeVLdsBlockDescriptor (name from the cross-reference list; the signature on listing
// line 402 is missing from this extract).
// Builds and returns the tensor descriptor for the V tile, with kNPerBlock = kN1 and
// kKPerBlock = kN0.
//   Xor == true : the N extent is grouped by XorGroupSize (element 0 of Gemm1WarpTile);
//                 a naive descriptor is built and then reshaped through
//                 transform_tensor_descriptor steps.
//   Xor == false: a single descriptor is built directly.
// The "XorLengthFold" variant is compiled only when CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD is
// non-zero; this file defines it as 0, so the Xor path uses the unfolded layout.
// NOTE(review): kKPack is computed from GetSmemKPackV<Problem>() but is not referenced
// in any line visible here; it may be used in the missing lines -- confirm.
// NOTE(review): most argument lines of the descriptor/transform calls are missing from
// this extract, so the exact strides and transforms cannot be verified here.
401 template <typename Problem, bool Xor = false>
403 {
404 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
405 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
406
407 constexpr index_t kKPack = GetSmemKPackV<Problem>();
408
409 constexpr auto v_lds_block_desc = [&]() {
410 if constexpr(Xor)
411 {
412 constexpr auto XorGroupSize =
413 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
414
415#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
// LDSLayerSize: VDataType elements in 256 bytes.
// XorLengthFold: how many kNPerBlock-element rows fit in that span.
416 constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType);
417 constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock;
418
419 if constexpr(XorLengthFold > 1)
420 {
421 constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
423 number<LDSLayerSize / XorGroupSize>{},
427 number<1>{});
428
429 constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
430 v_lds_block_desc_naive,
433 number<LDSLayerSize / XorGroupSize>{})),
437
438 constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor(
439 v_lds_block_desc_permuted,
443 number<kNPerBlock / XorGroupSize>{})),
447
449 v_lds_block_desc_tmp,
457 }
458 else
459#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
460 {
461 constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
463 number<kNPerBlock / XorGroupSize>{},
467 number<1>{});
468
469 constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
470 v_lds_block_desc_naive,
472 number<kKPerBlock>{}, number<kNPerBlock / XorGroupSize>{})),
476
478 v_lds_block_desc_permuted,
485 }
486 }
487 else
488 {
493 number<1>{});
494 }
495 }();
496
497 return v_lds_block_desc;
498 }
499
// GetQKBlockGemm: assembles the types describing the first GEMM (Q x K -> Sacc).
//   GemmProblem     : Q/K/Sacc data types, kBlockSize, and a kM0 x kN0 x kK0 tile shape
//                     with Gemm0BlockWarps / Gemm0WarpTile.
//   WarpGemm        : dispatcher selection keyed on the three Gemm0WarpTile extents; the
//                     trailing `true` is the argument after KPerWave (TransposeC in the
//                     WarpGemmDispatcher alias listed at the end of this page).
//   BlockGemmPolicy : BlockGemmARegBRegCRegV2CustomPolicy over the same types.
// NOTE(review): the end of the BlockGemmPolicy argument list (line 528) and the return
// statement (line 530) are missing from this extract.
500 template <typename Problem>
501 CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
502 {
503 using GemmProblem =
504 BlockGemmProblem<typename Problem::QDataType,
505 typename Problem::KDataType,
506 typename Problem::SaccDataType,
507 Problem::kBlockSize,
508 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
509 Problem::BlockFmhaShape::kN0,
510 Problem::BlockFmhaShape::kK0>,
511 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
512 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
513
514 using WarpGemm = WarpGemmDispatcher<typename Problem::QDataType,
515 typename Problem::KDataType,
516 typename Problem::SaccDataType,
517 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
518 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
519 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
520 true>;
521
522 using BlockGemmPolicy =
523 BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
524 typename Problem::KDataType,
525 typename Problem::SaccDataType,
526 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
527 WarpGemm,
529
531 }
532
// GetPVBlockGemm: assembles the types describing the second GEMM (P x V -> Oacc).
//   GemmProblem     : P/V/Oacc data types, kBlockSize, and a kM0 x kN1 x kK1 tile shape
//                     with Gemm1BlockWarps / Gemm1WarpTile.
//   WarpGemm        : dispatcher selection keyed on the three Gemm1WarpTile extents. The
//                     `true, false, false` arguments follow KPerWave, i.e. TransposeC,
//                     SwizzleA and UseStructuredSparsity in the WarpGemmDispatcher alias
//                     listed at the end of this page.
//   BlockGemmPolicy : BlockGemmARegBRegCRegV2CustomPolicy over the same types.
// NOTE(review): the final dispatcher argument starts with a test for warp-tile (N, K) being
// (16, 32) or (32, 16); its continuation (lines 561-562) is missing, so which value it
// selects cannot be confirmed here. The end of the BlockGemmPolicy argument list
// (line 570) and the return statement (line 572) are also missing.
533 template <typename Problem>
534 CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm()
535 {
536 using GemmProblem =
537 BlockGemmProblem<typename Problem::PDataType,
538 typename Problem::VDataType,
539 typename Problem::OaccDataType,
540 Problem::kBlockSize,
541 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
542 Problem::BlockFmhaShape::kN1,
543 Problem::BlockFmhaShape::kK1>,
544 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
545 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
546
547 using WarpGemm =
548 WarpGemmDispatcher<typename Problem::PDataType,
549 typename Problem::VDataType,
550 typename Problem::OaccDataType,
551 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
552 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
553 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
554 true,
555 false,
556 false,
557 ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
558 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
559 (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
560 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
563
564 using BlockGemmPolicy =
565 BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
566 typename Problem::VDataType,
567 typename Problem::OaccDataType,
568 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
569 WarpGemm,
571
573 }
574
// MakeKRegTileDistribution (name from the cross-reference list; the signature on listing
// line 576 is missing from this extract).
// Builds the K tile distribution by embedding the warp GEMM's B-operand encoding
// (WarpGemm::BWarpDstrEncoding) into an outer block-level encoding for the kN0 x kK0
// tile, using the warp counts in Gemm0BlockWarps.
// NOTE(review): the BlockGemm alias (line 578) and the outer encoding's arguments
// (lines 594-599) are not visible in this extract.
575 template <typename Problem>
577 {
579 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
580 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
581
582 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
583 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
584
585 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
586 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
587
// Number of warp-GEMM tiles each warp covers along N and along K.
588 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
589 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
590
591 // Read N first, then K
592 // This is the same data consume order as BlockGEMM
593 constexpr auto k_block_outer_dstr_encoding =
600
601 constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
602 k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
603
604 constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
605
606 return k_block_dstr;
607 }
608
// MakeVDramTileDistribution (name from the cross-reference list, which declares it
// CK_TILE_DEVICE rather than CK_TILE_HOST_DEVICE; the signature on listing line 610 is
// missing from this extract).
// Computes thread-layout factors for the V tile (kNPerBlock = kN1, kKPerBlock = kN0),
// using the same vector width that GetAlignmentV() computes:
//   NPerThread = vector width,  NThreads = kNPerBlock / NPerThread,
//   KThreadPerWarp = get_warp_size() / NThreads,  NumWarps = warps per block,
//   KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps).
// NOTE(review): the encoding built from these factors (lines 628-634) is not visible.
609 template <typename Problem>
611 {
612 constexpr index_t kBlockSize = Problem::kBlockSize;
613 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
614 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
615
616 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
617
618 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
619 static_assert(0 < ElemPerThread);
620 constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
621
622 constexpr index_t NPerThread = kMaxVecLoad;
623 constexpr index_t NThreads = kNPerBlock / NPerThread;
624 constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
625 constexpr index_t NumWarps = kBlockSize / get_warp_size();
626 constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps);
627
635 sequence<0, 1>>{});
636 }
637
// MakePRegTileDistribution (name from the cross-reference list; the signature on listing
// line 639 is missing from this extract).
// Builds the P tile distribution by embedding the warp GEMM's A-operand encoding
// (WarpGemm::AWarpDstrEncoding) into an outer block-level encoding for a kM0 x kN0 tile
// (kKPerBlock is kN0 here), using the warp counts in Gemm1BlockWarps.
// NOTE(review): the BlockGemm alias (line 641) and the outer encoding's arguments
// (lines 657-662) are not visible in this extract.
638 template <typename Problem>
640 {
642 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
643 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
644
645 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
646 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
647
648 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
649 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
650
// Number of warp-GEMM tiles each warp covers along M and along K.
651 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
652 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
653
654 // Read M first, then K
655 // This is the same data consume order as BlockGEMM
656 constexpr auto p_block_outer_dstr_encoding =
663
664 constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
665 p_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
666
667 constexpr auto p_block_dstr = make_static_tile_distribution(p_block_dstr_encode);
668
669 return p_block_dstr;
670 }
671
// MakeVRegTileDistribution (name from the cross-reference list; the signature on listing
// line 673 is missing from this extract).
// Embeds the warp GEMM's B-operand encoding (WarpGemm::BWarpDstrEncoding) into an outer
// block-level encoding for the kN1 x kK1 tile, using the warp counts in Gemm1BlockWarps.
// The returned distribution is then built from the `TransposedDstrEncode` member of a
// traits type parameterized on that encoding and on VDataType.
// NOTE(review): the BlockGemm alias (line 675), the outer encoding's arguments
// (lines 691-696) and the traits type name (line 702) are not visible in this extract;
// the traits type is presumably InputTileDistributionTraits from load_tile_transpose.hpp,
// which appears in this page's cross-reference list -- confirm.
672 template <typename Problem>
674 {
676 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
677 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
678
679 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
680 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
681
682 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
683 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
684
// Number of warp-GEMM tiles each warp covers along N and along K.
685 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
686 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
687
688 // Read N first, then K
689 // This is the same data consume order as BlockGEMM
690 constexpr auto v_block_outer_dstr_encoding =
697
698 constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
699 v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
700
701 constexpr auto v_block_dstr =
703 decltype(v_block_dstr_encode),
704 typename Problem::VDataType>::TransposedDstrEncode{});
705
706 return v_block_dstr;
707 }
708
// GetSmemNPackS: returns, as index_t, the number of SDataType elements that fit in
// 16 bytes; MakeSLdsBlockDescriptor() uses it as kNPack.
// NOTE(review): listing line 712, which presumably declares SDataType from Problem,
// is missing from this extract -- confirm against the original header.
709 template <typename Problem>
710 CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
711 {
713 return static_cast<index_t>(16 / sizeof(SDataType));
714 }
715
// MakeSLdsBlockDescriptor (name from the cross-reference list; the signature on listing
// line 717 is missing from this extract).
// Builds a naive descriptor with strides ((kM0 + 1) * kNPack, kNPack, 1), where
// kNPack = GetSmemNPackS<Problem>(), then reshapes it with transform_tensor_descriptor
// and returns the result.
// NOTE(review): the outermost stride spans kM0 + 1 groups rather than kM0, presumably as
// padding -- confirm. The lengths tuple (line 724) and the transform arguments
// (lines 731-735) are missing from this extract; kNPerBlock is presumably used there.
716 template <typename Problem>
718 {
719 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
720 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
721 constexpr index_t kNPack = GetSmemNPackS<Problem>();
722
723 constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
725 make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
727 number<1>{});
728
729 constexpr auto s_lds_block_desc = transform_tensor_descriptor(
730 s_lds_block_desc_0,
736
737 return s_lds_block_desc;
738 }
739
// MakeSRegTileDistribution (name from the cross-reference list; the signature on listing
// line 741 is missing from this extract).
// Derives factorizations of the M extent (kM0 = M0 * M1 * M2) and the K extent
// (kN0 = K0 * K1 * K2 * K3) from the warp-GEMM implementation's lane counts, builds a
// tile distribution encoding from them and returns the resulting static distribution.
//   M2 = Impl::kAMLane, M1 = MWarp, M0 = kM0 / (M2 * M1)
//   K2 = Impl::kABKLane, K3 = WG::kK / K2, K1 = kK1 / (K2 * K3), K0 = kN0 / kK1
// Here MWarp and NWarp are elements 1 and 2 of the policy's warp-GEMM config tuple.
// NOTE(review): the BlockGemm alias (line 743) and the encoding arguments
// (lines 766-771) are missing from this extract.
740 template <typename Problem>
742 {
744
745 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
746 using WG = remove_cvref_t<decltype(config.template at<0>())>;
747 constexpr index_t MWarp = config.template at<1>();
748 constexpr index_t NWarp = config.template at<2>();
749
750 // static_assert(MWarp == 1, "Check failed!");
751
752 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
753 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
754 constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
755
// NOTE(review): the comment on the next line does not match the visible code, where
// K2 is Impl::kABKLane and K3 is WG::kK / Impl::kABKLane -- it may be stale; confirm.
756 // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
757 constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
758 constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
759 constexpr index_t K1 = kKPerBlock / (K2 * K3);
760 constexpr index_t K0 = kTileK / kKPerBlock;
761 constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
762 constexpr index_t M1 = MWarp;
763 constexpr index_t M0 = kMPerBlock / (M2 * M1);
764
765 constexpr auto s2_block_dstr_encoding =
772
773 constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
774
775 return s2_block_dstr;
776 }
777
// GetSmemSizeQ (name from the cross-reference list; the signature on listing line 779 is
// missing from this extract).
// Returns the size in bytes of the Q buffer: the element space size of
// MakeQLdsBlockDescriptor<Problem>() times sizeof(QDataType).
// The descriptor is instantiated with its default Xor = false.
778 template <typename Problem>
780 {
781 return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
782 sizeof(typename Problem::QDataType);
783 }
784
// GetSmemSizeK (name from the cross-reference list; the signature on listing line 786 is
// missing from this extract).
// Returns the size in bytes of the K buffer: the element space size of
// MakeKLdsBlockDescriptor<Problem, LoadOnce>() times sizeof(KDataType).
// LoadOnce is forwarded; the descriptor's Xor parameter keeps its default (false).
785 template <typename Problem, bool LoadOnce = false>
787 {
788 return MakeKLdsBlockDescriptor<Problem, LoadOnce>().get_element_space_size() *
789 sizeof(typename Problem::KDataType);
790 }
791
// GetSmemSizeV (name from the cross-reference list; the signature on listing line 793 is
// missing from this extract).
// Returns the size in bytes of the V buffer: the element space size of
// MakeVLdsBlockDescriptor<Problem>() times sizeof(VDataType).
// The descriptor is instantiated with its default Xor = false.
792 template <typename Problem>
794 {
795 return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
796 sizeof(typename Problem::VDataType);
797 }
798
// GetSmemSizeS (name from the cross-reference list; the signature on listing line 800 is
// missing from this extract).
// Returns the size in bytes of the S buffer, or 0 when the first GEMM uses a single warp
// along N (element 1 of Gemm0BlockWarps is not greater than 1). When non-zero, the size
// is the element space size of MakeSLdsBlockDescriptor<Problem>() times
// sizeof(SaccDataType).
799 template <typename Problem>
801 {
802 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
803
804 return NWarp > 1 ? MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
805 sizeof(typename Problem::SaccDataType)
806 : 0;
807 }
808
// GetSmemSize (name from the cross-reference list; the signature on listing line 810 is
// missing from this extract).
// Returns a max(...) whose first operand is GetSmemSizeQ<Problem>().
// NOTE(review): the remaining operand(s) of max (listing line 815) are missing from this
// extract, so the full formula -- and how the alignment figures quoted in the comments
// below enter into it -- cannot be confirmed here.
809 template <typename Problem>
811 {
812 // Alignment on gfx950 is 1280 Bytes
813 // Alignment before gfx950 is 512 Bytes.
814 return max(GetSmemSizeQ<Problem>(),
816 }
817};
818
819} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
@ KMN
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:12
@ MNK
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:13
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:23
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:67
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeQ()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:779
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:673
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:81
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:576
static CK_TILE_HOST_DEVICE constexpr auto GetPVBlockGemm()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:534
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackQ()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:205
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:402
static CK_TILE_HOST_DEVICE constexpr auto GetSmemNPackS()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:710
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeS()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:800
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:213
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:501
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOacc()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:45
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:307
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeK()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:786
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:53
BlockFmhaPipelineQXKSVSCustomPolicy< true, false, 1, 1 > BasePolicy
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:24
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:171
static CK_TILE_HOST_DEVICE constexpr auto MakeSLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:717
static CK_TILE_HOST_DEVICE constexpr auto MakePRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:639
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeV()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:793
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:810
static CK_TILE_HOST_DEVICE constexpr auto MakeSRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:741
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:143
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:30
static CK_TILE_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp:610
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:373
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:338
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:23
Definition block_gemm_areg_breg_creg_v2.hpp:17
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192