blockwise_gemm_xdlops_skip_b_lds.hpp Source File

blockwise_gemm_xdlops_skip_b_lds.hpp Source File#

Composable Kernel: blockwise_gemm_xdlops_skip_b_lds.hpp Source File
blockwise_gemm_xdlops_skip_b_lds.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13template <index_t BlockSize,
14 typename FloatAB,
15 typename FloatAcc,
16 typename AK0MK1BlockDesc,
17 typename BK0K0BN0N1N2N3K1BlockDesc,
18 index_t MPerBlock,
19 index_t NPerBlock,
20 index_t K0PerBlock,
21 index_t MPerXDL,
22 index_t NPerXDL,
23 index_t MRepeat,
24 index_t NRepeat,
25 index_t KPack>
27{
// Compile-time index constants (Number<0> .. Number<3>) used to pick tuple/descriptor components.
28 static constexpr auto I0 = Number<0>{};
29 static constexpr auto I1 = Number<1>{};
30 static constexpr auto I2 = Number<2>{};
31 static constexpr auto I3 = Number<3>{};
32
// K extent of one block tile: K0PerBlock groups of KPack elements.
33 static constexpr index_t KPerBlock = K0PerBlock * KPack;
34
// Lengths of dimensions 0 and 2 of the A block descriptor type.
// NOTE(review): presumably the K0 and K1 lengths, going by the AK0MK1BlockDesc name - confirm.
35 static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
36 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
37
// NOTE(review): listing line 38 (the definition of xdlops_gemm used below) is not shown in this excerpt.
39
// Per-thread K extents: the block-level extents divided by xdlops_gemm.K0PerXdlops.
40 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
41 static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
42
// Wave counts along M and N, and the number of threads left per wave.
// These are integer divisions; the constructor's static_asserts reject tile sizes
// that would make them truncate.
43 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
44 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
45 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
46
48 FloatAcc,
49 MRepeat * NRepeat,
50 xdlops_gemm.GetRegSizePerXdlops(),
51 true>
53
// Returns a mutable reference to the member accumulator buffer c_thread_buf_.
54 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
55
// Converts the calling thread's 1-D id (get_thread_local_1d_id()) into a multi-component
// index by running it through threadid_to_wave_idx_adaptor and returning the bottom index.
// Callers in this file read component I0 as waveId_m and component I1 as waveId_n.
// NOTE(review): the adaptor's transform arguments (listing lines 61-63) are missing from
// this excerpt, so the exact decomposition cannot be confirmed here.
56 __device__ static auto GetWaveIdx()
57 {
58 const index_t thread_id = get_thread_local_1d_id();
59
60 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
64
65 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
66 }
67
68 __device__ static auto CalculateAThreadOriginDataIndex()
69 {
70 const auto wave_idx = GetWaveIdx();
71
72 const auto waveId_m = wave_idx[I0];
73
74 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
75
76 return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
77 }
78
79 __device__ static auto CalculateBThreadOriginDataIndex()
80 {
81 const auto wave_idx = GetWaveIdx();
82
83 const auto waveId_n = wave_idx[I1];
84
85 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
86
87 return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
88 }
89
// Computes the (m, n) origin of this thread's C sub-block for repeat indices (m0, n0)
// and xdlops block selectors (xdlops_i, blk_i); returns make_tuple(c_thread_m, c_thread_n).
// NOTE(review): the line carrying the function name and parameter list (listing line 92)
// and the argument lists of both adaptors (lines 102-104, 107-109) are missing from this
// excerpt; the listing's index names the function
// CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>).
90 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
91 __device__ static auto
93 {
94 const auto wave_idx = GetWaveIdx();
95
96 const auto waveId_m = wave_idx[I0];
97 const auto waveId_n = wave_idx[I1];
98
// Start of this thread's block inside the xdlops output, as a 2-component index.
99 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
100
// Adaptor taking a (m0, waveId_m, blk m) triple to a single m coordinate.
101 constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
105
// Adaptor taking a (n0, waveId_n, blk n) triple to a single n coordinate.
106 constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
110
// Component I0 of each bottom index is the flattened coordinate.
111 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
112 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
113 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
114 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
115
116 return make_tuple(c_thread_m, c_thread_n);
117 }
118
// Constructor body. The signature (listing line 119) is not shown in this excerpt.
// The visible body contains compile-time checks only.
120 {
// Both block descriptor types must be fully known at compile time.
121 static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
122 BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(),
123 "wrong! Desc should be known at compile-time");
124
// WaveSize is BlockSize / MWaves / NWaves (integer division); this rejects block sizes
// for which that division would have a remainder.
125 static_assert(BlockSize == MWaves * NWaves * WaveSize,
126 "BlockSize != MWaves * NWaves * WaveSize\n");
127
// The block tile must be a whole number of (repeat x XDL) tiles along both M and N,
// so that the MWaves / NWaves divisions are exact.
128 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
129 "wrong!");
130 }
131
// Builds the per-thread C descriptor from an 8-length tuple
// (MRepeat, NRepeat, 1, 1, M0, M1, M2, N), where M0, M1, M2 and N are the four
// lengths returned by xdlops_gemm.GetCM0M1M2NThreadBlkLengths().
// NOTE(review): the line that opens the return statement (listing line 141) is missing
// from this excerpt, so the descriptor factory being called cannot be confirmed here.
132 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
133 {
134 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
135
136 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
137 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
138 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
139 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
140
142 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
143 }
144
// Variant of the per-thread C descriptor with one extra leading length of 1:
// the tuple is (1, MRepeat, NRepeat, 1, 1, M0, M1, M2, N), where M0, M1, M2 and N are
// the four lengths returned by xdlops_gemm.GetCM0M1M2NThreadBlkLengths().
// NOTE(review): the line that opens the return statement (listing line 154) is missing
// from this excerpt, so the descriptor factory being called cannot be confirmed here.
145 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
146 {
147 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
148
149 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
150 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
151 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
152 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
153
155 make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
156 }
157
// Builds an intermediate block-level C descriptor (c_block_desc_m0_n0_m1_n1_m2_n2) and
// returns the result of xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 applied to it.
// NOTE(review): most of the intermediate descriptor's construction (listing lines 161-165)
// is missing from this excerpt; only its last length, Number<NPerXDL>, is visible.
158 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
159 {
160 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
166 Number<NPerXDL>{}));
167
168 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
169 }
170
// Builds an intermediate block-level C descriptor (c_block_desc_g_m0_n0_m1_n1_m2_n2) and
// returns the result of xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2 applied to it.
// NOTE(review): most of the intermediate descriptor's construction (listing lines 174-179)
// is missing from this excerpt; only its last length, Number<NPerXDL>, is visible.
171 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
172 {
173 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
180 Number<NPerXDL>{}));
181
182 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
183 c_block_desc_g_m0_n0_m1_n1_m2_n2);
184 }
185
// Re-expresses a 2-D (M, N) grid descriptor for the xdlops output layout.
// Dimension 0 (length M) is unmerged into (M / (MWaves * MPerXDL), MWaves, MPerXDL) and
// dimension 1 (length N) into (N / (NWaves * NPerXDL), NWaves, NPerXDL); the transformed
// descriptor is then handed to xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2.
// NOTE(review): the dimension-id arguments of transform_tensor_descriptor (listing lines
// 197-198) are missing from this excerpt. No divisibility check on M or N is visible
// here - presumably the caller guarantees M and N are multiples of the wave tile; verify.
186 template <typename CGridDesc_M_N>
187 __host__ __device__ static constexpr auto
188 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
189 {
190 const auto M = c_grid_desc_m_n.GetLength(I0);
191 const auto N = c_grid_desc_m_n.GetLength(I1);
192
193 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
194 c_grid_desc_m_n,
195 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
196 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
199
200 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
201 }
202
// Variant for a 3-D (G, M, N) grid descriptor.
// Dimension 1 (length M) is unmerged into (M / (MWaves * MPerXDL), MWaves, MPerXDL) and
// dimension 2 (length N) into (N / (NWaves * NPerXDL), NWaves, NPerXDL); the transformed
// descriptor is then handed to xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2.
// NOTE(review): the transform applied to G (listing line 213) and the dimension-id
// arguments (lines 216-217) are missing from this excerpt; G is presumably passed
// through unchanged - confirm against the full source.
203 template <typename CGridDesc_G_M_N>
204 __host__ __device__ static constexpr auto
205 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
206 {
207 const auto G = c_grid_desc_g_m_n.GetLength(I0);
208 const auto M = c_grid_desc_g_m_n.GetLength(I1);
209 const auto N = c_grid_desc_g_m_n.GetLength(I2);
210
211 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
212 c_grid_desc_g_m_n,
214 make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
215 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
218
219 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
220 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
221 }
222
234
235 __device__ void MoveABlockSliceWindow()
236 {
237 a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
238 make_multi_index(0, 0, 0, K0PerBlock * KPack));
239 }
240 __device__ void ResetABlockStartWindow()
241 {
242 a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
243 }
244
246
// Accumulates one block tile of the GEMM into c_thread_buf.
//   a_block_buf  - source buffer for A; each m0 slice is copied into a_thread_buf through
//                  a_thread_copy_ before use.
//   b_thread_buf - B operand; it is indexed directly with compile-time offsets from
//                  b_thread_desc_, with no copy step in this function.
//   c_thread_buf - accumulator written by xdlops_gemm.Run; this is the argument, not the
//                  member c_thread_buf_.
// NOTE(review): the declaration of a_thread_buf (listing line 252) and the header of the
// loop that defines k (line 266) are missing from this excerpt.
247 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
248 __device__ void Run(const ABlockBuffer& a_block_buf,
249 const BBlockBuffer& b_thread_buf,
250 CThreadBuffer& c_thread_buf) const
251 {
253 a_thread_desc_.GetElementSpaceSize());
254
// Compile-time loop over the M repeats.
255 static_for<0, MRepeat, 1>{}([&](auto m0) {
256 // read A
// Source origin (m0, 0, 0, 0) in the block descriptor -> destination origin (0, 0, 0, 0)
// in the thread descriptor.
257 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
258 make_tuple(m0, I0, I0, I0),
259 a_block_buf,
260 a_thread_desc_,
261 make_tuple(I0, I0, I0, I0),
262 a_thread_buf);
263
// Compile-time loop over the N repeats; the A slice copied above is reused for every n0.
264 static_for<0, NRepeat, 1>{}([&](auto n0) {
265 // read B
267 vector_type<FloatAB, KPack> a_thread_vec;
268 vector_type<FloatAB, KPack> b_thread_vec;
// k0 selects the KPack-sized group of B that corresponds to the current k.
269 constexpr index_t k0 = k / KPack;
// Gather KPack consecutive elements of A (at k + i) and of B (at (k0, n0, i)).
270 static_for<0, KPack, 1>{}([&](auto i) {
271 a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
272 [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
273 b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
274 [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
275 });
276
// The gathered vectors are reinterpreted as vectors of K1PerXdlops elements for xdlops_gemm.Run.
277 using mfma_input_type =
278 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
279
// Offset of the (m0, n0) accumulator vector inside c_thread_buf.
280 constexpr index_t c_offset =
281 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
282
283 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
284 b_thread_vec.template AsType<mfma_input_type>(),
285 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
286 });
287 });
288 });
289 }
290
291 private:
292 // A[M0, M1, M2, KPerThread]
293 static constexpr auto a_thread_desc_ =
295
296 // B[K0, NRepeat, KPack] -- three dimensions, indexed as (k0, n0, i) in Run
297 static constexpr auto b_thread_desc_ =
299 Number<NRepeat>{}, // repeat
300 Number<KPack>{}));
301
302 // C[M, N, NumRegXdlops]
303 static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
304 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
305
// Thread-level copy type used to move A from a_block_desc_m0_m1_m2_k into a_thread_desc_.
// Source and destination element types are both FloatAB, and each copy covers a
// 1 x 1 x 1 x KPerThread slice.
// NOTE(review): the remaining arguments (Sequence<0, 1, 2, 3>, 3, A_K1, A_K1) are
// presumably the dimension access order, the vectorized dimension and the vector
// sizes - confirm against ThreadwiseTensorSliceTransfer_v4.
306 using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
307 FloatAB,
308 decltype(a_block_desc_m0_m1_m2_k),
309 decltype(a_thread_desc_),
310 Sequence<1, 1, 1, KPerThread>,
311 Sequence<0, 1, 2, 3>,
312 3,
313 A_K1,
314 A_K1>;
315
// Copy object shared by Run, MoveABlockSliceWindow and ResetABlockStartWindow; its source
// coordinate starts at this thread's A origin.
316 AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
317};
318
319} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:54
static constexpr auto I2
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:30
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:158
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:171
static constexpr index_t NWaves
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:44
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:119
static constexpr auto a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:245
static constexpr auto I1
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:29
static constexpr index_t KPerBlock
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:33
static constexpr index_t MWaves
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:43
static constexpr index_t KPerThread
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:40
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:92
static constexpr auto I3
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:31
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:68
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:79
static constexpr auto xdlops_gemm
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:38
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:145
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:132
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:188
static constexpr auto I0
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:28
static constexpr index_t K0PerThread
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:41
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_thread_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:248
static constexpr index_t WaveSize
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:45
__device__ void ResetABlockStartWindow()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:240
__host__ static __device__ constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:223
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:205
__device__ void MoveABlockSliceWindow()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:235
static constexpr index_t A_K1
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:36
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:52
static constexpr index_t A_K0
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:35
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_xdlops_skip_b_lds.hpp:56
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition xdlops_gemm.hpp:1821
Definition functional2.hpp:33
Definition dtype_vector.hpp:10