gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp Source File

gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp Source File#

Composable Kernel: gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdarg>
7#include "ck/utility/env.hpp"
19
20namespace ck {
21
22// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
23// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
24template <typename ADataType,
25 typename B0DataType,
26 typename Acc0DataType,
27 typename B1DataType,
28 typename Acc1DataType,
29 typename CShuffleDataType,
30 typename CDataType,
31 typename AElementwiseOperation,
32 typename B0ElementwiseOperation,
33 typename AccElementwiseOperation,
34 typename B1ElementwiseOperation,
35 typename CElementwiseOperation,
36 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
37 typename AGridDesc,
38 typename B0GridDesc,
39 typename B1GridDesc,
40 typename CGridDesc_M_N,
41 index_t MPerBlock,
42 index_t LPerBlock,
43 index_t KPerBlock,
44 index_t AK1Value,
45 index_t BK1Value,
46 index_t NPerBlock,
47 index_t LTilePerBlock,
48 index_t L1Value,
49 index_t MPerWmma,
50 index_t LPerWmma,
51 index_t NPerWmma,
52 index_t MRepeat,
53 index_t LRepeat,
54 index_t NRepeat,
55 index_t BlockSize,
56 typename ABlockTransferThreadClusterLengths_K0_M_K1,
57 typename ABlockTransferThreadClusterArrangeOrder,
58 typename ABlockTransferSrcAccessOrder,
59 index_t ABlockTransferSrcVectorDim,
60 index_t ABlockTransferSrcScalarPerVector,
61 index_t ABlockTransferDstScalarPerVector_K1,
62 bool AThreadTransferSrcResetCoordinateAfterRun,
63 bool ABlockLdsExtraM,
64 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
65 typename B0BlockTransferThreadClusterArrangeOrder,
66 typename B0BlockTransferSrcAccessOrder,
67 index_t B0BlockTransferSrcVectorDim,
68 index_t B0BlockTransferSrcScalarPerVector,
69 index_t B0BlockTransferDstScalarPerVector_K1,
70 bool B0ThreadTransferSrcResetCoordinateAfterRun,
71 bool B0BlockLdsExtraL,
72 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
73 typename B1BlockTransferThreadClusterArrangeOrder,
74 typename B1BlockTransferSrcAccessOrder,
75 index_t B1BlockTransferSrcVectorDim,
76 index_t B1BlockTransferSrcScalarPerVector,
77 index_t B1BlockTransferDstScalarPerVector_L1,
78 bool B1ThreadTransferSrcResetCoordinateAfterRun,
79 bool B1BlockLdsExtraN,
80 index_t CShuffleMRepeatPerShuffle,
81 index_t CShuffleNRepeatPerShuffle,
82 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
83 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
84 bool PadN,
88{
89 static constexpr auto I0 = Number<0>{};
90 static constexpr auto I1 = Number<1>{};
91 static constexpr auto I2 = Number<2>{};
92 static constexpr auto I3 = Number<3>{};
93 static constexpr auto I4 = Number<4>{};
94 static constexpr auto I5 = Number<5>{};
95 static constexpr auto I6 = Number<6>{};
96
97 static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
98 static constexpr auto AK1 = Number<AK1Value>{};
99 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
100 static constexpr auto BK1 = Number<BK1Value>{};
101
102 static constexpr auto L0PerBlock = LTilePerBlock / L1Value;
103 static constexpr auto AL0 = Number<L0PerBlock / 2>{}; // TODO: Where does this 2 come from?
104 static constexpr auto AL1 = Number<L1Value>{};
105 static constexpr auto BL0 = Number<L0PerBlock>{};
106 static constexpr auto BL1 = Number<L1Value>{};
107
108 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
109 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
110 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
111
112 // TODO: I am pretty sure this is always 16 and *should* always be 16.
113 static constexpr auto KPack =
115
117
118 __host__ __device__ static constexpr auto MakeABlockDescriptor()
119 {
120 constexpr auto a_block_desc = [&]() {
121 // K0->M->K1 Per Block
122 constexpr auto K0PerBlock = KPerBlock / AK1;
123 constexpr auto max_lds_align = AK1;
124
125 if constexpr(ABlockLdsExtraM)
126 {
130 }
131 else
132 {
134 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, AK1), max_lds_align);
135 }
136 }();
137
138 return a_block_desc;
139 }
140
141 __host__ __device__ static constexpr auto MakeB0BlockDescriptor()
142 {
143 constexpr auto b0_block_desc = [&]() {
144 // K0->L->BK1 Per Block
145 constexpr auto K0PerBlock = KPerBlock / BK1;
146 constexpr auto max_lds_align = BK1;
147
148 if constexpr(B0BlockLdsExtraL)
149 {
153 }
154 else
155 {
157 make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1), max_lds_align);
158 }
159 }();
160
161 return b0_block_desc;
162 }
163
164 __host__ __device__ static constexpr auto MakeB1BlockDescriptor()
165 {
166 constexpr auto b1_block_desc = [&]() {
167 // L0->N->BL1 Per Block
168 constexpr auto max_lds_align = BL1;
169
170 if constexpr(B1BlockLdsExtraN)
171 {
175 }
176 else
177 {
179 make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1), max_lds_align);
180 }
181 }();
182
183 return b1_block_desc;
184 }
185
186 __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
187 {
188 constexpr auto a_block_copy_step = [&]() { return make_multi_index(AK0, 0, 0); }();
189 return a_block_copy_step;
190 }
191
192 __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep()
193 {
194 constexpr auto b0_block_copy_step = [&]() { return make_multi_index(BK0, 0, 0); }();
195 return b0_block_copy_step;
196 }
197
198 __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep()
199 {
200 constexpr auto b1_block_copy_step = [&]() { return make_multi_index(L0PerBlock, 0, 0); }();
201 return b1_block_copy_step;
202 }
203
204 template <typename ABlockDesc_>
205 __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
206 {
207 constexpr auto a_wave_desc = [&]() {
208 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
209 constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
210 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
211#ifdef __gfx12__
212 constexpr auto A_KRow = I2;
213#else
214 constexpr auto A_KRow = I1;
215#endif
217 ABlockDesc_{},
224 }();
225
226 return a_wave_desc;
227 }
228
229 template <typename B0BlockDesc_>
230 __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
231 {
232 constexpr auto b0_wave_desc = [&]() {
233 // BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1
234 constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
235 constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
236#ifdef __gfx12__
237 constexpr auto B_KRow = I2;
238#else
239 constexpr auto B_KRow = I1;
240#endif
242 B0BlockDesc_{},
249 }();
250
251 return b0_wave_desc;
252 }
253
254 template <typename A1BlockDesc_AL0_M_AL1>
255 __host__ __device__ static constexpr auto
256 MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
257 {
258 constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
259 constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
260 constexpr auto A_LRow = I1;
262 A1BlockDesc_AL0_M_AL1{},
268 }
269
270 template <typename B1BlockDesc_>
271 __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&)
272 {
273
274 constexpr auto b1_wave_desc = [&]() {
275 // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
276 constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
277 constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
278#ifdef __gfx12__
279 constexpr auto B_LRow = I2;
280#else
281 constexpr auto B_LRow = I1;
282#endif
284 B1BlockDesc_{},
291 }();
292
293 return b1_wave_desc;
294 }
295
296 __host__ __device__ static constexpr auto
297 // *Caution Here repeat is shuffle repeat
299 {
300 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
304 I1,
306
307 return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
308 }
309
310 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
311 {
312 // LDS allocation for A and B: be careful of alignment
313 const index_t gemm0_bytes_end =
314 (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
316
317 const index_t gemm1_bytes_end =
320
321 const index_t acc0_bytes_end =
324
325 const index_t c_block_bytes_end =
326 SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType);
327
328 return math::max(gemm0_bytes_end, gemm1_bytes_end, acc0_bytes_end, c_block_bytes_end);
329 }
330
331 // Blockwise gemm pipeline for gemm0, this replaces the old GridwiseGemmPipe +
332 // BlockwiseGemmWMMA. The latter had two enableLDS bools which we don't
333 // have anymore, the new pipelines ALWAYS use lds. Furthermore the original BlockwiseGemmWMMA
334 // used TransposeC = true which we still need to make the operation work.
337 BlkGemmPipelineVer,
338 BlkGemmPipeSched,
339 BlockSize,
340 ADataType,
341 B0DataType,
342 // TODO: Check if these compute types should always be
343 // equal to data type.
344 ADataType, // ComputeTypeA
345 B0DataType, // ComputeTypeB
346 Acc0DataType,
349 ABlockTransferSrcScalarPerVector,
350 B0BlockTransferSrcScalarPerVector,
351 MPerBlock,
352 LPerBlock,
353 KPerBlock,
354 MPerWmma,
355 LPerWmma,
356 MRepeat,
357 LRepeat,
358 KPack,
359 true>())>; // TransposeC (must be true to work), C' = B' x A'
360
361 // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
362 template <typename Block2CTileMap>
363 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
364 const B0GridDesc& b0_grid_desc,
365 const B1GridDesc& b1_grid_desc,
366 const CGridDesc_M_N& c_grid_desc_m_n,
367 const Block2CTileMap& block_2_ctile_map)
368 {
369 // Print lambda with env check and printf() style formmating.
370 const char* curFunc = __func__;
371 auto print = [&curFunc](const char* format, ...) -> void {
372 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
373 {
374#if defined(__clang__)
375#pragma clang diagnostic push
376#pragma clang diagnostic ignored "-Wformat-nonliteral"
377#endif
378 va_list args;
379 va_start(args, format);
380 std::vfprintf(stdout, format, args);
381 va_end(args);
382#if defined(__clang__)
383#pragma clang diagnostic pop
384#endif
385 std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
386 }
387 };
388
389 static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
390 LPerBlock % (LPerWmma * LRepeat) == 0 &&
391 NPerBlock % (NPerWmma * NRepeat) == 0,
392 "Invalid tuning param!");
393
394 const auto M = a_grid_desc.GetLength(I1);
395 const auto L = b0_grid_desc.GetLength(I1);
396 const auto K = a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
397 const auto N = b1_grid_desc.GetLength(I1);
398
399 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
400 {
401 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
402 {
403 print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n",
404 M,
405 N,
406 c_grid_desc_m_n.GetLength(I0),
407 c_grid_desc_m_n.GetLength(I1));
408 }
409 return false;
410 }
411
412 if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
413 {
414 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
415 {
416 print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | "
417 "M/L/K/NPerBlock = "
418 "%d, %d, %d, %d\n",
419 M,
420 L,
421 K,
422 N,
423 MPerBlock,
424 LPerBlock,
425 KPerBlock,
426 NPerBlock);
427 }
428 return false;
429 }
430
431 // check gemm1 gridwise gemm pipeline
432 if(!(LPerBlock % LTilePerBlock == 0))
433 {
434 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
435 {
436 print("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n",
437 LPerBlock,
438 LTilePerBlock);
439 }
440 return false;
441 }
442
443 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
444 {
445 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
446 {
447 print("GridwiseOp: invalid block_2_ctile_map\n");
448 }
449 return false;
450 }
451
452 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
453 return true;
454 }
455
456 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
457 {
458 const index_t num_loop = math::integer_divide_ceil(K, KPerBlock);
459 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
460 }
461
462 __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
463 {
464 const index_t num_loop = math::integer_divide_ceil(K, KPerBlock);
465 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
466 }
467
468 __host__ __device__ static constexpr auto
469 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
470 {
471 const auto M = c_grid_desc_m_n.GetLength(I0);
472 const auto N = c_grid_desc_m_n.GetLength(I1);
473
474 const auto MBlock = M / MPerBlock;
475 const auto NBlock = N / NPerBlock;
476
477 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
478 c_grid_desc_m_n,
483
484 return c_grid_desc_mblock_mperblock_nblock_nperblock;
485 }
486
487 // return block_id to C matrix tile idx (m0, n0) mapping
488 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
489 const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
490 {
492 c_grid_desc_m_n);
493 }
494
497 CGridDesc_M_N{}))>;
499 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
500
502 {
503 // LDS allocation for A and B: be careful of alignment
504 static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);
505
507 MakeABlockDescriptor().GetElementSpaceSize(), max_lds_align);
509 MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align);
511 MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align);
512
513 static constexpr auto a_block_space_offset = 0;
515 static constexpr auto b1_block_space_offset = 0;
516
517 // LDS allocation for reduction
518 // Feature to add, IntraThread Reduction
521
522 static constexpr auto reduction_space_offset = 0;
523
524 // LDS allocation for C shuffle in LDS
525 static constexpr auto c_block_space_size =
527 .GetElementSpaceSize();
528 };
529
530 template <bool HasMainKBlockLoop,
531 TailNumber TailNum,
532 typename Block2CTileMap = DefaultBlock2CTileMap>
533 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
534 const B0DataType* __restrict__ p_b0_grid,
535 const B1DataType* __restrict__ p_b1_grid,
536 CDataType* __restrict__ p_c_grid,
537 void* __restrict__ p_shared,
538 const AGridDesc& a_grid_desc,
539 const B0GridDesc& b0_grid_desc,
540 const B1GridDesc& b1_grid_desc,
542 c_grid_desc_mblock_mperblock_nblock_nperblock,
543 const AElementwiseOperation& a_element_op,
544 const B0ElementwiseOperation& b0_element_op,
545 const AccElementwiseOperation& acc_element_op,
546 const B1ElementwiseOperation& b1_element_op,
547 const CElementwiseOperation& c_element_op,
548 const Block2CTileMap& block_2_ctile_map)
549 {
550 // clang-format off
551/*******************************************************************************/
552// Memory buffer zone.
553 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
554 p_a_grid, a_grid_desc.GetElementSpaceSize());
555 const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
556 p_b0_grid, b0_grid_desc.GetElementSpaceSize());
557 const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
558 p_b1_grid, b1_grid_desc.GetElementSpaceSize());
560 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
561
562/*******************************************************************************/
563// BlockIdx.x -> [BlockId.m, BlockId.n]
564 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
565 if(!block_2_ctile_map.ValidCTileIndex(
566 block_work_idx,
567 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
568 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
569 { return; }
570
571 // Store BlockId into SGPR
572 const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
573 const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
574
575/*******************************************************************************/
576// set up Gemm0
577/*******************************************************************************/
578
579/*******************************************************************************/
580// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
581 constexpr auto a_block_desc = MakeABlockDescriptor();
582 constexpr auto b0_block_desc = MakeB0BlockDescriptor();
583
584 auto a_block_trait = [&](){
585 // A matrix blockwise copy
586 constexpr auto AK0PerBlock = KPerBlock/ AK1;
588 static_cast<ADataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
590
591 auto a_blockwise_copy =
593/* typename SrcElementwiseOperation, */ AElementwiseOperation,
594/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
595/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
596/* typename BlockSliceLengths, */ Sequence<AK0PerBlock, MPerBlock, AK1>,
597/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
598/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
599/* typename SrcData, */ ADataType,
600/* typename DstData, */ ADataType,
601/* typename SrcDesc, */ decltype(a_grid_desc),
602/* typename DstDesc, */ decltype(a_block_desc),
603/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
604/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
605/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
606/* index_t DstVectorDim, */ 2,
607/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
608/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
609/* index_t SrcScalarStrideInVector, */ 1,
610/* index_t DstScalarStrideInVector, */ 1,
611/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
612/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
613 BlockwiseGemmPipe::GlobalBufferNum>(
614 a_grid_desc,
615 make_multi_index(0, m_block_data_idx_on_grid, 0),
616 a_element_op,
617 a_block_desc,
618 make_multi_index(0, 0, 0),
620
621 return make_tuple(a_block_buf, a_blockwise_copy);
622 };
623
624 auto b0_block_trait = [&](){
626 static_cast<B0DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
628
629 auto b0_blockwise_copy =
631 B0ElementwiseOperation,
635 B0BlockTransferThreadClusterLengths_K0_L_K1,
636 B0BlockTransferThreadClusterArrangeOrder,
637 B0DataType,
638 B0DataType,
639 decltype(b0_grid_desc),
640 decltype(b0_block_desc),
641 B0BlockTransferSrcAccessOrder,
643 B0BlockTransferSrcVectorDim,
644 2,
645 B0BlockTransferSrcScalarPerVector,
646 B0BlockTransferDstScalarPerVector_K1,
647 1,
648 1,
649 B0ThreadTransferSrcResetCoordinateAfterRun,
650 true,
651 BlockwiseGemmPipe::GlobalBufferNum>(
652 b0_grid_desc,
653 make_multi_index(0, 0, 0),
654 b0_element_op,
655 b0_block_desc,
656 make_multi_index(0, 0, 0),
658
659 return make_tuple(b0_block_buf, b0_blockwise_copy);
660 };
661
662 auto a_block_buf = a_block_trait()[I0];
663 auto a_blockwise_copy = a_block_trait()[I1];
664
665 auto b0_block_buf = b0_block_trait()[I0];
666 auto b0_blockwise_copy = b0_block_trait()[I1];
667
668/*******************************************************************************/
669 // Gemm0
670 // Blockwise GEMM0 pipeline
671 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
672 auto blockwise_gemm0_pipeline = BlockwiseGemmPipe{};
673 auto acc0_thread_buf = blockwise_gemm0_pipeline.GetCThreadBuffer();
674
675 // Note that we are using the transposeC version of GetCThreadDescriptor.
676 constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
677 blockwise_gemm0_pipeline.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
678
679 constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
680 constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
681 constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
682 constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
683 constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
684 constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
685 constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
686
687 constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
688 acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
690 make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
691 make_pass_through_transform(laccvgprs)),
694
695/*******************************************************************************/
696 // Shift Per SUB_K
697 constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
698 constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep();
699
700 const auto a_block_reset_copy_step = [&](){
701 return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0);
702 }();
703
704 const auto b0_block_reset_copy_step = [&](){
705 return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0);
706 }();
707
708 const auto K = [&](){
709 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
710 }();
711
712 const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
713
714/*******************************************************************************/
715// set up Gemm1
716/*******************************************************************************/
717 // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
718 // A1 matrix in VGPR
719 constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
723
724 constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
725 constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
726 constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];
727
728 constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
729 make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1),
730 make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1));
731
732 // A1 matrix blockwise copy
733 auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
734 Acc0DataType,
735 ADataType,
736 decltype(acc0_thread_desc_l0perblock_mperblock_l1),
737 decltype(a1_thread_desc_l0perblock_mperblock_l1),
741 2,
743
745 a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());
746
747 constexpr auto b1_block_desc = MakeB1BlockDescriptor();
748
749 auto b1_block_trait = [&](){
751 static_cast<B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
753
754 auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
756/* typename SrcElementwiseOperation, */ B1ElementwiseOperation,
757/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough,
758/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
759/* typename BlockSliceLengths, */ Sequence<BL0, NPerBlock, BL1>,
760/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1,
761/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder,
762/* typename SrcData, */ B1DataType,
763/* typename DstData, */ B1DataType,
764/* typename SrcDesc, */ decltype(b1_grid_desc),
765/* typename DstDesc, */ decltype(b1_block_desc),
766/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder,
767/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
768/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim,
769/* index_t DstVectorDim, */ 2,
770/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector,
771/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1,
772/* index_t SrcScalarStrideInVector, */ 1,
773/* index_t DstScalarStrideInVector, */ 1,
774/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun,
775/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord
776 1>( // Used to be NumGemmKPrefetchStage, never tested / used for != 1
777 b1_grid_desc,
778 make_multi_index(0, n_block_data_idx_on_grid, 0),
779 b1_element_op,
780 b1_block_desc,
781 make_multi_index(0, 0, 0),
783
784 return make_tuple(b1_block_buf, b1_blockwise_copy);
785 };
786
787 auto b1_block_buf = b1_block_trait()[I0];
788 auto b1_blockwise_copy = b1_block_trait()[I1];
789
790 constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep();
791
792 auto blockwise_gemm1 =
793 BlockwiseGemmWMMA<BlockSize,
794 ADataType,
795 B1DataType,
796 Acc1DataType,
797 decltype(MakeA1WaveDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
798 decltype(MakeB1WaveDescriptor(b1_block_desc)),
799 MPerBlock,
800 NPerBlock,
801 LTilePerBlock,
802 MPerWmma,
803 NPerWmma,
804 MRepeat,
805 NRepeat,
806 KPack,
807 false, // Acc1EnableLds
808 true, // B1EnableLds
809 true>{make_tuple(0, 0, 0, 0, 0, 0)};
810
811 auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
812
813 const auto L = [&](){
814 return b0_grid_desc.GetLength(I1);
815 }();
816
817 const index_t num_gemm1_l_block_outer_loop = L / LPerBlock;
818 constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock;
819
820 // Initialize C
821 StaticBuffer<AddressSpaceEnum::Vgpr, Acc1DataType, acc1_thread_buf.Size(), true> c_thread_buf;
822 c_thread_buf.Clear();
823
824 // Empty BScale struct for the blockwise pipeline.
825 using BScale = typename BlockwiseGemmPipe::Empty;
826 auto b_scale_struct = BScale{};
827
828/*******************************************************************************/
829 //
830 // Kernel Main Stage
831 //
832 index_t gemm1_l_block_outer_index = 0;
833 // Outer loop, along GEMM_L
834 // Inner loop, along GEMM_K
835 do {
836 blockwise_gemm0_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc,
837 a_block_desc,
838 a_blockwise_copy,
839 a_grid_buf,
840 a_block_buf,
841 a_block_slice_copy_step,
842 b0_grid_desc,
843 b0_block_desc,
844 b0_blockwise_copy,
845 b0_grid_buf,
846 b0_block_buf,
847 b0_block_slice_copy_step,
848 acc0_thread_buf,
849 b_scale_struct,
850 KBlockMainLoop,
851 1); // num_k_block_per_scale
852
853 static_for<0, acc0_thread_buf.Size(), 1>{}(
854 [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
855
857
858 // gemm1
859 {
860 // TODO: explore using dynamic buffer for a1 thread buffer
861 // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
862 // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
863 // the A1 source buffer is static buffer holding the output of first GEMM and
864 // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
865 // explicitly in Run() below.
866
867 // Initialize acc1
868 acc1_thread_buf.Clear();
869
870 // preload data into LDS
871 b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
872
873 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
874 b1_block_slice_copy_step);
875
876 block_sync_lds(); // wait for reduction LDS read
877
878 b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
879
880 // main body
881 if constexpr(num_gemm1_l_block_inner_loop > 1)
882 {
883 static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
884 // Data cast from Acc0DataType to ADataType happens here
885 a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
887 acc0_thread_buf,
888 a1_thread_desc_l0perblock_mperblock_l1,
889 make_tuple(I0, I0, I0),
890 a1_thread_buf);
891
892 b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);
893
895
896 blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
897
899
900 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
901 b1_block_slice_copy_step);
902
903 b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
904 });
905 }
906 // tail
907 {
908 a1_blockwise_copy.Run(
909 acc0_thread_desc_l0perblock_mperblock_l1,
911 Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0),
912 acc0_thread_buf,
913 a1_thread_desc_l0perblock_mperblock_l1,
914 make_tuple(I0, I0, I0),
915 a1_thread_buf);
916
918
919 blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
920 }
921 } // end gemm1
922
923 constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
924 blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
925 constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
926 constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
927 constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
928 constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
929 constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
930 constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
931 constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
932
933 constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
934 make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
935 c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs));
936 constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
937 constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
938
941 auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
942 Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
943 Acc1DataType c = c_thread_buf[I]; // O
944 Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax.
945
946 c_thread_buf(I) = c_new; // O_new
947 });
948 });
949
950 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
951 a_block_reset_copy_step); // rewind K
952 b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc,
953 b0_block_reset_copy_step); // rewind K and step N
954
955 block_sync_lds(); // wait for gemm1 LDS read
956 }while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop);
957/*******************************************************************************/
958 // write out to C, implement shuffle
959 {
960 constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
961 blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
962
963 // This API Provide All dimension (size) you need
964 constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
965 blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
966
967 constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
968 constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
969 constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
970 constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
971 constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
972
973 // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
974 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
976
977 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
978 static_cast<CShuffleDataType*>(p_shared),
979 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
980
981 constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
982 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
986 Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
987 MWave, // MWave
988 MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma
989 )),
992 Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
993 NWave, // NWave
994 NSubGroup,
995 NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma
998
999 // calculate origin of thread output tensor on global memory
1000 // blockwise GEMM c matrix starting index
1001 const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0);
1002
1003 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1004 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1005
1006 const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor =
1008 make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))),
1011
1012 const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor =
1014 make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))),
1017
1018 const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex(
1019 make_multi_index(m_thread_data_on_block));
1020
1021 const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex(
1022 make_multi_index(n_thread_data_on_block));
1023
1024 // shuffle: threadwise copy C from VGPR to LDS
1025 auto c_thread_copy_vgpr_to_lds =
1027 CShuffleDataType,
1028 decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
1029 decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
1031 Sequence<CShuffleMRepeatPerShuffle,
1032 I1,
1033 I1,
1034 CShuffleNRepeatPerShuffle,
1035 I1,
1036 I1,
1037 NAccVgprs>,
1039 6,
1040 8, // vector write pixel
1042 1,
1043 true>{
1044 c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
1046 m_thread_data_on_block_idx[I1],
1047 m_thread_data_on_block_idx[I2],
1048 0,
1049 n_thread_data_on_block_idx[I1],
1050 n_thread_data_on_block_idx[I2],
1051 n_thread_data_on_block_idx[I3]),
1053
1054 // shuffle: blockwise copy C from LDS to global
1055 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1056 ThisThreadBlock, // ThreadGroup
1057 CElementwiseOperation, // ElementwiseOperation,
1058 CGlobalMemoryDataOperation, // DstInMemOp,
1059 Sequence<1,
1060 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1061 1,
1062 CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
1063 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1064 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1065 CShuffleDataType, // typename SrcData,
1066 CDataType, // typename DstData,
1067 decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
1068 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1069 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1070 3, // index_t VectorDim,
1071 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1072 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1073 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1074 {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1075 make_multi_index(0, 0, 0, 0),
1076 c_grid_desc_mblock_mperblock_nblock_nperblock,
1077 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1078 c_element_op};
1079
1080 // space filling curve for local reg & global memory
1081 // space filling curve for threadwise C in VGPR
1082 constexpr auto sfc_c_vgpr =
1085 Sequence<CShuffleMRepeatPerShuffle,
1086 1,
1087 1,
1088 CShuffleNRepeatPerShuffle,
1089 1,
1090 1,
1091 NAccVgprs>>{};
1092
1093 // space filling curve for shuffled blockwise C in global mem
1094 constexpr auto sfc_c_global =
1097 Sequence<1,
1098 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1099 1,
1100 CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
1101
1102 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1103
1104 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1105
1106 static_for<0, num_access, 1>{}([&](auto access_id) {
1107 // make sure it's safe to write to LDS
1109
1110 // each thread write its data from VGPR to LDS
1111 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
1112 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1113 c_thread_buf,
1114 c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
1115 c_shuffle_block_buf);
1116
1117 // make sure it's safe to read from LDS
1119
1120 // each block copy its data from LDS to global
1121 c_shuffle_block_copy_lds_to_global.Run(
1122 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1123 c_shuffle_block_buf,
1124 c_grid_desc_mblock_mperblock_nblock_nperblock,
1125 c_grid_buf);
1126
1127 if constexpr(access_id < num_access - 1)
1128 {
1129 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1130 // move on C
1131 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1132 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1133 }
1134 });
1135 }
1136 // clang-format on
1137 }
1138};
1139
1140} // namespace ck
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Vgpr
Definition amd_address_space.hpp:20
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_wmma.hpp:550
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:502
static constexpr auto a_block_space_offset
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:513
static constexpr auto b0_block_space_offset
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:514
static constexpr auto a_block_space_size_aligned
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:506
static constexpr auto max_lds_align
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:504
static constexpr auto b1_block_space_offset
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:515
static constexpr auto b1_block_space_size_aligned
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:510
static constexpr auto b0_block_space_size_aligned
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:508
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::SharedMemTrait::reduction_space_size_aligned
static constexpr index_t reduction_space_size_aligned
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:519
static constexpr auto reduction_space_offset
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:522
static constexpr auto c_block_space_size
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:525
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:88
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const B0ElementwiseOperation &b0_element_op, const AccElementwiseOperation &acc_element_op, const B1ElementwiseOperation &b1_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:533
__host__ static __device__ constexpr auto MakeB0BlockDescriptor()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:141
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:498
static constexpr auto I1
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:90
static constexpr auto AL0
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:103
static constexpr auto BL1
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:106
static constexpr auto I6
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:95
__host__ static __device__ constexpr auto MakeB1BlockSliceCopyStep()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:198
static constexpr auto I4
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:93
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:310
static constexpr auto AK1
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:98
static constexpr auto I2
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:91
__host__ static __device__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:462
__host__ static __device__ constexpr auto MakeB0BlockSliceCopyStep()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:192
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:456
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:298
static constexpr auto NWaves
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:110
static constexpr auto LWaves
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:109
static constexpr auto I5
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:94
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:116
static constexpr auto KPack
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:113
__host__ static __device__ constexpr auto MakeAWaveDescriptor(const ABlockDesc_ &)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:205
static constexpr auto BL0
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:105
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:488
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:469
static constexpr auto I0
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:89
__host__ static __device__ constexpr auto MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1 &)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:256
static constexpr auto BK0
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:99
static constexpr auto BK1
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:100
static constexpr auto MWaves
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:108
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:363
__host__ static __device__ constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_ &)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:271
static constexpr auto AL1
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:104
__host__ static __device__ constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_ &)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:230
__host__ static __device__ constexpr auto MakeABlockDescriptor()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:118
__host__ static __device__ constexpr auto MakeABlockSliceCopyStep()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:186
__host__ static __device__ constexpr auto MakeB1BlockDescriptor()
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:164
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:495
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, B0DataType, ADataType, B0DataType, AccDataType, decltype(MakeAWaveDescriptor(MakeABlockDescriptor())), decltype(MakeB0WaveDescriptor(MakeB0BlockDescriptor())), ABlockTransferSrcScalarPerVector, B0BlockTransferSrcScalarPerVector, MPerBlock, LPerBlock, KPerBlock, MPerWmma, LPerWmma, MRepeat, LRepeat, KPack, true >())> BlockwiseGemmPipe
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:335
static constexpr auto AK0
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:97
static constexpr auto L0PerBlock
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:102
static constexpr auto I3
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:92
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:16
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Threadwise data transfer.
Definition threadwise_tensor_slice_transfer.hpp:1720
Definition threadwise_tensor_slice_transfer.hpp:39
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129