Composable Kernel: device_moe_gemm_blockscale.hpp Source File

Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <hip/hip_runtime.h>
9
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename ALayout,
26 typename BLayout,
27 typename DsLayout,
28 typename CLayout,
29 typename ADataType,
30 typename AScaleDataType,
31 typename BDataType,
32 typename BScaleDataType,
33 typename DsDataType,
34 typename CDataType,
35 typename GemmAccDataType,
36 typename CShuffleDataType,
37 typename AElementwiseOperation,
38 typename BElementwiseOperation,
39 typename CElementwiseOperation,
40 GemmSpecialization GemmSpec,
41 index_t BlockSize,
42 index_t ScaleBlockM,
43 index_t ScaleBlockN,
44 index_t ScaleBlockK,
45 index_t MPerBlock,
46 index_t NPerBlock,
47 index_t KPerBlock,
48 index_t AK1,
49 index_t BK1,
50 index_t MPerXDL,
51 index_t NPerXDL,
52 index_t MXdlPerWave,
53 index_t NXdlPerWave,
54 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55 typename ABlockTransferThreadClusterArrangeOrder,
56 typename ABlockTransferSrcAccessOrder,
57 index_t ABlockTransferSrcVectorDim,
58 index_t ABlockTransferSrcScalarPerVector,
59 index_t ABlockTransferDstScalarPerVector_AK1,
60 bool ABlockLdsExtraM,
61 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62 typename BBlockTransferThreadClusterArrangeOrder,
63 typename BBlockTransferSrcAccessOrder,
64 index_t BBlockTransferSrcVectorDim,
65 index_t BBlockTransferSrcScalarPerVector,
66 index_t BBlockTransferDstScalarPerVector_BK1,
67 bool BBlockLdsExtraN,
68 index_t CShuffleMXdlPerWavePerShuffle,
69 index_t CShuffleNXdlPerWavePerShuffle,
70 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
71 typename CDEShuffleBlockTransferScalarPerVectors,
74 index_t ActivationOP = 0,
75 bool NSwizzle = false,
76 bool IsInputGemm = true,
77 bool MulRoutedWeight = false,
78 typename IndexType = index_t,
79 typename ComputeTypeA = CDataType,
80 typename ComputeTypeB = ComputeTypeA,
81 typename LDSTypeA = ComputeTypeA,
82 typename LDSTypeB = ComputeTypeB>
85 BLayout,
86 DsLayout,
87 CLayout,
88 ADataType,
89 AScaleDataType,
90 BDataType,
91 BScaleDataType,
92 DsDataType,
93 CDataType,
94 ScaleBlockM,
95 ScaleBlockN,
96 ScaleBlockK,
97 AElementwiseOperation,
98 BElementwiseOperation,
99 CElementwiseOperation>
100{
102 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
103 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
104 static constexpr index_t NumDTensor = DsDataType::Size();
105 template <index_t NXdlPerWave_>
107 ALayout,
108 BLayout,
109 DsLayout,
110 CLayout,
111 ADataType,
112 BDataType,
113 GemmAccDataType,
114 CShuffleDataType,
115 DsDataType,
116 CDataType,
117 AElementwiseOperation,
118 BElementwiseOperation,
119 CElementwiseOperation,
120 GemmSpec,
121 BlockSize,
122 ScaleBlockM,
123 ScaleBlockN,
124 ScaleBlockK,
125 MPerBlock,
126 NPerBlock,
127 KPerBlock,
128 AK1,
129 BK1,
130 MPerXDL,
131 NPerXDL,
132 MXdlPerWave,
133 NXdlPerWave_,
134 ABlockTransferThreadClusterLengths_AK0_M_AK1,
135 ABlockTransferThreadClusterArrangeOrder,
136 ABlockTransferSrcAccessOrder,
137 ABlockTransferSrcVectorDim,
138 ABlockTransferSrcScalarPerVector,
139 ABlockTransferDstScalarPerVector_AK1,
140 false,
141 ABlockLdsExtraM,
142 BBlockTransferThreadClusterLengths_BK0_N_BK1,
143 BBlockTransferThreadClusterArrangeOrder,
144 BBlockTransferSrcAccessOrder,
145 BBlockTransferSrcVectorDim,
146 BBlockTransferSrcScalarPerVector,
147 BBlockTransferDstScalarPerVector_BK1,
148 false,
149 BBlockLdsExtraN,
150 CShuffleMXdlPerWavePerShuffle,
151 math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_),
152 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
153 CDEShuffleBlockTransferScalarPerVectors,
154 BlkGemmPipeSched,
155 BlkGemmPipelineVer,
156 ActivationOP,
157 NSwizzle,
158 IsInputGemm,
159 MulRoutedWeight,
160 IndexType,
161 ComputeTypeA,
162 ComputeTypeB,
163 LDSTypeA,
164 LDSTypeB>;
167
// Public argument type, taken from the wave64 gridwise instantiation. On wave32 devices
// IsSupportedArgument reinterpret_casts it to GridwiseGemm32::Argument, so the two
// Argument layouts are presumably expected to be identical. NOTE(review): confirm.
168 using Argument = typename GridwiseGemm64::Argument;
169
170 static constexpr index_t APackedSize = []() {
172 return 2;
173 else
174 return 1;
175 }();
176
177 static constexpr index_t BPackedSize = []() {
179 return 2;
180 else
181 return 1;
182 }();
183
184 int GetPreShuffleParameters() override { return NPerXDL; }
185
186 // Invoker
187 struct Invoker : public BaseInvoker
188 {
189 template <typename GridwiseGemm>
190 float RunImp(const typename GridwiseGemm::Argument& arg,
191 const StreamConfig& stream_config = StreamConfig{})
192 {
193 if(stream_config.log_level_ > 0)
194 {
195 arg.Print();
196 }
197
198 if(!GridwiseGemm::CheckValidity(arg))
199 {
200 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
201 }
202
203 index_t gdx, gdy, gdz;
204 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
205
206 float ave_time = 0;
207
208 index_t k_grain = arg.KBatch * KPerBlock;
209 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
210
211 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
212 const auto RunKernel = [&](const auto& kernel) {
213 if(stream_config.flush_cache)
214 {
215
216 std::array<std::size_t, NumDTensor> DsSize;
217
218 auto arg_ = arg;
219
220 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
221 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
222 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
223 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
224
225 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
226 sizeof(ADataType) / APackedSize;
227 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
228 sizeof(BDataType) / BPackedSize;
229
230 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
231 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
232
233 static_for<0, NumDTensor, 1>{}([&](auto i) {
234 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
235 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
236 });
237 ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
238 DsDataType>
239 rotating_mem(arg_,
240 stream_config.rotating_count,
241 size_a_buffer,
242 size_b_buffer,
243 DsSize);
244 rotating_mem.Print();
245
246 auto run_flush_cache = [&]() {
247 // flush icache
249 // rotating mem
250 rotating_mem.Next();
251 // clear c mem
252 if(arg_.KBatch > 1)
253 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
254 0,
255 arg_.M * arg_.N * sizeof(CDataType),
256 stream_config.stream_id_));
257 };
258
260 stream_config,
261 run_flush_cache,
262 kernel,
263 dim3(gdx, gdy, gdz),
264 dim3(BlockSize),
265 0,
266 arg_);
267 }
268 else
269 {
270 if(arg.KBatch > 1)
271 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
272 0,
273 arg.M * arg.N * sizeof(CDataType),
274 stream_config.stream_id_));
275
276 ave_time = launch_and_time_kernel(
277 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
278 }
279 };
280
281 constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
282 4 * (1 + GridwiseGemm::NWave);
283 constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
284 4 * (2) * (IsInputGemm ? 2 : 1);
285 constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
286 BlockSize / 4 * (IsInputGemm ? 2 : 1);
287 constexpr auto estimated_reg_total =
288 estimated_reg_a + estimated_reg_b + estimated_reg_c;
289
290 constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
291
292 constexpr auto MemoryDataOp =
294
295 if(has_main_k_block_loop)
296 {
297 // Tail number always full
298 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
299 {
300 {
301 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
302 {
303 const auto kernel = kernel_moe_gemm<GridwiseGemm,
304 true,
305 MemoryDataOp,
306 minimum_occupancy,
308 RunKernel(kernel);
309 }
310 else
311 {
312 const auto kernel = kernel_moe_gemm<GridwiseGemm,
313 true,
314 MemoryDataOp,
315 minimum_occupancy,
317 RunKernel(kernel);
318 }
319 }
320 }
321 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
322 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
323 {
324 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
325 {
326 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
327 true,
328 MemoryDataOp,
329 minimum_occupancy,
331 RunKernel(kernel);
332 }
333 else
334 {
335 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
336 true,
337 MemoryDataOp,
338 minimum_occupancy,
340 RunKernel(kernel);
341 }
342 }
343 else
344 {
345 throw std::runtime_error("todo: only v1 & v2 support now");
346 }
347 }
348#if 1
349 else
350 {
351 // Tail number always 1
352 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
353 {
354 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
355 {
356 const auto kernel = kernel_moe_gemm<GridwiseGemm,
357 false,
358 MemoryDataOp,
359 minimum_occupancy,
361 RunKernel(kernel);
362 }
363 else
364 {
365 const auto kernel = kernel_moe_gemm<GridwiseGemm,
366 false,
367 MemoryDataOp,
368 minimum_occupancy,
370 RunKernel(kernel);
371 }
372 }
373 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
374 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
375 {
376 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
377 {
378 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
379 false,
380 MemoryDataOp,
381 minimum_occupancy,
383 RunKernel(kernel);
384 }
385 else
386 {
387 const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
388 false,
389 MemoryDataOp,
390 minimum_occupancy,
392 RunKernel(kernel);
393 }
394 }
395 }
396#endif
397
398 return ave_time;
399 }
400
402
403 // polymorphic
404 float Run(const BaseArgument* p_arg,
405 const StreamConfig& stream_config = StreamConfig{}) override
406 {
407 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
408 }
409 };
410
/// Compile-time hook for rejecting invalid template configurations.
/// It currently accepts every configuration unconditionally.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool always_valid = true;
    return always_valid;
}
416
417 static bool IsSupportedArgument(const Argument& arg)
418 {
419 // only impl kbatch 1 now
420 if(arg.KBatch > 1)
421 {
422 return false;
423 }
425 {
426 return false;
427 }
428 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
429 {
430 return false;
431 }
432
433 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
434 GemmSpec == GemmSpecialization::NKPadding ||
435 GemmSpec == GemmSpecialization::MNKPadding ||
436 GemmSpec == GemmSpecialization::KPadding))
437 {
438 return false;
439 }
440 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
441 {
442 return false;
443 }
444
445 if(get_warp_size() == 64)
446 {
447 if constexpr(NXdlPerWave64 > 0)
448 {
450 }
451 }
452 else
453 {
454 if constexpr(NXdlPerWave32 > 0)
455 {
457 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
458 }
459 }
460 return false;
461 }
462
463 // polymorphic
464 bool IsSupportedArgument(const BaseArgument* p_arg) override
465 {
466 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
467 }
468
469 static auto MakeArgument(const void* p_sorted_token_ids,
470 const void* p_sorted_expert_ids,
471 const void* p_max_token_id,
472 const void* p_a,
473 const void* p_b,
474 std::array<const void*, NumDTensor> p_ds,
475 void* p_c,
476 index_t NumTokens,
477 index_t TopK,
478 index_t M,
479 index_t N,
480 index_t K,
481 index_t StrideA,
482 index_t StrideB,
483 std::array<index_t, NumDTensor> StrideDs,
484 index_t StrideC,
485 const void* p_a_scale,
486 const void* p_b_scale,
487 index_t KBatch,
488 AElementwiseOperation a_element_op,
489 BElementwiseOperation b_element_op,
490 CElementwiseOperation c_element_op)
491 {
492 return Argument{static_cast<const index_t*>(p_sorted_token_ids),
493 static_cast<const index_t*>(p_sorted_expert_ids),
494 static_cast<const index_t*>(p_max_token_id),
495 static_cast<const ADataType*>(p_a),
496 static_cast<const BDataType*>(p_b),
497 p_ds,
498 static_cast<CDataType*>(p_c),
499 NumTokens,
500 TopK,
501 M,
502 N,
503 K,
504 StrideA,
505 StrideB,
506 StrideDs,
507 StrideC,
508 static_cast<const AScaleDataType*>(p_a_scale),
509 static_cast<const BScaleDataType*>(p_b_scale),
510 KBatch,
511 a_element_op,
512 b_element_op,
513 c_element_op};
514 }
515
516 static auto MakeInvoker() { return Invoker{}; }
517
518 // polymorphic
519 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
520 const void* p_b,
521 std::array<const void*, NumDTensor> p_ds,
522 void* p_c,
523 index_t M,
524 index_t N,
525 index_t K,
526 index_t StrideA,
527 index_t StrideB,
528 std::array<ck::index_t, NumDTensor> StrideDs,
529 index_t StrideC,
530 const void* p_a_scale,
531 const void* p_b_scale,
532 // index_t KBatch,
533 AElementwiseOperation a_element_op,
534 BElementwiseOperation b_element_op,
535 CElementwiseOperation c_element_op) override
536 {
537 return std::make_unique<Argument>(nullptr,
538 nullptr,
539 nullptr,
540 static_cast<const ADataType*>(p_a),
541 static_cast<const BDataType*>(p_b),
542 p_ds,
543 static_cast<CDataType*>(p_c),
544 M, // randoms set, no use
545 0,
546 M,
547 N,
548 K,
549 StrideA,
550 StrideB,
551 StrideDs,
552 StrideC,
553 static_cast<const AScaleDataType*>(p_a_scale),
554 static_cast<const BScaleDataType*>(p_b_scale),
555 1, // KBatch,
556 a_element_op,
557 b_element_op,
558 c_element_op);
559 }
560
561 // polymorphic
562 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
563 {
564 return std::make_unique<Invoker>(Invoker{});
565 }
566
567 // polymorphic
568 std::string GetTypeString() const override
569 {
570 auto str = std::stringstream();
571
572 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
575
576 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
580
581 // clang-format off
582 str << "DeviceMoeGEmm"
583 << "<"
584 << getGemmSpecializationString(GemmSpec) << ", "
585 << std::string(ALayout::name)[0]
586 << std::string(BLayout::name)[0]
587 << std::string(CLayout::name)[0]
588 << ">"
589 << " BlkSize: "
590 << BlockSize << ", "
591 << "BlkTile: "
592 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
593 << "WaveTile: "
594 << MPerXDL<<"x"<<NPerXDL << ", "
595 << "WaveMap: "
596 << MXdlPerWave<<"x" << NXdlPerWave<<", "
597 << "VmemReadVec: "
598 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
599 << "BlkGemmPipelineScheduler: "
600 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
601 << "BlkGemmPipelineVersion: "
602 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
603 << "BlkGemmPipelinePrefetchStages: "
604 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
605 // clang-format on
606
607 return str.str();
608 }
609};
610
611} // namespace device
612} // namespace tensor_operation
613} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:46
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_gemm.hpp:84
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_moe_gemm_blockscale.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_ab_scale.hpp:82
Definition device_moe_gemm_blockscale.hpp:188
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_moe_gemm_blockscale.hpp:190
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_moe_gemm_blockscale.hpp:404
Definition device_moe_gemm_blockscale.hpp:100
static constexpr index_t BPackedSize
Definition device_moe_gemm_blockscale.hpp:177
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_moe_gemm_blockscale.hpp:165
static constexpr auto NXdlPerWave32
Definition device_moe_gemm_blockscale.hpp:103
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_moe_gemm_blockscale.hpp:562
static bool IsSupportedArgument(const Argument &arg)
Definition device_moe_gemm_blockscale.hpp:417
static constexpr index_t APackedSize
Definition device_moe_gemm_blockscale.hpp:170
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_moe_gemm_blockscale.hpp:102
GridwiseMoeGemmBlockScale< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_moe_gemm_blockscale.hpp:106
int GetPreShuffleParameters() override
Definition device_moe_gemm_blockscale.hpp:184
std::string GetTypeString() const override
Definition device_moe_gemm_blockscale.hpp:568
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_moe_gemm_blockscale.hpp:166
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_moe_gemm_blockscale.hpp:464
typename GridwiseGemm64::Argument Argument
Definition device_moe_gemm_blockscale.hpp:168
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_moe_gemm_blockscale.hpp:469
static constexpr bool IsValidCompilationParameter()
Definition device_moe_gemm_blockscale.hpp:411
static constexpr index_t NumDTensor
Definition device_moe_gemm_blockscale.hpp:104
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, const void *p_a_scale, const void *p_b_scale, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_moe_gemm_blockscale.hpp:519
static auto MakeInvoker()
Definition device_moe_gemm_blockscale.hpp:516
Definition flush_cache.hpp:174