container_helper.hpp Source File

container_helper.hpp Source File#

Composable Kernel: container_helper.hpp Source File
utility/container_helper.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_CONTAINER_HELPER_HPP
5#define CK_CONTAINER_HELPER_HPP
6
7#include "sequence.hpp"
8#include "sequence_helper.hpp"
9#include "array.hpp"
10#include "tuple.hpp"
11#include "tuple_helper.hpp"
14
15namespace ck {
16
17template <typename TData, index_t NSize>
18__host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>& a, const TData& x)
19{
21
22 static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; });
23
24 r(Number<NSize>{}) = x;
25
26 return r;
27}
28
29template <typename... Ts, typename T>
30__host__ __device__ constexpr auto container_push_front(const Tuple<Ts...>& a, const T& x)
31{
32 return container_concat(make_tuple(x), a);
33}
34
35template <typename... Ts, typename T>
36__host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
37{
38 return container_concat(a, make_tuple(x));
39}
40
41template <typename TData, index_t NSize, index_t... IRs>
42__host__ __device__ constexpr auto
44{
45 static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
46
47 static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
48
49 return make_array(old_array[Number<IRs>{}]...);
50}
51
52template <typename TData, index_t NSize, index_t... IRs>
53__host__ __device__ constexpr auto
55{
57 old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
58}
59
60template <typename... Ts, index_t... IRs>
61__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple<Ts...>& old_tuple,
62 Sequence<IRs...> /*new2old*/)
63{
64 static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
65
66 static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
67
68 return make_tuple(old_tuple[Number<IRs>{}]...);
69}
70
71template <typename... Ts, index_t... IRs>
72__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<Ts...>& old_tuple,
73 Sequence<IRs...> old2new)
74{
76 old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
77}
78
79template <index_t... Is, index_t... IRs>
80__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
81 Sequence<IRs...> /*new2old*/)
82{
83 static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
84
85 static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
86
88}
89
90template <index_t... Is, index_t... IRs>
91__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is...> old_seq,
92 Sequence<IRs...> /* old2new */)
93{
94 static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
95
96 static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
97
98 constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{};
99
100 return container_reorder_given_new2old(old_seq, new2old);
101}
102
103#if !CK_WORKAROUND_SWDEV_275126
104// rocm-4.1 compiler would crash for recursive lambda
105template <typename Container,
106 typename Reduce,
107 typename Init,
108 index_t IBegin = 0,
109 index_t IEnd = Container::Size(),
110 index_t IStep = 1>
111__host__ __device__ constexpr auto container_reduce(const Container& x,
112 Reduce reduce,
113 Init init,
115 Number<IEnd> = Number<Container::Size()>{},
117{
118 static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
119
120 // f is recursive function, fs is a dummy of f
121 // i is index, y_old is current scan, r_old is current reduction
122 auto f = [&](auto fs, auto i, auto r_old) {
123 auto r_new = reduce(x[i], r_old);
124
125 if constexpr(i.value < IEnd - IStep)
126 {
127 // recursively call f/fs
128 return fs(fs, i + Number<IStep>{}, r_new);
129 }
130 else
131 {
132 return r_new;
133 }
134 };
135
136 // start recursion
137 return f(f, Number<IBegin>{}, init);
138}
139#else
// i is the current index, r_old is the reduction accumulated so far
141template <typename Container,
142 typename Reduce,
143 typename ROld,
144 index_t I,
145 index_t IEnd,
146 index_t IStep>
147__host__ __device__ constexpr auto container_reduce_impl(
148 const Container& x, Reduce reduce, ROld r_old, Number<I> i, Number<IEnd>, Number<IStep>)
149{
150 auto r_new = reduce(x[i], r_old);
151
152 if constexpr(i.value < IEnd - IStep)
153 {
155 x, reduce, r_new, i + Number<IStep>{}, Number<IEnd>{}, Number<IStep>{});
156 }
157 else
158 {
159 return r_new;
160 }
161}
162
163// rocm-4.1 compiler would crash for recursive lambda
164// container reduce with initial value
165template <typename Container,
166 typename Reduce,
167 typename Init,
168 index_t IBegin = 0,
169 index_t IEnd = Container::Size(),
170 index_t IStep = 1>
171__host__ __device__ constexpr auto container_reduce(const Container& x,
172 Reduce reduce,
173 Init init,
175 Number<IEnd> = Number<Container::Size()>{},
177{
178 static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
179
180 if constexpr(IEnd > IBegin)
181 {
184 }
185 else
186 {
187 return init;
188 }
189}
190#endif
191
192template <typename TData, index_t NSize, typename Reduce>
193__host__ __device__ constexpr auto
195{
197
198 TData r = init;
199
200 static_for<NSize - 1, 0, -1>{}([&](auto i) {
201 r = f(r, x[i]);
202 y(i) = r;
203 });
204
205 r = f(r, x[Number<0>{}]);
206 y(Number<0>{}) = r;
207
208 return y;
209}
210
211template <typename TData, index_t NSize, typename Reduce>
212__host__ __device__ constexpr auto
214{
216
217 TData r = init;
218
219 static_for<NSize - 1, 0, -1>{}([&](auto i) {
220 y(i) = r;
221 r = f(r, x[i]);
222 });
223
224 y(Number<0>{}) = r;
225
226 return y;
227}
228
229template <index_t... Is, typename Reduce, index_t Init>
230__host__ __device__ constexpr auto
235
236#if !CK_WORKAROUND_SWDEV_275126
237// rocm4.1 compiler would crash with recursive lambda
238template <typename... Xs, typename Reduce, typename Init>
239__host__ __device__ constexpr auto
241{
242 constexpr index_t NSize = sizeof...(Xs);
243
244 // f is recursive function, fs is a dummy of f
245 // i is index, y_old is current scan, r_old is current reduction
246 auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
247 auto r_new = reduce(x[i], r_old);
248
249 auto y_new = container_push_front(y_old, r_new);
250
251 if constexpr(i.value > 1)
252 {
253 // recursively call f/fs
254 return fs(fs, i - Number<1>{}, y_new, r_new);
255 }
256 else
257 {
258 return y_new;
259 }
260 };
261
262 // start recursion
263 return f(f, Number<NSize - 1>{}, make_tuple(init), init);
264}
265#else
266// i is index, y_old is current scan, r_old is current reduction
267template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
268__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl(
269 const Tuple<Xs...>& x, Reduce reduce, Number<I> i, YOld y_old, ROld r_old)
270{
271 auto r_new = reduce(x[i], r_old);
272
273 auto y_new = container_push_front(y_old, r_new);
274
275 if constexpr(i.value > 1)
276 {
277 // recursively call f/fs
278 return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new);
279 }
280 else
281 {
282 return y_new;
283 }
284}
285
286template <typename... Xs, typename Reduce, typename Init>
287__host__ __device__ constexpr auto
288container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
289{
290 constexpr index_t NSize = sizeof...(Xs);
291
293 x, reduce, Number<NSize - 1>{}, make_tuple(init), init);
294}
295#endif
296
// TODO: update to be like container_reverse_exclusive_scan so it can handle a Tuple of Number<>
298template <typename... Xs, typename Reduce, typename TData>
299__host__ __device__ constexpr auto
300container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
301{
302 constexpr index_t NSize = sizeof...(Xs);
303
304 Tuple<Xs...> y;
305
306 TData r = init;
307
308 static_for<NSize - 1, 0, -1>{}([&](auto i) {
309 r = f(r, x[i]);
310 y(i) = r;
311 });
312
313 r = f(r, x[Number<0>{}]);
314 y(Number<0>{}) = r;
315
316 return y;
317}
318
319template <typename X, typename... Ys>
320__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
321{
322 return container_concat(x, container_concat(ys...));
323}
324
325template <typename T, index_t NX, index_t NY>
326__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
327{
328 return unpack2(
329 [&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
330}
331
332template <typename... X, typename... Y>
333__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
334{
335 return unpack2(
336 [&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
337}
338
339template <typename Container>
340__host__ __device__ constexpr auto container_concat(const Container& x)
341{
342 return x;
343}
344
345template <typename T, index_t N, index_t... Is>
346__host__ __device__ constexpr auto get_container_subset(const Array<T, N>& arr, Sequence<Is...>)
347{
348 static_assert(N >= sizeof...(Is), "wrong! size");
349
350 return make_array(arr[Number<Is>{}]...);
351}
352
353template <typename... Ts, index_t... Is>
354__host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup, Sequence<Is...>)
355{
356 static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
357
358 return make_tuple(tup[Number<Is>{}]...);
359}
360
361template <typename T, index_t N, index_t... Is>
362__host__ __device__ constexpr void
363set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
364{
365 static_assert(N >= sizeof...(Is), "wrong! size");
366
367 static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
368}
369
370template <typename... Ys, index_t... Is, typename... Xs>
371__host__ __device__ constexpr void
373{
374 static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size");
375
376 static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
377}
378
379template <index_t... Is>
380__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
381{
382 using Seq = Sequence<Is...>;
383
384 return generate_tuple(
385 [&](auto i) {
386 constexpr index_t tmp = Seq::At(i);
387 return Number<tmp>{};
388 },
389 Seq::Size());
390}
391
392} // namespace ck
393#endif
Definition reduction_operator.hpp:13
CK_TILE_HOST_DEVICE constexpr auto container_reverse_exclusive_scan_impl(const tuple< Xs... > &x, Reduce reduce, number< I > i, YOld y_old, ROld r_old)
Definition tile/core/container/container_helper.hpp:311
CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl(const Container &x, Reduce reduce, ROld r_old, number< I > i, number< IEnd >, number< IStep >)
Definition tile/core/container/container_helper.hpp:174
Definition ck.hpp:268
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
__host__ __device__ constexpr auto container_push_front(const Tuple< Ts... > &a, const T &x)
Definition utility/container_helper.hpp:30
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_array()
Definition utility/array.hpp:64
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition utility/container_helper.hpp:213
__host__ __device__ constexpr auto container_reverse_inclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition utility/container_helper.hpp:194
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number< Init >)
Definition utility/sequence.hpp:805
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr void set_container_subset(Array< T, N > &y, Sequence< Is... > picks, const Array< T, sizeof...(Is)> &x)
Definition utility/container_helper.hpp:363
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition utility/array.hpp:14
Definition utility/sequence.hpp:43
__host__ static __device__ constexpr index_t At(index_t I)
Definition utility/sequence.hpp:53
Definition utility/tuple.hpp:117
Definition utility/sequence.hpp:618
Definition utility/sequence.hpp:623
Definition functional2.hpp:33