batched_contraction_kernel.hpp Source File
batched_contraction_kernel.hpp
Go to the documentation of this file.
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
Definition batched_contraction_kernel.hpp:100
const std::array< std::vector< ck_tile::index_t >, NumDTensor > Ds_dims
Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:162
const void * b_ptr
Pointer to input tensor B.
Definition batched_contraction_kernel.hpp:153
const std::vector< ck_tile::index_t > E_strides
Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:172
const std::vector< ck_tile::index_t > E_dims
Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:164
std::array< const void *, NumDTensor > ds_ptr
Array of pointers to auxiliary input tensors D.
Definition batched_contraction_kernel.hpp:154
const std::vector< ck_tile::index_t > B_dims
Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:160
void * e_ptr
Pointer to output tensor E.
Definition batched_contraction_kernel.hpp:155
const std::vector< ck_tile::index_t > A_dims
Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:158
const std::array< std::vector< ck_tile::index_t >, NumDTensor > Ds_strides
Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...].
Definition batched_contraction_kernel.hpp:170
const std::vector< ck_tile::index_t > A_strides
Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:166
ck_tile::index_t k_batch
Number of k-splits for split-K batching.
Definition batched_contraction_kernel.hpp:156
CK_TILE_HOST BatchedContractionHostArgs(const void *a_ptr_, const void *b_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, ck_tile::index_t k_batch_, const std::vector< ck_tile::index_t > &A_dims_, const std::vector< ck_tile::index_t > &B_dims_, const std::array< std::vector< ck_tile::index_t >, NumDTensor > &Ds_dims_, const std::vector< ck_tile::index_t > &E_dims_, const std::vector< ck_tile::index_t > &A_strides_, const std::vector< ck_tile::index_t > &B_strides_, const std::array< std::vector< ck_tile::index_t >, NumDTensor > &Ds_strides_, const std::vector< ck_tile::index_t > &E_strides_)
Constructor for batched contraction host arguments.
Definition batched_contraction_kernel.hpp:117
const void * a_ptr
Pointer to input tensor A.
Definition batched_contraction_kernel.hpp:152
const std::vector< ck_tile::index_t > B_strides
Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...].
Definition batched_contraction_kernel.hpp:168
Kernel arguments for batched tensor contraction operations.
Definition batched_contraction_kernel.hpp:189
const void * b_ptr
Definition batched_contraction_kernel.hpp:191
std::array< ck_tile::index_t, NumDTensor > stride_Ds
Definition batched_contraction_kernel.hpp:216
const void * a_ptr
Definition batched_contraction_kernel.hpp:190
ck_tile::index_t stride_A
Definition batched_contraction_kernel.hpp:213
ck_tile::index_t M_total
Definition batched_contraction_kernel.hpp:209
ck_tile::index_t G_total
Definition batched_contraction_kernel.hpp:208
ck_tile::index_t stride_E
Definition batched_contraction_kernel.hpp:217
std::array< ck_tile::index_t, NumDTensor > batch_stride_Ds
Definition batched_contraction_kernel.hpp:206
std::array< const void *, NumDTensor > ds_ptr
Definition batched_contraction_kernel.hpp:192
ck_tile::index_t M_dims[NumDimM]
Definition batched_contraction_kernel.hpp:196
ck_tile::index_t K_dims[NumDimK]
Definition batched_contraction_kernel.hpp:198
ck_tile::index_t stride_B
Definition batched_contraction_kernel.hpp:214
ck_tile::index_t G_dims[NumDimG]
Definition batched_contraction_kernel.hpp:200
ck_tile::index_t N_dims[NumDimN]
Definition batched_contraction_kernel.hpp:197
ck_tile::index_t batch_stride_E
Definition batched_contraction_kernel.hpp:205
ck_tile::index_t K_total
Definition batched_contraction_kernel.hpp:211
ck_tile::index_t batch_stride_A
Definition batched_contraction_kernel.hpp:203
ck_tile::index_t k_batch
Definition batched_contraction_kernel.hpp:194
ck_tile::index_t batch_stride_B
Definition batched_contraction_kernel.hpp:204
void * e_ptr
Definition batched_contraction_kernel.hpp:193
ck_tile::index_t N_total
Definition batched_contraction_kernel.hpp:210
GPU kernel for batched tensor contraction operations.
Definition batched_contraction_kernel.hpp:238
static CK_TILE_HOST constexpr auto GetBlockSize()
Returns the GPU block size for kernel launch.
Definition batched_contraction_kernel.hpp:318
static constexpr ck_tile::index_t NumDTensor
Number of auxiliary input D tensors.
Definition batched_contraction_kernel.hpp:259
static CK_TILE_HOST constexpr KernelArgs MakeKernelArgs(const BatchedContractionHostArgs< NumDTensor > &host_args)
Definition batched_contraction_kernel.hpp:330
static constexpr ck_tile::index_t NumDimM
Number of M (output row) dimensions.
Definition batched_contraction_kernel.hpp:253
ck_tile::UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Definition batched_contraction_kernel.hpp:271
static CK_TILE_HOST constexpr bool IsSupportedArguments(const KernelArgs &kargs)
Validates whether the given kernel arguments are supported.
Definition batched_contraction_kernel.hpp:290
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Epilogue pipeline for post-GEMM operations.
Definition batched_contraction_kernel.hpp:267
ck_tile::remove_cvref_t< Problem_ > Problem
Tensor contraction problem specification.
Definition batched_contraction_kernel.hpp:240
static CK_TILE_HOST constexpr auto GridSize(const KernelArgs &kargs)
Definition batched_contraction_kernel.hpp:323
static CK_TILE_HOST constexpr auto GetKernelName()
Returns the kernel name for debugging and profiling purposes.
Definition batched_contraction_kernel.hpp:284
static constexpr ck_tile::index_t NumDimG
Number of G (batch) dimensions.
Definition batched_contraction_kernel.hpp:252
ck_tile::remove_cvref_t< GemmPipeline_ > GemmPipeline
GEMM computation pipeline.
Definition batched_contraction_kernel.hpp:266
CK_TILE_DEVICE void operator()(const KernelArgs &kargs) const
Definition batched_contraction_kernel.hpp:461
ck_tile::remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition batched_contraction_kernel.hpp:263
ck_tile::remove_cvref_t< typename Problem::ADataType > ADataType
Data type for input tensor A.
Definition batched_contraction_kernel.hpp:241
static CK_TILE_HOST constexpr ck_tile::index_t GetSmemSize()
Returns the shared memory size required by the kernel.
Definition batched_contraction_kernel.hpp:311
ck_tile::remove_cvref_t< typename Problem::EDataType > EDataType
Data type for output tensor E.
Definition batched_contraction_kernel.hpp:248
static constexpr ck_tile::index_t NumDimN
Number of N (output column) dimensions.
Definition batched_contraction_kernel.hpp:255
BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor > KernelArgs
Definition batched_contraction_kernel.hpp:277
static constexpr ck_tile::index_t NumDimK
Number of K (contraction) dimensions.
Definition batched_contraction_kernel.hpp:257
ck_tile::remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition batched_contraction_kernel.hpp:245
ck_tile::remove_cvref_t< typename Problem::BDataType > BDataType
Data type for input tensor B.
Definition batched_contraction_kernel.hpp:243
static constexpr ck_tile::index_t kBlockSize
GPU block size inherited from GEMM kernel.
Definition batched_contraction_kernel.hpp:274
Definition universal_gemm_kernel.hpp:325
std::array< index_t, NumATensor > as_k_split_offset
Definition universal_gemm_kernel.hpp:368
std::array< index_t, NumBTensor > bs_k_split_offset
Definition universal_gemm_kernel.hpp:369
The Universal GEMM kernel template.
Definition universal_gemm_kernel.hpp:154
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition universal_gemm_kernel.hpp:955
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition universal_gemm_kernel.hpp:319
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition universal_gemm_kernel.hpp:257