gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp Source File

gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp Source File

Composable Kernel: gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp Source File
gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
19
22
23namespace ck {
24
// Kernel entry point; this page's index lists it as
// kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load.
//
// When compiled for gfx90a / gfx94 and GridwiseGemm::IsValidCompilationParameter<>()
// holds, the kernel declares a __shared__ byte array of
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and forwards all of its
// arguments, plus that array, to GridwiseGemm::Run<HasMainKBlockLoop>.
// For any other target, or when the compile-time check fails, the kernel does no work.
//
// NOTE(review): listing lines 41 and 43 are missing from this page -- presumably the
// __launch_bounds__ attribute and the line carrying the kernel name; confirm against
// the header itself.
25template <typename GridwiseGemm,
26 typename ADataType,
27 typename BDataType,
28 typename DsPointer,
29 typename EDataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CDEElementwiseOperation,
33 typename AGridDesc_AK0_M_AK1,
34 typename BGridDesc_BK0_N_BK1,
35 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
37 typename Block2ETileMap,
38 bool HasMainKBlockLoop>
39__global__ void
40#if CK_USE_LAUNCH_BOUNDS
42#endif
44 const ADataType* __restrict__ p_a_grid,
45 const BDataType* __restrict__ p_b_grid,
46 DsPointer p_ds_grid,
47 EDataType* __restrict__ p_e_grid,
48 const AElementwiseOperation a_element_op,
49 const BElementwiseOperation b_element_op,
50 const CDEElementwiseOperation cde_element_op,
51 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
52 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
53 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
54 ds_grid_desc_mblock_mperblock_nblock_nperblock,
55 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
56 e_grid_desc_mblock_mperblock_nblock_nperblock,
57 const Block2ETileMap block_2_etile_map)
58{
// The real body exists only for the two targets tested below; there is no `else`
// for the `if constexpr`, so an invalid parameter set compiles to an empty kernel.
59#if(defined(__gfx90a__) || defined(__gfx94__))
60 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
61 {
// Statically sized shared buffer; its byte count comes from the gridwise GEMM type.
62 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
63
64 GridwiseGemm::template Run<HasMainKBlockLoop>(
65 p_a_grid,
66 p_b_grid,
67 p_ds_grid,
68 p_e_grid,
69 p_shared,
70 a_element_op,
71 b_element_op,
72 cde_element_op,
73 a_grid_desc_ak0_m_ak1,
74 b_grid_desc_bk0_n_bk1,
75 ds_grid_desc_mblock_mperblock_nblock_nperblock,
76 e_grid_desc_mblock_mperblock_nblock_nperblock,
77 block_2_etile_map);
78 }
79#else
// Other targets: every parameter is assigned to `ignore` and nothing else happens
// (presumably to silence unused-parameter diagnostics -- confirm).
80 ignore = p_a_grid;
81 ignore = p_b_grid;
82 ignore = p_ds_grid;
83 ignore = p_e_grid;
84 ignore = a_element_op;
85 ignore = b_element_op;
86 ignore = cde_element_op;
87 ignore = a_grid_desc_ak0_m_ak1;
88 ignore = b_grid_desc_bk0_n_bk1;
89 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
90 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
91 ignore = block_2_etile_map;
92#endif
93}
94
95// GEMM:
96// input : A[M, K]
97// input : B[N, K]
98// input : D0[M, N], D1[M, N], ...
99// output : E[M, N]
100// C = a_op(A) * b_op(B)
101// E = cde_op(C, D0, D1, ...)
102// Assume:
103// D0, D1, ... and E have the same layout
104template <typename ALayout,
105 typename BLayout,
106 typename DsLayout,
107 typename ELayout,
108 typename ADataType,
109 typename BDataType,
110 typename AComputeDataType_,
111 typename AccDataType,
112 typename CShuffleDataType,
113 typename DsDataType,
114 typename EDataType,
115 typename AElementwiseOperation,
116 typename BElementwiseOperation,
117 typename CDEElementwiseOperation,
118 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
120 index_t NumGemmKPrefetchStage,
121 index_t BlockSize,
122 index_t MPerBlock,
123 index_t NPerBlock,
124 index_t KPerBlock,
125 index_t AK1Value,
126 index_t BK1Value,
127 index_t MPerXdl,
128 index_t NPerXdl,
129 index_t MXdlPerWave,
130 index_t NXdlPerWave,
131 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
132 typename ABlockTransferSrcAccessOrder,
133 index_t ABlockTransferSrcVectorDim,
134 index_t ABlockTransferScalarPerVector,
135 index_t ABlockLdsExtraM,
136 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
137 typename BBlockTransferSrcAccessOrder,
138 index_t BBlockTransferSrcVectorDim,
139 index_t BBlockTransferScalarPerVector,
140 index_t BBlockLdsExtraN,
141 index_t CShuffleMXdlPerWavePerShuffle,
142 index_t CShuffleNXdlPerWavePerShuffle,
143 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
144 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
145 LoopScheduler LoopSched,
147 typename BComputeDataType_ = AComputeDataType_>
149{
150 static constexpr index_t NumDTensor = DsDataType::Size();
151
152 static constexpr auto I0 = Number<0>{};
153 static constexpr auto I1 = Number<1>{};
154 static constexpr auto I2 = Number<2>{};
155 static constexpr auto I3 = Number<3>{};
156 static constexpr auto I4 = Number<4>{};
157 static constexpr auto I5 = Number<5>{};
158 static constexpr auto I6 = Number<6>{};
159 static constexpr auto I7 = Number<7>{};
160
161 static constexpr auto AK1 = Number<AK1Value>{};
162 static constexpr auto BK1 = Number<BK1Value>{};
163 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
164 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
165
167
170
171#if CK_GFX90A_DENORM_WORKAROUND
172 using AComputeDataType =
174#else
179#endif
180
181 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
182 {
184 {
185 // FIXME: our support for non-K-contiguous layouts is limited; it only works in some specific
186 // settings.
189 }
190 else
191 {
194 }
195 }
196
197 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
198 {
200 {
201 // FIXME: our support for non-K-contiguous layouts is limited; it only works in some specific
202 // settings.
205 }
206 else
207 {
210 }
211 }
212
213 __host__ __device__ static constexpr auto
215 {
216 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
217 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
218
219 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
223 I1,
225
226 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
227 }
228
// Builds a tuple holding one null `const DDataType*` per element type of DsDataType.
// Within this listing the result is only used through decltype, to define the
// DsGridPointer alias; the Argument constructor later fills in the real pointers.
// Resulting type:
229 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
230 static constexpr auto MakeDsGridPointer()
231 {
232 return generate_tuple(
233 [&](auto i) {
// Element type of the i-th D tensor, stripped of cv/reference qualifiers.
234 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
235
236 return static_cast<const DDataType*>(nullptr);
237 },
// NOTE(review): listing line 238 (the tuple-size argument, presumably
// Number<NumDTensor>{}) is missing from this page -- confirm.
239 }
240
// Returns the number of bytes of shared memory (LDS) the kernel must reserve.
//
// The result is the larger of two requirements:
//   * the A and B tiles: NumGemmKPrefetchStage copies of each, with every tile's
//     element count rounded up to a multiple of lcm(AK1, BK1);
//   * the C-shuffle tile.
// A maximum is taken rather than a sum; Run allocates the A/B buffers from p_shared
// and later casts the same p_shared pointer to CShuffleDataType* for the shuffle
// buffer, so the two uses share one allocation.
241 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
242 {
243 // LDS allocation for A and B: be careful of alignment.
244 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
245 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
246
// Alignment is expressed in elements, not bytes.
247 constexpr auto max_lds_align = math::lcm(AK1, BK1);
248
249 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
250 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
251
252 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
253 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
254
255 // LDS allocation for C shuffle.
// NOTE(review): listing line 257 (the initializer of this descriptor) is missing
// from this page -- presumably a call to the C-shuffle block descriptor getter.
256 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
258
259 constexpr auto c_block_size =
260 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
261
// Element counts are converted to bytes with the compute types for A/B and
// CShuffleDataType for C.
262 return math::max(
263 NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) +
264 NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType),
265 c_block_size * sizeof(CShuffleDataType));
266 }
267
268 __host__ __device__ static auto
270 {
271 constexpr auto matrix_padder =
273 MPerBlock, NPerBlock, KPerBlock};
274
275 const auto a_grid_desc_mraw_kraw = [&]() {
277 {
278 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
279 make_tuple(StrideA, I1));
280 }
282 {
283 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
284 make_tuple(I1, StrideA));
285 }
286 }();
287
288 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
289 }
290
291 __host__ __device__ static auto
293 {
294 constexpr auto matrix_padder =
296 MPerBlock, NPerBlock, KPerBlock};
297
298 const auto b_grid_desc_nraw_kraw = [&]() {
300 {
301 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
302 make_tuple(I1, StrideB));
303 }
305 {
306 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
307 make_tuple(StrideB, I1));
308 }
309 }();
310
311 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
312 }
313
314 __host__ __device__ static auto
316 {
317 constexpr auto matrix_padder =
319 MPerBlock, NPerBlock, KPerBlock};
320 const auto e_grid_desc_mraw_nraw = [&]() {
322 {
323 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
324 make_tuple(StrideE, I1));
325 }
327 {
328 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
329 make_tuple(I1, StrideE));
330 }
331 }();
332
333 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
334 }
335
// Builds a tuple of grid descriptors for the D tensors.
// Entry i is produced by MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]),
// i.e. each D tensor is described with the same routine used for E. This matches the
// assumption stated above the struct that D0, D1, ... and E share one layout.
336 __host__ __device__ static auto
337 MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
338 const std::array<index_t, NumDTensor>& NRaws,
339 const std::array<index_t, NumDTensor>& DsStride)
340 {
341 return generate_tuple(
342 [&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); },
// NOTE(review): listing line 343 (the tuple-size argument, presumably
// Number<NumDTensor>{}) is missing from this page -- confirm.
344 }
345
346 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
347 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
349 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
350
351 // A desc for source in blockwise copy.
352 __host__ __device__ static constexpr auto
354 {
355 const auto M = a_grid_desc_m_k.GetLength(I0);
356 const auto K = a_grid_desc_m_k.GetLength(I1);
357
358 const auto AK0 = K / AK1;
359
360 return transform_tensor_descriptor(a_grid_desc_m_k,
365 }
366
367 // B desc for source in blockwise copy.
368 __host__ __device__ static constexpr auto
370 {
371 const auto N = b_grid_desc_n_k.GetLength(I0);
372 const auto K = b_grid_desc_n_k.GetLength(I1);
373
374 const auto BK0 = K / BK1;
375
376 return transform_tensor_descriptor(b_grid_desc_n_k,
381 }
382
383 // E desc for destination in blockwise copy.
384 __host__ __device__ static constexpr auto
386 {
387 const auto M = e_grid_desc_m_n.GetLength(I0);
388 const auto N = e_grid_desc_m_n.GetLength(I1);
389
390 const auto MBlock = M / MPerBlock;
391 const auto NBlock = N / NPerBlock;
392
393 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
394 e_grid_desc_m_n,
399
400 return e_grid_desc_mblock_mperblock_nblock_nperblock;
401 }
402
403 // Ds desc for source in blockwise copy.
404 __host__ __device__ static constexpr auto
406 {
407 return generate_tuple(
408 [&](auto i) {
410 },
412 }
413
414 __host__ __device__ static constexpr auto
416 {
418 e_grid_desc_m_n);
419 }
420
427 DsGridDesc_M_N{}))>;
430 EGridDesc_M_N{}))>;
431
433
435
// Decides whether this gridwise GEMM can handle the problem described by the given
// descriptors. Returns true only when every check below passes.
//
// Compile-time checks (static_assert):
//   * MPerBlock / NPerBlock are multiples of the per-wave XDL tile sizes;
//   * KPerBlock is a multiple of both AK1Value and BK1Value;
//   * a constraint on AElementwiseOperation / BElementwiseOperation whose message
//     says only passthrough is supported for direct loads.
// Run-time checks, in order; the function returns false at the first failure:
//   1. A, B and E agree on M, N and K;
//   2. every D descriptor has the same M and N;
//   3. M, N and K are exact multiples of MPerBlock, NPerBlock and KPerBlock;
//   4. GridwiseGemmPipe accepts the resulting number of K iterations;
//   5. block_2_etile_map accepts the E descriptor;
//   6. A, B and E each occupy at most 2^31 bytes.
436 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
437 const BGridDesc_N_K& b_grid_desc_n_k,
438 const DsGridDesc_M_N& ds_grid_desc_m_n,
439 const EGridDesc_M_N& e_grid_desc_m_n,
440 const Block2ETileMap& block_2_etile_map)
441 {
442 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
443 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
444 "Invalid tuning param!");
445
446 static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
447 "KPerBlock must be divisible by AK1Value and BK1Value!");
448
// NOTE(review): listing lines 451 and 453 (the second argument of each is_same_v,
// presumably the passthrough operation type) are missing from this page.
449 static_assert(
450 std::is_same_v<AElementwiseOperation,
452 std::is_same_v<BElementwiseOperation,
454 "Direct load transfers do not support elementwise operations other than passthrough.");
455
// Dimension 0 of A is M and of B is N; dimension 1 of both is K.
456 const auto M = a_grid_desc_m_k.GetLength(I0);
457 const auto N = b_grid_desc_n_k.GetLength(I0);
458 const auto AK = a_grid_desc_m_k.GetLength(I1);
459 const auto BK = b_grid_desc_n_k.GetLength(I1);
460
461 // Check the consistency of descriptors.
462 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
463 {
464 return false;
465 }
466
467 bool valid = true;
468
// Accumulate the per-D result into `valid`; the early return happens after the loop.
469 static_for<0, NumDTensor, 1>{}([&](auto i) {
470 valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
471 N == ds_grid_desc_m_n[i].GetLength(I1));
472 });
473
474 if(!valid)
475 {
476 return false;
477 }
478
479 // Check the tile size.
// Only AK is tested against KPerBlock; AK == BK was already established above.
480 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
481 {
482 return false;
483 }
484
485 // Check gridwise gemm pipeline.
486 const auto num_k_loop = AK / KPerBlock;
487 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
488 {
489 return false;
490 }
491
492 // Check block-to-E-tile.
493 if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
494 {
495 return false;
496 }
497
498 // Check tensor size: cannot exceed 2GB.
// The limit is 2^31 bytes, computed in long_index_t. Only A, B and E are tested here;
// the D tensors are not part of this size check.
499 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
500
501 if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
502 b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
503 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
504 {
505 return false;
506 }
507
508 return true;
509 }
510
// Converts a K extent into a count of KPerBlock-sized steps and returns whatever
// GridwiseGemmPipe::CalculateHasMainLoop reports for that count.
// The division is integer division on index_t, so any remainder of K is dropped
// here; CheckValidity separately requires K to be a multiple of KPerBlock.
511 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
512 {
513 const index_t num_loop = K / KPerBlock;
514 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
515 }
516
517 using DsGridPointer = decltype(MakeDsGridPointer());
518
// Accessor for the MPerBlock template parameter of this gridwise GEMM.
519 __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
520
521 template <bool HasMainKBlockLoop,
522 typename AGridDesc_AK0_M_AK1,
523 typename BGridDesc_BK0_N_BK1,
526 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
527 const BDataType* __restrict__ p_b_grid,
528 DsGridPointer p_ds_grid,
529 EDataType* __restrict__ p_e_grid,
530 void* __restrict__ p_shared,
531 const AElementwiseOperation& a_element_op,
532 const BElementwiseOperation& b_element_op,
533 const CDEElementwiseOperation& cde_element_op,
534 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
535 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
537 ds_grid_desc_mblock_mperblock_nblock_nperblock,
539 e_grid_desc_mblock_mperblock_nblock_nperblock,
540 const Block2ETileMap& block_2_etile_map)
541 {
542 // Elementwise operations are not supported for A and B, arguments left only for the API
543 // consistency.
544 (void)a_element_op;
545 (void)b_element_op;
546
547 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
548 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
549
550 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
551 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
552
553 const auto ds_grid_buf = generate_tuple(
554 [&](auto i) {
556 p_ds_grid[i],
557 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
558 },
560
562 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
563
564 // Divide block work by [M, N].
565 const auto block_work_idx =
566 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
567
568 if(!block_2_etile_map.ValidCTileIndex(
569 block_work_idx,
570 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
571 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
572 {
573 return;
574 }
575
576 // This forces m/n_block_data_idx_on_grid into SGPR.
577 const index_t m_block_data_idx_on_grid =
578 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
579 const index_t n_block_data_idx_on_grid =
580 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
581
582 constexpr auto max_lds_align = math::lcm(AK1, BK1);
583
584 // A matrix in LDS memory, destination of blockwise copy.
585 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
586
587 // B matrix in LDS memory, destination of blockwise copy.
588 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
589
590 auto a_blockwise_copy =
593 ABlockTransferThreadClusterLengths_AK0_M_AK1,
594 ABlockTransferSrcAccessOrder,
595 ADataType,
597 decltype(a_grid_desc_ak0_m_ak1),
598 decltype(a_block_desc_ak0_m_ak1),
599 ABlockTransferSrcAccessOrder,
600 ABlockTransferSrcVectorDim,
601 2,
602 ABlockTransferScalarPerVector>(
603 a_grid_desc_ak0_m_ak1,
604 make_multi_index(0, m_block_data_idx_on_grid, 0),
605 a_block_desc_ak0_m_ak1,
606 make_multi_index(0, 0, 0));
607
608 auto b_blockwise_copy =
611 BBlockTransferThreadClusterLengths_BK0_N_BK1,
612 BBlockTransferSrcAccessOrder,
613 BDataType,
615 decltype(b_grid_desc_bk0_n_bk1),
616 decltype(b_block_desc_bk0_n_bk1),
617 BBlockTransferSrcAccessOrder,
618 BBlockTransferSrcVectorDim,
619 2,
620 BBlockTransferScalarPerVector>(
621 b_grid_desc_bk0_n_bk1,
622 make_multi_index(0, n_block_data_idx_on_grid, 0),
623 b_block_desc_bk0_n_bk1,
624 make_multi_index(0, 0, 0));
625
626 // GEMM definition
627 // c_mtx += transpose(a_mtx) * b_mtx
628 // a_mtx[K0PerBlock, MPerBlock] is in LDS
629 // b_mtx[K0PerBlock, NPerBlock] is in LDS
630 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
631 // register
632 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
633 constexpr bool is_single_rate_mfma =
636 lcm_AK1_BK1 <= 4) ||
637 (is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
639 lcm_AK1_BK1 < 32))
640 ? true
641 : false;
642 constexpr auto is_scale_mfma = false;
643
644 constexpr index_t KPack = math::max(lcm_AK1_BK1,
645 MfmaSelector<AComputeDataType_,
646 MPerXdl,
647 NPerXdl,
648 BComputeDataType_,
649 is_single_rate_mfma,
650 is_scale_mfma>::selected_mfma.k_per_blk);
651
653 BlockSize,
656 AccDataType,
657 decltype(a_block_desc_ak0_m_ak1),
658 decltype(b_block_desc_bk0_n_bk1),
659 MPerXdl,
660 NPerXdl,
661 MXdlPerWave,
662 NXdlPerWave,
663 KPack,
664 LoopSched,
665 AComputeDataType_,
666 BComputeDataType_>();
667
668 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
669
670 // LDS allocation for A and B: be careful of alignment.
671 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
672 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
673
674 const auto a_buffers_offset = 0;
675 auto a_block_buffers =
676 ck::lds_utils::AllocateLdsBuffers<AComputeDataType, NumGemmKPrefetchStage>(
677 p_shared,
678 a_block_desc_ak0_m_ak1.GetElementSpaceSize(),
679 a_buffers_offset,
680 max_lds_align);
681 const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage;
682 auto b_block_buffers =
683 ck::lds_utils::AllocateLdsBuffers<BComputeDataType, NumGemmKPrefetchStage>(
684 p_shared,
685 b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
686 b_buffers_offset,
687 max_lds_align);
688
689 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
690 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
691
692 const auto gridwise_gemm_pipeline =
694
695 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
696 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
697 KPerBlock);
698
699 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
700 a_block_desc_ak0_m_ak1,
701 a_blockwise_copy,
702 a_grid_buf,
703 a_block_buffers,
704 a_block_slice_copy_step,
705 b_grid_desc_bk0_n_bk1,
706 b_block_desc_bk0_n_bk1,
707 b_blockwise_copy,
708 b_grid_buf,
709 b_block_buffers,
710 b_block_slice_copy_step,
711 blockwise_gemm,
712 c_thread_buf,
713 num_k_block_main_loop);
714
715 // Shuffle C and write out.
716 {
717 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
718 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
719 "wrong!");
720
721 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
722 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
723
724 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
725 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
726
727 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
728 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
729 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
730
731 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
732 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
733 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
734 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
735 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
736 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
737 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
738 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
739
740 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
742
743 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
744 static_cast<CShuffleDataType*>(p_shared),
745 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
746
747 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
748 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
752 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
753 M1, // M1 = MWave
754 M2, // M2 * M3 * M4 = MPerXdl
755 M3,
756 M4)),
759 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
760 N1, // N1 = NWave
761 N2))), // N2 = NPerXdl
765
766 // Calculate the origin of thread output tensor on global memory.
767 const auto c_thread_mtx_on_block =
768 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
769
770 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
771 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
772
773 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
775 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
778
779 const auto m_thread_data_on_block_idx =
780 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
781 make_multi_index(m_thread_data_on_block));
782
783 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
788
789 const auto n_thread_data_on_block_idx =
790 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
791 make_multi_index(n_thread_data_on_block));
792
793 // Shuffle: threadwise copy C from VGPR to LDS.
794 auto c_thread_copy_vgpr_to_lds =
796 CShuffleDataType,
797 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
798 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
800 Sequence<CShuffleMXdlPerWavePerShuffle,
801 CShuffleNXdlPerWavePerShuffle,
802 I1,
803 I1,
804 M2,
805 I1,
806 M4,
807 I1>,
809 7,
810 1,
812 1,
813 true>{
814 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
816 0,
817 m_thread_data_on_block_idx[I1],
818 n_thread_data_on_block_idx[I1],
819 m_thread_data_on_block_idx[I2],
820 m_thread_data_on_block_idx[I3],
821 m_thread_data_on_block_idx[I4],
822 n_thread_data_on_block_idx[I2]),
824
825 // A tuple of references to C/Ds tensor descriptors.
826 const auto c_ds_desc_refs = concat_tuple_of_reference(
827 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
828 generate_tie([&](auto i) -> const auto& // return type should be reference
829 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
831
832 // A tuple of references to C/Ds grid buffers.
833 const auto c_ds_buf_refs = concat_tuple_of_reference(
834 tie(c_shuffle_block_buf),
835 generate_tie([&](auto i) -> const auto& // return type should be reference
836 { return ds_grid_buf[i]; },
838
839 // A tuple of the starting indices of the C/Ds blockwise copy.
840 const auto idx_c_ds_block_begin = container_concat(
841 make_tuple(make_multi_index(0, 0, 0, 0)),
843 [&](auto) {
844 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
845 },
847
848 // Blockwise copy C/D/E between LDS and global.
849 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
851 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
853 decltype(c_ds_desc_refs),
854 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
855 CDEElementwiseOperation,
856 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
857 Sequence<1,
858 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
859 1,
860 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
861 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
862 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
863 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
864 3, // index_t VectorDim,
865 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
869 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
870 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
871 {c_ds_desc_refs,
872 idx_c_ds_block_begin,
873 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
874 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
875 cde_element_op};
876
877 // Space filling curve for threadwise C in VGPR before shuffle.
878 constexpr auto sfc_c_vgpr =
881 Sequence<CShuffleMXdlPerWavePerShuffle,
882 CShuffleNXdlPerWavePerShuffle,
883 1,
884 1,
885 M2,
886 1,
887 M4,
888 1>>{};
889
890 // Space filling curve for shuffled blockwise C/D/E.
891 constexpr auto sfc_cde_block =
894 Sequence<1,
895 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
896 1,
897 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
898
899 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
900
901 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
902
903 static_for<0, num_access, 1>{}([&](auto access_id) {
904 // Make sure it's safe to write to LDS.
906
907 // Each thread writes its data from VGPR to LDS.
908 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
909 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
910 c_thread_buf,
911 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
912 c_shuffle_block_buf);
913
914 // Make sure it's safe to read from LDS.
916
917 // Each block copies its data from LDS to global.
918 cde_block_copy_lds_and_global.Run(
919 c_ds_desc_refs,
920 c_ds_buf_refs,
921 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
922 tie(e_grid_buf));
923
924 if constexpr(access_id < num_access - 1)
925 {
926 constexpr auto cde_lds_and_global_step =
927 sfc_cde_block.GetForwardStep(access_id);
928
929 // Move on Ds.
930 static_for<0, NumDTensor, 1>{}([&](auto i) {
931 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
932 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
933 });
934
935 // Move on E.
936 cde_block_copy_lds_and_global.MoveDstSliceWindow(
937 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
938 I0,
939 cde_lds_and_global_step);
940 }
941 });
942 }
943 }
944
946 {
947 Argument(const void* p_a_grid,
948 const void* p_b_grid,
949 std::array<const void*, NumDTensor> p_ds_grid,
950 void* p_e_grid,
951 index_t MRaw,
952 index_t NRaw,
953 index_t KRaw,
954 index_t StrideA,
955 index_t StrideB,
956 std::array<index_t, NumDTensor> StrideDs,
957 index_t StrideE,
958 AElementwiseOperation a_element_op,
959 BElementwiseOperation b_element_op,
960 CDEElementwiseOperation cde_element_op)
961 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
962 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
963 p_ds_grid_{},
964 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
965 a_grid_desc_m_k_{MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
966 b_grid_desc_n_k_{MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
968 e_grid_desc_m_n_{MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
974 a_element_op_{a_element_op},
975 b_element_op_{b_element_op},
976 cde_element_op_{cde_element_op},
977 MRaw_{MRaw},
978 NRaw_{NRaw},
979 KRaw_{KRaw}
980 {
981 static_for<0, NumDTensor, 1>{}([&](auto i) {
982 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
983 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
984 ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
985 });
986
992 {
995
998 }
999 }
1000
// Writes the problem-level tensor descriptors to std::cout, one per line, each
// prefixed with a label: A[M, K], B[N, K], every Ds[M, N] entry, then E[M, N].
1001 void Print() const
1002 {
1003 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
1004 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
// NOTE(review): listing line 1005 (the loop that invokes this lambda, presumably a
// static_for over the D tensors) is missing from this page -- confirm.
1006 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
1007 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
1008 }
1009
1010 // Pointers
1011 const ADataType* p_a_grid_;
1012 const BDataType* p_b_grid_;
1014 EDataType* p_e_grid_;
1015
1016 // Tensor descriptors for problem definition
1021
1022 // Tensor descriptors for block/thread-wise copy
1028
1029 // block-to-e-tile map
1031
1032 // element-wise ops
1033 AElementwiseOperation a_element_op_;
1034 BElementwiseOperation b_element_op_;
1035 CDEElementwiseOperation cde_element_op_;
1036
1037 // For checking vector load/store
1041 };
1042};
1043
1044} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v4
Definition gridwise_gemm_pipeline_selector.hpp:22
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
__global__ void kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:43
Definition block_to_ctile_map.hpp:261
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1026
const BDataType * p_b_grid_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1012
BElementwiseOperation b_element_op_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1034
index_t NRaw_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1039
index_t MRaw_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1038
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1023
index_t KRaw_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1040
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1027
void Print() const
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1001
AGridDesc_M_K a_grid_desc_m_k_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1017
EGridDesc_M_N e_grid_desc_m_n_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1020
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1024
DsGridDesc_M_N ds_grid_desc_m_n_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1019
EDataType * p_e_grid_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1014
Block2ETileMap block_2_etile_map_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1030
BGridDesc_N_K b_grid_desc_n_k_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1018
const ADataType * p_a_grid_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1011
AElementwiseOperation a_element_op_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1033
CDEElementwiseOperation cde_element_op_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1035
DsGridPointer p_ds_grid_
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:1013
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:947
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:149
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp:526
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition thread_group_tensor_slice_transfer_v7.hpp:42
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
__host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition matrix_padder.hpp:155
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
__host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition matrix_padder.hpp:147
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340