device_contraction_multiple_abd_xdl_cshuffle.hpp Source File

device_contraction_multiple_abd_xdl_cshuffle.hpp Source File#

Composable Kernel: device_contraction_multiple_abd_xdl_cshuffle.hpp Source File
device_contraction_multiple_abd_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <vector>
9
21
22namespace ck {
23
// Device kernel for the multiple-A/B/D contraction. When compiled for a supported target it
// declares a shared-memory scratch buffer and forwards every argument, unchanged and in order,
// to GridwiseGemm::Run<HasMainKBlockLoop>.
// NOTE(review): this listing omits original lines 40 and 42. Per the reference index at the
// bottom of this page, line 42 holds the kernel name
// kernel_contraction_multiple_abd_xdl_cshuffle; line 40 is presumably the launch-bounds
// attribute guarded by CK_USE_LAUNCH_BOUNDS -- confirm against the real header.
24template <typename GridwiseGemm,
25 typename AsPointer,
26 typename BsPointer,
27 typename DsPointer,
28 typename EDataType,
29 typename AElementwiseOperation,
30 typename BElementwiseOperation,
31 typename CDEElementwiseOperation,
32 typename AsGridDesc_AK0_M_AK1,
33 typename BsGridDesc_BK0_N_BK1,
34 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
35 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename Block2ETileMap,
37 bool HasMainKBlockLoop>
38__global__ void
39#if CK_USE_LAUNCH_BOUNDS
41#endif
43 AsPointer p_as_grid,
44 BsPointer p_bs_grid,
45 DsPointer p_ds_grid,
46 EDataType* __restrict__ p_e_grid,
47 const AElementwiseOperation a_element_op,
48 const BElementwiseOperation b_element_op,
49 const CDEElementwiseOperation cde_element_op,
50 const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
51 const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
52 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
53 ds_grid_desc_mblock_mperblock_nblock_nperblock,
54 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
55 e_grid_desc_mblock_mperblock_nblock_nperblock,
56 const Block2ETileMap block_2_etile_map)
57{
// The real body is only compiled when one of __gfx9__, __gfx11__ or __gfx12__ is defined.
58#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
// Skip the body entirely when GridwiseGemm reports its compile-time parameters as invalid.
59 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
60 {
// Shared-memory buffer whose size is given by GridwiseGemm::GetSharedMemoryNumberOfByte().
61 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
62
63 GridwiseGemm::template Run<HasMainKBlockLoop>(
64 p_as_grid,
65 p_bs_grid,
66 p_ds_grid,
67 p_e_grid,
68 p_shared,
69 a_element_op,
70 b_element_op,
71 cde_element_op,
72 as_grid_desc_ak0_m_ak1,
73 bs_grid_desc_bk0_n_bk1,
74 ds_grid_desc_mblock_mperblock_nblock_nperblock,
75 e_grid_desc_mblock_mperblock_nblock_nperblock,
76 block_2_etile_map);
77 }
78#else
// Other targets: every parameter is assigned to `ignore` so the kernel body is otherwise empty
// (presumably to avoid unused-parameter warnings).
79 ignore = p_as_grid;
80 ignore = p_bs_grid;
81 ignore = p_ds_grid;
82 ignore = p_e_grid;
83 ignore = a_element_op;
84 ignore = b_element_op;
85 ignore = cde_element_op;
86 ignore = as_grid_desc_ak0_m_ak1;
87 ignore = bs_grid_desc_bk0_n_bk1;
88 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
89 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
90 ignore = block_2_etile_map;
91#endif
92}
93
94} // namespace ck
95
96namespace ck {
97namespace tensor_operation {
98namespace device {
99
100// GEMM:
101// input : A[M, K]
102// input : B[N, K]
103// input : D0[M, N], D1[M, N], ...
104// output : E[M, N]
105// C = a_op(A) * b_op(B)
106// E = cde_op(C, D0, D1, ...)
107// Assume:
108// D0, D1, ... and E have the same layout
109template <index_t NumDimM,
110 index_t NumDimN,
111 index_t NumDimK,
112 typename AsDataType,
113 typename BsDataType,
114 typename AccDataType,
115 typename CShuffleDataType,
116 typename DsDataType,
117 typename EDataType,
118 typename AElementwiseOperation,
119 typename BElementwiseOperation,
120 typename CDEElementwiseOperation,
121 GemmSpecialization GemmSpec,
122 index_t NumGemmKPrefetchStage,
123 index_t BlockSize,
124 index_t MPerBlock,
125 index_t NPerBlock,
126 index_t KPerBlock,
127 index_t AK1,
128 index_t BK1,
129 index_t MPerXDL,
130 index_t NPerXDL,
131 index_t MXdlPerWave,
132 index_t NXdlPerWave,
133 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
134 typename ABlockTransferThreadClusterArrangeOrder,
135 typename ABlockTransferSrcAccessOrder,
136 index_t ABlockTransferSrcVectorDim,
137 index_t ABlockTransferSrcScalarPerVector,
138 index_t ABlockTransferDstScalarPerVector_AK1,
139 index_t ABlockLdsExtraM,
140 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
141 typename BBlockTransferThreadClusterArrangeOrder,
142 typename BBlockTransferSrcAccessOrder,
143 index_t BBlockTransferSrcVectorDim,
144 index_t BBlockTransferSrcScalarPerVector,
145 index_t BBlockTransferDstScalarPerVector_BK1,
146 index_t BBlockLdsExtraN,
147 index_t CShuffleMXdlPerWavePerShuffle,
148 index_t CShuffleNXdlPerWavePerShuffle,
149 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
150 index_t CDEBlockTransferScalarPerVector_NPerBlock,
154 : public DeviceContractionMultipleABD<NumDimM,
155 NumDimN,
156 NumDimK,
157 AsDataType,
158 BsDataType,
159 DsDataType,
160 EDataType,
161 AElementwiseOperation,
162 BElementwiseOperation,
163 CDEElementwiseOperation>
164{
166
// NXdlPerWave values returned by GetNXdlPerWave<true>() and GetNXdlPerWave<false>().
// NOTE(review): the 64/32 suffixes presumably mean wave64 vs wave32; GetNXdlPerWave is not
// visible in this listing (original line 167 is omitted) -- confirm in device_base.hpp.
168 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
169 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
170
// Number of A, B and D tensors, read from the Size() of the corresponding data-type tuples.
171 static constexpr index_t NumATensor = AsDataType::Size();
172 static constexpr index_t NumBTensor = BsDataType::Size();
173 static constexpr index_t NumDTensor = DsDataType::Size();
174
// Compile-time index constants.
175 static constexpr auto I0 = Number<0>{};
176 static constexpr auto I1 = Number<1>{};
177 static constexpr auto I2 = Number<2>{};
178 static constexpr auto I3 = Number<3>{};
179
// The compute data type is simply the output element type.
180 using ComputeDataType = EDataType;
181
182 // GridwiseGemm
// Alias template over NXdlPerWave_ that forwards this device op's template parameters to the
// grid-wise GEMM; NXdlPerWave_ takes the place of the class-level NXdlPerWave.
// NOTE(review): this listing omits original lines 184, 187 and 195 (the alias head and two
// template arguments), so the argument list shown here is incomplete. Lines 229-230 after
// the alias are also omitted; the reference index shows line 230 defining
// GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>.
183 template <index_t NXdlPerWave_>
185 AsDataType,
186 BsDataType,
188 AccDataType,
189 CShuffleDataType,
190 DsDataType,
191 EDataType,
192 AElementwiseOperation,
193 BElementwiseOperation,
194 CDEElementwiseOperation,
196 NumGemmKPrefetchStage,
197 BlockSize,
198 MPerBlock,
199 NPerBlock,
200 KPerBlock,
201 AK1,
202 BK1,
203 MPerXDL,
204 NPerXDL,
205 MXdlPerWave,
206 NXdlPerWave_,
207 ABlockTransferThreadClusterLengths_AK0_M_AK1,
208 ABlockTransferThreadClusterArrangeOrder,
209 ABlockTransferSrcAccessOrder,
210 ABlockTransferSrcVectorDim,
211 ABlockTransferSrcScalarPerVector,
212 ABlockTransferDstScalarPerVector_AK1,
213 false,
214 ABlockLdsExtraM,
215 BBlockTransferThreadClusterLengths_BK0_N_BK1,
216 BBlockTransferThreadClusterArrangeOrder,
217 BBlockTransferSrcAccessOrder,
218 BBlockTransferSrcVectorDim,
219 BBlockTransferSrcScalarPerVector,
220 BBlockTransferDstScalarPerVector_BK1,
221 false,
222 BBlockLdsExtraN,
223 CShuffleMXdlPerWavePerShuffle,
224 CShuffleNXdlPerWavePerShuffle,
225 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
226 CDEBlockTransferScalarPerVector_NPerBlock,
227 LoopSched,
228 PipelineVer>;
231
// Padder object shared by the Make*GridDescriptor helpers; it is brace-initialized with the
// block tile sizes MPerBlock, NPerBlock and KPerBlock.
// NOTE(review): original line 233 (the type being constructed) is omitted from this listing.
232 static constexpr auto matrix_padder =
234 MPerBlock, NPerBlock, KPerBlock};
235
// Builds the 2-D [M, K] grid descriptor of one A tensor from its per-dimension lengths and
// strides. Both vectors must hold NumDimM + NumDimK entries (asserted), M dimensions first.
// The naive descriptor is transformed into [MRaw, KRaw] and the result is passed through
// matrix_padder.PadADescriptor_M_K.
// NOTE(review): this listing omits original lines 254, 269 and 271 (the kDimIds initializer
// and two of the transform_tensor_descriptor arguments) -- see the real header for them.
236 static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_,
237 const std::vector<index_t>& a_ms_ks_strides_)
238 {
239 assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK &&
240 a_ms_ks_strides_.size() == NumDimM + NumDimK);
241
// Copies the first `num` entries of a std::vector into a tuple via generate_tuple.
242 const auto to_tuple = [&](auto& vec, auto num) {
243 return generate_tuple([&](auto i) { return vec[i]; }, num);
244 };
245
246 const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number<NumDimM + NumDimK>{});
247 const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number<NumDimM + NumDimK>{});
248
249 // dimension Ids for M0, M1, ...
250 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
251
252 // dimension Ids for K0, K1, ...
253 constexpr auto kDimIds =
255
256 // lengths for M0, M1, ...
257 const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
258
259 // lengths for K0, K1, ...
260 const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
261
262 // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
263 const auto a_grid_desc_ms_ks =
264 make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
265
266 // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
267 const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
268 a_grid_desc_ms_ks,
270 make_tuple(mDimIds, kDimIds),
272
273 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
274 }
275
// Builds a tuple of [M, K] descriptors by calling MakeAGridDescriptor_M_K on the i-th
// lengths/strides pair of each A tensor.
// NOTE(review): original line 284 (the second generate_tuple argument, i.e. the tuple size,
// and the closing of the call) is omitted from this listing.
276 __host__ __device__ static auto
277 MakeAsGridDescriptor_M_K(const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
278 const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides)
279 {
280 return generate_tuple(
281 [&](auto i) {
282 return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]);
283 },
285 }
286
287 // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
// Builds the 2-D [N, K] grid descriptor of one B tensor from its per-dimension lengths and
// strides. Both vectors must hold NumDimN + NumDimK entries (asserted). The naive descriptor
// is transformed into [NRaw, KRaw] and then passed through matrix_padder.PadBDescriptor_N_K.
// NOTE(review): this listing omits original lines 306, 321 and 323 (the kDimIds initializer
// and two of the transform_tensor_descriptor arguments).
288 static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_,
289 const std::vector<index_t>& b_ns_ks_strides_)
290 {
291 assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK &&
292 b_ns_ks_strides_.size() == NumDimN + NumDimK);
293
// Copies the first `num` entries of a std::vector into a tuple via generate_tuple.
294 const auto to_tuple = [&](auto& vec, auto num) {
295 return generate_tuple([&](auto i) { return vec[i]; }, num);
296 };
297
298 const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number<NumDimN + NumDimK>{});
299 const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number<NumDimN + NumDimK>{});
300
301 // dimension Ids for N0, N1, ...
302 constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
303
304 // dimension Ids for K0, K1, ...
305 constexpr auto kDimIds =
307
308 // lengths for K0, K1, ...
309 const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
310
311 // lengths for N0, N1, ...
312 const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
313
314 // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
315 const auto b_grid_desc_ns_ks =
316 make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
317
318 // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
319 const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
320 b_grid_desc_ns_ks,
322 make_tuple(nDimIds, kDimIds),
324
325 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
326 }
327
// Builds a tuple of [N, K] descriptors by calling MakeBGridDescriptor_N_K on the i-th
// lengths/strides pair of each B tensor.
// NOTE(review): original line 336 (the second generate_tuple argument and the closing of the
// call) is omitted from this listing.
328 __host__ __device__ static auto
329 MakeBsGridDescriptor_N_K(const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
330 const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides)
331 {
332 return generate_tuple(
333 [&](auto i) {
334 return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]);
335 },
337 }
338
339 // assume E[M0, M1, M2, ..., N0, N1, N2...]
// Builds the 2-D [M, N] grid descriptor of the E tensor (also reused for each D tensor) from
// its per-dimension lengths and strides. Both vectors must hold NumDimM + NumDimN entries
// (asserted). The naive descriptor is transformed into [MRaw, NRaw] and then passed through
// matrix_padder.PadCDescriptor_M_N.
// NOTE(review): this listing omits original lines 358, 373 and 375 (the nDimIds initializer
// and two of the transform_tensor_descriptor arguments).
340 static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_,
341 const std::vector<index_t>& e_ms_ns_strides_)
342 {
343 assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN &&
344 e_ms_ns_strides_.size() == NumDimM + NumDimN);
345
// Copies the first `num` entries of a std::vector into a tuple via generate_tuple.
346 const auto to_tuple = [&](auto& vec, auto num) {
347 return generate_tuple([&](auto i) { return vec[i]; }, num);
348 };
349
350 const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number<NumDimM + NumDimN>{});
351 const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number<NumDimM + NumDimN>{});
352
353 // dimension Ids for M0, M1, ...
354 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
355
356 // dimension Ids for N0, N1, ...
357 constexpr auto nDimIds =
359
360 // lengths for M0, M1, ...
361 const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
362
363 // lengths for N0, N1, ...
364 const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
365
366 // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
367 const auto e_grid_desc_ms_ns =
368 make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
369
370 // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
371 const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
372 e_grid_desc_ms_ns,
374 make_tuple(mDimIds, nDimIds),
376
377 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
378 }
379
// Builds a tuple of [M, N] descriptors for the D tensors by reusing MakeEGridDescriptor_M_N
// on the i-th lengths/strides pair.
// NOTE(review): original line 388 (the second generate_tuple argument and the closing of the
// call) is omitted from this listing.
380 static auto
381 MakeDsGridDescriptor_M_N(const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
382 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides)
383 {
384 return generate_tuple(
385 [&](auto i) {
386 return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
387 },
389 }
390
391 // desc for problem definition
396
397 // desc for blockwise copy
400 AsGridDesc_M_K{}))>;
403 BsGridDesc_N_K{}))>;
406 DsGridDesc_M_N{}))>;
409 EGridDesc_M_N{}))>;
410
411 // block-to-e-tile map
414
415 // Argument
// Host-side bundle of everything a launch needs: typed tensor pointers, problem descriptors,
// the block-to-E-tile map, the element-wise operators, and per-tensor vector-access metadata
// (contiguous dimension and maximum readable/writable element count) consumed by
// IsSupportedArgument.
// NOTE(review): this listing omits several original lines inside the struct (437-439, 455,
// 468, 481, 488, 494, 500, 504, 509-511, 515-518, 521, 532, 537). Per the reference index at
// the bottom of this page they include the p_as/p_bs/p_ds_grid_ pointer members, the
// *_grid_desc_* members, block_2_etile_map_, e_continous_dim_ and e_max_write_elems_.
416 struct Argument : public BaseArgument
417 {
418 Argument(std::array<const void*, NumATensor> p_as_grid,
419 std::array<const void*, NumBTensor> p_bs_grid,
420 std::array<const void*, NumDTensor> p_ds_grid,
421 void* p_e_grid,
422 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
423 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
424 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
425 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
426 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
427 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
428 const std::vector<index_t>& e_ms_ns_length,
429 const std::vector<index_t>& e_ms_ns_stride,
430 AElementwiseOperation a_element_op,
431 BElementwiseOperation b_element_op,
432 CDEElementwiseOperation cde_element_op)
433 : p_as_grid_{},
434 p_bs_grid_{},
435 p_ds_grid_{},
436 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
440 e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)},
441 block_2_etile_map_{GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
442 a_element_op_{a_element_op},
443 b_element_op_{b_element_op},
444 cde_element_op_{cde_element_op}
445 {
446 // populate pointer, desc for As
447 static_for<0, NumATensor, 1>{}([&](auto i) {
448 // using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
449 using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
450
451 // A pointer
452 p_as_grid_(i) = static_cast<const ADataType*>(p_as_grid[i]);
453
454 // A desc
456 MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
457 });
458
459 // populate pointer, desc for Bs
460 static_for<0, NumBTensor, 1>{}([&](auto i) {
461 // using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
462 using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
463
464 // B pointer
465 p_bs_grid_(i) = static_cast<const BDataType*>(p_bs_grid[i]);
466
467 // B desc
469 MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
470 });
471
472 // populate pointer, desc for Ds
473 static_for<0, NumDTensor, 1>{}([&](auto i) {
474 // using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
475 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
476
477 // D pointer
478 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
479
480 // D desc
482 MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
483 });
484
485 // for sanity check of vector memory access
// Each loop stores the result of CalculateMaxRead for one tensor; the assignment targets
// are on lines omitted from this listing (488, 494, 500, 504).
486 for(index_t i = 0; i < NumATensor; ++i)
487 {
489 CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
490 }
491
492 for(index_t i = 0; i < NumBTensor; ++i)
493 {
495 CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
496 }
497
498 for(index_t i = 0; i < NumDTensor; ++i)
499 {
501 CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
502 }
503
505 CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
506 }
507
508 // pointers
512 EDataType* p_e_grid_;
513
514 // tensor descriptors for problem definition
519
520 // block-to-e-tile map
522
523 // element-wise op
524 AElementwiseOperation a_element_op_;
525 BElementwiseOperation b_element_op_;
526 CDEElementwiseOperation cde_element_op_;
527
528 // Describe which dimension of A/B/D/E is the contiguous one.
// Each array is sized by its own tensor count. bs_/ds_continous_dim_ were previously sized
// with NumATensor/NumBTensor, which is out of bounds whenever NumBTensor > NumATensor or
// NumDTensor > NumBTensor, because IsSupportedArgument indexes them up to NumBTensor and
// NumDTensor respectively.
529 std::array<index_t, NumATensor> as_continous_dim_;
530 std::array<index_t, NumBTensor> bs_continous_dim_;
531 std::array<index_t, NumDTensor> ds_continous_dim_;
533
// Maximum element counts for vector access, as returned by CalculateMaxRead.
534 std::array<index_t, NumATensor> as_max_read_elems_;
535 std::array<index_t, NumBTensor> bs_max_read_elems_;
536 std::array<index_t, NumDTensor> ds_max_read_elems_;
538 };
539
540 // Invoker
// Launches the contraction kernel for a prepared Argument.
// NOTE(review): this listing omits original lines 543, 549-552, 564, 575, 584-588, 607 and
// 622. Per the reference index at the bottom of this page, line 543 is
// `using Argument = DeviceOp::Argument;` and the INVOKER_RUN_IMPL macro precedes the
// polymorphic Run at line 625.
541 struct Invoker : public BaseInvoker
542 {
544
// Validates the argument against GridwiseGemm, derives the blockwise descriptors and the grid
// size, then launches the kernel with HasMainKBlockLoop chosen at run time from K.
// Returns whatever launch_and_time_kernel returns.
// Throws std::runtime_error when GridwiseGemm::CheckValidity rejects the descriptors.
545 template <typename GridwiseGemm>
546 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
547 {
548 if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_,
553 {
554 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
555 }
556 auto as_grid_desc_ak0_m_ak1 =
557 GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(arg.as_grid_desc_m_k_);
558
559 auto bs_grid_desc_bk0_n_bk1 =
560 GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(arg.bs_grid_desc_n_k_);
561
562 auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
563 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
565
566 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
567 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
568 arg.e_grid_desc_m_n_);
// Grid size is computed by the block-to-E-tile map from the E descriptor.
569 const index_t grid_size =
570 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
571
// Generic lambda: the integral_constant argument turns the run-time decision below into the
// compile-time HasMainKBlockLoop template argument of the kernel.
572 auto launch_kernel = [&](auto has_main_k_block_loop) {
573 constexpr bool has_main_loop = has_main_k_block_loop.value;
574
576 GridwiseGemm,
577 typename GridwiseGemm::AsGridPointer,
578 typename GridwiseGemm::BsGridPointer,
579 typename GridwiseGemm::DsGridPointer,
580 EDataType,
581 AElementwiseOperation,
582 BElementwiseOperation,
583 CDEElementwiseOperation,
589 has_main_loop>;
590
// One thread block of BlockSize threads per grid entry; the dynamic LDS byte count is 0.
591 return launch_and_time_kernel(stream_config,
592 kernel,
593 dim3(grid_size),
594 dim3(BlockSize),
595 0,
596 arg.p_as_grid_,
597 arg.p_bs_grid_,
598 arg.p_ds_grid_,
599 arg.p_e_grid_,
600 arg.a_element_op_,
601 arg.b_element_op_,
602 arg.cde_element_op_,
603 as_grid_desc_ak0_m_ak1,
604 bs_grid_desc_bk0_n_bk1,
605 ds_grid_desc_mblock_mperblock_nblock_nperblock,
606 e_grid_desc_mblock_mperblock_nblock_nperblock,
608 };
609
// K is read from the first A descriptor (length of dimension I1 of the [M, K] descriptor).
610 const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1);
611
612 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
613 {
614 return launch_kernel(integral_constant<bool, true>{});
615 }
616 else
617 {
618 return launch_kernel(integral_constant<bool, false>{});
619 }
620 }
621
623
624 // polymorphic
// Type-erased entry point: downcasts p_arg to Argument and forwards to the typed Run.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so p_arg must
// actually point to an Argument of this device op.
625 float Run(const BaseArgument* p_arg,
626 const StreamConfig& stream_config = StreamConfig{}) override
627 {
628 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
629 }
630 };
631
// Returns true only when the argument can be run by this instance: every A/B/D/E tensor must
// either use scalar access (ScalarPerVector == 1) or have a maximum access length divisible
// by the configured vector width along the contiguous dimension the transfer vectorizes on.
// The final verdict comes from the wave-size-specific branch at the end of the function.
// NOTE(review): this listing omits original lines 634-636 (the condition of the first early
// return), 714-718 and 725-729 (the bodies of the two `if constexpr` branches) -- see the
// real header for them.
632 static bool IsSupportedArgument(const Argument& arg)
633 {
637 {
638 return false;
639 }
640 // check vector load/store
641 {
642 bool valid_as_access = true;
643 static_for<0, NumATensor, 1>{}([&](auto i) {
644 const bool valid_a_vector_size =
645 arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
// Vector dim 1 pairs with contiguous-dim value 0, vector dim 2 with contiguous-dim value 1.
646 const bool valid_a_access_dim_m =
647 ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0;
648 const bool valid_a_access_dim_k =
649 ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1;
650 const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
651 if(!((valid_a_vector_size && valid_a_access_dim) ||
652 ABlockTransferSrcScalarPerVector == 1))
653 {
654 valid_as_access = false;
655 }
656 });
657 if(!valid_as_access)
658 {
659 return false;
660 }
661
662 bool valid_bs_access = true;
663 static_for<0, NumBTensor, 1>{}([&](auto i) {
664 const bool valid_b_vector_size =
665 arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
666 const bool valid_b_access_dim_n =
667 BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0;
668 const bool valid_b_access_dim_k =
669 BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1;
670 const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
671 if(!((valid_b_vector_size && valid_b_access_dim) ||
672 BBlockTransferSrcScalarPerVector == 1))
673 {
674 valid_bs_access = false;
675 }
676 });
677 if(!valid_bs_access)
678 {
679 return false;
680 }
681
682 bool valid_ds_access = true;
683 static_for<0, NumDTensor, 1>{}([&](auto i) {
684 const bool valid_d_vector_size =
685 arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
686 // Vector read of Ds is always on N dimension.
687 const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1;
688 if(!((valid_d_vector_size && valid_d_access_dim) ||
689 CDEBlockTransferScalarPerVector_NPerBlock == 1))
690 {
691 valid_ds_access = false;
692 }
693 });
694 if(!valid_ds_access)
695 {
696 return false;
697 }
698
699 const bool valid_e_vector_size =
700 arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
701 // Vector write of E is always on N dimension.
702 const bool valid_e_access_dim = arg.e_continous_dim_ == 1;
703 if(!((valid_e_vector_size && valid_e_access_dim) ||
704 CDEBlockTransferScalarPerVector_NPerBlock == 1))
705 {
706 return false;
707 }
708 }
709
// Select the wave64 or wave32 configuration. The previous test `get_warp_size() > 0` was
// always true, which made the NXdlPerWave32 branch unreachable.
710 if(get_warp_size() == 64)
711 {
712 if constexpr(NXdlPerWave64 > 0)
713 {
719 }
720 }
721 else
722 {
723 if constexpr(NXdlPerWave32 > 0)
724 {
730 }
731 }
732
// Reached when the selected wave configuration is not instantiable (NXdlPerWave == 0).
733 return false;
734 }
735
736 // polymorphic
// Type-erased overload: downcasts p_arg to Argument and forwards to the static overload.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so p_arg must
// actually point to an Argument of this device op.
737 bool IsSupportedArgument(const BaseArgument* p_arg) override
738 {
739 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
740 }
741
// Constructs an Argument by value, forwarding all parameters to its constructor unchanged and
// in the same order: tensor pointers, then A/B/D/E lengths and strides, then the three
// element-wise operators.
742 static auto MakeArgument(std::array<const void*, NumATensor> p_as,
743 std::array<const void*, NumBTensor> p_bs,
744 std::array<const void*, NumDTensor> p_ds,
745 void* p_e,
746 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
747 const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
748 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
749 const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
750 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
751 const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
752 const std::vector<index_t>& e_ms_ns_length,
753 const std::vector<index_t>& e_ms_ns_stride,
754 AElementwiseOperation a_element_op,
755 BElementwiseOperation b_element_op,
756 CDEElementwiseOperation cde_element_op)
757 {
758 return Argument{p_as,
759 p_bs,
760 p_ds,
761 p_e,
762 a_ms_ks_lengths,
763 a_ms_ks_strides,
764 b_ns_ks_lengths,
765 b_ns_ks_strides,
766 d_ms_ns_lengths,
767 d_ms_ns_strides,
768 e_ms_ns_length,
769 e_ms_ns_stride,
770 a_element_op,
771 b_element_op,
772 cde_element_op};
773 }
774
// Returns a default-constructed Invoker by value.
775 static auto MakeInvoker() { return Invoker{}; }
776
777 // polymorphic
// Heap-allocating counterpart of MakeArgument: builds an Argument with std::make_unique from
// the same parameters, in the same order, and returns it as std::unique_ptr<BaseArgument>.
778 std::unique_ptr<BaseArgument>
779 MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
780 std::array<const void*, NumBTensor> p_bs,
781 std::array<const void*, NumDTensor> p_ds,
782 void* p_e,
783 const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_lengths,
784 const std::array<std::vector<index_t>, NumATensor>& as_ms_ks_strides,
785 const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_lengths,
786 const std::array<std::vector<index_t>, NumBTensor>& bs_ns_ks_strides,
787 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths,
788 const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides,
789 const std::vector<index_t>& e_ms_ns_length,
790 const std::vector<index_t>& e_ms_ns_stride,
791 AElementwiseOperation a_element_op,
792 BElementwiseOperation b_element_op,
793 CDEElementwiseOperation cde_element_op) override
794 {
795 return std::make_unique<Argument>(p_as,
796 p_bs,
797 p_ds,
798 p_e,
799 as_ms_ks_lengths,
800 as_ms_ks_strides,
801 bs_ns_ks_lengths,
802 bs_ns_ks_strides,
803 ds_ms_ns_lengths,
804 ds_ms_ns_strides,
805 e_ms_ns_length,
806 e_ms_ns_stride,
807 a_element_op,
808 b_element_op,
809 cde_element_op);
810 }
811
812 // polymorphic
// Returns a heap-allocated Invoker (constructed from a temporary Invoker{}) as
// std::unique_ptr<BaseInvoker>.
813 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
814 {
815 return std::make_unique<Invoker>(Invoker{});
816 }
817
818 // polymorphic
// Returns a human-readable description of this instance: the operation name, a
// comma-separated list of tile/vector template parameters, the GEMM specialization string,
// then the loop scheduler and pipeline version names.
// NOTE(review): LoopSched and PipelineVer are not among the template parameters visible in
// this listing (original lines 151-153 are omitted); presumably they are declared there.
819 std::string GetTypeString() const override
820 {
821 auto str = std::stringstream();
822
// Enum-to-name lookup tables. They are looked up with operator[] below, so an enum value
// missing from a table would print as an empty string.
823 std::map<LoopScheduler, std::string> LoopSchedToString{
824 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
825
826 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
827 {PipelineVersion::v2, "v2"}};
828
829 // clang-format off
830 str << "DeviceContractionMultipleABD_Xdl_CShuffle"
831 << "<"
832 << BlockSize << ", "
833 << MPerBlock << ", "
834 << NPerBlock << ", "
835 << KPerBlock << ", "
836 << AK1 << ", "
837 << BK1 << ", "
838 << MPerXDL << ", "
839 << NPerXDL << ", "
840 << MXdlPerWave << ", "
841 << NXdlPerWave << ", "
842 << ABlockTransferSrcScalarPerVector << ", "
843 << BBlockTransferSrcScalarPerVector << ", "
844 << CShuffleMXdlPerWavePerShuffle << ", "
845 << CShuffleNXdlPerWavePerShuffle << ", "
846 << getGemmSpecializationString(GemmSpec)
847 << ">"
848 << " LoopScheduler: "
849 << LoopSchedToString[LoopSched] << ", "
850 << "PipelineVersion: "
851 << PipelineVersionToString[PipelineVer];
852 // clang-format on
853
854 return str.str();
855 }
856};
857
858} // namespace device
859} // namespace tensor_operation
860} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_contraction_utils.hpp:33
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__global__ void kernel_contraction_multiple_abd_xdl_cshuffle(AsPointer p_as_grid, BsPointer p_bs_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:42
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:417
EDataType * p_e_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:512
std::array< index_t, NumATensor > bs_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:530
GridwiseGemm64::BsGridPointer p_bs_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:510
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:517
std::array< index_t, NumATensor > as_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:529
BsGridDesc_N_K bs_grid_desc_n_k_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:516
AElementwiseOperation a_element_op_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:524
BElementwiseOperation b_element_op_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:525
EGridDesc_M_N e_grid_desc_m_n_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:518
std::array< index_t, NumBTensor > ds_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:531
Argument(std::array< const void *, NumATensor > p_as_grid, std::array< const void *, NumBTensor > p_bs_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:418
index_t e_max_write_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:537
std::array< index_t, NumDTensor > ds_max_read_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:536
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:511
AsGridDesc_M_K as_grid_desc_m_k_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:515
index_t e_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:532
std::array< index_t, NumATensor > as_max_read_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:534
GridwiseGemm64::AsGridPointer p_as_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:509
std::array< index_t, NumBTensor > bs_max_read_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:535
Block2ETileMap block_2_etile_map_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:521
CDEElementwiseOperation cde_element_op_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:526
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:542
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:546
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:625
DeviceOp::Argument Argument
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:543
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:164
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:813
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:230
remove_cvref_t< decltype(MakeEGridDescriptor_M_N({}, {}))> EGridDesc_M_N
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:395
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAsGridDescriptor_AK0_M_AK1( AsGridDesc_M_K{}))> AsGridDesc_AK0_M_AK1
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:398
static constexpr index_t NumATensor
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:171
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:381
remove_cvref_t< decltype(MakeBsGridDescriptor_N_K({}, {}))> BsGridDesc_N_K
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:393
DeviceContractionMultipleABD_Xdl_CShuffle DeviceOp
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:165
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:737
__host__ static __device__ auto MakeAsGridDescriptor_M_K(const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_strides)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:277
static constexpr auto I2
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:177
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:779
static auto MakeBGridDescriptor_N_K(const std::vector< index_t > &b_ns_ks_lengths_, const std::vector< index_t > &b_ns_ks_strides_)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:288
std::string GetTypeString() const override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:819
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:404
static constexpr auto I1
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:176
static bool IsSupportedArgument(const Argument &arg)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:632
__host__ static __device__ auto MakeBsGridDescriptor_N_K(const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_strides)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:329
GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:184
remove_cvref_t< decltype(MakeAsGridDescriptor_M_K({}, {}))> AsGridDesc_M_K
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:392
static constexpr auto matrix_padder
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:232
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:229
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBsGridDescriptor_BK0_N_BK1( BsGridDesc_N_K{}))> BsGridDesc_BK0_N_BK1
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:401
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:407
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}))> DsGridDesc_M_N
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:394
static constexpr auto NXdlPerWave32
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:169
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:168
EDataType ComputeDataType
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:180
static constexpr auto I3
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:178
static auto MakeArgument(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:742
static constexpr index_t NumDTensor
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:173
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_ms_ks_lengths_, const std::vector< index_t > &a_ms_ks_strides_)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:236
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:412
static constexpr index_t NumBTensor
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:172
static auto MakeInvoker()
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:775
static constexpr auto I0
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:175
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_ms_ns_lengths_, const std::vector< index_t > &e_ms_ns_strides_)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:340
Definition device_contraction_multiple_abd.hpp:34
Definition matrix_padder.hpp:180