device_gemm_dpp.hpp Source File

device_gemm_dpp.hpp Source File

Composable Kernel: device_gemm_dpp.hpp Source File
device_gemm_dpp.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <sstream>
7
17
18namespace ck {
19namespace tensor_operation {
20namespace device {
21
22template <typename ADataType,
23 typename BDataType,
24 typename CDataType,
25 typename AccDataType,
26 typename ALayout,
27 typename BLayout,
28 typename CLayout,
29 typename AElementwiseOperation,
30 typename BElementwiseOperation,
31 typename CElementwiseOperation,
32 GemmSpecialization GemmSpec,
33 ck::index_t BlockSize,
34 ck::index_t MPerBlock,
35 ck::index_t NPerBlock,
36 ck::index_t KPerBlock,
37 ck::index_t AK1,
38 ck::index_t BK1,
39 ck::index_t MPerDpp,
40 ck::index_t NPerDpp,
41 ck::index_t MDppPerWave,
42 ck::index_t NDppPerWave,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 ck::index_t ABlockTransferSrcVectorDim,
47 ck::index_t ABlockTransferSrcScalarPerVector,
48 ck::index_t ABlockTransferDstScalarPerVector_K1,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
53 ck::index_t BBlockTransferSrcVectorDim,
54 ck::index_t BBlockTransferSrcScalarPerVector,
55 ck::index_t BBlockTransferDstScalarPerVector_K1,
56 bool BBlockLdsAddExtraN,
57 ck::index_t CThreadTransferSrcDstVectorDim,
58 ck::index_t CThreadTransferDstScalarPerVector,
59 ck::index_t NumPrefetch = 1,
61struct DeviceGemmDpp : public DeviceGemm<ALayout,
62 BLayout,
63 CLayout,
64 ADataType,
65 BDataType,
66 CDataType,
67 AElementwiseOperation,
68 BElementwiseOperation,
69 CElementwiseOperation>
70{
72 BlockSize,
73 ADataType,
74 AccDataType,
75 CDataType,
77 ALayout,
78 BLayout,
79 CLayout,
80 AElementwiseOperation,
81 BElementwiseOperation,
82 CElementwiseOperation,
83 GemmSpec,
84 MPerBlock,
85 NPerBlock,
86 KPerBlock,
87 MPerDpp,
88 NPerDpp,
89 AK1,
90 BK1,
91 MDppPerWave,
92 NDppPerWave,
93 ABlockTransferThreadClusterLengths_K0_M_K1,
94 ABlockTransferThreadClusterArrangeOrder,
95 ABlockTransferSrcAccessOrder,
96 ABlockTransferSrcVectorDim,
97 ABlockTransferSrcScalarPerVector,
98 ABlockTransferDstScalarPerVector_K1,
99 false, // AThreadTransferSrcResetCoordinateAfterRun,
100 ABlockLdsAddExtraM,
101 BBlockTransferThreadClusterLengths_K0_N_K1,
102 BBlockTransferThreadClusterArrangeOrder,
103 BBlockTransferSrcAccessOrder,
104 BBlockTransferSrcVectorDim,
105 BBlockTransferSrcScalarPerVector,
106 BBlockTransferDstScalarPerVector_K1,
107 false, // BThreadTransferSrcResetCoordinateAfterRun,
108 BBlockLdsAddExtraN,
109 Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder,
110 CThreadTransferSrcDstVectorDim,
111 CThreadTransferDstScalarPerVector,
112 NumPrefetch,
113 PipelineVer>;
114
115 using Argument = typename GridwiseGemm::Argument;
116
117 // Invoker
118 struct Invoker : public BaseInvoker
119 {
// Typed entry point of the invoker.
// - Optionally prints the argument.
// - Throws std::runtime_error("... has invalid setting") when the (not rendered) guard fires.
// - Computes the launch grid from (M, N).
// - Launches one of two kernel_gemm_dpp<GridwiseGemm, bool> instantiations.
// - Returns the time reported by launch_and_time_kernel.
120 float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
121 {
// Dump the kernel argument when logging is enabled.
122 if(stream_config.log_level_ > 0)
123 {
124 karg.Print();
125 }
126
// NOTE(review): the guarding `if(...)` (original line 127) is missing from this view.
// It presumably rejects settings GridwiseGemm cannot handle. Confirm against the header.
128 {
129 throw std::runtime_error(
130 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting");
131 }
132
// Grid dimensions are derived from the problem's M and N only.
133 const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
134
135 float ave_time = 0;
136
// NOTE(review): the branch condition (original line 137) is missing from this view.
// It selects between the `true` and `false` kernel instantiations.
// Confirm what the bool template argument controls.
138 {
139 const auto kernel = kernel_gemm_dpp<GridwiseGemm, true>;
140
// Block of BlockSize threads, 0 bytes of dynamic LDS; the argument is passed by value.
141 ave_time = launch_and_time_kernel(
142 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
143 }
144 else
145 {
146 const auto kernel = kernel_gemm_dpp<GridwiseGemm, false>;
147
// Same launch configuration as the other branch; only the kernel instantiation differs.
148 ave_time = launch_and_time_kernel(
149 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
150 }
151
152 return ave_time;
153 }
154
155 // polymorphic
156 float Run(const BaseArgument* p_arg,
157 const StreamConfig& stream_config = StreamConfig{}) override
158 {
159 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
160 }
161 };
162
/// Compile-time sanity check of the template parameters.
/// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accepts_all_configurations = true;
    return accepts_all_configurations;
}
168
// Returns GridwiseGemm::CheckValidity(karg) when the guarding condition holds,
// and false otherwise.
// NOTE(review): the guard `if(...)` (original line 171) is not rendered in this view.
// This file references ck::is_gfx103_supported() and ck::is_gfx11_supported(),
// so the guard is presumably a target-architecture check. Confirm.
169 static bool IsSupportedArgument(const Argument& karg)
170 {
172 {
173 return GridwiseGemm::CheckValidity(karg);
174 }
// Guard not satisfied: this op cannot run the argument.
175 return false;
176 }
177
178 // polymorphic
179 bool IsSupportedArgument(const BaseArgument* p_arg) override
180 {
181 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
182 }
183
184 static auto MakeArgument(const ADataType* p_a,
185 const BDataType* p_b,
186 CDataType* p_c,
187 index_t M,
188 index_t N,
189 index_t K,
190 index_t StrideA,
191 index_t StrideB,
192 index_t StrideC,
193 AElementwiseOperation,
194 BElementwiseOperation,
195 CElementwiseOperation)
196 {
197 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
198 }
199
200 static auto MakeInvoker() { return Invoker{}; }
201
202 // polymorphic
203 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
204 const void* p_b,
205 void* p_c,
206 index_t M,
207 index_t N,
208 index_t K,
209 index_t StrideA,
210 index_t StrideB,
211 index_t StrideC,
212 AElementwiseOperation,
213 BElementwiseOperation,
214 CElementwiseOperation) override
215 {
216 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
217 static_cast<const BDataType*>(p_b),
218 static_cast<CDataType*>(p_c),
219 M,
220 N,
221 K,
222 StrideA,
223 StrideB,
224 StrideC);
225 }
226
227 // polymorphic
228 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
229 {
230 return std::make_unique<Invoker>(Invoker{});
231 }
232
233 // polymorphic
234 std::string GetTypeString() const override
235 {
236 auto str = std::stringstream();
237
238 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
239 {PipelineVersion::v2, "v2"}};
240
241 // clang-format off
242 str << "DeviceGemmDpp"
243 << "<"
244 << BlockSize << ", "
245 << MPerBlock << ", "
246 << NPerBlock << ", "
247 << KPerBlock << ", "
248 << AK1 << ", "
249 << BK1 << ", "
250 << MPerDpp << ", "
251 << NPerDpp << ", "
252 << MDppPerWave << ", "
253 << MDppPerWave << ", "
254 << ABlockTransferSrcScalarPerVector << ", "
255 << ABlockTransferDstScalarPerVector_K1 << ", "
256 << BBlockTransferSrcScalarPerVector << ", "
257 << BBlockTransferDstScalarPerVector_K1
258 << ">"
259 << " NumPrefetch: "
260 << NumPrefetch << ", "
261 << "PipelineVersion: "
262 << PipelineVersionToString[PipelineVer];
263 // clang-format on
264
265 return str.str();
266 }
267};
268
269} // namespace device
270} // namespace tensor_operation
271} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__global__ void kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_dpp.hpp:29
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dpp.hpp:96
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_dpp.hpp:156
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_dpp.hpp:120
Definition device_gemm_dpp.hpp:70
GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, ALayout, BLayout, CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, MPerBlock, NPerBlock, KPerBlock, MPerDpp, NPerDpp, AK1, BK1, MDppPerWave, NDppPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 0, 2, 4, 1, 3, 5 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumPrefetch, PipelineVer > GridwiseGemm
Definition device_gemm_dpp.hpp:71
typename GridwiseGemm::Argument Argument
Definition device_gemm_dpp.hpp:115
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_dpp.hpp:179
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_dpp.hpp:163
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_dpp.hpp:184
static bool IsSupportedArgument(const Argument &karg)
Definition device_gemm_dpp.hpp:169
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_dpp.hpp:203
static auto MakeInvoker()
Definition device_gemm_dpp.hpp:200
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_dpp.hpp:228
std::string GetTypeString() const override
Definition device_gemm_dpp.hpp:234
Definition device_gemm.hpp:22