device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp Source File

device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp Source File

Composable Kernel: device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp Source File
device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23template <typename GridwiseGemm,
24 typename GemmDesc,
25 GemmSpecialization GemmSpec,
26 typename AsLayout,
27 typename BsLayout,
28 typename DsLayout,
29 typename ELayout,
30 typename Block2ETileMap,
31 typename GroupedGemmBlock2ETileMap,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CDEElementwiseOperation,
35 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
36 bool HasMainKBlockLoop>
37__global__ void
38#if CK_USE_LAUNCH_BOUNDS
40#endif
42 const index_t group_count,
43 const index_t grid_size_grp,
44 const AElementwiseOperation a_element_op,
45 const BElementwiseOperation b_element_op,
46 const CDEElementwiseOperation cde_element_op)
47{
48#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
49 if constexpr(GridwiseGemm::template IsValidCompilationParameter<EGlobalMemoryDataOperation>())
50 {
51 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52
53 const index_t KBatch = 1;
54
55 const index_t block_id = get_block_1d_id();
56
57 const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
59
60 const index_t group_id = block_id / grid_size_grp;
61
62 if(group_id >= group_count)
63 return;
64
65 const index_t M = gemm_desc_ptr[group_id].M;
66 const index_t N = gemm_desc_ptr[group_id].N;
67 const index_t K = gemm_desc_ptr[group_id].K;
68
69 if(M == 0 || N == 0 || K == 0)
70 return;
71
72 const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
73 const auto StrideBs = gemm_desc_ptr[group_id].StrideBs;
74 const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
75 const auto StrideE = gemm_desc_ptr[group_id].StrideE;
76
77 const auto e_grid_desc_m_n =
78 GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
79
80 const index_t BlockStart = group_id * grid_size_grp;
81
82 const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};
83
84 const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);
85
86 constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size();
87 constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size();
88 constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size();
89
90 typename GridwiseGemm::AsGridPointer p_as_grid_;
91 typename GridwiseGemm::BsGridPointer p_bs_grid_;
92 typename GridwiseGemm::DsGridPointer p_ds_grid_;
93
94 static_for<0, NumATensor, 1>{}([&](auto i) {
95 using ADataType = remove_cvref_t<decltype(p_as_grid_(i))>;
96 p_as_grid_(i) = static_cast<ADataType>(gemm_desc_ptr[group_id].p_as_grid[i]);
97 });
98
99 static_for<0, NumBTensor, 1>{}([&](auto i) {
100 using BDataType = remove_cvref_t<decltype(p_bs_grid_(i))>;
101 p_bs_grid_(i) = static_cast<BDataType>(gemm_desc_ptr[group_id].p_bs_grid[i]);
102 });
103
104 static_for<0, NumDTensor, 1>{}([&](auto i) {
105 using DDataType = remove_cvref_t<decltype(p_ds_grid_(i))>;
106 p_ds_grid_(i) = static_cast<DDataType>(gemm_desc_ptr[group_id].p_ds_grid[i]);
107 });
108
109 index_t id_off = 0;
110 index_t id_local = get_block_1d_id() - BlockStart;
111
112 while(id_local < local_grid_size)
113 {
114 const auto block_2_etile_map =
115 GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
116
117 GridwiseGemm::
118 template Run<HasMainKBlockLoop, GemmSpec, AsLayout, BsLayout, DsLayout, ELayout>(
119 p_as_grid_,
120 p_bs_grid_,
121 p_ds_grid_,
122 gemm_desc_ptr[group_id].p_e_grid,
123 p_shared,
124 a_element_op,
125 b_element_op,
126 cde_element_op,
127 M,
128 N,
129 K,
130 StrideAs,
131 StrideBs,
132 StrideDs,
133 StrideE,
134 block_2_etile_map);
135
136 id_off += grid_size_grp;
137 id_local += grid_size_grp;
138 }
139 }
140#else
141 ignore = gemm_descs_const;
142 ignore = group_count;
143 ignore = grid_size_grp;
144 ignore = a_element_op;
145 ignore = b_element_op;
146 ignore = cde_element_op;
147#endif
148}
149
150template <typename AsLayout,
151 typename BsLayout,
152 typename DsLayout,
153 typename ELayout,
154 typename AsDataType,
155 typename BsDataType,
156 typename AccDataType,
157 typename CShuffleDataType,
158 typename DsDataType,
159 typename EDataType,
160 typename AElementwiseOperation,
161 typename BElementwiseOperation,
162 typename CDEElementwiseOperation,
163 GemmSpecialization GemmSpec,
164 ck::index_t NumPrefetch,
165 ck::index_t BlockSize,
166 ck::index_t MPerBlock,
167 ck::index_t NPerBlock,
168 ck::index_t KPerBlock,
169 ck::index_t AK1,
170 ck::index_t BK1,
171 ck::index_t MPerXDL,
172 ck::index_t NPerXDL,
173 ck::index_t MXdlPerWave,
174 ck::index_t NXdlPerWave,
175 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
176 typename ABlockTransferThreadClusterArrangeOrder,
177 typename ABlockTransferSrcAccessOrder,
178 ck::index_t ABlockTransferSrcVectorDim,
179 ck::index_t ABlockTransferSrcScalarPerVector,
180 ck::index_t ABlockTransferDstScalarPerVector_AK1,
181 bool ABlockLdsExtraM,
182 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
183 typename BBlockTransferThreadClusterArrangeOrder,
184 typename BBlockTransferSrcAccessOrder,
185 ck::index_t BBlockTransferSrcVectorDim,
186 ck::index_t BBlockTransferSrcScalarPerVector,
187 ck::index_t BBlockTransferDstScalarPerVector_BK1,
188 bool BBlockLdsExtraN,
189 index_t CShuffleMXdlPerWavePerShuffle,
190 index_t CShuffleNXdlPerWavePerShuffle,
191 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
192 index_t CDEBlockTransferScalarPerVector_NPerBlock,
193 typename ComputeType = EDataType,
196 : public DeviceGroupedGemmMultiABDFixedNK<AsLayout,
197 BsLayout,
198 DsLayout,
199 ELayout,
200 AsDataType,
201 BsDataType,
202 DsDataType,
203 EDataType,
204 AElementwiseOperation,
205 BElementwiseOperation,
206 CDEElementwiseOperation>
207{
210 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
211 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
212
213 static constexpr index_t NumATensor = AsDataType::Size();
214 static constexpr index_t NumBTensor = BsDataType::Size();
215 static constexpr index_t NumDTensor = DsDataType::Size();
216
217 static constexpr auto I0 = Number<0>{};
218 static constexpr auto I1 = Number<1>{};
219 static constexpr auto I2 = Number<2>{};
220
221 static constexpr index_t NumGemmKPrefetchStage = 1;
222
223 // GridwiseGemm
224 template <index_t NXdlPerWave_>
226 AsDataType,
227 BsDataType,
228 ComputeType,
229 AccDataType,
230 CShuffleDataType,
231 DsDataType,
232 EDataType,
233 AElementwiseOperation,
234 BElementwiseOperation,
235 CDEElementwiseOperation,
238 BlockSize,
239 MPerBlock,
240 NPerBlock,
241 KPerBlock,
242 AK1,
243 BK1,
244 MPerXDL,
245 NPerXDL,
246 MXdlPerWave,
247 NXdlPerWave_,
248 ABlockTransferThreadClusterLengths_AK0_M_AK1,
249 ABlockTransferThreadClusterArrangeOrder,
250 ABlockTransferSrcAccessOrder,
251 ABlockTransferSrcVectorDim,
252 ABlockTransferSrcScalarPerVector,
253 ABlockTransferDstScalarPerVector_AK1,
254 false,
255 ABlockLdsExtraM,
256 BBlockTransferThreadClusterLengths_BK0_N_BK1,
257 BBlockTransferThreadClusterArrangeOrder,
258 BBlockTransferSrcAccessOrder,
259 BBlockTransferSrcVectorDim,
260 BBlockTransferSrcScalarPerVector,
261 BBlockTransferDstScalarPerVector_BK1,
262 false,
263 BBlockLdsExtraN,
264 CShuffleMXdlPerWavePerShuffle,
265 CShuffleNXdlPerWavePerShuffle,
266 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
267 CDEBlockTransferScalarPerVector_NPerBlock,
268 LoopSched>;
271 template <typename UnderlyingBlockToCTileMap>
273 {
274 using underlying_type = UnderlyingBlockToCTileMap;
275
277 UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
278 {
279 block_to_ctile_map_ = block_to_ctile_map;
280 block_start_ = block_start;
281 id_off_ = id_off;
282 }
283
284 template <typename TopIdx>
285 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
286 {
287 auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
289
290 return make_tuple(
291 // idx_bot[Number<0>{}],
292 idx_bot[Number<1>{}],
293 idx_bot[Number<2>{}]);
294 }
295
296 template <typename CTileIdx, typename CTileDim>
297 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
298 const CTileDim& c_tile_dim) const
299 {
300 return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
301 }
302
303 template <typename CGridDesc_M_N>
304 __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
305 {
306 return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
307 }
308
309 template <typename CGridDesc_M_N>
310 __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
311 {
312 return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
313 }
314
315 UnderlyingBlockToCTileMap block_to_ctile_map_;
318 };
319
320 template <index_t MPerBlock_, index_t NPerBlock_>
322 {
323 static constexpr auto I0 = Number<0>{};
324 static constexpr auto I1 = Number<1>{};
325
326 __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
327
336
338 index_t N,
339 index_t KBatch,
340 index_t M01 = 8)
341 : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
342 {
343 }
344
345 template <typename CGridDesc_M_N>
347 const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
349 c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
350 {
351 }
352
353 __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
354 {
355 const auto M0 = math::integer_divide_ceil(M, MPerBlock);
356 const auto N0 = math::integer_divide_ceil(N, NPerBlock);
357
358 return M0 * N0 * KBatch_;
359 }
360
361 template <typename CGridDesc_M_N>
362 __host__ __device__ constexpr index_t
363 CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
364 {
365 return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
366 }
367
368 template <typename CGridDesc_M_N>
369 __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
370 {
371 return true;
372 }
373
374 template <typename TopIdx>
375 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
376 {
377 auto block_1d_id = idx_top[I0];
378
379 const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
380 const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
381
382 block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
383
384 const index_t idx_ksplit = block_1d_id / (M0 * N0);
385 block_1d_id = block_1d_id % (M0 * N0);
386
387 index_t idx_N0 = block_1d_id % N0;
388 index_t idx_M0 = block_1d_id / N0;
389
390 const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
391
392 index_t idx_M00 = idx_M0 / M01_;
393 index_t idx_M01 = idx_M0 % M01_;
394 index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
395
396 return make_tuple(idx_ksplit,
397 idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
398 idx_N0_M01_local / M01_adapt);
399 }
400
401 template <typename CTileIdx, typename CTileDim>
402 __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
403 const CTileDim& /* c_tile_dim */) const
404 {
405 return true; // always valid provided that user gets grid size from CalculateGridSize()
406 }
407
408 private:
409 index_t M_;
410 index_t N_;
411 index_t KBatch_;
412 index_t M01_;
413 };
414
417
419 {
420 // pointers
421 std::array<const void*, NumATensor> as_ptr_;
422 std::array<const void*, NumBTensor> bs_ptr_;
423 std::array<const void*, NumDTensor> ds_ptr_;
424 void* e_ptr_;
425
427 std::array<index_t, NumATensor> StrideAs_;
428 std::array<index_t, NumBTensor> StrideBs_;
429 std::array<index_t, NumDTensor> StrideDs_;
431 };
432
433 // Argument
434 struct Argument : public BaseArgument
435 {
436
438
439 Argument(std::vector<std::array<const void*, NumATensor>>&,
440 std::vector<std::array<const void*, NumBTensor>>&,
441 std::vector<std::array<const void*, NumDTensor>>&,
442 std::vector<void*>&,
443 std::vector<GemmMultiABDDesc>& gemm_descs,
444 AElementwiseOperation a_element_op = AElementwiseOperation{},
445 BElementwiseOperation b_element_op = BElementwiseOperation{},
446 CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
447 : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
448 {
449 grid_size_ = 0;
450
451 k_batch_ = 1;
452
454
455 group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
456
458
459 index_t group_id = 0;
460
461 sum_of_m = gemm_descs[0].M_;
463 const index_t N = gemm_descs[0].N_;
464 const index_t K = gemm_descs[0].K_;
465
466 for(std::size_t i = 0; i < gemm_descs.size(); i++)
467 {
468 if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
469 {
470 throw std::runtime_error("wrong! M/N/K is not identical");
471 }
472
473 a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
474 b_mtx_nraw_kraw_.emplace_back(N, K);
475
476 // pointer
477 std::array<const void*, NumATensor> p_as_grid;
478 std::array<const void*, NumBTensor> p_bs_grid;
479 std::array<const void*, NumDTensor> p_ds_grid;
480
481 static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; });
482 static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; });
483 static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; });
484
485 std::array<index_t, NumATensor> StrideAs;
486 std::array<index_t, NumBTensor> StrideBs;
487 std::array<index_t, NumDTensor> StrideDs;
488
489 const index_t StrideE = gemm_descs[i].stride_C_;
490
491 if(gemm_descs[i].stride_As_.size() != NumATensor)
492 {
493 throw std::runtime_error(
494 "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
495 }
496
497 static_for<0, NumATensor, 1>{}(
498 [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; });
499
500 if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
501 {
502 throw std::runtime_error(
503 "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
504 }
505
506 static_for<0, NumBTensor, 1>{}(
507 [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; });
508
509 if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
510 {
511 throw std::runtime_error(
512 "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
513 }
514
515 static_for<0, NumDTensor, 1>{}(
516 [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; });
517
518 const auto e_grid_desc_m_n =
519 GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
520 AverM, N, StrideE);
521
522 // block-to-e-tile map
523 const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
524
525 grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
526
527 if(group_id * grid_size_grp_ != grid_size_)
528 {
529 throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
530 }
531
533
534 // check block-to-E-tile
535 if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
536 {
537 throw std::runtime_error("wrong! block_2_etile_map validation failed");
538 }
539
540 gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
541 p_as_grid,
542 p_bs_grid,
543 p_ds_grid,
544 nullptr,
545 AverM,
546 N,
547 K,
548 StrideAs,
549 StrideBs,
550 StrideDs,
551 StrideE,
552 });
553
554 group_id++;
555 }
556
557 const auto e_grid_desc_sum_m_n =
558 GridwiseGemm64::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
560
561 const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
562
563 barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
564 }
565
566 // private:
568
569 AElementwiseOperation a_element_op_;
570 BElementwiseOperation b_element_op_;
571 CDEElementwiseOperation c_element_op_;
572
573 std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
574 std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
575 std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
576
578
583
585 };
586
587 // Invoker
588 struct Invoker : public BaseInvoker
589 {
591
592 template <typename GridwiseGemm>
593 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
594 {
595 bool has_main_k_block_loop = true;
596
597 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
598 {
599 if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
600 has_main_k_block_loop)
601 {
602 throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
603 }
604 }
605
606 if(arg.grouped_gemm_kernel_args_dev == nullptr)
607 {
608 throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
609 }
610
611 float ave_time = 0;
612
613 auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
614 const auto kernel = kernel_grouped_gemm_xdl_fixed_nk<
615 GridwiseGemm,
616 GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>,
617 GemmSpec,
618 AsLayout,
619 BsLayout,
620 DsLayout,
621 ELayout,
624 AElementwiseOperation,
625 BElementwiseOperation,
626 CDEElementwiseOperation,
627 e_global_memory_operation_,
628 has_main_k_block_loop_>;
629
631 stream_config,
632 kernel,
633 dim3(arg.grid_size_),
634 dim3(BlockSize),
635 0,
637 arg.gemm_desc_kernel_arg_.size(),
638 arg.grid_size_grp_,
639 arg.a_element_op_,
640 arg.b_element_op_,
641 arg.c_element_op_);
642 };
643
645 constexpr auto Set = InMemoryDataOperationEnum::Set;
646
647 if(arg.k_batch_ > 1)
648 {
649 if(has_main_k_block_loop)
650 {
651 ave_time =
652 launch_kernel(integral_constant<bool, true>{},
653 integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
654 }
655 else
656 {
657 ave_time =
658 launch_kernel(integral_constant<bool, false>{},
659 integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
660 }
661 }
662 else
663 {
664 if(has_main_k_block_loop)
665 {
666 ave_time = launch_kernel(integral_constant<bool, true>{},
667 integral_constant<InMemoryDataOperationEnum, Set>{});
668 }
669 else
670 {
671 ave_time = launch_kernel(integral_constant<bool, false>{},
672 integral_constant<InMemoryDataOperationEnum, Set>{});
673 }
674 }
675
676 return ave_time;
677 }
678
680
681 // polymorphic
682 float Run(const BaseArgument* p_arg,
683 const StreamConfig& stream_config = StreamConfig{}) override
684 {
685 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
686 }
687 };
688
689 static bool IsSupportedArgument(const Argument& arg)
690 {
691 // Split-K autodeduction is not supported
692 if(arg.k_batch_ < 1)
693 {
694 return false;
695 }
696
698 {
699 return false;
700 }
701
702 bool supported = true;
703
704 // If we use padding we do not support vector loads for dimensions not divisible by vector
705 // load size.
706 if constexpr(GemmSpec != GemmSpecialization::Default)
707 {
708 // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
709 // thus we have to adapt it to the {M,K} or {N,K} layout.
710 const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
711 const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
712
713 for(index_t i = 0; i < arg.group_count_; ++i)
714 {
715 const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
716 const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
717
718 supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
719 supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
720 }
721 }
722
723 return supported;
724 }
725
726 // polymorphic
727 bool IsSupportedArgument(const BaseArgument* p_arg) override
728 {
729 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
730 }
731
732 static auto MakeArgument(std::vector<std::array<const void*, NumATensor>>& p_As,
733 std::vector<std::array<const void*, NumBTensor>>& p_Bs,
734 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
735 std::vector<void*>& p_Es,
736 std::vector<GemmMultiABDDesc> gemm_descs,
737 AElementwiseOperation a_element_op = AElementwiseOperation{},
738 BElementwiseOperation b_element_op = BElementwiseOperation{},
739 CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
740 {
741 return Argument{
742 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
743 }
744
745 static auto MakeInvoker() { return Invoker{}; }
746
747 // polymorphic
748 std::unique_ptr<BaseArgument>
749 MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_As,
750 std::vector<std::array<const void*, NumBTensor>>& p_Bs,
751 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
752 std::vector<void*>& p_Es,
753 std::vector<GemmMultiABDDesc>& gemm_descs,
754 AElementwiseOperation a_element_op = AElementwiseOperation{},
755 BElementwiseOperation b_element_op = BElementwiseOperation{},
756 CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override
757 {
758 return std::make_unique<Argument>(
759 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
760 }
761
762 // polymorphic
763 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
764 {
765 return std::make_unique<Invoker>(Invoker{});
766 }
767
768 // polymorphic
769 std::string GetTypeString() const override
770 {
771 auto str = std::stringstream();
772
773 // clang-format off
774 str << "DeviceGroupedGemm_Xdl_Fixed_NK"
775 << "<"
776 << BlockSize << ", "
777 << MPerBlock << ", "
778 << NPerBlock << ", "
779 << KPerBlock << ", "
780 << AK1 << ", "
781 << BK1 << ", "
782 << MPerXDL << ", "
783 << NPerXDL << ", "
784 << MXdlPerWave << ", "
785 << NXdlPerWave << ", "
786 << ABlockTransferSrcScalarPerVector << ", "
787 << BBlockTransferSrcScalarPerVector << ", "
788 << CShuffleMXdlPerWavePerShuffle << ", "
789 << CShuffleNXdlPerWavePerShuffle << ", "
790 << getGemmSpecializationString(GemmSpec)
791 << ">";
792 // clang-format on
793
794 return str.str();
795 }
796
797 static void SetElementwiseOps(Argument& arg,
798 AElementwiseOperation a_element_op,
799 BElementwiseOperation b_element_op,
800 CDEElementwiseOperation c_element_op)
801 {
802 arg.a_element_op_ = a_element_op;
803 arg.b_element_op_ = b_element_op;
804 arg.c_element_op_ = c_element_op;
805 }
806
807 static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
808 {
809 arg.grouped_gemm_kernel_args_dev = kernel_args;
810 }
811
812 // polymorphic
813 void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
814 {
815 return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
816 }
817
819 AElementwiseOperation a_element_op,
820 BElementwiseOperation b_element_op,
821 CDEElementwiseOperation c_element_op) const override
822 {
823
825 *dynamic_cast<Argument*>(p_arg), a_element_op, b_element_op, c_element_op);
826 }
827
828 size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
829 {
830 auto arg = *dynamic_cast<const Argument*>(p_arg);
831
832 return arg.group_count_ *
834 }
835
836#if 0
837 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
838 {
839 auto arg = *dynamic_cast<const Argument*>(p_arg);
840
841 return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
842 }
843
844 void SetWorkSpacePointer(BaseArgument* p_arg,
845 void* p_workspace,
846 const StreamConfig& stream_config = StreamConfig{}) const override
847 {
848 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
849 p_arg_->p_workspace_ = p_workspace;
850
852 hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
853 }
854#endif
855
856 static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
857
858 // polymorphic
859 void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
860 {
861 return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
862 }
863};
864
865} // namespace device
866} // namespace tensor_operation
867} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const index_t grid_size_grp, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:41
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
unsigned int uint32_t
Definition stdint.h:126
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
Definition functional2.hpp:33
Definition device_base.hpp:197
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition device_base.hpp:249
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition device_base.hpp:247
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ bool CheckValidity(const CGridDesc_M_N &) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:369
__host__ __device__ bool ValidCTileIndex(const CTileIdx &, const CTileDim &) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:402
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(const CGridDesc_M_N &c_grid_desc_m_n, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:346
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:363
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:375
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops & operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops &&)=default
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:353
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, index_t N, index_t KBatch, index_t M01=8)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:337
Block2ETileMap block_to_ctile_map_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:315
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:285
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:297
__host__ bool CheckValidity(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:304
UnderlyingBlockToCTileMap underlying_type
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:274
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n) const
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:310
__host__ __device__ OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off=0)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:276
std::array< const void *, NumATensor > as_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:421
std::array< const void *, NumBTensor > bs_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:422
void * e_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:424
index_t M_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:426
std::array< index_t, NumBTensor > StrideBs_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:428
index_t K_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:426
std::array< index_t, NumDTensor > StrideDs_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:429
std::array< const void *, NumDTensor > ds_ptr_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:423
std::array< index_t, NumATensor > StrideAs_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:427
index_t N_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:426
index_t StrideE_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:430
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:435
Argument(std::vector< std::array< const void *, NumATensor > > &, std::vector< std::array< const void *, NumBTensor > > &, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &, std::vector< GemmMultiABDDesc > &gemm_descs, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{})
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:439
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:573
index_t sum_of_m
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:582
BElementwiseOperation b_element_op_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:570
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:574
index_t grid_size_grp_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:580
index_t grid_size_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:579
AElementwiseOperation a_element_op_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:569
CDEElementwiseOperation c_element_op_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:571
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:575
const void * grouped_gemm_kernel_args_dev
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:577
index_t group_count_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:567
void UpdateKBatch(index_t)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:437
index_t barrier_size_grp_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:581
index_t k_batch_
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:584
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:589
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:593
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:682
DeviceOp::Argument Argument
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:590
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:207
static constexpr auto I0
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:217
static constexpr auto I2
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:219
static auto MakeArgument(std::vector< std::array< const void *, NumATensor > > &p_As, std::vector< std::array< const void *, NumBTensor > > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmMultiABDDesc > gemm_descs, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{})
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:732
static constexpr index_t NumBTensor
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:214
static constexpr index_t NumGemmKPrefetchStage
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:221
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:828
static constexpr index_t NumATensor
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:213
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< std::array< const void *, NumATensor > > &p_As, std::vector< std::array< const void *, NumBTensor > > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &p_Ds, std::vector< void * > &p_Es, std::vector< GemmMultiABDDesc > &gemm_descs, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{}) override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:749
static constexpr auto I1
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:218
void SetElementwiseOps(BaseArgument *p_arg, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:818
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops< MPerBlock, NPerBlock > Block2ETileMap
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:415
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:763
static void SetKBatch(Argument &arg, index_t k_batch)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:856
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:211
static auto MakeInvoker()
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:745
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:269
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:727
void SetKBatch(BaseArgument *p_arg, index_t k_batch) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:859
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:210
OffsettedBlockToCTileMapMLoops< Block2ETileMap > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:416
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:689
std::string GetTypeString() const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:769
void SetDeviceKernelArgs(BaseArgument *p_arg, const void *kernel_args) const override
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:813
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK DeviceOp
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:208
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:270
GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:225
static void SetDeviceKernelArgs(Argument &arg, const void *kernel_args)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:807
static void SetElementwiseOps(Argument &arg, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:797
static constexpr index_t NumDTensor
Definition device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp:215
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:73
Definition device_grouped_gemm.hpp:80
Definition device_grouped_gemm_multi_abd_fixed_nk.hpp:17