BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy > Struct Template Reference
#include <block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp>
Public Types | |
| using | QDataType = remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = remove_cvref_t<typename Problem::VDataType> |
| using | GemmDataType = remove_cvref_t<typename Problem::GemmDataType> |
| using | BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | AccDataType = remove_cvref_t<typename Problem::AccDataType> |
| using | DDataType = remove_cvref_t<typename Problem::DDataType> |
| using | RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
| using | OGradDataType = remove_cvref_t<typename Problem::OGradDataType> |
| using | QGradDataType = remove_cvref_t<typename Problem::QGradDataType> |
| using | KGradDataType = remove_cvref_t<typename Problem::KGradDataType> |
| using | VGradDataType = remove_cvref_t<typename Problem::VGradDataType> |
| using | BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType> |
| using | FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
| using | FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout> |
| using | BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
Public Member Functions | |
| template<typename... Ts> | |
| CK_TILE_DEVICE auto | operator() (void *smem_ptr, Ts &&... args) const |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename OGradDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename DDramBlockWindowTmp, typename QGradDramBlockWindowTmp, typename KGradDramBlockWindowTmp, typename VGradDramBlockWindowTmp, typename BiasGradDramBlockWindowTmp, typename QGradEpilogue, typename KGradEpilogue, typename VGradEpilogue, typename PositionEncoding> | |
| CK_TILE_DEVICE auto | run (KDataType *__restrict__ k_lds_ptr, VDataType *__restrict__ v_lds_ptr, OGradDataType *__restrict__ do_lds_ptr, QDataType *__restrict__ q_lds_ptr, LSEDataType *__restrict__ lse_lds_ptr, DDataType *__restrict__ d_lds_ptr, GemmDataType *__restrict__ ds_lds_ptr, BiasDataType *__restrict__ bias_lds_ptr, const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const RandValDramBlockWindowTmp &randval_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, const LSEDramBlockWindowTmp &lse_dram_block_window_tmp, const DDramBlockWindowTmp &d_dram_block_window_tmp, const QGradDramBlockWindowTmp &dq_dram_block_window_tmp, const KGradDramBlockWindowTmp &dk_dram_block_window_tmp, const VGradDramBlockWindowTmp &dv_dram_block_window_tmp, const BiasGradDramBlockWindowTmp &dbias_dram_block_window_tmp, const QGradEpilogue &dq_epilogue, const KGradEpilogue &dk_epilogue, const VGradEpilogue &dv_epilogue, FmhaMask mask, PositionEncoding position_encoding, float raw_scale, float scale, float rp_undrop, float scale_rp_undrop, FmhaDropout &dropout) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
| static CK_TILE_HOST_DEVICE LSEDataType | get_validated_lse (const LSEDataType raw_lse) |
Static Public Attributes | |
| static constexpr auto | is_qr_qtr_dor_pipeline = true |
| static constexpr index_t | kBlockPerCu = Problem::kBlockPerCu |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr index_t | kK2 = BlockFmhaShape::kK2 |
| static constexpr index_t | kK3 = BlockFmhaShape::kK3 |
| static constexpr index_t | kK4 = BlockFmhaShape::kK4 |
| static constexpr index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr index_t | kVHeaddim = BlockFmhaShape::kVHeaddim |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr index_t | kPadHeadDimQ = Problem::kPadHeadDimQ |
| static constexpr index_t | kPadHeadDimV = Problem::kPadHeadDimV |
| static constexpr auto | BiasEnum = Problem::BiasEnum |
| static constexpr bool | kHasBiasGrad = Problem::kHasBiasGrad |
| static constexpr bool | kIsDeterministic = Problem::kIsDeterministic |
| static constexpr bool | kUseTrLoad = Problem::kUseTrLoad |
| static constexpr index_t | kAlignmentQ |
| static constexpr index_t | kAlignmentK |
| static constexpr index_t | kAlignmentV |
| static constexpr index_t | kAlignmentOGrad |
| static constexpr index_t | kAlignmentQGrad = 1 |
| static constexpr index_t | kAlignmentKGrad |
| static constexpr index_t | kAlignmentVGrad |
| static constexpr index_t | kAlignmentBias = 1 |
| static constexpr const char * | name = "trload_kr_ktr_vr" |
Member Typedef Documentation
◆ AccDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::AccDataType = remove_cvref_t<typename Problem::AccDataType> |
◆ BiasDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
◆ BiasGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType> |
◆ BlockFmhaShape
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ DDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::DDataType = remove_cvref_t<typename Problem::DDataType> |
◆ FmhaDropout
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout> |
◆ FmhaMask
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
◆ GemmDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::GemmDataType = remove_cvref_t<typename Problem::GemmDataType> |
◆ KDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::KDataType = remove_cvref_t<typename Problem::KDataType> |
◆ KGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::KGradDataType = remove_cvref_t<typename Problem::KGradDataType> |
◆ LSEDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ ODataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ OGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::OGradDataType = remove_cvref_t<typename Problem::OGradDataType> |
◆ QDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::QDataType = remove_cvref_t<typename Problem::QDataType> |
◆ QGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::QGradDataType = remove_cvref_t<typename Problem::QGradDataType> |
◆ RandValOutputDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
◆ VDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::VDataType = remove_cvref_t<typename Problem::VDataType> |
◆ VGradDataType
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
| using ck_tile::BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR< Problem, Policy >::VGradDataType = remove_cvref_t<typename Problem::VGradDataType> |
Member Function Documentation
◆ get_validated_lse()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
inline static |
◆ GetSmemSize()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
inline static constexpr |
◆ operator()()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
template<typename... Ts>
|
inline |
◆ run()
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename OGradDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename DDramBlockWindowTmp, typename QGradDramBlockWindowTmp, typename KGradDramBlockWindowTmp, typename VGradDramBlockWindowTmp, typename BiasGradDramBlockWindowTmp, typename QGradEpilogue, typename KGradEpilogue, typename VGradEpilogue, typename PositionEncoding>
|
inline |
Member Data Documentation
◆ BiasEnum
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ is_qr_qtr_dor_pipeline
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kAlignmentBias
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kAlignmentK
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
= kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>()
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52
◆ kAlignmentKGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
= kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>()
◆ kAlignmentOGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
= kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_dot_do_o.hpp:24
◆ kAlignmentQ
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
= kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>()
◆ kAlignmentQGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kAlignmentV
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
= kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>()
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53
◆ kAlignmentVGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
Initial value:
= kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>()
◆ kBlockPerCu
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kBlockSize
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kHasBiasGrad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kIsDeterministic
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kIsGroupMode
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK0
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK1
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK2
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK3
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kK4
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kM0
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kN0
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimQ
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimV
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kQKHeaddim
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kUseTrLoad
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ kVHeaddim
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
◆ name
template<typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
static constexpr |
The documentation for this struct was generated from the following file: