device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp Source File

device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp Source File#

Composable Kernel: device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp Source File
device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename GridwiseGemm,
25 typename ComputePtrOffsetOfStridedBatch,
26 bool HasMainKBlockLoop,
27 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
28 index_t MinimumOccupancy = 1,
30__global__ void
31#if CK_USE_LAUNCH_BOUNDS
32__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
33#endif
35 typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
36 // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
37 // argument through implicit conversion to base class!
38 const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
39{
40#if(defined(__gfx11__) || defined(__gfx12__))
41#if defined(__gfx11__)
42 // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
43 using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
44 if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
45 (std::is_same_v<c_data_type, ck::half_t> ||
46 std::is_same_v<c_data_type, ck::bhalf_t>)))
47 {
48#endif
49 constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
50 typename GridwiseGemm::EpilogueCShuffle>();
51 // The normal approach to batching would be to increase the grid size by just stretching out
52 // the grid Z dimension (which is the outermost dimension), but this depends on lower level
53 // functions not directly using the Z dimension for other calculations. As it turns out, k
54 // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
55 // we will use the grid Y dimension for batching. This may be a bit fragile.
56 __shared__ char p_shared[LDS_size];
57
58 const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
59
60 const long_index_t a_batch_offset =
61 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
62 const long_index_t b_batch_offset =
63 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
64 const long_index_t c_batch_offset =
65 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
66 const long_index_t b_scale_batch_offset =
67 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx));
68
69 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
70
71 // shift A matrices pointer for splitk
72 typename GridwiseGemm::AsGridPointer p_as_grid_shift;
74 using ADataType_ =
75 remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
76 p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
77 splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
78 });
79
80 // shift B matrices pointer for splitk
81 typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
83 using BDataType_ =
84 remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
85 p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
86 splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
87 });
88
89 auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
90
91 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
92 p_as_grid_shift,
93 p_bs_grid_shift,
94 karg.p_ds_grid,
95 karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
96 karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
97 p_shared,
98 karg,
99 karg.a_element_op,
100 karg.b_element_op,
101 karg.cde_element_op,
102 epilogue_args);
103#if defined(__gfx11__)
104 }
105#endif
106#else
107 ignore = karg;
108 ignore = compute_ptr_offset_of_batch;
109#endif
110}
111
208template <typename ALayout,
209 typename BLayout,
210 typename CLayout,
211 typename ADataType,
212 typename BDataType,
213 typename BScaleDataType,
214 typename CDataType,
215 typename AccDataType,
216 typename CShuffleDataType,
217 typename AElementwiseOperation,
218 typename BElementwiseOperation,
219 typename CElementwiseOperation,
220 GemmSpecialization GemmSpec,
221 index_t BlockSize,
222 index_t ScaleBlockN, // scale block for N
223 index_t ScaleBlockK, // scale block for K
224 index_t MPerBlock,
225 index_t NPerBlock,
226 index_t KPerBlock,
227 index_t AK1,
228 index_t BK1,
229 index_t MPerWmma,
230 index_t NPerWmma,
231 index_t MRepeat,
232 index_t NRepeat,
233 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
234 typename ABlockTransferThreadClusterArrangeOrder,
235 typename ABlockTransferSrcAccessOrder,
236 index_t ABlockTransferSrcVectorDim,
237 index_t ABlockTransferSrcScalarPerVector,
238 index_t ABlockTransferDstScalarPerVector_AK1,
239 bool ABlockLdsExtraM,
240 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
241 typename BBlockTransferThreadClusterArrangeOrder,
242 typename BBlockTransferSrcAccessOrder,
243 index_t BBlockTransferSrcVectorDim,
244 index_t BBlockTransferSrcScalarPerVector,
245 index_t BBlockTransferDstScalarPerVector_BK1,
246 bool BBlockLdsExtraN,
247 index_t CShuffleMRepeatPerShuffle,
248 index_t CShuffleNRepeatPerShuffle,
249 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
250 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
253 typename ComputeTypeA = CDataType,
254 typename ComputeTypeB = ComputeTypeA,
255 bool PermuteA = false,
256 bool PermuteB = false>
258 : public DeviceBatchedGemmV2BScale<ALayout,
259 BLayout,
260 CLayout,
261 ADataType,
262 BDataType,
263 BScaleDataType,
264 CDataType,
265 ScaleBlockN,
266 ScaleBlockK,
267 AElementwiseOperation,
268 BElementwiseOperation,
269 CElementwiseOperation>
270{
271 // We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and
272 // permuteB arguments so for now we are not including this functionality.
273 static_assert(PermuteA == false,
274 "Permute A functionality not supported by DeviceBatchedGemm operations.\n");
275 static_assert(PermuteB == false,
276 "Permute B functionality not supported by DeviceBatchedGemm operations.\n");
277
279 {
281 index_t BatchStrideB,
282 index_t BatchStrideC,
283 index_t BatchStrideScaleB)
284 : BatchStrideA_(BatchStrideA),
285 BatchStrideB_(BatchStrideB),
286 BatchStrideC_(BatchStrideC),
287 BatchStrideScaleB_(BatchStrideScaleB)
288 {
289 }
290
291 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
292 {
293 return g_idx * static_cast<long_index_t>(BatchStrideA_);
294 }
295
296 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
297 {
298 return g_idx * static_cast<long_index_t>(BatchStrideB_) / GridwiseGemm::BPackedSize;
299 }
300
301 __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
302 {
303 return g_idx * static_cast<long_index_t>(BatchStrideC_);
304 }
305 __host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const
306 {
307 return g_idx * static_cast<long_index_t>(BatchStrideScaleB_);
308 }
309
310 private:
311 index_t BatchStrideA_;
312 index_t BatchStrideB_;
313 index_t BatchStrideC_;
314 index_t BatchStrideScaleB_;
315 };
316
317 // GridwiseGemm
319 ALayout,
320 BLayout,
321 Tuple<>, // DsLayout
322 CLayout,
325 BScaleDataType,
326 AccDataType,
327 CShuffleDataType,
328 Tuple<>, // DsDataType
329 CDataType,
330 AElementwiseOperation,
331 BElementwiseOperation,
332 CElementwiseOperation,
333 GemmSpec,
334 BlockSize,
335 ScaleBlockN,
336 ScaleBlockK,
337 MPerBlock,
338 NPerBlock,
339 KPerBlock,
340 AK1,
341 BK1,
342 MPerWmma,
343 NPerWmma,
344 MRepeat,
345 NRepeat,
346 ABlockTransferThreadClusterLengths_AK0_M_AK1,
347 ABlockTransferThreadClusterArrangeOrder,
348 ABlockTransferSrcAccessOrder,
349 ABlockTransferSrcVectorDim,
350 ABlockTransferSrcScalarPerVector,
351 ABlockTransferDstScalarPerVector_AK1,
352 false,
353 ABlockLdsExtraM,
354 BBlockTransferThreadClusterLengths_BK0_N_BK1,
355 BBlockTransferThreadClusterArrangeOrder,
356 BBlockTransferSrcAccessOrder,
357 BBlockTransferSrcVectorDim,
358 BBlockTransferSrcScalarPerVector,
359 BBlockTransferDstScalarPerVector_BK1,
360 false,
361 BBlockLdsExtraN,
362 CShuffleMRepeatPerShuffle,
363 CShuffleNRepeatPerShuffle,
364 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
366 BlkGemmPipeSched,
367 BlkGemmPipelineVer,
368 ComputeTypeA,
369 ComputeTypeB,
370 PermuteA, // PermuteA not supported by DeviceBatchedGemm base class.
371 PermuteB>; // PermuteB not supported by DeviceBatchedGemm base class.
372
373 // Argument
374 struct Argument : public GridwiseGemm::Argument
375 {
376 __host__ Argument(const ADataType* p_a_grid_,
377 const BDataType* p_b_grid_,
378 CDataType* p_c_grid_,
379 index_t M_,
380 index_t N_,
381 index_t K_,
382 index_t StrideA_,
383 index_t StrideB_,
384 index_t StrideC_,
385 index_t StrideScaleB_,
386 index_t BatchStrideA_,
387 index_t BatchStrideB_,
388 index_t BatchStrideC_,
389 index_t BatchStrideScaleB_,
390 const BScaleDataType* p_b_scale_grid_,
391 index_t Batch_,
392 index_t k_batch_,
393 AElementwiseOperation a_element_op_,
394 BElementwiseOperation b_element_op_,
395 CElementwiseOperation c_element_op_,
396 bool is_reduce_ = false)
397 : GridwiseGemm::Argument(std::array<const void*, 1>{p_a_grid_},
398 std::array<const void*, 1>{p_b_grid_},
399 std::array<const void*, 0>{}, // p_ds_grid_
400 p_c_grid_,
401 M_,
402 N_,
403 K_,
404 std::array<index_t, 1>{StrideA_},
405 std::array<index_t, 1>{StrideB_},
406 std::array<index_t, 0>{}, // StrideDs_
407 StrideC_,
408 StrideScaleB_,
409 p_b_scale_grid_,
410 k_batch_,
411 a_element_op_,
412 b_element_op_,
413 c_element_op_,
414 is_reduce_),
415 Batch(Batch_),
417 BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_}
418 {
419 }
420
423 };
424
434 struct Invoker : public BaseInvoker
435 {
441 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
442 {
443 if(stream_config.log_level_ > 0)
444 {
445 arg.Print();
446 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
447 }
448
450 {
451 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
452 }
453
454 index_t gdx, gdy, gdz;
455 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
456
457 // The normal approach to batching would be to increase the grid size by just stretching
458 // out the grid Z dimension (which is the outermost dimension), but this depends on
459 // lower level functions not directly using the Z dimension for other calculations. As
460 // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset.
461 // Therefore, for now we will use the grid Y dimension for batching. This may be a bit
462 // fragile.
463 gdy *= arg.Batch;
464
465 float ave_time = 0;
466
467 index_t k_grain = arg.KBatch * KPerBlock;
468 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
469
470 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
471
472 const auto Run = [&](const auto& kernel) {
473 if(stream_config.flush_cache)
474 {
475 Argument arg_ = arg;
476
477 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
478 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
479 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
480 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
481
482 // Packed sizes are 1 for all implemented data types but we include it anyway
483 // for future compatibility.
484 // Note: the grid descriptors and size_a / size_b do *not* take batching into
485 // account, so we have to manually multiply overall buffer sizes for rotating
486 // memory by batch.
487 std::array<std::size_t, 1> size_as_buffers;
488 size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
489 sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch;
490
491 std::array<std::size_t, 1> size_bs_buffers;
492 size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
493 sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch;
494
496 Tuple<ADataType>,
497 Tuple<BDataType>,
498 Tuple<>>
499 rotating_mem(arg_,
500 stream_config.rotating_count,
501 size_as_buffers,
502 size_bs_buffers,
503 std::array<std::size_t, 0>{});
504 rotating_mem.Print();
505
506 auto run_flush_cache = [&]() {
508 rotating_mem.Next();
509 // clear c mem
510 if(arg_.KBatch > 1)
511 // Note: we multiply by batch since we want to clear the C matrix for
512 // the whole batch. Untested since we don't have k batching ATM.
513 // Note: This seems incorrect for non-contiguous memory layouts for C
514 // (padding, gaps).
516 hipMemsetAsync(arg_.p_e_grid,
517 0,
518 arg_.Batch * arg_.M * arg_.N * sizeof(CDataType),
519 stream_config.stream_id_));
520 };
521
523 stream_config,
524 run_flush_cache,
525 kernel,
526 dim3(gdx, gdy, gdz),
527 dim3(BlockSize),
528 0,
529 arg_,
530 arg_.compute_ptr_offset_of_batch);
531 }
532 else
533 {
534 auto clear_workspace = [&]() {
535 // clear c mem
536 if(arg.KBatch > 1)
537 // Note: we multiply by batch since we want to clear the C matrix for
538 // the whole batch. Untested since we don't have k batching ATM.
539 // Note: This seems incorrect for non-contiguous memory layouts for C
540 // (padding, gaps).
542 hipMemsetAsync(arg.p_e_grid,
543 0,
544 arg.Batch * arg.M * arg.N * sizeof(CDataType),
545 stream_config.stream_id_));
546 };
547
549 stream_config,
550 clear_workspace,
551 kernel,
552 dim3(gdx, gdy, gdz),
553 dim3(BlockSize),
554 0,
555 arg,
557 }
558 };
559
560 constexpr index_t minimum_occupancy = []() {
561 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
562 {
563 return 2;
564 }
565 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
566 {
567 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
568 }
569 else
570 {
571 return 1;
572 }
573 }();
574
575 if(has_main_k_block_loop)
576 {
577 // Tail number always full
578 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
579 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
580 {
581 if(arg.KBatch > 1)
582 {
585 ComputePtrOffsetOfStridedBatch,
586 true,
588 minimum_occupancy>;
589 Run(kernel);
590 }
591 else
592 {
596 true,
598 minimum_occupancy>;
599 Run(kernel);
600 }
601 }
602 else
603 {
604 throw std::runtime_error("Pipeline not implemented");
605 }
606 }
607 else
608 {
609 // Tail number always 1
610 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
611 {
612 if(arg.KBatch > 1)
613 {
616 ComputePtrOffsetOfStridedBatch,
617 false,
619 minimum_occupancy>;
620 Run(kernel);
621 }
622 else
623 {
627 false,
629 minimum_occupancy>;
630 Run(kernel);
631 }
632 }
633 }
634
635 return ave_time;
636 }
637
638 // polymorphic
639 float Run(const BaseArgument* p_arg,
640 const StreamConfig& stream_config = StreamConfig{}) override
641 {
642 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
643 }
644 };
645
/// @brief Reports whether the compile-time configuration is valid.
/// @return Always true; no check is implemented yet.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
651
652 static bool IsSupportedArgument(const Argument& arg)
653 {
655 {
656 return false;
657 }
658
659 if constexpr(std::is_same_v<CDataType, ck::half_t> ||
660 std::is_same_v<CDataType, ck::bhalf_t>)
661 {
662 if(arg.KBatch > 1 && ck::is_gfx11_supported())
663 {
664 // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
665 return false;
666 }
667 }
668
669 if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
670 std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
671 {
673 {
674 return false;
675 }
676 }
677
678 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
679 GemmSpec == GemmSpecialization::NKPadding ||
680 GemmSpec == GemmSpecialization::MNKPadding ||
681 GemmSpec == GemmSpecialization::KPadding))
682 {
683 return false;
684 }
685
686 return GridwiseGemm::CheckValidity(arg);
687 }
688
689 // polymorphic
690 bool IsSupportedArgument(const BaseArgument* p_arg) override
691 {
692 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
693 }
694
695 index_t GetKPerBlock() override { return KPerBlock; }
696 bool GetPermuteB() override { return PermuteB; }
697
698 static auto MakeArgument(const ADataType* p_a,
699 const BDataType* p_b,
700 CDataType* p_c,
701 index_t M,
702 index_t N,
703 index_t K,
704 index_t StrideA,
705 index_t StrideB,
706 index_t StrideC,
707 index_t StrideScaleB,
708 index_t BatchStrideA,
709 index_t BatchStrideB,
710 index_t BatchStrideC,
711 index_t BatchStrideScaleB,
712 const BScaleDataType* p_b_scale,
713 index_t Batch,
714 AElementwiseOperation,
715 BElementwiseOperation,
716 CElementwiseOperation,
717 index_t KBatch = 1)
718 {
719 return Argument{p_a,
720 p_b,
721 p_c,
722 M,
723 N,
724 K,
725 StrideA,
726 StrideB,
727 StrideC,
728 StrideScaleB,
729 BatchStrideA,
730 BatchStrideB,
731 BatchStrideC,
732 BatchStrideScaleB,
733 p_b_scale,
734 Batch,
735 KBatch};
736 }
737
738 static auto MakeInvoker() { return Invoker{}; }
739
740 // polymorphic
741 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
742 const void* p_b,
743 void* p_c,
744 index_t M,
745 index_t N,
746 index_t K,
747 index_t StrideA,
748 index_t StrideB,
749 index_t StrideC,
750 index_t StrideScaleB,
751 index_t BatchStrideA,
752 index_t BatchStrideB,
753 index_t BatchStrideC,
754 index_t BatchStrideScaleB,
755 const void* p_b_scale,
756 index_t Batch,
757 index_t KBatch,
758 AElementwiseOperation a_element_op,
759 BElementwiseOperation b_element_op,
760 CElementwiseOperation c_element_op) override
761 {
762 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
763 static_cast<const BDataType*>(p_b),
764 static_cast<CDataType*>(p_c),
765 M,
766 N,
767 K,
768 StrideA,
769 StrideB,
770 StrideC,
771 StrideScaleB,
772 BatchStrideA,
773 BatchStrideB,
774 BatchStrideC,
775 BatchStrideScaleB,
776 static_cast<const BScaleDataType*>(p_b_scale),
777 Batch,
778 KBatch,
779 a_element_op,
780 b_element_op,
781 c_element_op);
782 }
783
784 // polymorphic
785 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
786 {
787 return std::make_unique<Invoker>(Invoker{});
788 }
789
790 // polymorphic
791 std::string GetTypeString() const override
792 {
793 auto str = std::stringstream();
794
795 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
798
799 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
805
806 // clang-format off
807 str << "DeviceBatchedGemm_Wmma_CShuffleV3_BScale"
808 << "<"
809 << getGemmSpecializationString(GemmSpec) << ", "
810 << std::string(ALayout::name)[0]
811 << std::string(BLayout::name)[0]
812 << std::string(CLayout::name)[0]
813 << ">"
814 << " BlkSize: "
815 << BlockSize << ", "
816 << "BlkTile: "
817 << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
818 << "WaveTile: "
819 << MPerWmma << "x"<<NPerWmma << ", "
820 << "WaveMap: "
821 << MRepeat << "x" << NRepeat << ", "
822 << "VmemReadVec: "
823 << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
824 << "BlkGemmPipelineScheduler: "
825 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
826 << "BlkGemmPipelineVersion: "
827 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
828 << "BlkGemmPipelinePrefetchStages: "
829 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
830 << "KPack: "
832 // clang-format on
833
834 return str.str();
835 }
837};
838
839} // namespace device
840} // namespace tensor_operation
841} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
__global__ void kernel_batched_gemm_b_scale_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:34
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
int64_t long_index_t
Definition ck.hpp:300
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
STL namespace.
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:127
static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:375
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:422
index_t Batch
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:421
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, index_t BatchStrideA_, index_t BatchStrideB_, index_t BatchStrideC_, index_t BatchStrideScaleB_, const BScaleDataType *p_b_scale_grid_, index_t Batch_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:376
Helper structure responsible for kernel invocation.
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:435
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
This function issues GPU kernel execution.
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:441
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:639
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:280
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:291
__host__ __device__ constexpr long_index_t GetScaleBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:305
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:296
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:301
"Universal" Batched GEMM operation without SplitK support.
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:270
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:646
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const BScaleDataType *p_b_scale, index_t Batch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t KBatch=1)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:698
GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:318
std::string GetTypeString() const override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:791
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:785
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const void *p_b_scale, index_t Batch, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:741
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:690
bool GetPermuteB() override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:696
index_t GetKPerBlock() override
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:695
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:652
static auto MakeInvoker()
Definition device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp:738
Definition device_batched_gemm.hpp:60
Definition flush_cache.hpp:21