device_grouped_query_attention_forward_wmma.hpp Source File

device_grouped_query_attention_forward_wmma.hpp Source File

Composable Kernel: device_grouped_query_attention_forward_wmma.hpp Source File
device_grouped_query_attention_forward_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <numeric>
9#include <initializer_list>
10#include <cstdlib>
#include <stdexcept>
11
12#include "ck/ck.hpp"
24
25namespace ck {
26namespace tensor_operation {
27namespace device {
28
29// Grouped-Query Attention (GQA) kernel implementation
30// K and V have QueryGroupNumber heads; the G1 query heads are split evenly among them
// (IsSupportedArgument rejects G1 % QueryGroupNumber != 0).
31// Q [G0, G1, M, K] * K [G0, QueryGroupNumber, K, N] = P [G0, G1, M, N]
32// P [G0, G1, M, N] * V [G0, QueryGroupNumber, N, O] = Out [G0, G1, M, O]
// alpha is forwarded to AccElementwiseOperation (applied to the first GEMM's accumulator).
// input_permute / output_permute select between the two memory layouts documented next to
// the stride arrays below.
33template <typename DeviceOp,
34 typename GridwiseOp,
35 typename ADataType,
36 typename B0DataType,
37 typename B1DataType,
38 typename CDataType,
39 typename AElementwiseOperation,
40 typename B0ElementwiseOperation,
41 typename AccElementwiseOperation,
42 typename B1ElementwiseOperation,
43 typename CElementwiseOperation,
44 ck::index_t QueryGroupNumber,
45 bool HasMainKBlockLoop>
46__global__ void
47#if CK_USE_LAUNCH_BOUNDS
49#endif
50 kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid,
51 const B0DataType* __restrict__ p_b0_grid,
52 const B1DataType* __restrict__ p_b1_grid,
53 CDataType* __restrict__ p_c_grid,
54 index_t M, // SequenceQ
55 index_t N, // SequenceK
56 index_t K, // HeadDim of Q/K
57 index_t O, // HeadDim of V / Out
58 index_t G0, // Batch
59 index_t G1, // HeadNum of Q / Out
60 float alpha,
61 bool input_permute,
62 bool output_permute)
63{
64#if(defined(__gfx11__) || defined(__gfx12__))
65
66 // clang-format off
67// ***************************************************
68 const auto q_head = G1;
69 const auto kv_head = QueryGroupNumber;
70// Make Tensor Descriptors
71 constexpr index_t array_size = 4;
72 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, q_head, M, K};
73 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
74 input_permute
75 ? std::array<ck::index_t, array_size>{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K]
76 : std::array<ck::index_t, array_size>{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
77
78 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, kv_head, N, K};
79 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
80 input_permute
81 ? std::array<ck::index_t, array_size>{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, kv_head, K]
82 : std::array<ck::index_t, array_size>{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, kv_head, N, K]
83
84 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, kv_head, O, N};
85 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
86 input_permute
87 ? std::array<ck::index_t, array_size>{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, kv_head, O]
88 : std::array<ck::index_t, array_size>{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, kv_head, N, O]
89
90 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, q_head, M, O};
91 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
92 output_permute
93 ? std::array<ck::index_t, array_size>{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O]
94 : std::array<ck::index_t, array_size>{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
95
96 const auto a_element_op = AElementwiseOperation{};
97 const auto b0_element_op = B0ElementwiseOperation{};
98 const auto acc0_element_op = AccElementwiseOperation{alpha};
99 const auto b1_element_op = B1ElementwiseOperation{};
100 const auto c_element_op = CElementwiseOperation{};
101 // DeviceOp::MakeArgument() is a host-only function, so the descriptors are rebuilt here on the device from the raw sizes.
102
103 const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
104 const auto b0_grid_desc =
105 DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
106 const auto b1_grid_desc =
107 DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
108 const auto c_grid_desc_m_n =
109 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
110 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
111 GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
112 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
113
114 const auto a_grid_desc_g_m_k =
115 DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
116 const auto b0_grid_desc_g_l_k =
117 DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
118 const auto b1_grid_desc_g_n_l =
119 DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
120 const auto c_grid_desc_g_m_n =
121 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
122 const auto compute_base_ptr_of_batch =
123 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
124 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
125 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
126
127 // clang-format on
// Shared-memory scratch for the gridwise op, sized by GridwiseOp::GetSharedMemoryNumberOfByte().
128 __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
// The grid is split into batch_count equal runs of consecutive blocks; each run works on one
// batch index g_idx. NOTE(review): assumes grid_size is a multiple of batch_count (the
// Invoker launches G0 * G1 * M-tiles * O-tiles blocks).
129 const index_t num_blocks_per_batch =
130 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
131 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
132
// Q and Out are indexed by the query batch g_idx. K and V are shared within a head group, so
// g_idx is mapped to K/V batch g_idx * QueryGroupNumber / G1. Assuming the merged batch index
// is g0 * G1 + h (TODO confirm against Transform::MakeCGridDescriptor_G_M_N), this is
// g0 * QueryGroupNumber + h / (G1 / QueryGroupNumber) when G1 % QueryGroupNumber == 0.
133 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
134 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
135 const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
136 compute_base_ptr_of_batch.GetB0BasePtr(g_idx * QueryGroupNumber / G1)));
137 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
138 compute_base_ptr_of_batch.GetB1BasePtr(g_idx * QueryGroupNumber / G1)));
139 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
140 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
141
142 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
143 p_b0_grid + b0_batch_offset,
144 p_b1_grid + b1_batch_offset,
145 p_c_grid + c_batch_offset,
146 p_shared,
147 a_grid_desc,
148 b0_grid_desc,
149 b1_grid_desc,
150 c_grid_desc_mblock_mperblock_nblock_nperblock,
151 a_element_op,
152 b0_element_op,
153 acc0_element_op,
154 b1_element_op,
155 c_element_op,
156 c0_matrix_mask,
157 block_2_ctile_map);
158#else
// Unsupported target: the kernel body compiles to nothing; parameters are marked used.
159 ignore = p_a_grid;
160 ignore = p_b0_grid;
161 ignore = p_b1_grid;
162 ignore = p_c_grid;
163 ignore = M;
164 ignore = N;
165 ignore = K;
166 ignore = O;
167 ignore = G0;
168 ignore = G1;
169 ignore = alpha;
170 ignore = input_permute;
171 ignore = output_permute;
172#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
173}
174
175// Computes C = A * B0 * B1
176// MN = MK * KL * LN
177// ^^^^^^ (Acc0)
178// ^^^^^^^^^^^ (Acc1)
179template <index_t NumDimG,
180 index_t NumDimM,
181 index_t NumDimL,
182 index_t NumDimK,
183 index_t NumDimN,
184 typename ADataType,
185 typename B0DataType,
186 typename B1DataType,
187 typename CDataType,
188 typename Acc0BiasDataType,
189 typename Acc0DataType,
190 typename Acc1BiasDataType,
191 typename Acc1DataType,
192 typename CShuffleDataType,
193 typename AElementwiseOperation,
194 typename B0ElementwiseOperation,
195 typename AccElementwiseOperation,
196 typename B1ElementwiseOperation,
197 typename CElementwiseOperation,
198 GemmSpecialization GemmSpec,
203 ck::index_t NumPrefetch,
204 ck::index_t QueryGroupNumber,
205 ck::index_t BlockSize,
206 ck::index_t MPerBlock,
207 ck::index_t LPerBlock,
208 ck::index_t KPerBlock,
209 ck::index_t AK1,
210 ck::index_t BK1,
211 ck::index_t NPerBlock,
212 ck::index_t LTilePerBlock,
213 ck::index_t L1,
214 ck::index_t MPerWmma,
215 ck::index_t LPerWmma,
216 ck::index_t NPerWmma,
217 ck::index_t MRepeat,
218 ck::index_t LRepeat,
219 ck::index_t NRepeat,
220 typename ABlockTransferThreadClusterLengths_K0_M_K1,
221 typename ABlockTransferThreadClusterArrangeOrder,
222 typename ABlockTransferSrcAccessOrder,
223 ck::index_t ABlockTransferSrcVectorDim,
224 ck::index_t ABlockTransferSrcScalarPerVector,
225 ck::index_t ABlockTransferDstScalarPerVector_K1,
226 bool ABlockLdsAddExtraM,
227 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
228 typename B0BlockTransferThreadClusterArrangeOrder,
229 typename B0BlockTransferSrcAccessOrder,
230 ck::index_t B0BlockTransferSrcVectorDim,
231 ck::index_t B0BlockTransferSrcScalarPerVector,
232 ck::index_t B0BlockTransferDstScalarPerVector_K1,
233 bool B0BlockLdsAddExtraL,
234 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
235 typename B1BlockTransferThreadClusterArrangeOrder,
236 typename B1BlockTransferSrcAccessOrder,
237 ck::index_t B1BlockTransferSrcVectorDim,
238 ck::index_t B1BlockTransferSrcScalarPerVector,
239 ck::index_t B1BlockTransferDstScalarPerVector_L1,
240 bool B1BlockLdsAddExtraN,
241 index_t CShuffleMRepeatPerShuffle,
242 index_t CShuffleNRepeatPerShuffle,
243 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
244 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
245 MaskingSpecialization MaskingSpec,
250 NumDimM,
251 NumDimL,
252 NumDimK,
253 NumDimN,
254 ADataType,
255 B0DataType,
256 B1DataType,
257 CDataType,
258 Acc0BiasDataType,
259 Acc1BiasDataType,
260 AElementwiseOperation,
261 B0ElementwiseOperation,
262 AccElementwiseOperation,
263 B1ElementwiseOperation,
264 CElementwiseOperation,
265 MaskingSpec>
266{
267 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
268 "Number of dimension must be greater than 0");
269
270 static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
271 static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
272
273 // TODO ANT: implement bias combination
274 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
275
276 static constexpr index_t NumDimGemm0M = NumDimM;
277 static constexpr index_t NumDimGemm0N = NumDimL;
278 static constexpr index_t NumDimGemm0K = NumDimK;
279 static constexpr index_t NumDimGemm1M = NumDimM;
280 static constexpr index_t NumDimGemm1N = NumDimN;
281 static constexpr index_t NumDimGemm1K = NumDimL;
282
284
285 static constexpr auto I0 = Number<0>{};
286 static constexpr auto I1 = Number<1>{};
287 static constexpr auto I2 = Number<2>{};
288 static constexpr auto I3 = Number<3>{};
289 static constexpr auto I4 = Number<4>{};
290 static constexpr auto I5 = Number<5>{};
291 static constexpr auto I6 = Number<6>{};
292
293 static constexpr auto WmmaK = 16;
294
295 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
296 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
297 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
298
299 static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true;
300 static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
301 static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
302
303 static constexpr auto AEnableLds_manu = false;
304 static constexpr auto B0EnableLds_manu = true;
305 static constexpr auto B1EnableLds_manu = true;
306
307 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
308 static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1);
309 static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1);
310
314 GemmSpec,
315 ASpec,
316 B0Spec,
317 B1Spec,
318 CSpec>;
319
// Builds the A (query) grid descriptor from flattened G/M/K lengths and strides.
// The descriptor shape depends on AEnableLds: the two branches call different Transform helpers
// (partly not visible here), both starting from Transform::MakeAGridDescriptor_M_K and using
// AK1 as the K1 vector length.
// NOTE(review): the array extent is NumDimG + NumDimM + NumDimN although A's dimensions are
// G/M/K; this only matches when NumDimN == NumDimK — confirm.
320 __host__ __device__ static auto MakeAGridDescriptor(
321 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
322 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
323 {
324 if constexpr(AEnableLds)
325 {
327 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
328 Number<AK1>{});
329 }
330 else
331 {
332 return Transform::
334 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec,
335 a_gs_ms_ks_strides_vec),
340 Number<AK1>{});
341 }
342 }
343
// Builds the B0 (key) grid descriptor from flattened G/L/K lengths and strides.
// The descriptor shape depends on B0EnableLds: both branches start from
// Transform::MakeB0GridDescriptor_N_K and use BK1 as the K1 vector length (the outer Transform
// calls are partly not visible here).
// NOTE(review): the array extent is NumDimG + NumDimM + NumDimN although B0's dimensions are
// G/L/K — confirm the dimension counts agree for all instantiations.
344 __host__ __device__ static auto MakeB0GridDescriptor(
345 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths_vec,
346 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides_vec)
347 {
348 if constexpr(B0EnableLds)
349 {
351 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
352 b0_gs_ls_ks_strides_vec),
353 Number<BK1>{});
354 }
355 else
356 {
357 return Transform::
359 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
360 b0_gs_ls_ks_strides_vec),
365 Number<BK1>{});
366 }
367 }
368
// Builds the B1 (value) grid descriptor from flattened G/N/L lengths and strides.
// The descriptor shape depends on B1EnableLds: both branches start from
// Transform::MakeB1GridDescriptor_N_K and use L1 as the inner vector length (the outer
// Transform calls are partly not visible here).
369 __host__ __device__ static auto MakeB1GridDescriptor(
370 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths_vec,
371 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides_vec)
372 {
373 if constexpr(B1EnableLds)
374 {
376 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
377 b1_gs_ns_ls_strides_vec),
378 Number<L1>{});
379 }
380 else
381 {
382 return Transform::
384 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
385 b1_gs_ns_ls_strides_vec),
390 Number<L1>{});
391 }
392 }
393
// Descriptor types, deduced by calling the builders with empty (zero-filled) arrays.
394 using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
395 using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
396 using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
402
// Selects the masking predicate object at compile time from MaskingSpec. The
// MaskOutUpperTriangle branch's return statement is not visible here.
// NOTE(review): there is no branch for any other MaskingSpecialization value; such an
// instantiation would deduce a void return type — confirm all enumerators are covered.
403 __host__ __device__ constexpr static auto make_MaskOutPredicate()
404 {
405 if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
406 {
407 return MaskDisabledPredicate{};
408 }
409 else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
410 {
412 }
413 }
415
417 {
418 __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
419 const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
420 const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
421 const CGridDesc_G_M_N& c_grid_desc_g_m_n)
422 : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
423 b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
424 b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
425 c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
426 {
427 }
428
429 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
430 {
431 return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
432 }
433
434 __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
435 {
436 return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
437 }
438
439 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
440 {
441 return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0));
442 }
443
444 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
445 {
446 return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
447 }
448
449 private:
450 AGridDesc_G_M_K a_grid_desc_g_m_k_;
451 B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
452 B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
453 CGridDesc_G_M_N c_grid_desc_g_m_n_;
454 };
455
456 // GridwiseOp
458 // DataType Family
459 ADataType,
460 B0DataType,
461 Acc0DataType,
462 B1DataType,
463 Acc1DataType,
464 CShuffleDataType,
465 CDataType,
466 // ElementwiseOp Family
467 AElementwiseOperation,
468 B0ElementwiseOperation,
469 AccElementwiseOperation,
470 B1ElementwiseOperation,
471 CElementwiseOperation,
473 // InMemory Data Descriptor
474 AGridDesc,
478 // Tiling Family
479 MPerBlock,
480 LPerBlock,
481 KPerBlock,
482 AK1,
483 BK1,
484 NPerBlock,
485 LTilePerBlock,
486 L1,
487 MPerWmma,
488 LPerWmma,
489 NPerWmma,
490 MRepeat,
491 LRepeat,
492 NRepeat,
493 // ThreadCluster Family
494 BlockSize,
495 ABlockTransferThreadClusterLengths_K0_M_K1,
496 ABlockTransferThreadClusterArrangeOrder,
497 ABlockTransferSrcAccessOrder,
498 ABlockTransferSrcVectorDim,
499 ABlockTransferSrcScalarPerVector,
500 ABlockTransferDstScalarPerVector_K1,
501 true,
503 ABlockLdsAddExtraM,
504 B0BlockTransferThreadClusterLengths_K0_L_K1,
505 B0BlockTransferThreadClusterArrangeOrder,
506 B0BlockTransferSrcAccessOrder,
507 B0BlockTransferSrcVectorDim,
508 B0BlockTransferSrcScalarPerVector,
509 B0BlockTransferDstScalarPerVector_K1,
510 true,
512 B0BlockLdsAddExtraL,
513 B1BlockTransferThreadClusterLengths_L0_N_L1,
514 B1BlockTransferThreadClusterArrangeOrder,
515 B1BlockTransferSrcAccessOrder,
516 B1BlockTransferSrcVectorDim,
517 B1BlockTransferSrcScalarPerVector,
518 B1BlockTransferDstScalarPerVector_L1,
519 false,
521 B1BlockLdsAddExtraN,
522 CShuffleMRepeatPerShuffle,
523 CShuffleNRepeatPerShuffle,
524 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
525 CShuffleBlockTransferScalarPerVector_NPerBlock,
528 NumPrefetch,
529 LoopSched,
530 PipelineVer>;
531
// Raw (unpacked) problem description produced by MakeArgument. The Invoker forwards these
// fields directly to the kernel, which rebuilds all tensor descriptors from them on the device.
532 struct RawArg : public BaseArgument
533 {
534 RawArg(const ADataType* p_a_grid,
535 const B0DataType* p_b0_grid,
536 const B1DataType* p_b1_grid,
537 CDataType* p_c_grid,
538 index_t M,
539 index_t N,
540 index_t K,
541 index_t O,
542 index_t G0,
543 index_t G1,
544 float alpha,
545 bool input_permute,
546 bool output_permute)
547 : p_a_grid_{p_a_grid},
548 p_b0_grid_{p_b0_grid},
549 p_b1_grid_{p_b1_grid},
550 p_c_grid_{p_c_grid},
551 M_{M},
552 N_{N},
553 K_{K},
554 O_{O},
555 G0_{G0},
556 G1_{G1},
557 alpha_{alpha},
558 input_permute_{input_permute},
559 output_permute_{output_permute}
560 {
561 }
562 // Pointers
563 const ADataType* p_a_grid_;
564 const B0DataType* p_b0_grid_;
565 const B1DataType* p_b1_grid_;
566 CDataType* p_c_grid_;
567
// The size and layout-flag members (M_, N_, K_, O_, G0_, G1_, input_permute_, output_permute_)
// are initialized above; their declarations are not visible in this view.
568 // Raw Problem Size
575 float alpha_;
578 };
579
580 static auto MakeArgument(const ADataType* p_a,
581 const B0DataType* p_b0,
582 const B1DataType* p_b1,
583 CDataType* p_c,
584 index_t M,
585 index_t N,
586 index_t K,
587 index_t O,
588 index_t G0,
589 index_t G1,
590 float alpha,
591 bool input_permute,
592 bool output_permute)
593 {
594 return RawArg{
595 p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute};
596 }
597
// Host-side check of whether this instance can run the problem in `arg`. The checks are:
//  - architecture and accumulator types (the conditions are not visible in this view);
//  - G1 divisible by QueryGroupNumber;
//  - GridwiseOp::CheckValidity on descriptors rebuilt from the raw sizes;
//  - batch count;
//  - vector-access extents and strides.
// Returns false when unsupported, usually after printing a short reason.
598 static bool IsSupportedArgument(const RawArg& arg)
599 {
601 {
603 {
604 printf("DeviceOp: Acc0 Type err");
605 return false;
606 }
607
609 {
610 printf("DeviceOp: Acc1 Type err");
611 return false;
612 }
613 }
614 else
615 {
616 printf("DeviceOp: Arch err");
617 return false;
618 }
619
// The G1 query heads must split evenly over the QueryGroupNumber K/V heads (the kernel maps
// query heads to K/V heads with g_idx * QueryGroupNumber / G1).
620 if(arg.G1_ % QueryGroupNumber != 0)
621 {
622 return false;
623 }
624
625 constexpr index_t array_size = 4;
626 ck::index_t G0 = arg.G0_;
627 ck::index_t G1 = arg.G1_;
628 ck::index_t M = arg.M_;
629 ck::index_t N = arg.N_;
630 ck::index_t K = arg.K_;
631 ck::index_t O = arg.O_;
632 bool input_permute = arg.input_permute_;
633 bool output_permute = arg.output_permute_;
634
635 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
636 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
637 input_permute ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1}
638 // A layout [G0, M, G1, K]
639 : std::array<ck::index_t, array_size>{
640 G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
641
// NOTE(review): B0/B1 are described here with G1 heads, but the kernel builds them with
// QueryGroupNumber (kv_head) heads. The validity and vector checks below therefore validate a
// different K/V layout than the kernel reads — confirm whether QueryGroupNumber should be used.
642 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
643 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
644 input_permute ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1}
645 // B0 layout [G0, N, G1, K]
646 : std::array<ck::index_t, array_size>{
647 G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
648
649 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
650 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
651 input_permute ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O}
652 // B1 layout [G0, N, G1, O]
653 : std::array<ck::index_t, array_size>{
654 G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
655
656 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
657 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
658 output_permute ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1}
659 // C layout [G0, M, G1, O]
660 : std::array<ck::index_t, array_size>{
661 G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
662
663 const auto a_grid_desc =
664 DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
665 const auto b0_grid_desc =
666 DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
667 const auto b1_grid_desc =
668 DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
669 const auto c_grid_desc_m_n =
670 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
671
672 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
673
674 const auto c_grid_desc_g_m_n =
675 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
676 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
677
679 a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
680 {
681 return false;
682 }
683
684 // Check if C permute dimension matches GEMM + GEMM shape
685 const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded
686
687 if(!(c_g == batch_count))
688 {
689 printf("DeviceOp: BatchCount err");
690 return false;
691 }
692
693 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
694 // vector is out of bounds
695 // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
696 const auto MzRaw = M;
697 const auto LzRaw = N;
698 const auto KzRaw = K;
699 const auto NzRaw = O;
700
701 // Check scalar per vector requirement
702 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
703 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
704 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
705 const auto c_extent_lowest = NzRaw;
706
707 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
708 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
709 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
710 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
711 {
712 printf("DeviceOp: Data Transfer Vector scalar err");
713 return false;
714 }
715
// Strides of the two lowest (innermost) dimensions of each tensor; only elements [0] and [1]
// are initialized explicitly.
716 std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
717 a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
718 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
719 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
720 b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
721 b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
722 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
723 b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
724 b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
725 std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
726 c_gs_ms_os_strides[NumDimG + NumDimM - 1],
727 c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};
728
729 // Check vector load/store requirement
730 const auto a_stride_lowest =
731 ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
732 const auto b0_stride_lowest =
733 B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
734 const auto b1_stride_lowest =
735 B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
736 const auto c_stride_lowest = c_mz_nz_strides_[1];
737
// NOTE(review): this passes as soon as ANY one of the four tensors has unit stride in its
// vector dimension (`||`). A per-tensor vector-access requirement would normally need all of
// them (`&&`) — confirm intended.
738 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
739 c_stride_lowest == 1))
740 {
741 printf("DeviceOp: Data Vectorize transfer err");
742 return false;
743 }
744
745 return true;
746 }
747
748 // polymorphic
749 bool IsSupportedArgument(const BaseArgument* p_arg) override
750 {
751 return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
752 }
753
754 // Argument
755 struct Argument : public BaseArgument
756 {
758 const ADataType* p_a_grid,
759 const B0DataType* p_b0_grid,
760 const B1DataType* p_b1_grid,
761 CDataType* p_c_grid,
762 const std::array<void*, NumAcc0Bias> p_acc0_biases,
763 const std::array<void*, NumAcc1Bias> p_acc1_biases,
764 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
765 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
766 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
767 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
768 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
769 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
770 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
771 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
772 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
773 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
774 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
775 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
776 const index_t M01,
777 const index_t N01,
778 AElementwiseOperation a_element_op,
779 B0ElementwiseOperation b0_element_op,
780 AccElementwiseOperation acc_element_op,
781 B1ElementwiseOperation b1_element_op,
782 CElementwiseOperation c_element_op)
783 : p_a_grid_{p_a_grid},
784 p_b0_grid_{p_b0_grid},
785 p_b1_grid_{p_b1_grid},
786 p_c_grid_{p_c_grid},
787 a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
789 DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
791 DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
793 Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
795 Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
797 Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
799 Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
801 Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
803 block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
804 a_element_op_{a_element_op},
805 b0_element_op_{b0_element_op},
806 acc_element_op_{acc_element_op},
807 b1_element_op_{b1_element_op},
808 c_element_op_{c_element_op},
810 raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
811 b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
812 b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
813 b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
814 a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
815 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
816 b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1],
817 b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
818 b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1],
819 b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
820 c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1],
821 c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
825 {
826 // TODO ANT: implement bias addition
827 ignore = p_acc0_biases;
828 ignore = p_acc1_biases;
829 ignore = acc0_biases_gs_ms_ls_lengths;
830 ignore = acc0_biases_gs_ms_ls_strides;
831 ignore = acc1_biases_gs_ms_ns_lengths;
832 ignore = acc1_biases_gs_ms_ns_strides;
833
836 {
840 }
841 }
842
843 // Pointers
844 const ADataType* p_a_grid_;
845 const B0DataType* p_b0_grid_;
846 const B1DataType* p_b1_grid_;
847 CDataType* p_c_grid_;
848
849 // Tensor Descriptors
854
859
862
863 // Block to Tile mapping
865
866 // ElementwiseOp
867 AElementwiseOperation a_element_op_;
868 B0ElementwiseOperation b0_element_op_;
869 AccElementwiseOperation acc_element_op_;
870 B1ElementwiseOperation b1_element_op_;
871 CElementwiseOperation c_element_op_;
872
873 // check C0 masking and padding
875
876 // Strides for the last M/N/K dimensions of A/B0/B1/C
877 // for sanity check of vector load/store
878 std::array<index_t, NumDimG + NumDimM + NumDimN> raw_lengths_mz_lz_kz_nz_;
879 std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_;
880 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_;
881 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_;
882 std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_;
883
885 // Batch Offset
887 };
888
889 struct Invoker : public BaseInvoker
890 {
892
// Launches kernel_grouped_query_attention_wmma over a 1-D grid with one workgroup per
// (G0, G1, M-tile, O-tile) output tile:
//   grid_size = G0 * G1 * ceil(M / MPerBlock) * ceil(O / NPerBlock),
// with BlockSize threads per block. The raw sizes in `arg` are passed unchanged to the kernel.
// The HasMainKBlockLoop variant is picked at runtime; the selecting condition is not visible
// here and presumably uses K. Returns the value produced by launch_and_time_kernel.
// NOTE(review): grid_size is computed in index_t and could overflow for very large problems.
893 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
894 {
895 const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock);
896 const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock);
897
898 const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0;
899 const auto K = arg.K_;
900 // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
// Instantiates and launches the kernel for a compile-time HasMainKBlockLoop value.
901 auto launch_kernel = [&](auto has_main_k_block_loop) {
904 ADataType,
905 B0DataType,
906 B1DataType,
907 CDataType,
908 AElementwiseOperation,
909 B0ElementwiseOperation,
910 AccElementwiseOperation,
911 B1ElementwiseOperation,
912 CElementwiseOperation,
913 QueryGroupNumber,
914 has_main_k_block_loop>;
915
916 return launch_and_time_kernel(stream_config,
917 kernel,
918 dim3(grid_size),
919 dim3(BlockSize),
920 0,
921 arg.p_a_grid_,
922 arg.p_b0_grid_,
923 arg.p_b1_grid_,
924 arg.p_c_grid_,
925 arg.M_,
926 arg.N_,
927 arg.K_,
928 arg.O_,
929 arg.G0_,
930 arg.G1_,
931 arg.alpha_,
932 arg.input_permute_,
933 arg.output_permute_);
934 };
935
937 {
938 return launch_kernel(integral_constant<bool, true>{});
939 }
940 else
941 {
942 return launch_kernel(integral_constant<bool, false>{});
943 }
944 }
945
946 // polymorphic
947 float Run(const BaseArgument* p_arg,
948 const StreamConfig& stream_config = StreamConfig{}) override
949 {
950 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
951 }
952 };
953
// Compile-time parameter validation hook. No checks are implemented yet, so every
// instantiation is reported as valid.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
959#if 0
960 static bool IsSupportedArgument(const Argument& arg)
961 {
963 {
965 {
966 printf("DeviceOp: Acc0 Type err");
967 return false;
968 }
969
971 {
972 printf("DeviceOp: Acc1 Type err");
973 return false;
974 }
975 }
976 else
977 {
978 printf("DeviceOp: Arch err");
979 return false;
980 }
981
982 if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
983 arg.b0_grid_desc,
984 arg.b1_grid_desc,
985 arg.c_grid_desc_m_n_,
986 arg.block_2_ctile_map_))
987 {
988 return false;
989 }
990
991 // Check if C permute dimension matches GEMM + GEMM shape
992 const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
993
994 if(!(c_g == arg.batch_count_))
995 {
996 printf("DeviceOp: BatchCount err");
997 return false;
998 }
999
1000 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
1001 // vector is out of bounds
1002 // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
1003 const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
1004 const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
1005 const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
1006 const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
1007
1008 // Check scalar per vector requirement
1009 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
1010 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
1011 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
1012 const auto c_extent_lowest = NzRaw;
1013
1014 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
1015 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
1016 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
1017 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
1018 {
1019 printf("DeviceOp: Data Transfer Vector scalar err");
1020 return false;
1021 }
1022
1023 // Check vector load/store requirement
1024 const auto a_stride_lowest =
1025 ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
1026 const auto b0_stride_lowest =
1027 B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
1028 const auto b1_stride_lowest =
1029 B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
1030 const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
1031
1032 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
1033 c_stride_lowest == 1))
1034 {
1035 printf("DeviceOp: Data Vectorize transfer err");
1036 return false;
1037 }
1038
1039 return true;
1040 }
1041
1042 // polymorphic
1043 bool IsSupportedArgument(const BaseArgument* p_arg) override
1044 {
1045 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1046 }
1047
1048 static auto MakeArgument(
1049 const ADataType* p_a,
1050 const B0DataType* p_b0,
1051 const B1DataType* p_b1,
1052 CDataType* p_c,
1053 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1054 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1055 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
1056 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
1057 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
1058 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
1059 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
1060 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
1061 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
1062 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
1063 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1064 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1065 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1066 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1067 AElementwiseOperation a_element_op,
1068 B0ElementwiseOperation b0_element_op,
1069 AccElementwiseOperation acc_element_op,
1070 B1ElementwiseOperation b1_element_op,
1071 CElementwiseOperation c_element_op)
1072 {
1073 return Argument{p_a,
1074 p_b0,
1075 p_b1,
1076 p_c,
1077 p_acc0_biases,
1078 p_acc1_biases,
1079 a_gs_ms_ks_lengths,
1080 a_gs_ms_ks_strides,
1081 b0_gs_ls_ks_lengths,
1082 b0_gs_ls_ks_strides,
1083 b1_gs_ns_ls_lengths,
1084 b1_gs_ns_ls_strides,
1085 c_gs_ms_ns_lengths,
1086 c_gs_ms_ns_strides,
1087 acc0_biases_gs_ms_ls_lengths,
1088 acc0_biases_gs_ms_ls_strides,
1089 acc1_biases_gs_ms_ns_lengths,
1090 acc1_biases_gs_ms_ns_strides,
1091 1,
1092 1,
1093 a_element_op,
1094 b0_element_op,
1095 acc_element_op,
1096 b1_element_op,
1097 c_element_op};
1098 }
1099#endif
1100
1101 // polymorphic
1102 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1103 const void* p_a,
1104 const void* p_b0,
1105 const void* p_b1,
1106 void* p_c,
1107 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1108 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1109 const std::vector<index_t>& a_gs_ms_ks_lengths,
1110 const std::vector<index_t>& a_gs_ms_ks_strides,
1111 const std::vector<index_t>& b0_gs_ls_ks_lengths,
1112 const std::vector<index_t>& b0_gs_ls_ks_strides,
1113 const std::vector<index_t>& b1_gs_ns_ls_lengths,
1114 const std::vector<index_t>& b1_gs_ns_ls_strides,
1115 const std::vector<index_t>& c_gs_ms_ns_lengths,
1116 const std::vector<index_t>& c_gs_ms_ns_strides,
1117 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1118 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1119 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1120 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1121 AElementwiseOperation a_element_op,
1122 B0ElementwiseOperation b0_element_op,
1123 AccElementwiseOperation acc_element_op,
1124 B1ElementwiseOperation b1_element_op,
1125 CElementwiseOperation c_element_op) override
1126 {
1127 std::array<index_t, NumDimG + NumDimM + NumDimN> a_lengths;
1128 std::array<index_t, NumDimG + NumDimM + NumDimN> a_strides;
1129 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lengths;
1130 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_strides;
1131 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_lengths;
1132 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_strides;
1133 std::array<index_t, NumDimG + NumDimM + NumDimN> c_lengths;
1134 std::array<index_t, NumDimG + NumDimM + NumDimN> c_strides;
1135 std::transform(a_gs_ms_ks_lengths.begin(),
1136 a_gs_ms_ks_lengths.end(),
1137 a_lengths.begin(),
1138 [](index_t i) { return i; });
1139 std::transform(a_gs_ms_ks_strides.begin(),
1140 a_gs_ms_ks_strides.end(),
1141 a_strides.begin(),
1142 [](index_t i) { return i; });
1143 std::transform(b0_gs_ls_ks_lengths.begin(),
1144 b0_gs_ls_ks_lengths.end(),
1145 b0_lengths.begin(),
1146 [](index_t i) { return i; });
1147 std::transform(b0_gs_ls_ks_strides.begin(),
1148 b0_gs_ls_ks_strides.end(),
1149 b0_strides.begin(),
1150 [](index_t i) { return i; });
1151 std::transform(b1_gs_ns_ls_lengths.begin(),
1152 b1_gs_ns_ls_lengths.end(),
1153 b1_lengths.begin(),
1154 [](index_t i) { return i; });
1155 std::transform(b1_gs_ns_ls_strides.begin(),
1156 b1_gs_ns_ls_strides.end(),
1157 b1_strides.begin(),
1158 [](index_t i) { return i; });
1159 std::transform(c_gs_ms_ns_lengths.begin(),
1160 c_gs_ms_ns_lengths.end(),
1161 c_lengths.begin(),
1162 [](index_t i) { return i; });
1163 std::transform(c_gs_ms_ns_strides.begin(),
1164 c_gs_ms_ns_strides.end(),
1165 c_strides.begin(),
1166 [](index_t i) { return i; });
1167 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
1168 static_cast<const B0DataType*>(p_b0),
1169 static_cast<const B1DataType*>(p_b1),
1170 static_cast<CDataType*>(p_c),
1171 p_acc0_biases,
1172 p_acc1_biases,
1173 a_lengths,
1174 a_strides,
1175 b0_lengths,
1176 b0_strides,
1177 b1_lengths,
1178 b1_strides,
1179 c_lengths,
1180 c_strides,
1181 acc0_biases_gs_ms_ls_lengths,
1182 acc0_biases_gs_ms_ls_strides,
1183 acc1_biases_gs_ms_ns_lengths,
1184 acc1_biases_gs_ms_ns_strides,
1185 1,
1186 1,
1187 a_element_op,
1188 b0_element_op,
1189 acc_element_op,
1190 b1_element_op,
1191 c_element_op);
1192 }
1193
1194 static auto MakeInvoker() { return Invoker{}; }
1195
1196 // polymorphic
1197 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1198 {
1199 return std::make_unique<Invoker>(Invoker{});
1200 }
1201
1202 // polymorphic
1203 std::string GetTypeString() const override
1204 {
1205 auto str = std::stringstream();
1206
1207 std::map<LoopScheduler, std::string> LoopSchedToString{
1208 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
1209
1210 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
1211 {PipelineVersion::v2, "v2"}};
1212
1213 // clang-format off
1214 str << "DeviceGroupedQueryAttentionForward_Wmma, "
1215 << "QueryGroupNumber: "
1216 << QueryGroupNumber << ", "
1217 << "<"
1218 << BlockSize << ", "
1219 << MPerBlock << ", "
1220 << LPerBlock << ", "
1221 << KPerBlock << ", "
1222 << AK1 << ", "
1223 << BK1 << ", "
1224 << MPerBlock << ", "
1225 << NPerBlock << ", "
1226 << LTilePerBlock << ", "
1227 << L1 << ", "
1228 << getGemmSpecializationString(GemmSpec) << ", "
1229 << "ASpec" << getTensorSpecializationString(ASpec) << ", "
1230 << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
1231 << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
1232 << "CSpec" << getTensorSpecializationString(CSpec) << ", "
1233 << getMaskingSpecializationString(MaskingSpec)
1234 << ">"
1235 << " AEnableLds: "
1236 << AEnableLds << ", "
1237 << "B0EnableLds: "
1238 << B0EnableLds << ", "
1239 << "B1EnableLds: "
1240 << B1EnableLds << ", "
1241 << "NumPrefetch: "
1242 << NumPrefetch << ", "
1243 << "LoopScheduler: "
1244 << LoopSchedToString[LoopSched] << ", "
1245 << "PipelineVersion: "
1246 << PipelineVersionToString[PipelineVer];
1247 // clang-format on
1248
1249 return str.str();
1250 }
1251};
1252
1253} // namespace device
1254} // namespace tensor_operation
1255} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
__global__ void kernel_grouped_query_attention_wmma(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_grouped_query_attention_forward_wmma.hpp:50
GemmSpecialization
Definition gemm_specialization.hpp:11
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
int64_t long_index_t
Definition ck.hpp:300
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:93
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:682
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeDefaultBlock2CTileMap
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:672
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:679
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:653
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CalculateHasMainKBlockLoop
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:645
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:511
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition transform_contraction_to_gemm_arraybase.hpp:122
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm_arraybase.hpp:245
__host__ static __device__ auto MakeCGridDescriptor_G_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:375
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm_arraybase.hpp:172
__host__ static __device__ auto MakeB1GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:307
__host__ static __device__ auto MakeB1GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:301
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(const BGridDesc_L_K &b_grid_desc_l_k, const WmmaK &, const LRepeat &, const LWaves &, const LPerWmma &, const BK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:266
__host__ static __device__ auto MakeB0GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:228
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm_arraybase.hpp:318
__host__ static __device__ auto MakeAGridDescriptor_G_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:156
__host__ static __device__ constexpr auto MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const WmmaK &, const MRepeat &, const MWaves &, const MPerWmma &, const AK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:193
__host__ static __device__ auto MakeCGridDescriptor_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:381
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(const BGridDesc_N_L &b_grid_desc_n_l, const WmmaL &, const NRepeat &, const NWaves &, const NPerWmma &, const BL1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:340
__host__ static __device__ auto MakeAGridDescriptor_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:162
__host__ static __device__ auto MakeB0GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:234
Definition device_base.hpp:197
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
Definition device_grouped_query_attention_forward_wmma.hpp:756
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_query_attention_forward_wmma.hpp:853
const B0DataType * p_b0_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:845
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition device_grouped_query_attention_forward_wmma.hpp:855
AccElementwiseOperation acc_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:869
B1ElementwiseOperation b1_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:870
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_query_attention_forward_wmma.hpp:861
std::array< index_t, NumDimG+NumDimM+NumDimN > a_mz_kz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:879
B0GridDesc_G_L_K b0_grid_desc_g_l_k_
Definition device_grouped_query_attention_forward_wmma.hpp:856
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_query_attention_forward_wmma.hpp:886
AGridDesc a_grid_desc
Definition device_grouped_query_attention_forward_wmma.hpp:850
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_grouped_query_attention_forward_wmma.hpp:864
std::array< index_t, NumDimG+NumDimM+NumDimN > b1_nz_lz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:881
CElementwiseOperation c_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:871
B1GridDesc b1_grid_desc
Definition device_grouped_query_attention_forward_wmma.hpp:852
const ADataType * p_a_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:844
B0ElementwiseOperation b0_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:868
C0MatrixMask c0_matrix_mask_
Definition device_grouped_query_attention_forward_wmma.hpp:874
std::array< index_t, NumDimG+NumDimM+NumDimN > b0_lz_kz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:880
CDataType * p_c_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:847
CGridDesc_G_M_N c_grid_desc_g_m_n_
Definition device_grouped_query_attention_forward_wmma.hpp:858
B0GridDesc b0_grid_desc
Definition device_grouped_query_attention_forward_wmma.hpp:851
std::array< index_t, NumDimG+NumDimM+NumDimN > raw_lengths_mz_lz_kz_nz_
Definition device_grouped_query_attention_forward_wmma.hpp:878
Argument(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, const index_t M01, const index_t N01, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_grouped_query_attention_forward_wmma.hpp:757
index_t batch_count_
Definition device_grouped_query_attention_forward_wmma.hpp:884
B1GridDesc_G_N_L b1_grid_desc_g_n_l_
Definition device_grouped_query_attention_forward_wmma.hpp:857
AElementwiseOperation a_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:867
std::array< index_t, NumDimG+NumDimM+NumDimN > c_mz_nz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:882
const B1DataType * p_b1_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:846
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:429
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:434
__host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const B0GridDesc_G_L_K &b0_grid_desc_g_l_k, const B1GridDesc_G_N_L &b1_grid_desc_g_n_l, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_grouped_query_attention_forward_wmma.hpp:418
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:444
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:439
Definition device_grouped_query_attention_forward_wmma.hpp:890
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_query_attention_forward_wmma.hpp:947
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_query_attention_forward_wmma.hpp:893
DeviceOp::RawArg Argument
Definition device_grouped_query_attention_forward_wmma.hpp:891
Definition device_grouped_query_attention_forward_wmma.hpp:533
index_t G1_
Definition device_grouped_query_attention_forward_wmma.hpp:574
index_t N_
Definition device_grouped_query_attention_forward_wmma.hpp:570
float alpha_
Definition device_grouped_query_attention_forward_wmma.hpp:575
index_t K_
Definition device_grouped_query_attention_forward_wmma.hpp:571
bool output_permute_
Definition device_grouped_query_attention_forward_wmma.hpp:577
const ADataType * p_a_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:563
index_t M_
Definition device_grouped_query_attention_forward_wmma.hpp:569
const B1DataType * p_b1_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:565
RawArg(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_grouped_query_attention_forward_wmma.hpp:534
index_t O_
Definition device_grouped_query_attention_forward_wmma.hpp:572
bool input_permute_
Definition device_grouped_query_attention_forward_wmma.hpp:576
index_t G0_
Definition device_grouped_query_attention_forward_wmma.hpp:573
const B0DataType * p_b0_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:564
CDataType * p_c_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:566
Definition device_grouped_query_attention_forward_wmma.hpp:266
static constexpr auto NWaves
Definition device_grouped_query_attention_forward_wmma.hpp:297
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_query_attention_forward_wmma.hpp:954
static constexpr auto I0
Definition device_grouped_query_attention_forward_wmma.hpp:285
static constexpr auto B0EnableLds_manu
Definition device_grouped_query_attention_forward_wmma.hpp:304
static constexpr index_t NumDimGemm1K
Definition device_grouped_query_attention_forward_wmma.hpp:281
static constexpr auto AEnableLds_auto
Definition device_grouped_query_attention_forward_wmma.hpp:299
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_grouped_query_attention_forward_wmma.hpp:395
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_L
Definition device_grouped_query_attention_forward_wmma.hpp:400
static constexpr auto B1EnableLds_auto
Definition device_grouped_query_attention_forward_wmma.hpp:301
DeviceGroupedQueryAttentionForward_Wmma DeviceOp
Definition device_grouped_query_attention_forward_wmma.hpp:283
static constexpr auto I1
Definition device_grouped_query_attention_forward_wmma.hpp:286
static constexpr auto I5
Definition device_grouped_query_attention_forward_wmma.hpp:290
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_grouped_query_attention_forward_wmma.hpp:401
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_grouped_query_attention_forward_wmma.hpp:396
__host__ __device__ static constexpr auto make_MaskOutPredicate()
Definition device_grouped_query_attention_forward_wmma.hpp:403
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_grouped_query_attention_forward_wmma.hpp:394
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition device_grouped_query_attention_forward_wmma.hpp:320
static constexpr auto B0EnableLds
Definition device_grouped_query_attention_forward_wmma.hpp:308
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_grouped_query_attention_forward_wmma.hpp:414
static constexpr auto MWaves
Definition device_grouped_query_attention_forward_wmma.hpp:295
static constexpr index_t NumDimGemm0M
Definition device_grouped_query_attention_forward_wmma.hpp:276
static constexpr index_t NumDimGemm0N
Definition device_grouped_query_attention_forward_wmma.hpp:277
static constexpr auto I6
Definition device_grouped_query_attention_forward_wmma.hpp:291
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_query_attention_forward_wmma.hpp:1197
static constexpr index_t NumDimGemm1M
Definition device_grouped_query_attention_forward_wmma.hpp:279
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides_vec)
Definition device_grouped_query_attention_forward_wmma.hpp:344
static bool IsSupportedArgument(const RawArg &arg)
Definition device_grouped_query_attention_forward_wmma.hpp:598
static constexpr index_t NumAcc1Bias
Definition device_grouped_query_attention_forward_wmma.hpp:271
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_query_attention_forward_wmma.hpp:749
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< NumDimG, NumDimM, NumDimL, NumDimK, NumDimN >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, ASpec, B0Spec, B1Spec, CSpec > Transform
Definition device_grouped_query_attention_forward_wmma.hpp:311
static constexpr auto B1EnableLds_manu
Definition device_grouped_query_attention_forward_wmma.hpp:305
static constexpr index_t NumDimGemm1N
Definition device_grouped_query_attention_forward_wmma.hpp:280
static constexpr auto B0EnableLds_auto
Definition device_grouped_query_attention_forward_wmma.hpp:300
static constexpr auto AEnableLds_manu
Definition device_grouped_query_attention_forward_wmma.hpp:303
static auto MakeInvoker()
Definition device_grouped_query_attention_forward_wmma.hpp:1194
static constexpr auto I4
Definition device_grouped_query_attention_forward_wmma.hpp:289
static auto MakeArgument(const ADataType *p_a, const B0DataType *p_b0, const B1DataType *p_b1, CDataType *p_c, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_grouped_query_attention_forward_wmma.hpp:580
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_grouped_query_attention_forward_wmma.hpp:397
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b0_gs_ls_ks_lengths, const std::vector< index_t > &b0_gs_ls_ks_strides, const std::vector< index_t > &b1_gs_ns_ls_lengths, const std::vector< index_t > &b1_gs_ns_ls_strides, const std::vector< index_t > &c_gs_ms_ns_lengths, const std::vector< index_t > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_grouped_query_attention_forward_wmma.hpp:1102
static constexpr index_t NumDimGemm0K
Definition device_grouped_query_attention_forward_wmma.hpp:278
static constexpr auto I2
Definition device_grouped_query_attention_forward_wmma.hpp:287
static constexpr auto B1EnableLds
Definition device_grouped_query_attention_forward_wmma.hpp:309
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides_vec)
Definition device_grouped_query_attention_forward_wmma.hpp:369
static constexpr auto LWaves
Definition device_grouped_query_attention_forward_wmma.hpp:296
static constexpr auto AEnableLds
Definition device_grouped_query_attention_forward_wmma.hpp:307
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) B0GridDesc_G_L_K
Definition device_grouped_query_attention_forward_wmma.hpp:399
std::string GetTypeString() const override
Definition device_grouped_query_attention_forward_wmma.hpp:1203
static constexpr auto WmmaK
Definition device_grouped_query_attention_forward_wmma.hpp:293
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_grouped_query_attention_forward_wmma.hpp:398
static constexpr index_t NumAcc0Bias
Definition device_grouped_query_attention_forward_wmma.hpp:270
GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_grouped_query_attention_forward_wmma.hpp:457
static constexpr auto I3
Definition device_grouped_query_attention_forward_wmma.hpp:288
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43