moe_sorting_problem.hpp Source File

moe_sorting_problem.hpp Source File#

Composable Kernel: moe_sorting_problem.hpp Source File
moe_sorting_problem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7#include <string>
8#include <type_traits>
9
10namespace ck_tile {
11
// Compile-time configuration ("problem" descriptor) for a MoE sorting kernel:
// it only republishes its template arguments as static constexpr members.
// NOTE(review): this listing is missing original lines 16, 19 and 20. Line 16
// presumably holds the `struct` declaration (its name is not visible here).
// Per the cross-reference index at the end of this page, line 19 defines
// `WeightType = remove_cvref_t<WeightType_>` and line 20 defines
// `IndexType = remove_cvref_t<IndexType_>`.
12template <typename IndexType_,
13 typename WeightType_,
14 index_t InternalLoadUnroll_,
15 index_t ExpertTile_ = 0>
17{
18 // TODO: this kernel only supports warp-per-row
21
22 static constexpr index_t WarpSize = get_warp_size();
23 static constexpr index_t WarpsPerBlock = 1; // fixed at 1, not a template parameter
24 static constexpr index_t InternalLoadUnroll =
25 InternalLoadUnroll_; // TODO: needs a better design (like tile size)
26 static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
27};
28
// Compile-time configuration ("problem" descriptor) for a MoE sorting kernel
// variant that adds sub-token tiling and EP-related switches. Every template
// argument is republished as a static constexpr member of the same name
// (without the trailing underscore).
// NOTE(review): this listing is missing original lines 37, 40 and 41. Line 37
// presumably holds the `struct` declaration (its name is not visible here).
// Per the cross-reference index at the end of this page, line 40 defines
// `WeightType = remove_cvref_t<WeightType_>` and line 41 defines
// `IndexType = remove_cvref_t<IndexType_>`.
29template <typename IndexType_,
30 typename WeightType_,
31 index_t SubTokenTile_, // 1,2,4,8 (enforced by the static_assert below), or 0 in the future
32 bool SubTokenOneShot_, // whether we loop over only once or not
33 bool LocalExpertMasking_, // used in EP case (EP presumably = expert parallelism; confirm)
34 bool LocalToken_, // used in EP case
35 bool SkipExpertsWithZeroTokens_ = true,
36 index_t ExpertTile_ = 0>
38{
39 // TODO: this kernel only supports warp-per-row
42
43 static constexpr index_t WarpSize = get_warp_size();
44 static constexpr index_t WarpsPerBlock = 1; // fixed at 1, not a template parameter
45 static constexpr index_t SubTokenTile = SubTokenTile_;
46 static constexpr bool SubTokenOneShot = SubTokenOneShot_;
47 static constexpr bool LocalExpertMasking = LocalExpertMasking_;
48 static constexpr bool LocalToken = LocalToken_;
49 static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
50 static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8); // any other SubTokenTile is a compile error
51 static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
52};
53
// Compile-time configuration ("problem" descriptor) for a MoE sorting kernel
// variant that additionally carries a mesh type. Unlike the two structs above
// it, the visible lines expose no WarpSize / WarpsPerBlock / ExpertTile members.
// NOTE(review): this listing is missing original lines 61 and 64-66. Line 61
// presumably holds the `struct` declaration (its name is not visible here).
// Per the cross-reference index at the end of this page, line 64 defines
// `WeightType = remove_cvref_t<WeightType_>`, line 65 defines
// `MeshType = remove_cvref_t<MeshType_>` and line 66 defines
// `IndexType = remove_cvref_t<IndexType_>`.
54template <typename IndexType_,
55 typename WeightType_, // used for expert mesh in ws -- NOTE(review): this remark looks like it belongs to MeshType_ on the next line; confirm
56 typename MeshType_,
57 index_t SubTokenTile_, // 1,2,4,8, or 16 (enforced by the static_assert below)
58 bool LocalExpertMasking_, // used in EP case (EP presumably = expert parallelism; confirm)
59 bool LocalToken_, // used in EP case
60 bool SkipExpertsWithZeroTokens_ = true>
62{
63 // TODO: this kernel only supports warp-per-row
67
68 static constexpr index_t SubTokenTile = SubTokenTile_;
69 static constexpr bool LocalExpertMasking = LocalExpertMasking_;
70 static constexpr bool LocalToken = LocalToken_;
71 static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
72 static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
73 SubTokenTile == 8 || SubTokenTile == 16); // any other SubTokenTile is a compile error
74};
75
// Compile-time configuration that republishes its three template arguments as
// static constexpr members; it has no type parameters.
// NOTE(review): this listing is missing original line 77, which presumably
// holds the `struct` declaration (its name is not visible here).
// LocalToken_ : forwarded to LocalToken.
// BlockSize_  : forwarded to BlockSize; defaults to 1024.
// Occu_       : forwarded to Occu; defaults to 1 (presumably an occupancy
//               setting -- confirm against the kernel that consumes it).
76template <bool LocalToken_, index_t BlockSize_ = 1024, index_t Occu_ = 1>
78{
79 static constexpr bool LocalToken = LocalToken_;
80 static constexpr index_t BlockSize = BlockSize_;
81 static constexpr index_t Occu = Occu_;
82};
83
84} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
int32_t index_t
Definition integer.hpp:9
Definition moe_sorting_problem.hpp:78
static constexpr bool LocalToken
Definition moe_sorting_problem.hpp:79
static constexpr index_t BlockSize
Definition moe_sorting_problem.hpp:80
static constexpr index_t Occu
Definition moe_sorting_problem.hpp:81
Definition moe_sorting_problem.hpp:38
static constexpr bool LocalToken
Definition moe_sorting_problem.hpp:48
remove_cvref_t< IndexType_ > IndexType
Definition moe_sorting_problem.hpp:41
static constexpr index_t WarpSize
Definition moe_sorting_problem.hpp:43
static constexpr bool SkipExpertsWithZeroTokens
Definition moe_sorting_problem.hpp:49
remove_cvref_t< WeightType_ > WeightType
Definition moe_sorting_problem.hpp:40
static constexpr bool LocalExpertMasking
Definition moe_sorting_problem.hpp:47
static constexpr index_t WarpsPerBlock
Definition moe_sorting_problem.hpp:44
static constexpr index_t ExpertTile
Definition moe_sorting_problem.hpp:51
static constexpr bool SubTokenOneShot
Definition moe_sorting_problem.hpp:46
static constexpr index_t SubTokenTile
Definition moe_sorting_problem.hpp:45
Definition moe_sorting_problem.hpp:17
remove_cvref_t< WeightType_ > WeightType
Definition moe_sorting_problem.hpp:19
static constexpr index_t WarpsPerBlock
Definition moe_sorting_problem.hpp:23
static constexpr index_t ExpertTile
Definition moe_sorting_problem.hpp:26
remove_cvref_t< IndexType_ > IndexType
Definition moe_sorting_problem.hpp:20
static constexpr index_t WarpSize
Definition moe_sorting_problem.hpp:22
static constexpr index_t InternalLoadUnroll
Definition moe_sorting_problem.hpp:24
Definition moe_sorting_problem.hpp:62
remove_cvref_t< IndexType_ > IndexType
Definition moe_sorting_problem.hpp:66
remove_cvref_t< WeightType_ > WeightType
Definition moe_sorting_problem.hpp:64
static constexpr index_t SubTokenTile
Definition moe_sorting_problem.hpp:68
static constexpr bool LocalToken
Definition moe_sorting_problem.hpp:70
static constexpr bool SkipExpertsWithZeroTokens
Definition moe_sorting_problem.hpp:71
static constexpr bool LocalExpertMasking
Definition moe_sorting_problem.hpp:69
remove_cvref_t< MeshType_ > MeshType
Definition moe_sorting_problem.hpp:65