14template <
typename Problem_,
typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
28 template <
typename CBlockTensor,
typename ABlockTensorTmp,
typename BBlockWindowTmp>
30 const ABlockTensorTmp& a_block_tensor_tmp,
31 const BBlockWindowTmp& b_block_window_tmp)
const
34 std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
35 std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
36 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
42 constexpr index_t MPerBlock = BlockGemmShape::kM;
43 constexpr index_t NPerBlock = BlockGemmShape::kN;
44 constexpr index_t KPerBlock = BlockGemmShape::kK;
50 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
54 constexpr index_t MWarp = config.template at<1>();
55 constexpr index_t NWarp = config.template at<2>();
57 static_assert(MWarp == 1 && NWarp == 1,
"Check failed!");
59 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
60 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
61 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
63 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
64 constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
68 constexpr auto c_block_outer_dstr_encoding =
77 c_block_outer_dstr_encoding,
typename WG::CWarpDstrEncoding{});
85 a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
89 b_block_window_tmp.get_bottom_tensor_view(),
91 b_block_window_tmp.get_window_origin() +
multi_index<2>{iNWarp * WG::kN, 0},
95 array<
array<
decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
98 for(
index_t nIter = 0; nIter < NIterPerWarp; nIter++)
100 for(
index_t kIter = 0; kIter < KIterPerWarp; kIter++)
103 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
114 b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
117 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
126 .get_static_tile_distribution_encoding())>>,
129 using AWarpDstr =
typename WG::AWarpDstr;
130 using CWarpDstr =
typename WG::CWarpDstr;
132 using AWarpTensor =
typename WG::AWarpTensor;
133 using CWarpTensor =
typename WG::CWarpTensor;
135 constexpr auto a_warp_y_lengths =
136 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
137 constexpr auto c_warp_y_lengths =
138 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
147 AWarpTensor a_warp_tensor;
149 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
155 const auto b_warp_tensor =
load_tile(b_warp_windows(nIter)(kIter));
158 CWarpTensor c_warp_tensor;
160 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
165 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
168 c_block_tensor.set_y_sliced_thread_data(
171 c_warp_tensor.get_thread_buffer());
177 template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
180 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
184 constexpr index_t MWarp = config.template at<1>();
185 constexpr index_t NWarp = config.template at<2>();
187 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
188 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
190 constexpr auto a_block_outer_dstr_encoding =
199 a_block_outer_dstr_encoding,
typename WG::AWarpDstrEncoding{});
206 constexpr index_t MPerBlock = BlockGemmShape::kM;
207 constexpr index_t NPerBlock = BlockGemmShape::kN;
209 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
213 constexpr index_t MWarp = config.template at<1>();
214 constexpr index_t NWarp = config.template at<2>();
216 static_assert(MWarp == 1 && NWarp == 1,
"Check failed!");
218 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
219 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
222 constexpr auto c_block_outer_dstr_encoding =
231 c_block_outer_dstr_encoding,
typename WG::CWarpDstrEncoding{});
233 static_assert(
decltype(c_block_dstr_encode)::NDimP == 1,
"Check failed!");
237 return c_block_tensor;
241 template <
typename ABlockTensorTmp,
typename BBlockWindowTmp>
243 const BBlockWindowTmp& b_block_window_tmp)
const
246 operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
247 return c_block_tensor;
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:16
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:17
static constexpr index_t kBlockSize
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:24
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:19
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:204
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:20
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:29
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:18
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:22
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:21
static CK_TILE_DEVICE constexpr auto MakeABlockTileDistribution()
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:178
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:242
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192