batched_transpose_kernel.hpp Source File

batched_transpose_kernel.hpp Source File

Composable Kernel: batched_transpose_kernel.hpp Source File
batched_transpose_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10#include <string>
11#include <type_traits>
12
13namespace ck_tile {
14
26
27template <typename Pipeline_>
29{
30
34
// Element type handled by the kernel, taken from the pipeline's Problem description.
35 using Type = typename Problem::DataType;
36
// Block size re-exported from Problem so it is available as a kernel-level constant.
37 static constexpr index_t kBlockSize = Problem::kBlockSize;
38
48
51
// Host-side helper that builds the launch grid for the given host arguments.
//
// Returns a dim3 of (grid_size_x, grid_size_y, grid_size_z). The z extent is
// the batch count, which matches the kernel body's use of blockIdx.z to pick
// the batch. The initialisers of grid_size_x / grid_size_y are not shown in
// this listing (listing lines 55 and 57 are elided).
// NOTE(review): presumably they are integer_divide_ceil over height/width and
// the per-block tile sizes - confirm against the real header.
52 CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args)
53 {
54 const size_t grid_size_x =
56 const size_t grid_size_y =
// One grid slice along z per batch entry.
58 const size_t grid_size_z = host_args.batch;
59 return dim3(grid_size_x, grid_size_y, grid_size_z);
60 }
61
// Host-side helper that converts host arguments (Hargs) into the kernel
// argument struct (Kargs) by copying fields one-to-one.
//
// The visible assignments copy the input/output pointers and the batch,
// height and width extents. Listing line 70 is elided here.
// NOTE(review): Kargs also declares dim_stride, so the missing line may set
// it - confirm against the real header.
62 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
63 {
64 Kargs k;
65 k.p_input = h.p_input;
66 k.p_output = h.p_output;
67 k.batch = h.batch;
68 k.height = h.height;
69 k.width = h.width;
71 return k;
72 }
73
// Host-side accessor for the block size; simply forwards Problem::kBlockSize.
74 CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; }
75
// Body of the device entry point `operator()(Kargs kargs) const` (the
// signature sits on listing line 76, which is elided here).
//
// Each block derives a tile origin from blockIdx.x / blockIdx.y and a batch
// offset from blockIdx.z, builds tensor views over the input and the output,
// places a tile window on each view, and hands both windows to Pipeline.
// Several listing lines inside the view/window expressions are elided
// (92, 96-97, 100-101, 105, 109-110, 113-114, 119, 124).
77 {
// Compile-time tile configuration pulled from Problem.
78 static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
79 static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
80 static constexpr bool kPadM = Problem::kPadM;
81 static constexpr bool kPadN = Problem::kPadN;
82 static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput;
83 static constexpr ck_tile::index_t VectorStrideInput = 1;
84 static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput;
85 static constexpr ck_tile::index_t VectorStrideOutput = 1;
86
// Tile origin for this block: blockIdx.x scaled by kMPerBlock and blockIdx.y
// scaled by kNPerBlock, each passed through amd_wave_read_first_lane.
87 const auto iM = amd_wave_read_first_lane(blockIdx.x * kMPerBlock);
88 const auto iN = amd_wave_read_first_lane(blockIdx.y * kNPerBlock);
// Element offset of this block's batch: blockIdx.z * height * width.
// NOTE(review): height/width are index_t (int32_t) and no widening cast is
// applied, so this product looks like it is evaluated in 32 bits and could
// wrap for very large batch * height * width - confirm.
89 const auto offset = amd_wave_read_first_lane(blockIdx.z * kargs.height * kargs.width);
90
// Input view: base pointer advanced by the batch offset, lengths
// (height, width), strides (width, 1); the result is then passed to
// pad_tensor_view (its remaining arguments are elided in this listing).
91 const auto x_m_n = [&]() {
93 static_cast<const Type*>(kargs.p_input) + offset,
94 make_tuple(kargs.height, kargs.width),
95 make_tuple(kargs.width, 1),
98
99 return pad_tensor_view(x_dram_naive,
102 }();
103
// Output view: same batch offset, but lengths (width, height) and strides
// (height, 1), i.e. the two extents are swapped relative to the input view.
104 const auto y_n_m = [&]() {
106 static_cast<Type*>(kargs.p_output) + offset,
107 make_tuple(kargs.width, kargs.height),
108 make_tuple(kargs.height, 1),
111
112 return pad_tensor_view(y_dram_naive,
115 }();
116
// Input window placed at origin (iM, iN); the window-lengths argument is
// elided in this listing.
117 auto x_block_window = make_tile_window(
118 x_m_n,
120 {static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
121
// Output window placed at the swapped origin (iN, iM).
122 auto y_block_window = make_tile_window(
123 y_n_m,
125 {static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
126
// Invoke a default-constructed Pipeline on the two windows.
127 Pipeline{}(x_block_window, y_block_window);
128 }
129};
130} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition batched_transpose_kernel.hpp:16
index_t height
Definition batched_transpose_kernel.hpp:20
index_t batch
Definition batched_transpose_kernel.hpp:19
void * p_output
Definition batched_transpose_kernel.hpp:18
index_t dim_block_w
Definition batched_transpose_kernel.hpp:24
index_t dim_stride
Definition batched_transpose_kernel.hpp:22
const void * p_input
Definition batched_transpose_kernel.hpp:17
index_t width
Definition batched_transpose_kernel.hpp:21
index_t dim_block_h
Definition batched_transpose_kernel.hpp:23
Definition batched_transpose_kernel.hpp:40
index_t width
Definition batched_transpose_kernel.hpp:45
index_t height
Definition batched_transpose_kernel.hpp:44
index_t dim_stride
Definition batched_transpose_kernel.hpp:46
index_t batch
Definition batched_transpose_kernel.hpp:43
const void * p_input
Definition batched_transpose_kernel.hpp:41
void * p_output
Definition batched_transpose_kernel.hpp:42
Definition batched_transpose_kernel.hpp:29
static CK_TILE_HOST constexpr auto GridSize(const Hargs &host_args)
Definition batched_transpose_kernel.hpp:52
remove_cvref_t< typename Pipeline::Problem > Problem
Definition batched_transpose_kernel.hpp:33
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition batched_transpose_kernel.hpp:76
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition batched_transpose_kernel.hpp:62
static CK_TILE_HOST constexpr auto BlockSize()
Definition batched_transpose_kernel.hpp:74
typename Problem::DataType Type
Definition batched_transpose_kernel.hpp:35
static CK_TILE_DEVICE index_t counter
Definition batched_transpose_kernel.hpp:31
BatchedTransposeHostArgs Hargs
Definition batched_transpose_kernel.hpp:50
remove_cvref_t< Pipeline_ > Pipeline
Definition batched_transpose_kernel.hpp:32
BatchedTransposeKargs Kargs
Definition batched_transpose_kernel.hpp:49
static constexpr index_t kBlockSize
Definition batched_transpose_kernel.hpp:37
Definition coordinate_transform.hpp:1392
Definition tile/core/container/sequence.hpp:49