shuffle_tile.hpp Source File

shuffle_tile.hpp Source File#

Composable Kernel: shuffle_tile.hpp Source File
shuffle_tile.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// NOTE: the #include directives of the original file (listing lines 6-18) are missing from this extract.
20namespace ck_tile {
21namespace detail {
22
23template <typename OutTensor, typename InTensor>
24CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
25{
26 constexpr auto I0 = number<0>{};
27
28 using DataType = typename InTensor::DataType;
29
30 constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
31 constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
32
33 // y_dim_out_to_in
34 constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
35 using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
36
37 map<array<index_t, 2>, index_t> rh_major_minor_to_y_;
38
40 constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
41 constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
42
43 rh_major_minor_to_y_({rh_major, rh_minor}) = i;
44 });
45
46 return rh_major_minor_to_y_;
47 };
48
49 constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{});
50 constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});
51
52 constexpr auto y_dim_out_to_in = [&] {
53 map<index_t, index_t> y_dim_out_to_in_;
54
55 for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
56 {
57 y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
58 }
59
60 return y_dim_out_to_in_;
61 }();
62
63 //
64 constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
65
66 constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
67
68 // input and output vector dim in the order of input Y dims
69 constexpr index_t y_dim_vec_in = NDimY - 1;
70 constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
71
72 // vector lengths
73 constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
74 constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
75
76 // # of vectors
77 constexpr index_t num_vec_in = vec_length_out;
78 constexpr index_t num_vec_out = vec_length_in;
79
82
83 // using InVec = typename InVec::type;
84 // using OutVec = typename OutVec::type;
85
86 // SFC
87 constexpr auto scalars_per_access_arr = generate_array(
88 [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
90
91 constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
92
93 using SFC_Y = space_filling_curve<decltype(y_lengths),
95 decltype(scalars_per_access)>;
96
97 constexpr index_t num_access = SFC_Y::get_num_of_access();
98
99 static_assert(num_access > 0, "wrong! num_access should be larger than 0");
100
101 // in/out vectors to be transposed
104
105 // loop over SFC and do transpose
106 static_for<0, num_access, 1>{}([&](auto iAccess) {
107 // data index [y0, y1, ...] in the order of input tensor
108 constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
109
110 // get input vectors
111 static_for<0, num_vec_in, 1>{}([&](auto i) {
112 constexpr auto idx_y_in = generate_tuple(
113 [&](auto ii) {
114 return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
115 },
116 number<NDimY>{});
117
118 constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
119 static_assert(in_offset % vec_length_in == 0);
120
121 in_vectors(i).template get_as<InVec>()(I0) =
122 in_tensor.get_thread_buffer()
123 .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
124 });
125
126 // transpose
128
129 // set output vectors
130 static_for<0, num_vec_out, 1>{}([&](auto i) {
131 constexpr auto idx_y_out_tmp = generate_array(
132 [&](auto ii) {
133 return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
134 : static_cast<index_t>(idx_y_start[ii]);
135 },
136 number<NDimY>{});
137
138 constexpr auto idx_y_out =
139 container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
140
141 constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
142 static_assert(out_offset % vec_length_out == 0);
143
144 out_tensor.get_thread_buffer().template set_as<OutVec>(
146 out_vectors[i].template get_as<OutVec>()[I0]);
147 });
148 });
149}
150
151} // namespace detail
152
153template <typename OutTensor, typename InTensor>
154CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
155{
156 using InDataType = typename InTensor::DataType;
157 using OutDataType = typename OutTensor::DataType;
158
159 using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode;
160 using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;
161
162 // type convert
164
165 // shuffle
166 if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
167 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
168 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
169 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
170 InDstrEncode::NDimY == OutDstrEncode::NDimY)
171 {
173 }
174 else
175 {
176 static_assert(false, "The shuffle should always happen!");
177 }
178}
179
180} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition arch.hpp:385
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor &out_tensor, const InTensor &in_tensor)
Definition shuffle_tile.hpp:24
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const array< TData, NSize > &old_array, sequence< IRs... >)
Definition tile/core/container/container_helper.hpp:39
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto generate_array(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1115
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition map.hpp:16
Definition space_filling_curve.hpp:20
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67
Definition tile/core/utility/transpose_vectors.hpp:20
#define TO_SEQUENCE(a, n)
Definition to_sequence.hpp:10