device_batched_gemm_wmma_cshuffle_v3.hpp Source File

device_batched_gemm_wmma_cshuffle_v3.hpp Source File

Composable Kernel: device_batched_gemm_wmma_cshuffle_v3.hpp Source File
device_batched_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename GridwiseGemm,
25 typename ComputePtrOffsetOfStridedBatch,
26 bool HasMainKBlockLoop,
27 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
28 index_t MinimumOccupancy = 1,
30__global__ void
31#if CK_USE_LAUNCH_BOUNDS
32__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
33#endif
35 typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
36 // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
37 // argument through implicit conversion to base class!
38 const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
39{
40#if(defined(__gfx11__) || defined(__gfx12__))
41#if defined(__gfx11__)
42 // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
43 using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
44 if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
45 (std::is_same_v<c_data_type, ck::half_t> ||
46 std::is_same_v<c_data_type, ck::bhalf_t>)))
47 {
48#endif
49 // The normal approach to batching would be to increase the grid size by just stretching out
50 // the grid Z dimension (which is the outermost dimension), but this depends on lower level
51 // functions not directly using the Z dimension for other calculations. As it turns out, k
52 // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
53 // we will use the grid Y dimension for batching. This may be a bit fragile.
54 const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
55
56 const long_index_t a_batch_offset =
57 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
58 const long_index_t b_batch_offset =
59 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
60 const long_index_t c_batch_offset =
61 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
62
63 constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
64 typename GridwiseGemm::EpilogueCShuffle>();
65 __shared__ char p_shared[LDS_size];
66
67 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
68
69 // shift A matrices pointer for splitk
70 typename GridwiseGemm::AsGridPointer p_as_grid_shift;
72 using ADataType_ =
73 remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
74 p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
75 splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
76 });
77
78 // shift B matrices pointer for splitk
79 typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
81 using BDataType_ =
82 remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
83 p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
84 splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
85 });
86
87 auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
88
89 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
90 p_as_grid_shift,
91 p_bs_grid_shift,
92 karg.p_ds_grid,
93 karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
94 p_shared,
95 karg,
96 karg.a_element_op,
97 karg.b_element_op,
98 karg.cde_element_op,
99 epilogue_args);
100#if defined(__gfx11__)
101 }
102#endif
103#else
104 ignore = karg;
105 ignore = compute_ptr_offset_of_batch;
106#endif
107}
108
205template <typename ALayout,
206 typename BLayout,
207 typename CLayout,
208 typename ADataType,
209 typename BDataType,
210 typename CDataType,
211 typename AccDataType,
212 typename CShuffleDataType,
213 typename AElementwiseOperation,
214 typename BElementwiseOperation,
215 typename CElementwiseOperation,
216 GemmSpecialization GemmSpec,
217 index_t BlockSize,
218 index_t MPerBlock,
219 index_t NPerBlock,
220 index_t KPerBlock,
221 index_t AK1,
222 index_t BK1,
223 index_t MPerWmma,
224 index_t NPerWmma,
225 index_t MRepeat,
226 index_t NRepeat,
227 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
228 typename ABlockTransferThreadClusterArrangeOrder,
229 typename ABlockTransferSrcAccessOrder,
230 index_t ABlockTransferSrcVectorDim,
231 index_t ABlockTransferSrcScalarPerVector,
232 index_t ABlockTransferDstScalarPerVector_AK1,
233 bool ABlockLdsExtraM,
234 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
235 typename BBlockTransferThreadClusterArrangeOrder,
236 typename BBlockTransferSrcAccessOrder,
237 index_t BBlockTransferSrcVectorDim,
238 index_t BBlockTransferSrcScalarPerVector,
239 index_t BBlockTransferDstScalarPerVector_BK1,
240 bool BBlockLdsExtraN,
241 index_t CShuffleMRepeatPerShuffle,
242 index_t CShuffleNRepeatPerShuffle,
243 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
244 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
247 typename ComputeTypeA = CDataType,
248 typename ComputeTypeB = ComputeTypeA,
249 bool PermuteA = false,
250 bool PermuteB = false>
252 BLayout,
253 CLayout,
254 ADataType,
255 BDataType,
256 CDataType,
257 AElementwiseOperation,
258 BElementwiseOperation,
259 CElementwiseOperation>
260{
261 // We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and
262 // permuteB arguments so for now we are not including this functionality.
263 static_assert(PermuteA == false,
264 "Permute A functionality not supported by DeviceBatchedGemm operations.\n");
265 static_assert(PermuteB == false,
266 "Permute B functionality not supported by DeviceBatchedGemm operations.\n");
267
269 {
271 index_t BatchStrideB,
272 index_t BatchStrideC)
273 : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
274 {
275 }
276
277 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
278 {
279 return g_idx * static_cast<long_index_t>(BatchStrideA_);
280 }
281
282 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
283 {
284 return g_idx * static_cast<long_index_t>(BatchStrideB_);
285 }
286
287 __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
288 {
289 return g_idx * static_cast<long_index_t>(BatchStrideC_);
290 }
291
292 private:
293 index_t BatchStrideA_;
294 index_t BatchStrideB_;
295 index_t BatchStrideC_;
296 };
297
298 // GridwiseGemm
300 ALayout,
301 BLayout,
302 Tuple<>, // DsLayout
303 CLayout,
306 AccDataType,
307 CShuffleDataType,
308 Tuple<>, // DsDataType
309 CDataType,
310 AElementwiseOperation,
311 BElementwiseOperation,
312 CElementwiseOperation,
313 GemmSpec,
314 BlockSize,
315 MPerBlock,
316 NPerBlock,
317 KPerBlock,
318 AK1,
319 BK1,
320 MPerWmma,
321 NPerWmma,
322 MRepeat,
323 NRepeat,
324 ABlockTransferThreadClusterLengths_AK0_M_AK1,
325 ABlockTransferThreadClusterArrangeOrder,
326 ABlockTransferSrcAccessOrder,
327 ABlockTransferSrcVectorDim,
328 ABlockTransferSrcScalarPerVector,
329 ABlockTransferDstScalarPerVector_AK1,
330 false,
331 ABlockLdsExtraM,
332 BBlockTransferThreadClusterLengths_BK0_N_BK1,
333 BBlockTransferThreadClusterArrangeOrder,
334 BBlockTransferSrcAccessOrder,
335 BBlockTransferSrcVectorDim,
336 BBlockTransferSrcScalarPerVector,
337 BBlockTransferDstScalarPerVector_BK1,
338 false,
339 BBlockLdsExtraN,
340 CShuffleMRepeatPerShuffle,
341 CShuffleNRepeatPerShuffle,
342 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
344 BlkGemmPipeSched,
345 BlkGemmPipelineVer,
346 ComputeTypeA,
347 ComputeTypeB,
348 false, // PermuteA not supported by DeviceBatchedGemm base class.
349 false>; // PermuteB not supported by DeviceBatchedGemm base class.
350
351 // Argument
352 struct Argument : public GridwiseGemm::Argument
353 {
354 __host__ Argument(const ADataType* p_a_grid_,
355 const BDataType* p_b_grid_,
356 CDataType* p_c_grid_,
357 index_t M_,
358 index_t N_,
359 index_t K_,
360 index_t StrideA_,
361 index_t StrideB_,
362 index_t StrideC_,
363 index_t BatchStrideA_,
364 index_t BatchStrideB_,
365 index_t BatchStrideC_,
366 index_t Batch_,
367 index_t k_batch_,
368 AElementwiseOperation a_element_op_,
369 BElementwiseOperation b_element_op_,
370 CElementwiseOperation cde_element_op_,
371 bool is_reduce_ = false)
372 : GridwiseGemm::Argument(std::array<const void*, 1>{p_a_grid_},
373 std::array<const void*, 1>{p_b_grid_},
374 std::array<const void*, 0>{}, // p_ds_grid_
375 p_c_grid_,
376 M_,
377 N_,
378 K_,
379 std::array<index_t, 1>{StrideA_},
380 std::array<index_t, 1>{StrideB_},
381 std::array<index_t, 0>{}, // StrideDs_
382 StrideC_,
383 k_batch_,
384 a_element_op_,
385 b_element_op_,
386 cde_element_op_,
387 is_reduce_),
388 Batch(Batch_),
389 compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_}
390 {
391 }
392
395 };
396
406 struct Invoker : public BaseInvoker
407 {
413 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
414 {
415 if(stream_config.log_level_ > 0)
416 {
417 arg.Print();
418 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
419 }
420
422 {
423 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
424 }
425
426 index_t gdx, gdy, gdz;
427 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
428
429 // The normal approach to batching would be to increase the grid size by just stretching
430 // out the grid Z dimension (which is the outermost dimension), but this depends on
431 // lower level functions not directly using the Z dimension for other calculations. As
432 // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset.
433 // Therefore, for now we will use the grid Y dimension for batching. This may be a bit
434 // fragile.
435 gdy *= arg.Batch;
436
437 float ave_time = 0;
438
439 index_t k_grain = arg.KBatch * KPerBlock;
440 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
441
442 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
443
444 const auto Run = [&](const auto& kernel) {
445 if(stream_config.flush_cache)
446 {
447 Argument arg_ = arg;
448
449 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
450 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
451 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
452 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
453
454 // Packed sizes are 1 for all implemented data types but we include it anyway
455 // for future compatibility.
456 // Note: the grid descriptors and size_a / size_b do *not* take batching into
457 // account, so we have to manually multiply overall buffer sizes for rotating
458 // memory by batch.
459 std::array<std::size_t, 1> size_as_buffers;
460 size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
461 sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch;
462
463 std::array<std::size_t, 1> size_bs_buffers;
464 size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
465 sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch;
466
468 Tuple<ADataType>,
469 Tuple<BDataType>,
470 Tuple<>>
471 rotating_mem(arg_,
472 stream_config.rotating_count,
473 size_as_buffers,
474 size_bs_buffers,
475 std::array<std::size_t, 0>{});
476 rotating_mem.Print();
477
478 auto run_flush_cache = [&]() {
479 // flush icache
481 // rotating mem
482 rotating_mem.Next();
483 // clear c mem
484 if(arg_.KBatch > 1)
485 // Note: we multiply by batch since we want to clear the C matrix for
486 // the whole batch. Untested since we don't have k batching ATM.
487 // Note: This seems incorrect for non-contiguous memory layouts for C
488 // (padding, gaps).
490 hipMemsetAsync(arg_.p_e_grid,
491 0,
492 arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
493 stream_config.stream_id_));
494 };
495
497 stream_config,
498 run_flush_cache,
499 kernel,
500 dim3(gdx, gdy, gdz),
501 dim3(BlockSize),
502 0,
503 arg_,
504 arg_.compute_ptr_offset_of_batch);
505 }
506 else
507 {
508 auto clear_workspace = [&]() {
509 // clear c mem
510 if(arg.KBatch > 1)
511 // Note: we multiply by batch since we want to clear the C matrix for
512 // the whole batch. Untested since we don't have k batching ATM.
513 // Note: This seems incorrect for non-contiguous memory layouts for C
514 // (padding, gaps).
516 hipMemsetAsync(arg.p_e_grid,
517 0,
518 arg.Batch * arg.M * arg.N * sizeof(CDataType),
519 stream_config.stream_id_));
520 };
521
523 stream_config,
524 clear_workspace,
525 kernel,
526 dim3(gdx, gdy, gdz),
527 dim3(BlockSize),
528 0,
529 arg,
531 }
532 };
533
534 constexpr index_t minimum_occupancy = []() {
535 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
536 {
537 return 2;
538 }
539 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
540 {
541 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
542 }
543 else
544 {
545 return 1;
546 }
547 }();
548
549 if(has_main_k_block_loop)
550 {
551 // Tail number always full
552 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
553 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
554 {
555 if(arg.KBatch > 1)
556 {
557 const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
559 ComputePtrOffsetOfStridedBatch,
560 true,
562 minimum_occupancy>;
563 Run(kernel);
564 }
565 else
566 {
567 const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
570 true,
572 minimum_occupancy>;
573 Run(kernel);
574 }
575 }
576 else
577 {
578 // TODO: Implement
579 }
580 }
581 else
582 {
583 // Tail number always 1
584 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
585 {
586 if(arg.KBatch > 1)
587 {
588 const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
590 ComputePtrOffsetOfStridedBatch,
591 false,
593 minimum_occupancy>;
594 Run(kernel);
595 }
596 else
597 {
598 const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3<
601 false,
603 minimum_occupancy>;
604 Run(kernel);
605 }
606 }
607 }
608
609 return ave_time;
610 }
611
612 // polymorphic
613 float Run(const BaseArgument* p_arg,
614 const StreamConfig& stream_config = StreamConfig{}) override
615 {
616 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
617 }
618 };
619
// Compile-time configuration check; for now it accepts every template configuration.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
625
// Host-side check of whether `arg` can be run by this kernel instance.
// Returns false for rejected hardware / data-type / shape combinations; otherwise the
// decision is delegated to GridwiseGemm::CheckValidity(arg).
626 static bool IsSupportedArgument(const Argument& arg)
627 {
// NOTE(review): the condition guarding this early return (source line 628) is missing
// from this listing; presumably a target-architecture check (gfx11/gfx12 support) —
// confirm against the real header.
629 {
630 return false;
631 }
632
// Split-K (KBatch > 1) with a half/bf16 C type is rejected on gfx11, which lacks the
// packed f16/bf16 atomic-add instructions.
633 if constexpr(std::is_same_v<CDataType, ck::half_t> ||
634 std::is_same_v<CDataType, ck::bhalf_t>)
635 {
636 if(arg.KBatch > 1 && ck::is_gfx11_supported())
637 {
638 // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
639 return false;
640 }
641 }
642
// fp8/bf8 compute types are rejected under a condition (source line 646) that is missing
// from this listing — presumably a device-capability check; TODO confirm.
643 if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
644 std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
645 {
647 {
648 return false;
649 }
650 }
651
// K must be a multiple of both AK1 and BK1 unless the GEMM specialization pads K
// (MKPadding, NKPadding, MNKPadding or KPadding).
652 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
653 GemmSpec == GemmSpecialization::NKPadding ||
654 GemmSpec == GemmSpecialization::MNKPadding ||
655 GemmSpec == GemmSpecialization::KPadding))
656 {
657 return false;
658 }
659
// All remaining tile/shape constraints are checked by the gridwise GEMM.
660 return GridwiseGemm::CheckValidity(arg);
661 }
662
663 // polymorphic
664 bool IsSupportedArgument(const BaseArgument* p_arg) override
665 {
666 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
667 }
668
669 // TODO: This is not part of the DeviceBatchedGemm base class but it was part of
670 // DeviceBatchedGemmV2. Remove?
671 // index_t GetKPerBlock() override { return KPerBlock; }
672 // bool GetPermuteA() override { return PermuteA; }
673 // bool GetPermuteB() override { return PermuteB; }
674
675 static auto MakeArgument(const ADataType* p_a,
676 const BDataType* p_b,
677 CDataType* p_c,
678 index_t M,
679 index_t N,
680 index_t K,
681 index_t StrideA,
682 index_t StrideB,
683 index_t StrideC,
684 index_t BatchStrideA,
685 index_t BatchStrideB,
686 index_t BatchStrideC,
687 index_t Batch,
688 AElementwiseOperation,
689 BElementwiseOperation,
690 CElementwiseOperation)
691 {
692 return Argument{p_a,
693 p_b,
694 p_c,
695 M,
696 N,
697 K,
698 StrideA,
699 StrideB,
700 StrideC,
701 BatchStrideA,
702 BatchStrideB,
703 BatchStrideC,
704 Batch,
705 1, /* KBatch */
706 AElementwiseOperation{},
707 BElementwiseOperation{},
708 CElementwiseOperation{}};
709 }
710
711 static auto MakeInvoker() { return Invoker{}; }
712
713 // polymorphic
714 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
715 const void* p_b,
716 void* p_c,
717 index_t M,
718 index_t N,
719 index_t K,
720 index_t StrideA,
721 index_t StrideB,
722 index_t StrideC,
723 index_t BatchStrideA,
724 index_t BatchStrideB,
725 index_t BatchStrideC,
726 index_t Batch,
727 AElementwiseOperation,
728 BElementwiseOperation,
729 CElementwiseOperation) override
730 {
731 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
732 static_cast<const BDataType*>(p_b),
733 static_cast<CDataType*>(p_c),
734 M,
735 N,
736 K,
737 StrideA,
738 StrideB,
739 StrideC,
740 BatchStrideA,
741 BatchStrideB,
742 BatchStrideC,
743 Batch,
744 1,
745 AElementwiseOperation{},
746 BElementwiseOperation{},
747 CElementwiseOperation{}); // KBatch
748 }
749
750 // polymorphic
751 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
752 {
753 return std::make_unique<Invoker>(Invoker{});
754 }
755
756 // polymorphic
757 std::string GetTypeString() const override
758 {
759 auto str = std::stringstream();
760
761 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
764
765 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
771
772 // clang-format off
773 str << "DeviceBatchedGemm_Wmma_CShuffleV3"
774 << "<"
775 << getGemmSpecializationString(GemmSpec) << ", "
776 << std::string(ALayout::name)[0]
777 << std::string(BLayout::name)[0]
778 << std::string(CLayout::name)[0]
779 << ">"
780 << " BlkSize: "
781 << BlockSize << ", "
782 << "BlkTile: "
783 << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
784 << "WaveTile: "
785 << MPerWmma << "x"<<NPerWmma << ", "
786 << "WaveMap: "
787 << MRepeat << "x" << NRepeat << ", "
788 << "VmemReadVec: "
789 << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
790 << "BlkGemmPipelineScheduler: "
791 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
792 << "BlkGemmPipelineVersion: "
793 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
794 << "BlkGemmPipelinePrefetchStages: "
795 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
796 << "KPack: "
798 // clang-format on
799
800 return str.str();
801 }
803};
804
805} // namespace device
806} // namespace tensor_operation
807} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
__global__ void kernel_batched_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:34
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
int64_t long_index_t
Definition ck.hpp:300
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
STL namespace.
Definition ck/stream_config.hpp:10
static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:353
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:394
index_t Batch
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:393
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t BatchStrideA_, index_t BatchStrideB_, index_t BatchStrideC_, index_t Batch_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:354
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC)
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:270
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:277
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:282
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:287
Helper structure responsible for kernel invocation.
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:407
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:613
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:413
"Universal" Batched GEMM operation without SplitK support.
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:260
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:751
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:620
static auto MakeInvoker()
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:711
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:664
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:714
std::string GetTypeString() const override
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:757
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:675
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:626
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, false, false > GridwiseGemm
Definition device_batched_gemm_wmma_cshuffle_v3.hpp:299
Definition device_batched_gemm.hpp:25
Definition flush_cache.hpp:21