block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp Source File

block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp Source File

Composable Kernel: block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp Source File
block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14template <typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
16{
34 // using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
35
37
38 static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
39 static constexpr index_t kBlockSize = Problem::kBlockSize;
40
41 static constexpr index_t kM0 = BlockFmhaShape::kM0;
42 static constexpr index_t kN0 = BlockFmhaShape::kN0;
43 static constexpr index_t kK0 = BlockFmhaShape::kK0;
44 static constexpr index_t kK1 = BlockFmhaShape::kK1;
45 static constexpr index_t kK2 = BlockFmhaShape::kK2;
46 static constexpr index_t kK3 = BlockFmhaShape::kK3;
47 static constexpr index_t kK4 = BlockFmhaShape::kK4;
48 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
49 static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
50
51 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
52 static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
53 static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
54 static constexpr auto BiasEnum = Problem::BiasEnum;
55 static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
56 static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
57 static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
58 static_assert(kUseTrLoad, "This pipeline uses trload!");
59
60 // last dimension vector length used to create tensor view(and decide buffer_load vector length)
61 // ... together with tensor distribution. tensor dist should able to overwrite this
62 static constexpr index_t kAlignmentQ =
63 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
64 static constexpr index_t kAlignmentK =
65 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
66 static constexpr index_t kAlignmentV =
67 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
68 static constexpr index_t kAlignmentOGrad =
69 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
70 static constexpr index_t kAlignmentQGrad = 1;
71 static constexpr index_t kAlignmentKGrad =
72 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
73 static constexpr index_t kAlignmentVGrad =
74 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
75 static constexpr index_t kAlignmentBias = 1;
76
77 static constexpr const char* name = "trload_kr_ktr_vr";
78
80 {
81 return Policy::template GetSmemSize<Problem>();
82 }
83
85 {
86 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking)
87 return (raw_lse == -numeric<LSEDataType>::infinity()) //
89 : raw_lse;
90 else
91 return raw_lse;
92 };
93 template <typename... Ts>
94 CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
95 {
96 // LDS allocation
97 // cast to char* to do pointer arithmetic
98 const auto smem_ptr_ = reinterpret_cast<char*>(smem_ptr);
99 const auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr_);
100 const auto v_lds_ptr =
101 reinterpret_cast<VDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
102
103 const auto do_lds_ptr0 = reinterpret_cast<OGradDataType*>(smem_ptr_);
104 const auto do_lds_ptr1 = reinterpret_cast<OGradDataType*>(
105 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
106 const auto q_lds_ptr0 = reinterpret_cast<QDataType*>( //
107 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
108 Policy::template GetSmemSizeOGrad<Problem>());
109 const auto q_lds_ptr1 = reinterpret_cast<QDataType*>( //
110 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
111 Policy::template GetSmemSizeOGrad<Problem>() +
112 Policy::template GetSmemSizeQ<Problem>());
113 const auto lse_lds_ptr0 = reinterpret_cast<LSEDataType*>(
114 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
115 Policy::template GetSmemSizeOGrad<Problem>() +
116 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
117 const auto lse_lds_ptr1 = reinterpret_cast<LSEDataType*>(
118 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
119 Policy::template GetSmemSizeOGrad<Problem>() +
120 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
121 Policy::template GetSmemSizeLSE<Problem>());
122 const auto d_lds_ptr0 = reinterpret_cast<DDataType*>(
123 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
124 Policy::template GetSmemSizeOGrad<Problem>() +
125 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
126 Policy::template GetSmemSizeLSE<Problem>() +
127 Policy::template GetSmemSizeLSE<Problem>());
128 const auto d_lds_ptr1 = reinterpret_cast<DDataType*>(
129 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
130 Policy::template GetSmemSizeOGrad<Problem>() +
131 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
132 Policy::template GetSmemSizeLSE<Problem>() +
133 Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
134 const auto ds_lds_ptr = reinterpret_cast<GemmDataType*>(
135 smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
136 Policy::template GetSmemSizeOGrad<Problem>() +
137 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
138 Policy::template GetSmemSizeLSE<Problem>() +
139 Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>() +
140 Policy::template GetSmemSizeD<Problem>());
141 const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
142 return run(k_lds_ptr,
143 v_lds_ptr,
144 do_lds_ptr0,
145 do_lds_ptr1,
146 q_lds_ptr0,
147 q_lds_ptr1,
148 lse_lds_ptr0,
149 lse_lds_ptr1,
150 d_lds_ptr0,
151 d_lds_ptr1,
152 ds_lds_ptr,
153 bias_lds_ptr,
154 std::forward<Ts>(args)...);
155 }
156
157 template <typename QDramBlockWindowTmp,
158 typename KDramBlockWindowTmp,
159 typename VDramBlockWindowTmp,
160 typename BiasDramBlockWindowTmp,
161 typename RandValDramBlockWindowTmp,
162 typename OGradDramBlockWindowTmp,
163 typename LSEDramBlockWindowTmp,
164 typename DDramBlockWindowTmp,
165 typename QGradDramBlockWindowTmp,
166 typename BiasGradDramBlockWindowTmp,
167 typename PositionEncoding>
169 KDataType* __restrict__ k_lds_ptr,
170 VDataType* __restrict__ v_lds_ptr,
171 OGradDataType* __restrict__ do_lds_ptr0,
172 OGradDataType* __restrict__ do_lds_ptr1,
173 QDataType* __restrict__ q_lds_ptr0,
174 QDataType* __restrict__ q_lds_ptr1,
175 LSEDataType* __restrict__ lse_lds_ptr0,
176 LSEDataType* __restrict__ lse_lds_ptr1,
177 DDataType* __restrict__ d_lds_ptr0,
178 DDataType* __restrict__ d_lds_ptr1,
179 GemmDataType* __restrict__ ds_lds_ptr,
180 BiasDataType* __restrict__ bias_lds_ptr,
181 const QDramBlockWindowTmp& q_dram_block_window_tmp,
182 const KDramBlockWindowTmp& k_dram_block_window_tmp,
183 const VDramBlockWindowTmp& v_dram_block_window_tmp,
184 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
185 const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
186 const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
187 const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
188 const DDramBlockWindowTmp& d_dram_block_window_tmp,
189 const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
190 const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
191 FmhaMask mask,
192 PositionEncoding position_encoding,
193 float raw_scale,
194 float scale,
195 float rp_undrop,
196 float scale_rp_undrop,
197 FmhaDropout& dropout) const
198 {
199 static_assert(
200 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
201 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
202 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
203 std::is_same_v<OGradDataType,
205 std::is_same_v<LSEDataType,
207 std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
208 "wrong!");
209
210 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
211 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
212 kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
213 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
214 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
215 kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
216 kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
217 kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
218 kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
219 kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
220 kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
221 "wrong!");
222
223 // Block GEMM
224 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
225 constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
226 constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
227 constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
228 constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
229
230 // init VGrad & KGrad
231 auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
232 auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
233
234 // K, HBM ->LDS ->Reg
235 auto k_dram_window =
236 make_tile_window(Policy::template TransformXDramTensorView<KDataType>(
237 k_dram_block_window_tmp.get_bottom_tensor_view()),
238 k_dram_block_window_tmp.get_window_lengths(),
239 k_dram_block_window_tmp.get_window_origin(),
240 Policy::template MakeKDramTileDistribution<Problem>());
241
242 const auto k_origin = k_dram_window.get_window_origin();
243
244 // Early termination
245 const auto [seqlen_q_start, seqlen_q_end] =
246 mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
247
248 const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
249
250 // check early exit if masked and no work to do.
251 if constexpr(FmhaMask::IsMasking)
252 {
253 if(num_total_loop <= 0)
254 {
255 // Note: here dk_acc&dv_acc are all cleard, return it
256 // Note: v loaded but no fence, ignore it.
257 return make_tuple(dk_acc, dv_acc);
258 }
259 }
260
262 k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
263 auto k_lds_write_window =
265
266 //------------------------------------------------------------------
267 // V, HBM ->LDS ->Reg
268 auto v_dram_window =
269 make_tile_window(Policy::template TransformXDramTensorView<VDataType>(
270 v_dram_block_window_tmp.get_bottom_tensor_view()),
271 v_dram_block_window_tmp.get_window_lengths(),
272 v_dram_block_window_tmp.get_window_origin(),
273 Policy::template MakeVDramTileDistribution<Problem>());
275 v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
276 auto v_lds_write_window =
278
279 //------------------------------------------------------------------
280 // KT, HBM -> LDS --trload-->Reg
281 async_load_tile(k_lds_write_window, k_dram_window);
282 async_load_tile(v_lds_write_window, v_dram_window);
283 __builtin_amdgcn_s_waitcnt(3952);
285
286 //------------------------------------------------------------------
287 // Pre-Load KV into Registers
289 k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor<Problem>());
290 auto k_lds_read_window =
291 make_tile_window(k_lds_read,
293 k_lds_write_window.get_window_origin(),
294 Policy::template MakeKRegBlockDescriptor<Problem>());
295 auto k_reg_tensor = load_tile(k_lds_read_window);
296
297 auto kt_lds_read_window =
298 make_tile_window(k_lds_read,
300 {0, 0},
301 Policy::template MakeKTRegBlockDescriptor<Problem>());
302
303 auto kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
304
306 v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor<Problem>());
307 auto v_lds_read_window =
308 make_tile_window(v_lds_read,
310 v_lds_write_window.get_window_origin(),
311 Policy::template MakeVRegBlockDescriptor<Problem>());
312 auto v_reg_tensor = load_tile(v_lds_read_window);
313
314 __builtin_amdgcn_s_waitcnt(3952);
316 //---------------------------- Loop Load in ----------------------------//
317 // Q: HBM -->LDS
318 auto q_dram_window =
319 make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
320 q_dram_block_window_tmp.get_bottom_tensor_view()),
321 q_dram_block_window_tmp.get_window_lengths(),
322 {seqlen_q_start, 0},
323 Policy::template MakeQDramTileDistribution<Problem>());
324
326 q_lds_ptr0, Policy::template MakeQLdsWriteBlockDescriptor<Problem>());
327 auto q_lds_write_window =
329
331 q_lds_ptr0, Policy::template MakeQLdsReadBlockDescriptor<Problem>());
332 auto q_lds_read_window =
333 make_tile_window(q_lds_read,
335 q_lds_write_window.get_window_origin(),
336 Policy::template MakeQRegSliceBlockDescriptor<Problem>());
337 auto qt_lds_read_window =
338 make_tile_window(q_lds_read,
340 {0, 0},
341 Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
342
343 // dO: HBM ->LDS ---load--> Reg
344 // dOT: \-loadtr-> Reg
345 auto do_dram_window =
346 make_tile_window(Policy::template TransformXDramTensorView<OGradDataType>(
347 do_dram_block_window_tmp.get_bottom_tensor_view()),
348 do_dram_block_window_tmp.get_window_lengths(),
349 {seqlen_q_start, 0},
350 Policy::template MakeOGradDramTileDistribution<Problem>());
351
353 do_lds_ptr0, Policy::template MakeOGradLdsWriteBlockDescriptor<Problem>());
354 auto do_lds_write_window =
356
358 do_lds_ptr0, Policy::template MakeOGradLdsReadBlockDescriptor<Problem>());
359 auto do_lds_read_window =
360 make_tile_window(do_lds_read,
362 do_lds_write_window.get_window_origin(),
363 Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
364 auto dot_lds_read_window =
365 make_tile_window(do_lds_read,
367 {0, 0},
368 Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
369
370 // dS: Reg -> Reg -> LDS
372 ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
373
374 auto ds_lds_window =
375 make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
376
377 // transform it to make it from col-major to row-major; prepared for load_tile_transpose
379 ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem, true>());
380 auto ds_lds_read_window =
381 make_tile_window(ds_lds_t,
383 {0, 0},
384 Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
385
386 // Bias: HBM ->Reg ->Reg ->LDS
387 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
388
389 auto bias_dram_window =
390 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
391 bias_dram_block_window_tmp.get_window_lengths(),
392 {seqlen_q_start, bias_origin.at(number<1>{})},
393 Policy::template MakeBiasTileDistribution<Problem>());
394
396 bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
397 auto bias_lds_write_window =
398 make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
399
400 auto bias_s_lds_read_window =
401 make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
402 bias_lds_write_window.get_window_lengths(),
403 bias_lds_write_window.get_window_origin(),
404 Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
405
406 static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
407 "BiasDataType and BiasGradDataType should be the same!");
408
409 // LSE: HBM -> LDS ->Reg
410 auto lse_dram_window =
411 make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
412 lse_dram_block_window_tmp.get_window_lengths(),
413 {seqlen_q_start},
414 Policy::template MakeLSEDDramTileDistribution<Problem>());
415
417 lse_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
418
419 auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
420
421 auto lse_lds_read_window =
422 make_tile_window(lse_lds,
424 {0},
425 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
426
427 // D: HBM ->Reg
428 auto d_dram_window =
429 make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
430 d_dram_block_window_tmp.get_window_lengths(),
431 {seqlen_q_start},
432 Policy::template MakeLSEDDramTileDistribution<Problem>());
433
435 d_lds_ptr0, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
436 auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
437 auto d_lds_read_window =
438 make_tile_window(d_lds,
440 {0},
441 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem>());
442
443 // RandVal: HBM ->Reg
444 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
445 randval_dram_block_window_tmp, seqlen_q_start);
446
447 // BiasGrad
448 // Reg ->LDS ->Reg ->HBM
449 const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
450
451 auto dbias_dram_window =
452 make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
453 dbias_dram_block_window_tmp.get_window_lengths(),
454 {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
455
456 auto dbias_lds_read_window =
457 make_tile_window(bias_lds,
459 {0, 0},
460 Policy::template MakeShuffledBiasTileDistribution<Problem>());
461
462 // ----------------------------Loop write out------------------------------//
463 auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
464 dq_dram_block_window_tmp.get_window_lengths(),
465 {seqlen_q_start, 0});
466
467 index_t i_total_loops = 0;
468 index_t seqlen_q_step = seqlen_q_start;
469 static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
470 static_assert(kM0 == kK1, "kM0 should equal to kK1");
471 static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
472 static_assert(kM0 == kK3, "kM0 should equal to kK3");
473 constexpr index_t k4_loops = kN0 / kK4;
474
475 clear_tile(dv_acc);
476 clear_tile(dk_acc);
477
478 __builtin_amdgcn_sched_barrier(0);
479
480 decltype(load_tile(q_lds_read_window)) q_reg_tensor;
481 decltype(load_tile(lse_lds_read_window)) lse;
482 decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor;
483 decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next;
484 decltype(load_tile(do_lds_read_window)) do_reg_tensor;
485 decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor;
486 decltype(load_tile(d_lds_read_window)) d;
487 decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor;
488 decltype(gemm_0.MakeCBlockTile()) s_acc, p;
489 decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
490 decltype(gemm_4.MakeCBlockTile()) dq_acc;
491
492 index_t i_total_bodys = 0;
493 auto main_body_impl = [&](auto is_prologue_,
494 auto is_epilogue_,
495 QDataType* const __restrict__ q_lds_ptr_curr,
496 QDataType* const __restrict__ q_lds_ptr_next,
497 OGradDataType* const __restrict__ do_lds_ptr_curr,
498 OGradDataType* const __restrict__ do_lds_ptr_next,
499 LSEDataType* const __restrict__ lse_lds_ptr_curr,
500 LSEDataType* const __restrict__ lse_lds_ptr_next,
501 DDataType* const __restrict__ d_lds_ptr_curr,
502 DDataType* const __restrict__ d_lds_ptr_next
503
504 ) mutable {
505 constexpr bool is_prologue = is_prologue_.value;
506 constexpr bool is_epilogue = is_epilogue_.value;
507 static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
508 constexpr bool is_main_body = is_prologue && is_epilogue;
509 if constexpr(is_prologue)
510 {
511 lse_lds_write_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_next);
512 async_load_tile(lse_lds_write_window, lse_dram_window);
513 move_tile_window(lse_dram_window, {kM0});
514
515 d_lds_write_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_next);
516 async_load_tile(d_lds_write_window, d_dram_window);
517 move_tile_window(d_dram_window, {kM0});
518
519 q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
520 async_load_tile(q_lds_write_window, q_dram_window);
521 move_tile_window(q_dram_window, {kM0, 0});
522
523 do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
524 async_load_tile(do_lds_write_window, do_dram_window);
525 move_tile_window(do_dram_window, {kM0, 0});
526 }
527 if constexpr(is_epilogue)
528 {
529 // STAGE 1, Q@K Gemm0
530 s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
531
532 dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
533 dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
534 }
535 if constexpr(is_epilogue)
536 {
537 lse_lds_read_window.set_bottom_tensor_view_data_ptr(lse_lds_ptr_curr);
538 lse = load_tile(lse_lds_read_window);
539 d_lds_read_window.set_bottom_tensor_view_data_ptr(d_lds_ptr_curr);
540 d = load_tile(d_lds_read_window);
541 }
542 if constexpr(is_main_body)
543 Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
544 __builtin_amdgcn_sched_barrier(0);
545 if constexpr(is_epilogue)
546 {
547 // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
548 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
549 {
550 const auto bias_tile = load_tile(bias_dram_window);
551 auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
552 Policy::template MakeShuffledBiasTileDistribution<Problem>());
553 shuffle_tile(shuffled_bias_tile, bias_tile);
554 store_tile(bias_lds_write_window, shuffled_bias_tile);
556 auto bias_s_tile = load_tile(bias_s_lds_read_window);
558 [&](auto& x, const auto& y) {
560 },
561 s_acc,
562 bias_s_tile);
563 move_tile_window(bias_dram_window, {kM0, 0});
564 __builtin_amdgcn_sched_barrier(0);
565 }
566 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
567 {
568 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
569 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
570 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
571 const auto tile_idx = get_x_indices_from_distributed_indices(
572 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
573
574 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
575 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
576 constexpr auto i_j_idx = make_tuple(idx0, idx1);
577
578 s_acc(i_j_idx) *= scale;
579 position_encoding.update(s_acc(i_j_idx), row, col);
580 });
581 });
582 }
583
584 {
585 bool need_perpixel_check = mask.IsEdgeTile(
586 seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
587 if(need_perpixel_check)
588 {
589 set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
590 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
591 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
592 return mask.IsOutOfBound(row, col);
593 });
594 }
595 }
596
597 constexpr auto p_spans = decltype(p)::get_distributed_spans();
598 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
599 constexpr auto i_idx = make_tuple(idx0);
600 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
601
602 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
603 constexpr auto i_j_idx = make_tuple(idx0, idx1);
604
605 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
607 p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
608 else
609 p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
610 });
611 });
612
613 if constexpr(FmhaDropout::IsDropout)
614 {
615 dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
616 seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
617 }
618 const auto p_gemm = [&]() { // dropout / type conversion
619 if constexpr(FmhaDropout::IsDropout)
620 {
621 return tile_elementwise_in(
622 [](const auto& x) {
623 return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
624 },
625 p);
626 }
627 else
628 {
629 return cast_tile<GemmDataType>(p);
630 }
631 }();
632
633 // STAGE 4, OGrad@V Gemm2
634 dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
635
636 qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
637 qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
638
639 // STAGE 3, P^T@OGrad^T Gemm1
641 Policy::template MakePTRegSliceBlockDescriptor<Problem>());
642 pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
643 gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
644 }
646 if constexpr(is_main_body)
647 Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
648 __builtin_amdgcn_sched_barrier(0);
649 if constexpr(is_epilogue)
650 {
651 // STAGE 5, P^T(PGrad^T - D)
652 constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
653 sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
654 constexpr auto i_idx = make_tuple(idx0);
655 sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
656 constexpr auto i_j_idx = make_tuple(idx0, idx1);
657 bool undrop_flag = p[i_j_idx] >= 0;
658 ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
659 ? (dp_acc[i_j_idx] - d[i_idx])
660 : d[i_idx]);
661 });
662 });
663
664 if constexpr(kHasBiasGrad)
665 {
666 const auto dbias = [&]() {
667 if constexpr(FmhaDropout::IsDropout)
668 {
669 return tile_elementwise_in(
670 [&rp_undrop](const auto& x) {
671 return type_convert<BiasGradDataType>(x * rp_undrop);
672 },
673 ds);
674 }
675 else
676 {
678 }
679 }();
680 store_tile(bias_lds_write_window, dbias);
681 __builtin_amdgcn_s_waitcnt(3952);
683 auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
685 Policy::template MakeBiasTileDistribution<Problem>());
686 shuffle_tile(dbias_tile, shuffled_dbias_tile);
687 store_tile(dbias_dram_window, dbias_tile);
688 move_tile_window(dbias_dram_window, {kM0, 0});
689 __builtin_amdgcn_sched_barrier(0);
690 }
691 }
692 if constexpr(is_epilogue)
693 {
694 // STAGE 6, SGrad^T@Q^T Gemm3
695 const auto ds_gemm = cast_tile<GemmDataType>(ds);
697 Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
698 dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
699 gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
700
701 if constexpr(kHasBiasGrad)
702 {
703 // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse
704 // LDS.
706 }
707 store_tile(ds_lds_window, ds_gemm);
708 }
709 s_waitcnt</*vmcnt=*/0>();
711 if constexpr(is_prologue)
712 {
713 q_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
714 q_reg_tensor = load_tile(q_lds_read_window);
715 }
716 if constexpr(is_epilogue)
717 {
718 ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
719 move_tile_window(ds_lds_read_window, {kK4, 0});
720 }
721 if constexpr(is_main_body)
722 Policy::template HotLoopScheduler<Problem>::SchedulerGemm3();
723 __builtin_amdgcn_sched_barrier(0);
724 if constexpr(is_epilogue)
725 {
726 // STAGE7 SGrad@K^T Gemm4
727 clear_tile(dq_acc);
728 static_for<0, k4_loops, 1>{}([&](auto i_k4) {
729 if constexpr(i_k4 < k4_loops - 1)
730 {
731 ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
732 move_tile_window(ds_lds_read_window, {kK4, 0});
733 }
734 auto kt_reg_tensor_slice = get_slice_tile( //
735 kt_reg_tensor,
737 sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
738 gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
739
740 if constexpr(i_k4 < k4_loops - 1)
741 {
742 ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
743 }
744 });
745 move_tile_window(ds_lds_read_window, {-kN0, 0});
746 }
748 if constexpr(is_prologue)
749 {
750 do_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
751 do_reg_tensor = load_tile(do_lds_read_window);
752 }
753 if constexpr(is_main_body)
754 Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
755 if constexpr(is_epilogue)
756 {
757 // QGrad Scale
758 if constexpr(FmhaDropout::IsDropout)
759 {
760 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
761 dq_acc);
762 }
763 else
764 {
765 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
766 }
767 if constexpr(kIsDeterministic)
768 {
769 store_tile(dq_dram_window, dq_acc);
770 }
771 else
772 {
773 update_tile(dq_dram_window, dq_acc);
774 }
775 move_tile_window(dq_dram_window, {kM0, 0});
776 }
777 };
778
779 auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
780 const bool is_even = (i_total_bodys % 2 == 0);
781 const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
782 const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
783 const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
784 const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
785 const auto lse_lds_ptr_curr = is_even ? lse_lds_ptr1 : lse_lds_ptr0;
786 const auto lse_lds_ptr_next = is_even ? lse_lds_ptr0 : lse_lds_ptr1;
787 const auto d_lds_ptr_curr = is_even ? d_lds_ptr1 : d_lds_ptr0;
788 const auto d_lds_ptr_next = is_even ? d_lds_ptr0 : d_lds_ptr1;
789 main_body_impl(is_prologue_,
790 is_epilogue_,
791 q_lds_ptr_curr,
792 q_lds_ptr_next,
793 do_lds_ptr_curr,
794 do_lds_ptr_next,
795 lse_lds_ptr_curr,
796 lse_lds_ptr_next,
797 d_lds_ptr_curr,
798 d_lds_ptr_next);
799 i_total_bodys += 1;
800 };
801
802 main_body(std::true_type{}, std::false_type{});
803 // Hot loop
804 if(num_total_loop > 1)
805 {
806 do
807 {
808 main_body(std::true_type{}, std::true_type{});
809 i_total_loops += 1;
810 seqlen_q_step += kM0;
811 } while(i_total_loops < num_total_loop - 1);
812 }
813 main_body(std::false_type{}, std::true_type{});
814
815 // Results Scale
816 if constexpr(FmhaDropout::IsDropout)
817 {
818 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
819 dk_acc);
820 tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
821 }
822 else
823 {
824 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
825 }
826
827 return make_tuple(dk_acc, dv_acc);
828 }
829};
830
831} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:119
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
Transpose-loads a tile from a tensor and returns the resulting tensor with a new (transposed) tile distribution.
Definition load_tile_transpose.hpp:403
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void s_waitcnt()
Definition arch.hpp:241
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:16
static constexpr index_t kM0
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:41
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:18
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:26
static constexpr bool kHasBiasGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:55
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:53
remove_cvref_t< typename Problem::OGradDataType > OGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:27
static constexpr index_t kVHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:49
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:21
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:79
static constexpr const char * name
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:77
remove_cvref_t< typename Problem::KGradDataType > KGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:29
static constexpr index_t kBlockSize
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:39
CK_TILE_DEVICE auto operator()(void *smem_ptr, Ts &&... args) const
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:94
static constexpr index_t kK1
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:44
static constexpr index_t kAlignmentQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:62
static constexpr auto BiasEnum
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:54
remove_cvref_t< typename Problem::BiasGradDataType > BiasGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:31
static constexpr index_t kQKHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:48
static constexpr index_t kAlignmentOGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:68
remove_cvref_t< typename Problem::GemmDataType > GemmDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:20
static constexpr index_t kK0
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:43
static constexpr index_t kAlignmentVGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:73
remove_cvref_t< typename Problem::DDataType > DDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:24
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:22
static constexpr index_t kAlignmentBias
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:75
CK_TILE_DEVICE auto run(KDataType *__restrict__ k_lds_ptr, VDataType *__restrict__ v_lds_ptr, OGradDataType *__restrict__ do_lds_ptr0, OGradDataType *__restrict__ do_lds_ptr1, QDataType *__restrict__ q_lds_ptr0, QDataType *__restrict__ q_lds_ptr1, LSEDataType *__restrict__ lse_lds_ptr0, LSEDataType *__restrict__ lse_lds_ptr1, DDataType *__restrict__ d_lds_ptr0, DDataType *__restrict__ d_lds_ptr1, GemmDataType *__restrict__ ds_lds_ptr, BiasDataType *__restrict__ bias_lds_ptr, const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const RandValDramBlockWindowTmp &randval_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, const LSEDramBlockWindowTmp &lse_dram_block_window_tmp, const DDramBlockWindowTmp &d_dram_block_window_tmp, const QGradDramBlockWindowTmp &dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp &dbias_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float raw_scale, float scale, float rp_undrop, float scale_rp_undrop, FmhaDropout &dropout) const
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:168
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:56
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:52
static constexpr index_t kAlignmentK
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:64
static constexpr index_t kAlignmentQGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:70
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:38
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:25
static constexpr bool kUseTrLoad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:57
static constexpr index_t kK4
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:47
remove_cvref_t< typename Problem::FmhaDropout > FmhaDropout
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:33
static constexpr index_t kAlignmentV
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:66
remove_cvref_t< typename Problem::QGradDataType > QGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:28
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:23
static constexpr index_t kAlignmentKGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:71
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:19
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:51
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:36
remove_cvref_t< typename Problem::VGradDataType > VGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:30
static CK_TILE_HOST_DEVICE LSEDataType get_validated_lse(const LSEDataType raw_lse)
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:84
static constexpr index_t kK3
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:46
static constexpr index_t kK2
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:45
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:32
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:17
static constexpr index_t kN0
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp:42
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43