moe_flatmm_kernel.hpp Source File

moe_flatmm_kernel.hpp Source File#

Composable Kernel: moe_flatmm_kernel.hpp Source File
moe_flatmm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11#include "ck_tile/host.hpp"
12
13// #define disable_tile_gs
14
15namespace ck_tile {
16
// Host-side argument bundle for the MoE flat-MM kernel.
// On top of the ScaleFlatmmHostArgs base it stores the routing inputs
// (p_sorted_token_ids, p_sorted_expert_ids, p_sorted_expert_weights, p_max_token_id),
// the NumTokens / NumExperts / TopK counts, the n/k padded-zero counts and an
// expert bias (exp_bias); see the member-initializer list of the full constructor below.
// NOTE(review): apart from exp_bias, the member declarations themselves are not
// visible in this listing.
17template <class ScaleM = FlatmmScalePointer<-1>,
18 class ScaleN = FlatmmScalePointer<-1>,
19 class ExpertBias = FlatmmScalePointer<-1>>
20struct MoeFlatmmHostArgs : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>
21{
// Expert bias handle; its type defaults to FlatmmScalePointer<-1>.
31 ExpertBias exp_bias;
32
33 CK_TILE_HOST MoeFlatmmHostArgs() noexcept = default;
34
// Convenience constructor without padding counts: delegates to the full
// constructor, passing 0 for both n_padded_zeros_ and k_padded_zeros_.
35 CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t* p_sorted_token_ids_,
36 const void* p_sorted_expert_weights_,
37 const ck_tile::index_t* p_sorted_expert_ids_,
38 const ck_tile::index_t* p_max_token_id_,
39 const void* a_ptr_,
40 const void* b_ptr_,
41 void* c_ptr_,
42 ck_tile::index_t NumTokens_,
43 ck_tile::index_t NumExperts_,
44 ck_tile::index_t TopK_,
45 ck_tile::index_t k_batch_,
46 ck_tile::index_t M_,
47 ck_tile::index_t N_,
48 ck_tile::index_t K_,
49 ck_tile::index_t stride_A_,
50 ck_tile::index_t stride_B_,
51 ck_tile::index_t stride_C_,
52 ScaleM scale_m_ = {},
53 ScaleN scale_n_ = {},
54 ExpertBias exp_bias_ = {})
55 : MoeFlatmmHostArgs(p_sorted_token_ids_,
56 p_sorted_expert_weights_,
57 p_sorted_expert_ids_,
58 p_max_token_id_,
59 a_ptr_,
60 b_ptr_,
61 c_ptr_,
62 NumTokens_,
63 NumExperts_,
64 TopK_,
65 k_batch_,
66 M_,
67 N_,
68 K_,
69 stride_A_,
70 stride_B_,
71 stride_C_,
72 0, // n_padded_zeros_
73 0, // k_padded_zeros_
74 scale_m_,
75 scale_n_,
76 exp_bias_)
77 {
78 }
79
// Full constructor (its first line is not visible in this listing).
// Forwards the GEMM arguments to ScaleFlatmmHostArgs with empty D-tensor pointer
// and stride arrays, then stores the MoE-specific fields and the padding counts.
81 const void* p_sorted_expert_weights_,
82 const ck_tile::index_t* p_sorted_expert_ids_,
83 const ck_tile::index_t* p_max_token_id_,
84 const void* a_ptr_,
85 const void* b_ptr_,
86 void* c_ptr_,
87 ck_tile::index_t NumTokens_,
88 ck_tile::index_t NumExperts_,
89 ck_tile::index_t TopK_,
90 ck_tile::index_t k_batch_,
94 ck_tile::index_t stride_A_,
95 ck_tile::index_t stride_B_,
96 ck_tile::index_t stride_C_,
97 ck_tile::index_t n_padded_zeros_ = 0,
98 ck_tile::index_t k_padded_zeros_ = 0,
99 ScaleM scale_m_ = {},
100 ScaleN scale_n_ = {},
101 ExpertBias exp_bias_ = {})
102 : ScaleFlatmmHostArgs<ScaleM, ScaleN, 0>(a_ptr_,
103 b_ptr_,
104 {}, // d_ptr_array
105 c_ptr_,
106 k_batch_,
107 M_,
108 N_,
109 K_,
110 stride_A_,
111 stride_B_,
112 {}, // d_stride_array
113 stride_C_,
114 scale_m_,
115 scale_n_),
116 NumTokens(NumTokens_),
117 NumExperts(NumExperts_),
118 TopK(TopK_),
119 p_sorted_token_ids(p_sorted_token_ids_),
120 p_sorted_expert_ids(p_sorted_expert_ids_),
121 p_max_token_id(p_max_token_id_),
122 p_sorted_expert_weights(p_sorted_expert_weights_),
123 n_padded_zeros(n_padded_zeros_),
124 k_padded_zeros(k_padded_zeros_),
125 exp_bias(exp_bias_)
126 {
127 }
128};
129
136
137namespace moe {
138
140{
141 template <typename T>
142 CK_TILE_HOST_DEVICE T operator()(T gate, T linear = 1) const
143 {
144 ck_tile::element_wise::Silu{}(gate, gate);
145 return gate * linear;
146 };
147};
148
149struct Swiglu
150{
151 const float alpha;
152 const float limit;
153
// Stores the two parameters read by operator(): `alpha` scales the gate inside the
// exponential, `limit` is the clamp bound applied to both inputs.
155 Swiglu(float alpha_ = 1.702f, float limit_ = 7.0f) // use value in gpt-oss as default
156 : alpha(alpha_), limit(limit_)
157 {
158 }
159
160 template <typename T>
161 CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
162 {
163 static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
164 std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
165 std::is_same_v<T, int32_t>,
166 "Data type is not supported by this operation!");
167
168 constexpr T one = type_convert<T>(1);
169
170 gate = gate < limit ? gate : limit;
171 linear = linear < limit ? (linear > -limit ? linear : -limit) : limit;
172
173 if constexpr(std::is_same_v<T, float>)
174 {
175 return gate * __builtin_amdgcn_rcpf(one + ck_tile::exp(alpha * -gate)) * (linear + 1);
176 }
177 else
178 {
179 return gate * (one / (one + ck_tile::exp(alpha * -gate))) * (linear + 1);
180 }
181 }
182};
183
184} // namespace moe
185
186template <typename TilePartitioner_,
187 typename FlatmmPipeline_,
188 typename EpiloguePipeline_,
189 MoeFlatmmKind kind,
190 typename FusedActivation = moe::MoeSilu>
192{
203 static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
204 static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
205
208 // Below type is actually accumulation data type - the output of block GEMM.
210
211 using AccDataType = float;
212 using ActivationOp = FusedActivation;
213
214 static constexpr index_t NumDTensor = DsDataType::size();
215
216 static constexpr auto I0 = number<0>();
217 static constexpr auto I1 = number<1>();
218 static constexpr auto I2 = number<2>();
219 static constexpr auto I3 = number<3>();
220
221 static_assert(DsLayout::size() == DsDataType::size(),
222 "The size of DsLayout and DsDataType should be the same");
223
224 static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
225 static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
226
227 // static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
228 static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
229 static constexpr index_t kNPerBlock = EpiloguePipeline::kNPerBlock;
230 static constexpr index_t MWave = EpiloguePipeline::MWave;
231 static constexpr index_t NWave = EpiloguePipeline::NWave;
232 static constexpr index_t MPerXdl = EpiloguePipeline::MPerXdl;
233 static constexpr index_t NPerXdl = EpiloguePipeline::NPerXdl;
234 static constexpr index_t KPerXdl = EpiloguePipeline::KPerXdl;
235 static constexpr index_t isCTransposed = EpiloguePipeline::isCTransposed;
236 static constexpr index_t kMPerIteration = MPerXdl * MWave;
237 static constexpr index_t kNPerIteration = NPerXdl * NWave;
239
240 static constexpr int OutputNPerBlock =
241 IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
242
243 // MXFP4_Pipeline only has a scale for B, and its GranularityK is 32
244 static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
245 static constexpr int MXFP4N_Pack = 2;
246 static constexpr int MXFP4K_Pack = 2;
247
248 static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
249 static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1;
250
252
253 template <class ScaleM = FlatmmScalePointer<-1>,
254 class ScaleN = FlatmmScalePointer<-1>,
255 class ExpertBias = FlatmmScalePointer<-1>>
280
281 template <class ScaleM = FlatmmScalePointer<-1>,
282 class ScaleN = FlatmmScalePointer<-1>,
283 class ExpertBias = FlatmmScalePointer<-1>>
284 CK_TILE_HOST static constexpr auto
286 {
288 hostArgs.p_sorted_expert_ids,
289 hostArgs.p_max_token_id,
291 hostArgs.a_ptr,
292 hostArgs.b_ptr,
293 hostArgs.e_ptr,
294 hostArgs.NumTokens,
295 hostArgs.TopK,
296 hostArgs.M,
297 hostArgs.N,
298 hostArgs.K,
299 hostArgs.stride_A,
300 hostArgs.stride_B,
301 hostArgs.stride_C,
302 hostArgs.k_batch,
303 hostArgs.n_padded_zeros,
304 hostArgs.k_padded_zeros,
305 hostArgs.scale_m,
306 hostArgs.scale_n,
307 hostArgs.exp_bias};
308 }
309
310 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
311 {
312 return concat(
313 '_', "moe_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
314 }
315
316 static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); }
317
318 static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
319 {
320 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
321 }
// Grid dimensions derived from kernel arguments.
// Persistent mode: x is min(multiProcessorCount * maxActiveBlocksPerCU, tile count).
// Otherwise: x is TilePartitioner::GridSize(kargs.M, kargs.N).
// In both cases y is 1 and z is kargs.k_batch.
322 template <class MoeFlatmmKernelArgs>
323 static constexpr auto GridSize(const MoeFlatmmKernelArgs& kargs)
324 {
325 if constexpr(UsePersistentKernel)
326 {
// NOTE(review): `prop` is not initialised before the query below.
327 hipDeviceProp_t prop;
// NOTE(review): device 0 is hard-coded; confirm it matches the device the
// kernel is actually launched on.
328 int deviceId = 0; // default device
329
330 constexpr int block_size = MoeFlatmmKernel::BlockSize().x;
331 int dync_smem_size = 0;
332 int maxActiveBlocksPerCU = 0;
333
// NOTE(review): both HIP return codes land in `e` and are never inspected. If
// either call fails, `prop` / `maxActiveBlocksPerCU` are used as-is, which can
// produce a grid x-dimension of 0 or an indeterminate value.
334 [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
335
336 e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
337 &maxActiveBlocksPerCU,
339 block_size,
340 dync_smem_size);
341
342 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
343 const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
344
345 // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
346 // << ", persistent_block_size: " << persistent_block_size
347 // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
348
// Split-K is not expected in persistent mode (IsSupportedArgument rejects it too).
349 assert(kargs.k_batch == 1);
350 return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
351 }
352 else
353 {
354 return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
355 }
356 }
357
359 {
360 return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
361 }
363 {
364 return FlatmmPipeline::GetSmemSize();
365 }
366
368 {
369 template <class KernelArgs>
370 __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
371 {
372 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
373 const index_t K_t = kargs.k_batch * K1;
374 const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
375
376 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
377 {
378 a_k_split_offset = k_id * KRead;
379 }
380 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
381 {
382 a_k_split_offset = k_id * KRead * kargs.stride_A;
383 }
384
385 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
386 {
387 b_k_split_offset = k_id * KRead * kargs.stride_B;
388 }
389 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
390 {
391 b_k_split_offset = k_id * KRead;
392 }
393
394 if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
395 {
396 splitted_k = KRead;
397 }
398 else
399 {
400 splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
401 }
402 }
403
407 };
408
// Host-side validation of kernel arguments against the compiled tile configuration.
// Failures in the k_batch, A, B and E/C checks print to std::cerr and return false
// immediately. D-tensor failures are reported through CK_TILE_ERROR, accumulated in
// DTesnorIsValid, and only surface in the final return value.
409 template <typename KernelArgs>
410 CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
411 {
// Split-K restriction tied to the C vector size (the rest of this condition is
// not visible in this listing).
412 if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
414 {
415 if(kargs.k_batch != 1)
416 {
417 std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
418 return false;
419 }
420 }
421 if constexpr(UsePersistentKernel)
422 {
423 if(kargs.k_batch != 1)
424 {
425 std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
426 return false;
427 }
428 }
429
// A tensor: the checked dimension is K for row-major A, M otherwise.
430 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
431 {
432 if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
433 {
434 std::cerr << "Can't support K that is not a multiple of KPerBlock"
435 " without padding!"
436 << std::endl;
437 return false;
438 }
439 if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
440 {
441 std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
442 return false;
443 }
444 }
445 else
446 {
447 if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
448 {
449 std::cerr << "Can't support M that is not a multiple of MPerBlock"
450 " without padding!"
451 << std::endl;
452 return false;
453 }
454 if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
455 {
456 std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
457 return false;
458 }
459 }
460
// B tensor: for row-major B only the vector-size check on N is active (the
// NPerBlock check is commented out); otherwise K is checked as for A.
461 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
462 {
463 // if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
464 // {
465 // std::cerr << "Can't support N that is not a multiple of NPerBlock"
466 // " without padding!"
467 // << std::endl;
468 // return false;
469 // }
470 if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
471 {
472 std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
473 return false;
474 }
475 }
476 else
477 {
478 if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
479 {
480 std::cerr << "Can't support K that is not a multiple of KPerBlock"
481 " without padding!"
482 << std::endl;
483 return false;
484 }
485 if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
486 {
487 std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
488 return false;
489 }
490 }
491
// D tensors: every D layout must equal ELayout and satisfy the same
// divisibility rules; results are accumulated instead of returning early.
492 bool DTesnorIsValid = {true};
493 static_for<0, NumDTensor, 1>{}([&](auto index) {
494 using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
// NOTE(review): a layout mismatch marks the arguments invalid without emitting
// any diagnostic, unlike every other failure in this function.
495 if(std::is_same_v<DiLayout, ELayout> == false)
496 {
497 DTesnorIsValid = false;
498 }
499 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
500 {
501 if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
502 {
503 CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
504 "NPerBlock without padding!");
505 DTesnorIsValid = false;
506 }
507 if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
508 {
509 CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
510 DTesnorIsValid = false;
511 }
512 }
513 else
514 {
515 if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
516 {
517 CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
518 "MPerBlock without padding!");
519
520 DTesnorIsValid = false;
521 }
522 if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
523 {
524 CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
525 DTesnorIsValid = false;
526 }
527 }
528 });
529
// E/C tensor.
530 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
531 {
// NOTE(review): this tests kargs.stride_C, while the message (and the matching
// checks above) talk about N -- confirm which quantity is intended.
532 if(kargs.stride_C % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
533 {
534 std::cerr << "Can't support N that is not a multiple of NPerBlock"
535 " without padding!"
536 << std::endl;
537 return false;
538 }
539 if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
540 {
541 std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
542 return false;
543 }
544 }
545 else
546 {
547 if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
548 {
549 std::cerr << "Can't support M that is not a multiple of MPerBlock"
550 " without padding!"
551 << std::endl;
552 return false;
553 }
554 if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
555 {
556 std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
557 return false;
558 }
559 }
560 return DTesnorIsValid;
561 }
562
565 typename KernelArgs>
566 CK_TILE_DEVICE static auto
568 const BDataType* b_flat_ptr,
569 EDataType* e_ptr,
570 [[maybe_unused]] const AccDataType* exp_weight_ptr,
571 const int expert_id,
572 const KernelArgs& kargs,
573 const SplitKBatchOffset& splitk_batch_offset)
574 {
// Builds the tensor views consumed by the kernel and returns them as a tuple:
// (A view, flat-B view, C view, flat B-scale view).
// NOTE(review): the function-name line and the view-construction call lines are
// not visible in this listing.

// A view: the token dimension is NumTokens for the input GEMM and
// NumTokens * TopK otherwise; the K dimension is this block's split-K extent.
575 const auto& a_tensor_view = [&]() {
576 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
577 {
579 a_ptr,
580 make_tuple(IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK,
581 splitk_batch_offset.splitted_k),
582 make_tuple(kargs.stride_A, 1),
583 number<FlatmmPipeline::GetVectorSizeA()>{},
584 number<1>{});
585 }
586 else
587 {
589 a_ptr,
590 make_tuple(splitk_batch_offset.splitted_k,
591 IsInputGemm ? kargs.NumTokens : kargs.NumTokens * kargs.TopK),
592 make_tuple(kargs.stride_A, 1),
593 number<FlatmmPipeline::GetVectorSizeA()>{},
594 number<1>{});
595 }
596 }();
597
// NOTE(review): kFlatN is computed as N * K / kFlatK, i.e. N / WarpTile[1], but the
// N * K product is formed first -- check for overflow on large shapes if index_t
// is 32-bit.
598 index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
599 index_t kFlatN = kargs.N * kargs.K / kFlatK;
600
601 const auto& b_flat_tensor_view = [&]() {
603 b_flat_ptr,
604 make_tuple(kFlatN - kargs.n_padded_zeros / NPerXdl, kFlatK),
605 make_tuple(kFlatK, 1),
606 number<FlatmmPipeline::GetVectorSizeB()>{},
607 number<1>{});
608 }();
609
// C view: the row count mirrors A (NumTokens * TopK for the input GEMM, NumTokens
// otherwise); the column count is N / 2 for gate-up kernels.
610 // TODO: enable vector write for C in ColMajor
611 const auto& c_tensor_view = [&]() {
612 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
613 {
615 e_ptr,
616 make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
617 IsGateUp ? kargs.N / 2 : kargs.N),
618 make_tuple(kargs.stride_C, 1),
619 number<EpiloguePipeline::GetVectorSizeC()>{},
620 number<1>{});
621 }
622 else
623 {
// NOTE(review): `kargs.NumToken` below differs from `kargs.NumTokens` used in the
// row-major branch above. It looks like a typo that only fails to compile when
// this branch is instantiated -- confirm and fix.
625 e_ptr,
626 make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
627 IsGateUp ? kargs.N / 2 : kargs.N),
628 make_tuple(1, kargs.stride_C),
629 number<1>{},
630 number<1>{});
631 }
632 }();
633
// B-scale view: scale_k is 1 when GranularityK == 0, otherwise ceil(K / GranularityK).
// The base pointer is advanced by expert_id * N * scale_k elements.
634 auto scale_n = kargs.scale_n;
635 constexpr int GranularityK = decltype(scale_n)::GranularityK;
636
637 index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
638 index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
639 index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
640
641 using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
642
643 const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
644 reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
645 make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
646 make_tuple(FlatScaleK, 1),
647 number<8>{},
648 number<1>{});
649
650 return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
651 }
652
// Applies pad_tensor_view to the A view (tuple element I0) and the C view (I2).
// The elements at I1 and I3 are returned unchanged.
// NOTE(review): the padding arguments of the pad_tensor_view calls are not visible
// in this listing.
653 template <typename TensorView>
654 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
655 {
656 const auto& a_pad_view = [&]() {
657 const auto& a_tensor_view = views.at(I0);
658 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
659 {
660 return pad_tensor_view(a_tensor_view,
664 }
665 else
666 {
667 return pad_tensor_view(a_tensor_view,
671 }
672 }();
673
674 // TODO vector write in for C in ColMajor
675 const auto& c_pad_view = [&]() {
676 const auto& c_tensor_view = views.at(I2);
677 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
678 {
679 return pad_tensor_view(
680 c_tensor_view,
683 }
684 else
685 {
686 return pad_tensor_view(
687 c_tensor_view,
690 }
691 }();
692
693 return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
694 }
695
// Creates the per-block tile windows and returns them as a tuple:
// (A window, flat-B window, C window, B-scale window).
// coord_m is used only as the row origin of a row-major A window; the C window
// starts at row 0. coord_n drives the N origin of the B, C and scale windows.
// NOTE(review): the window-length arguments of the make_tile_window calls are not
// visible in this listing.
696 template <typename PadView>
697 CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
698 [[maybe_unused]] const index_t coord_m,
699 const index_t coord_n)
700 {
701 const auto& a_pad_view = views.at(number<0>{});
702 const auto& b_flat_pad_view = views.at(number<1>{});
703 const auto& c_pad_view = views.at(number<2>{});
704
705 const auto& a_block_window = [&]() {
706 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
707 {
708 return make_tile_window(a_pad_view,
711 {coord_m, 0}); // NOTE!
712 }
713 else
714 {
715 return make_tile_window(a_pad_view,
718 {0, 0}); // NOTE!
719 }
720 }();
721
// The B window's N origin is halved only for gate-up kernels that are not on the
// MXFP4 pipeline.
722 constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
723
724 const auto& b_flat_block_window =
725 make_tile_window(b_flat_pad_view,
728 {static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
729 (isNonInterleaveGateUp ? 1 : 2)),
730 0});
731
// For gate-up kernels the output column offset is half of coord_n.
732 const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
733
734 auto c_block_window = make_tile_window(
735 c_pad_view,
737 {0, // offset_m is included when construct C-scatter-window offsets
738 output_N_offset});
739
740 constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
741 constexpr int XDLPerLoadScaleB =
742 MXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
743
744 auto scale_block_window =
745 make_tile_window(views.at(I3),
747 number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
748 XDLPerLoadScaleB / GranularityK>{}),
749 {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
750
751 return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
752 }
753
// Kernel entry point that walks output tiles. The first tile index is blockIdx.x
// and each iteration advances by gridDim.x. When UsePersistentKernel is false the
// do/while body runs exactly once; otherwise it repeats while the tile index is
// below TilePartitioner::GridSize(kargs.M, kargs.N).
// NOTE(review): the signature line of this operator is not visible in this listing.
754 template <class MoeFlatmmKernelArgs>
756 {
757 int partition_idx = blockIdx.x;
758 int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
759 do
760 {
// Map the linear tile index to (M, N) block offsets and run the per-tile overload.
761 const auto [block_offset_m, block_offset_n] =
762 TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
763
764 this->operator()(kargs, block_offset_m, block_offset_n);
765 partition_idx += gridDim.x;
766 } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
767 }
768
769 template <class MoeFlatmmKernelArgs>
771 {
772
773 // const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
774 const index_t coord_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
775 const index_t coord_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
776 const index_t max_token_id = kargs.p_max_token_id[0];
777 // allocate LDS
778 __shared__ char smem_ptr_ping[GetSmemPingSize()];
779 __shared__ char smem_ptr_pong[GetSmemPongSize()];
780
781 const index_t expert_id = kargs.p_sorted_expert_ids[iM];
782
783 constexpr auto a_dram_dist = FlatmmPipeline::GetADramTileDistribution();
784 const auto a_coord = a_dram_dist.calculate_index(); // 2d thread offset, [i_row, i_col]
785
786 constexpr ck_tile::index_t DramMRepeat =
787 decltype(a_dram_dist)::DstrEncode::hs_lengthss_[number<0>{}][number<0>{}];
789
790 constexpr index_t token_id_offset = 24;
791 constexpr index_t token_id_mask = (1 << token_id_offset) - 1;
792
793 auto row_to_token_idx = [&](auto row_idx) {
794 const index_t fused_token =
795 kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
796 index_t gather_token_id = fused_token & token_id_mask;
797 if constexpr(!IsInputGemm)
798 {
799 gather_token_id = gather_token_id * kargs.TopK + (fused_token >> token_id_offset);
800 }
801 return gather_token_id;
802 };
803
804 if(coord_m >= max_token_id)
805 return;
806
807 static_for<0, DramMRepeat, 1>{}([&](auto m0) {
808 const auto row_idx =
809 coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
810 index_t gather_token_id = row_to_token_idx(row_idx);
811 a_offsets[m0] = std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
812 ? gather_token_id * kargs.stride_A
813 : gather_token_id;
814 });
815
816 const SplitKBatchOffset splitk_batch_offset(kargs);
817 const long_index_t expert_stride =
818 __builtin_amdgcn_readfirstlane(long_index_t(kargs.N) * kargs.K);
819
820 const ADataType* a_ptr =
821 static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
822 const BDataType* b_flat_ptr =
823 static_cast<const BDataType*>(kargs.b_ptr) +
824 (splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
825 EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
826
827 const AccDataType* exp_weight_ptr =
828 static_cast<const AccDataType*>(kargs.p_sorted_expert_weights);
829
830 const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
831 a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
832 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
833
834 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, coord_m, coord_n);
835
836 const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
837
838 // Run GEMM cooperatively by whole workgroup.
839 const auto& a_block_window = gemm_tile_windows.at(I0);
840 const auto& b_block_window = gemm_tile_windows.at(I1);
841 const auto& scale_block_window = gemm_tile_windows.at(I3);
842
843 auto a_gather_block_tile =
844 ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
845 a_block_window.get_window_lengths(),
846 a_block_window.get_window_origin(),
847 a_dram_dist,
848 a_offsets); // K DRAM tile window for
849
850 auto c_block_tile = [&] {
851 if constexpr(MXFP4_Pipeline)
852 {
853 // MXFP4_Pipeline uses gate-up interleave 16 layout for weight
854 // so don't need extra processing
855 return FlatmmPipeline{}(a_gather_block_tile,
856 b_block_window,
857 scale_block_window, // weight scale with granularityK = 32
858 num_loop,
859 kargs.k_padded_zeros,
860 smem_ptr_ping,
861 smem_ptr_pong);
862 }
863 else
864 {
865 return FlatmmPipeline{}(a_gather_block_tile,
866 b_block_window,
868 num_loop,
869 smem_ptr_ping,
870 smem_ptr_pong);
871 }
872 }();
873
874 auto& c_block_window = gemm_tile_windows.at(number<2>{});
875
876 // Run EpiloguePipeline
877 {
878 using EpiProblem = typename EpiloguePipeline::Problem;
879 using ODataType = typename EpiloguePipeline::ODataType;
880 using CWarpDstr = typename EpiloguePipeline::CWarpDstr;
881
882 constexpr index_t NumMXdlPerWavePerShuffle = EpiloguePipeline::NumMXdlPerWavePerShuffle;
883 constexpr index_t NumNXdlPerWavePerShuffle = EpiloguePipeline::NumNXdlPerWavePerShuffle;
884 constexpr index_t MPerIterationShuffle = EpiloguePipeline::MPerIterationShuffle;
885 constexpr index_t NPerIterationShuffle = EpiloguePipeline::NPerIterationShuffle;
886
887 constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
888 constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
889 constexpr index_t OutputNRepeat = IsGateUp ? NRepeat / 2 : NRepeat;
890
891 [[maybe_unused]] constexpr index_t EpiVectorSizeC = EpiloguePipeline::GetVectorSizeC();
892 [[maybe_unused]] constexpr index_t BlockedXDLN_PerWarp =
893 EpiloguePipeline::BlockedXDLN_PerWarp;
894
895 static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
896
897 constexpr index_t OutputNumNXdlPerWavePerShuffle =
898 IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
899 constexpr index_t LDS_NPerIterationShuffle =
900 IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
901
902 constexpr auto lds_block_desc = make_naive_tensor_descriptor(
905
906 // EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
908 reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
909
910 constexpr int ScaleGranularityM = decltype(kargs.scale_m)::GranularityMN;
911 constexpr int ScaleGranularityN = decltype(kargs.scale_n)::GranularityMN;
912
913 constexpr index_t scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
914 : 1; // per-token scale
915 constexpr index_t scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
916 : 1; // per-channel scale
917
918 auto output_acc_tile_distr =
927 typename CWarpDstr::DstrEncode{}));
928
929 const auto scale_m_coord =
930 output_acc_tile_distr.calculate_index(); // 2d thread offset, [i_row, i_col]
931
932 constexpr index_t kM2 = 4; // Val-dim
933 constexpr index_t kM1 = get_warp_size() / NPerXdl; // Thr-dim
934 constexpr index_t kM0 = MPerXdl / kM1 / kM2; // Var-dim
935
936 constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
938
939 if constexpr(!MXFP4_Pipeline)
940 static_for<0, MRepeat, 1>{}([&](auto mIter) {
941 static_for<0, kM0, 1>{}([&](auto m0) {
942 static_for<0, kM2, 1>{}([&](auto m2) {
943 const auto row_idx =
944 coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
945 scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
946 row_to_token_idx(row_idx);
947 });
948 });
949 });
950
951 constexpr int DynamicTileOffsetFlag = 0;
952
953 constexpr bool EnableBias = decltype(kargs.exp_bias)::GranularityMN != -1;
954
955 auto permute_tensor_view = [&](auto naive_view, auto is_needed_to_permute_N_PACK) {
956 if constexpr(!is_needed_to_permute_N_PACK)
957 {
958 return naive_view;
959 }
960 else
961 {
962 auto view1 = transform_tensor_view(
963 naive_view,
967 number<NRepeat / N_Pack>{},
970 number<NPerXdl>{}))),
974 view1,
978 number<NRepeat / N_Pack>{},
981 number<NPerXdl>{}))),
984 }
985 };
986
987 auto scale_m_window =
989 kargs.scale_m.ptr,
990 make_tuple(kargs.M, 1),
991 make_tuple(scale_stride_m, 0),
992 number<1>{}, // gather load can't vectorize
993 number<1>{}),
996 {0, 0}, // offset m is included in gather offsets
997 output_acc_tile_distr,
998 scale_m_offsets);
999
1000 auto scale_n_window = make_tile_window(
1002 kargs.scale_n.ptr + expert_id * kargs.N,
1003 make_tuple(1, kargs.N),
1004 make_tuple(0, scale_stride_n),
1005 number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
1006 number<1>{}), // MXF4_Pipeline does't use scale_n, so there is no need to
1007 // permute as n_pack
1009 number < IsGateUp ? TilePartitioner::NPerBlock / 2
1010 : TilePartitioner::NPerBlock > {}),
1011 {0, IsGateUp ? coord_n / 2 : coord_n},
1012 output_acc_tile_distr);
1013
1014 auto scale_n_up_window = make_tile_window(
1016 kargs.scale_n.ptr + expert_id * kargs.N + kargs.N / 2,
1017 make_tuple(1, kargs.N),
1018 make_tuple(0, scale_stride_n),
1019 number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
1020 number<1>{}),
1022 number<TilePartitioner::NPerBlock / 2>{}),
1023 {0, coord_n / 2},
1024 output_acc_tile_distr);
1025
1027 kargs.exp_bias.ptr + expert_id * kargs.N,
1028 make_tuple(1, kargs.N),
1029 make_tuple(0, scale_stride_n),
1030 number<FlatmmPipeline::GetVectorSizeB()>{},
1031 number<1>{});
1032
1033 auto exp_bias_window = make_tile_window(
1034 permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
1036 number < IsGateUp ? TilePartitioner::NPerBlock / 2
1037 : TilePartitioner::NPerBlock > {}),
1038 {0, IsGateUp ? coord_n / 2 : coord_n},
1039 output_acc_tile_distr);
1040
1041 auto exp_bias_up_window =
1043 kargs.exp_bias.ptr + expert_id * kargs.N + kargs.N / 2,
1044 make_tuple(1, kargs.N),
1045 make_tuple(0, scale_stride_n),
1046 number<FlatmmPipeline::GetVectorSizeB()>{},
1047 number<1>{}),
1049 number<TilePartitioner::NPerBlock / 2>{}),
1050 {0, coord_n / 2},
1051 output_acc_tile_distr);
1052
1053 auto exp_weight_window =
1055 static_cast<const float*>(kargs.p_sorted_expert_weights),
1056 make_tuple(kargs.M, 1),
1057 make_tuple(1, 0),
1058 number<FlatmmPipeline::GetVectorSizeA()>{},
1059 number<1>{}),
1062 {coord_m, 0},
1063 output_acc_tile_distr);
1064
1065 using ScaleMBuffer = decltype(load_tile(scale_m_window));
1066 using ScaleNBuffer = decltype(load_tile(scale_n_window));
1067 using ExpBiasBuffer = decltype(load_tile(exp_bias_window));
1068 using ExpWeightBuffer = decltype(load_tile(exp_weight_window));
1069
1070 ScaleMBuffer scale_m_buffer;
1071 ScaleNBuffer scale_n_buffer, scale_n_up_buffer;
1072
1073 ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
1074 ExpWeightBuffer exp_weight_buffer;
1075
1076 if constexpr(!MXFP4_Pipeline)
1077 {
1078 scale_m_window.load(scale_m_buffer);
1079 scale_n_buffer = load_tile(scale_n_window);
1080 if constexpr(IsGateUp)
1081 scale_n_up_buffer = load_tile(scale_n_up_window);
1082 }
1083
1084 if constexpr(EnableBias)
1085 {
1086 exp_bias_buffer = load_tile(exp_bias_window);
1087 if constexpr(IsGateUp)
1088 exp_bias_up_buffer = load_tile(exp_bias_up_window);
1089 }
1090 if constexpr(!IsInputGemm)
1091 exp_weight_buffer = load_tile(exp_weight_window);
1092
1093 auto in_lds_window = make_tile_window(
1094 o_lds_block,
1096 {0, 0});
1097
1098 auto out_lds_window = make_tile_window(
1099 o_lds_block,
1101 {0, 0});
1102
1106
1107 constexpr index_t num_access = SFC::get_num_of_access();
1108
1109 static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
1110 "Currently, the CShuffle EpiloguePipeline only supports the Row Major "
1111 "Output layout");
1112
1113 using TileEncodingPattern = tile_distribution_encoding_pattern_2d<
1114 kBlockSize,
1115 MPerIterationShuffle,
1116 LDS_NPerIterationShuffle,
1117 kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
1119 EpiProblem::kNumWaveGroups>;
1120
1121 constexpr auto dram_tile_distribution =
1122 TileEncodingPattern::make_2d_static_tile_distribution();
1123
1124 constexpr auto LdsTileDistr = [&] {
1125 if constexpr(IsGateUp)
1129 sequence<>,
1131 // merge two contiguous N
1136 sequence<0, 0>>{},
1137 typename CWarpDstr::DstrEncode{}));
1138 else
1140 EpiloguePipeline::MakeLdsDistributionEncode());
1141 }();
1142
1143 using LDSTileTensor =
1144 decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
1145 LDSTileTensor lds_tile[2];
1146
1147 constexpr auto c_warp_y_lengths =
1148 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1149 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
1150 constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
1151 OutputNumNXdlPerWavePerShuffle;
1152
1153 auto epi_tile_idx_slice =
1154 [&](const auto& acc_tile_like_tensor, auto epi_m_idx, auto epi_n_idx) {
1155 return acc_tile_like_tensor.get_y_sliced_thread_data(
1156 merge_sequences(sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1157 epi_n_idx * OutputNumNXdlPerWavePerShuffle>{},
1158 c_warp_y_index_zeros),
1161 c_warp_y_lengths));
1162 };
1163
1164 auto gate_up_epi_tile_idx_interleave_slice = [&](auto& dest_gate_tensor,
1165 auto& dest_up_tensor,
1166 const auto& acc_tile_like_tensor,
1167 auto epi_m_idx,
1168 auto epi_n_idx) {
1170 dest_gate_tensor.set_y_sliced_thread_data(
1171 merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
1173 acc_tile_like_tensor.get_y_sliced_thread_data(
1175 sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1176 epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl>{},
1177 c_warp_y_index_zeros),
1179 c_warp_y_lengths)));
1180 dest_up_tensor.set_y_sliced_thread_data(
1181 merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
1183 acc_tile_like_tensor.get_y_sliced_thread_data(
1185 sequence<epi_m_idx * NumMXdlPerWavePerShuffle,
1186 epi_n_idx * NumNXdlPerWavePerShuffle + 2 * n_xdl + 1>{},
1187 c_warp_y_index_zeros),
1189 c_warp_y_lengths)));
1190 });
1191 };
1192
1193 auto process_epi_tile = [&](auto lds_stage, auto epi_m, auto epi_n) {
1194 if constexpr(IsGateUp)
1195 {
1196 LDSTileTensor gate_tensor, up_tensor;
1197
1198 gate_up_epi_tile_idx_interleave_slice(
1199 gate_tensor, up_tensor, c_block_tile, epi_m, epi_n);
1200 auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
1201 auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
1202 auto epi_scale_n_up = epi_tile_idx_slice(scale_n_up_buffer, epi_m, epi_n);
1203
1204 auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
1205 auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
1206
1207 static_for<0, ActVectorSize, 1>{}([&](auto idx) {
1208 if constexpr(!MXFP4_Pipeline)
1209 {
1210 gate_tensor.get_thread_buffer()[idx] *=
1211 epi_scale_m[idx] * epi_scale_n[idx];
1212 up_tensor.get_thread_buffer()[idx] *=
1213 epi_scale_m[idx] * epi_scale_n_up[idx];
1214 }
1215 if constexpr(EnableBias)
1216 {
1217 gate_tensor.get_thread_buffer()[idx] += epi_exp_bias[idx];
1218 up_tensor.get_thread_buffer()[idx] += epi_exp_bias_up[idx];
1219 }
1220 lds_tile[lds_stage].get_thread_buffer().at(idx) =
1221 ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
1222 up_tensor.get_thread_buffer().at(idx));
1223 });
1224 }
1225 else
1226 {
1227 lds_tile[lds_stage].get_thread_buffer() =
1228 epi_tile_idx_slice(c_block_tile, epi_m, epi_n);
1229 auto epi_scale_m = epi_tile_idx_slice(scale_m_buffer, epi_m, epi_n);
1230 auto epi_scale_n = epi_tile_idx_slice(scale_n_buffer, epi_m, epi_n);
1231 auto epi_exp_weight = epi_tile_idx_slice(exp_weight_buffer, epi_m, epi_n);
1232 auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
1233
1234 static_for<0, ActVectorSize, 1>{}([&](auto idx) {
1235 if constexpr(!MXFP4_Pipeline)
1236 lds_tile[lds_stage].get_thread_buffer()[idx] *=
1237 epi_scale_m[idx] * epi_scale_n[idx];
1238 if constexpr(EnableBias)
1239 lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
1240 if constexpr(!IsInputGemm)
1241 lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
1242 else // for mlp1 gate-only
1243 lds_tile[lds_stage].get_thread_buffer()[idx] =
1244 ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
1245 });
1246 }
1247 };
1248
1249 constexpr int NumMEpiTile = MRepeat / NumMXdlPerWavePerShuffle;
1250 constexpr int MPerThread = TileEncodingPattern::Y2;
1252 c_scatter_offsets;
1253 auto c_coord = dram_tile_distribution.calculate_index();
1254 static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
1255 static_for<0, MPerThread, 1>{}([&](auto m0) {
1256 auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
1257 auto fused_token =
1258 kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
1259
1260 index_t scatter_token_id = fused_token & token_id_mask;
1261 if constexpr(IsInputGemm)
1262 scatter_token_id =
1263 scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
1264 c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
1265 });
1266 });
1267
1268 //===----------------------------------------------------------------------===//
1269 // Pingpong process start
1270 //===----------------------------------------------------------------------===//
1271 process_epi_tile(number<0>{}, number<0>{}, number<0>{});
1272
1273 static_for<0, num_access, 1>{}([&](auto iAccess) {
1274 constexpr int read_stage = iAccess % 2;
1275 constexpr int write_stage = read_stage ^ 1;
1276
1278 constexpr auto idx_y_start = SFC::get_index(number<iAccess.value>{});
1279 constexpr auto mIter = number<idx_y_start.at(number<0>{}) / MPerIterationShuffle>{};
1280
1281 const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile[read_stage]);
1282
1283 store_tile(in_lds_window, c_warptile_in_tensor_casted);
1284
1285 if constexpr(iAccess < num_access - 1)
1286 {
1287 constexpr auto idx_y_start_next = SFC::get_index(number<iAccess.value + 1>{});
1288 constexpr auto mIter_next =
1289 number<idx_y_start_next.at(number<0>{}) / MPerIterationShuffle>{};
1290 constexpr auto nIter_next =
1291 number<idx_y_start_next.at(number<1>{}) / NPerIterationShuffle>{};
1292
1293 process_epi_tile(number<write_stage>{}, mIter_next, nIter_next);
1294 }
1295
1297
1298 auto c_out_tensor =
1299 load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
1300 auto c_scatter_tile_window =
1301 make_tile_scatter_gather(c_block_window.get_bottom_tensor_view(),
1302 c_block_window.get_window_lengths(),
1303 c_block_window.get_window_origin(),
1304 dram_tile_distribution,
1305 c_scatter_offsets[mIter]);
1306
1307 if constexpr(!IsInputGemm ||
1308 EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add)
1309 c_scatter_tile_window.update(c_out_tensor);
1310 else
1311 c_scatter_tile_window.store(c_out_tensor);
1312
1313 if constexpr(iAccess != num_access - 1)
1314 {
1315 constexpr auto step = SFC::get_forward_step(iAccess);
1316 // row_offset of out windows has been included in scatter offset
1317 move_tile_window(c_block_window,
1318 {0, step.at(number<1>{}) / number < IsGateUp ? 2 : 1 > {}});
1319 }
1320 });
1321 }
1322 }
1323};
1324
1325} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition moe_flatmm_kernel.hpp:137
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
memory_operation_enum
Definition arch.hpp:56
@ atomic_add
Definition arch.hpp:58
@ set
Definition arch.hpp:57
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
@ thread_raked
Thread raked pattern.
Definition static_encoding_pattern.hpp:94
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
MoeFlatmmKind
Definition moe_flatmm_kernel.hpp:131
@ kFFN_gemm1_gate_up
Definition moe_flatmm_kernel.hpp:133
@ kFFN_gemm2
Definition moe_flatmm_kernel.hpp:134
@ kFFN_gemm1_gate_only
Definition moe_flatmm_kernel.hpp:132
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition tile_scatter_gather.hpp:906
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
unsigned int uint32_t
Definition stdint.h:126
index_t N
Definition flatmm_kernel.hpp:170
const void * a_ptr
Definition flatmm_kernel.hpp:161
index_t stride_B
Definition flatmm_kernel.hpp:173
index_t stride_C
Definition flatmm_kernel.hpp:178
index_t K
Definition flatmm_kernel.hpp:171
const void * b_ptr
Definition flatmm_kernel.hpp:162
index_t k_batch
Definition flatmm_kernel.hpp:181
index_t stride_A
Definition flatmm_kernel.hpp:172
void * e_ptr
Definition flatmm_kernel.hpp:166
index_t M
Definition flatmm_kernel.hpp:169
Definition flatmm_kernel.hpp:33
Definition moe_flatmm_kernel.hpp:21
ck_tile::index_t NumExperts
Definition moe_flatmm_kernel.hpp:23
const void * p_sorted_expert_weights
Definition moe_flatmm_kernel.hpp:28
const ck_tile::index_t * p_max_token_id
Definition moe_flatmm_kernel.hpp:27
ck_tile::index_t NumTokens
Definition moe_flatmm_kernel.hpp:22
const ck_tile::index_t * p_sorted_expert_ids
Definition moe_flatmm_kernel.hpp:26
ExpertBias exp_bias
Definition moe_flatmm_kernel.hpp:31
const ck_tile::index_t n_padded_zeros
Definition moe_flatmm_kernel.hpp:29
const ck_tile::index_t * p_sorted_token_ids
Definition moe_flatmm_kernel.hpp:25
const ck_tile::index_t k_padded_zeros
Definition moe_flatmm_kernel.hpp:30
CK_TILE_HOST MoeFlatmmHostArgs(const ck_tile::index_t *p_sorted_token_ids_, const void *p_sorted_expert_weights_, const ck_tile::index_t *p_sorted_expert_ids_, const ck_tile::index_t *p_max_token_id_, const void *a_ptr_, const void *b_ptr_, void *c_ptr_, ck_tile::index_t NumTokens_, ck_tile::index_t NumExperts_, ck_tile::index_t TopK_, ck_tile::index_t k_batch_, ck_tile::index_t M_, ck_tile::index_t N_, ck_tile::index_t K_, ck_tile::index_t stride_A_, ck_tile::index_t stride_B_, ck_tile::index_t stride_C_, ck_tile::index_t n_padded_zeros_=0, ck_tile::index_t k_padded_zeros_=0, ScaleM scale_m_={}, ScaleN scale_n_={}, ExpertBias exp_bias_={})
Definition moe_flatmm_kernel.hpp:80
CK_TILE_HOST MoeFlatmmHostArgs() noexcept=default
ck_tile::index_t TopK
Definition moe_flatmm_kernel.hpp:24
Definition moe_flatmm_kernel.hpp:257
ck_tile::index_t K
Definition moe_flatmm_kernel.hpp:269
ExpertBias exp_bias
Definition moe_flatmm_kernel.hpp:278
ck_tile::index_t stride_B
Definition moe_flatmm_kernel.hpp:271
ScaleM scale_m
Definition moe_flatmm_kernel.hpp:276
ck_tile::index_t k_padded_zeros
Definition moe_flatmm_kernel.hpp:275
const void * b_ptr
Definition moe_flatmm_kernel.hpp:263
ck_tile::index_t stride_A
Definition moe_flatmm_kernel.hpp:270
ck_tile::index_t k_batch
Definition moe_flatmm_kernel.hpp:273
ck_tile::index_t stride_C
Definition moe_flatmm_kernel.hpp:272
void * e_ptr
Definition moe_flatmm_kernel.hpp:264
const ck_tile::index_t * p_max_token_id
Definition moe_flatmm_kernel.hpp:260
ScaleN scale_n
Definition moe_flatmm_kernel.hpp:277
ck_tile::index_t NumTokens
Definition moe_flatmm_kernel.hpp:265
ck_tile::index_t M
Definition moe_flatmm_kernel.hpp:267
ck_tile::index_t n_padded_zeros
Definition moe_flatmm_kernel.hpp:274
ck_tile::index_t TopK
Definition moe_flatmm_kernel.hpp:266
const ck_tile::index_t * p_sorted_token_ids
Definition moe_flatmm_kernel.hpp:258
const ck_tile::index_t * p_sorted_expert_ids
Definition moe_flatmm_kernel.hpp:259
const void * a_ptr
Definition moe_flatmm_kernel.hpp:262
ck_tile::index_t N
Definition moe_flatmm_kernel.hpp:268
const void * p_sorted_expert_weights
Definition moe_flatmm_kernel.hpp:261
Definition moe_flatmm_kernel.hpp:368
index_t splitted_k
Definition moe_flatmm_kernel.hpp:406
index_t b_k_split_offset
Definition moe_flatmm_kernel.hpp:405
__device__ SplitKBatchOffset(const KernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition moe_flatmm_kernel.hpp:370
index_t a_k_split_offset
Definition moe_flatmm_kernel.hpp:404
Definition moe_flatmm_kernel.hpp:192
static constexpr int OutputNPerBlock
Definition moe_flatmm_kernel.hpp:240
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition moe_flatmm_kernel.hpp:209
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition moe_flatmm_kernel.hpp:202
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition moe_flatmm_kernel.hpp:195
static constexpr index_t NumDTensor
Definition moe_flatmm_kernel.hpp:214
float AccDataType
Definition moe_flatmm_kernel.hpp:211
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition moe_flatmm_kernel.hpp:201
static constexpr auto GridSize(const MoeFlatmmKernelArgs &kargs)
Definition moe_flatmm_kernel.hpp:323
static constexpr auto I1
Definition moe_flatmm_kernel.hpp:217
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition moe_flatmm_kernel.hpp:193
static constexpr auto I3
Definition moe_flatmm_kernel.hpp:219
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPongSize()
Definition moe_flatmm_kernel.hpp:362
static constexpr index_t kBlockSize
Definition moe_flatmm_kernel.hpp:203
static constexpr bool IsInputGemm
Definition moe_flatmm_kernel.hpp:224
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition moe_flatmm_kernel.hpp:654
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition moe_flatmm_kernel.hpp:198
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition moe_flatmm_kernel.hpp:197
static constexpr int MXFP4N_Pack
Definition moe_flatmm_kernel.hpp:245
static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
Definition moe_flatmm_kernel.hpp:318
static constexpr bool UsePersistentKernel
Definition moe_flatmm_kernel.hpp:204
FusedActivation ActivationOp
Definition moe_flatmm_kernel.hpp:212
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition moe_flatmm_kernel.hpp:207
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition moe_flatmm_kernel.hpp:199
static constexpr bool MXFP4_Pipeline
Definition moe_flatmm_kernel.hpp:244
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition moe_flatmm_kernel.hpp:206
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition moe_flatmm_kernel.hpp:200
static constexpr index_t kMPerBlock
Definition moe_flatmm_kernel.hpp:228
static constexpr index_t MWave
Definition moe_flatmm_kernel.hpp:230
static constexpr index_t KPerXdl
Definition moe_flatmm_kernel.hpp:234
static constexpr auto BlockSize() -> dim3
Definition moe_flatmm_kernel.hpp:316
static constexpr bool IsGateUp
Definition moe_flatmm_kernel.hpp:225
static constexpr index_t kNPerBlock
Definition moe_flatmm_kernel.hpp:229
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPingSize()
Definition moe_flatmm_kernel.hpp:358
static CK_TILE_HOST const std::string GetName()
Definition moe_flatmm_kernel.hpp:310
static constexpr index_t NPerXdl
Definition moe_flatmm_kernel.hpp:233
static constexpr index_t kNPerIteration
Definition moe_flatmm_kernel.hpp:237
static constexpr index_t kMPerIteration
Definition moe_flatmm_kernel.hpp:236
static constexpr int WeightPackedSize
Definition moe_flatmm_kernel.hpp:251
static constexpr auto I0
Definition moe_flatmm_kernel.hpp:216
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition moe_flatmm_kernel.hpp:410
static constexpr index_t isCTransposed
Definition moe_flatmm_kernel.hpp:235
static constexpr int K_Pack
Definition moe_flatmm_kernel.hpp:249
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs) const
Definition moe_flatmm_kernel.hpp:755
static constexpr int N_Pack
Definition moe_flatmm_kernel.hpp:248
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t coord_m, const index_t coord_n)
Definition moe_flatmm_kernel.hpp:697
static constexpr int MXFP4K_Pack
Definition moe_flatmm_kernel.hpp:246
static constexpr index_t kNRepeat
Definition moe_flatmm_kernel.hpp:238
static CK_TILE_HOST constexpr auto MakeKernelArgs(const MoeFlatmmHostArgs< ScaleM, ScaleN, ExpertBias > &hostArgs)
Definition moe_flatmm_kernel.hpp:285
static constexpr index_t MPerXdl
Definition moe_flatmm_kernel.hpp:232
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, EDataType *e_ptr, const AccDataType *exp_weight_ptr, const int expert_id, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition moe_flatmm_kernel.hpp:567
CK_TILE_DEVICE void operator()(MoeFlatmmKernelArgs kargs, index_t iM, index_t iN) const
Definition moe_flatmm_kernel.hpp:770
static constexpr index_t NWave
Definition moe_flatmm_kernel.hpp:231
static constexpr auto I2
Definition moe_flatmm_kernel.hpp:218
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition moe_flatmm_kernel.hpp:194
CK_TILE_HOST ScaleFlatmmHostArgs()=default
ScaleM scale_m
Definition flatmm_kernel.hpp:219
ScaleN scale_n
Definition flatmm_kernel.hpp:220
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:1014
Definition type_traits.hpp:115
Definition moe_flatmm_kernel.hpp:140
CK_TILE_HOST_DEVICE T operator()(T gate, T linear=1) const
Definition moe_flatmm_kernel.hpp:142
const float alpha
Definition moe_flatmm_kernel.hpp:151
const float limit
Definition moe_flatmm_kernel.hpp:152
CK_TILE_HOST_DEVICE Swiglu(float alpha_=1.702f, float limit_=7.0f)
Definition moe_flatmm_kernel.hpp:155
CK_TILE_HOST_DEVICE T operator()(T gate, T linear) const
Definition moe_flatmm_kernel.hpp:161
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49
Definition space_filling_curve.hpp:20
Definition tile/core/utility/functional.hpp:43
Class creating 2D static tile distribution with different load/store patterns.
Definition static_encoding_pattern.hpp:130
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192