tensor.hpp Source File

tensor.hpp Source File

Composable Kernel: tensor.hpp Source File
tensor.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10// Disable from doxygen docs generation
12namespace ck {
13namespace wrapper {
15
16// Disable from doxygen docs generation
18namespace {
19namespace detail {
25template <typename T>
26__host__ __device__ constexpr bool HasSlice(T&&)
27{
28 return is_detected<is_slice, T>::value;
29}
30template <typename... Ts>
31__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
32{
33 return (HasSlice(Ts{}) || ...);
34}
35
43template <typename... Ts, typename SlicedShape>
44__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
45 const SlicedShape& shape)
46{
47 // Pack each value in tuple to remove empty tuples after generation
48 auto new_shape = generate_tuple(
49 [&](auto i) {
50 constexpr auto num_i = Number<i>{};
51 if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
52 {
53 if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
54 {
55 // if tuple does not have any slice then we can remove dimension
56 return Tuple<>{};
57 }
58 else
59 {
60 // if tuple then recurrence
61 return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
62 }
63 }
64 else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
65 {
66 // calculate new dimension
67 const auto& dim = size(shape.At(num_i));
68 const auto val = idxs.At(num_i).range(dim);
69 return make_tuple(val);
70 }
71 else
72 {
73 // remove dimension for just value
74 return Tuple<>{};
75 }
76 },
77 Number<Tuple<Ts...>::Size()>{});
78 // Remove empty tuples (deleted elements) and return
79 return UnrollNestedTuple<0, 1>(new_shape);
80}
81
89template <typename T, typename Shape>
90__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
91{
92 const auto unrolled_shape = UnrollNestedTuple(shape);
93 return generate_tuple(
94 [&](auto i) {
95 // dimension offset from idx
96 const auto dim = unrolled_shape.At(Number<i>{});
97 const auto dim_idx = idx % dim;
98 idx /= dim;
99 return make_freeze_transform(dim_idx);
100 },
101 Number<decltype(unrolled_shape)::Size()>{});
102}
103
111template <typename... Ts, typename Shape>
112__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
113 const Shape& shape)
114{
115 // Pack each value in tuple to remove empty tuples after generation
116 auto transforms = generate_tuple(
117 [&](auto i) {
118 constexpr auto num_i = Number<i>{};
119 if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
120 {
121 return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
122 }
123 else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
124 {
125
126 const auto from = idx.At(num_i).from_;
127 const auto dim = size<num_i>(shape);
128 const auto range = idx.At(num_i).range(dim);
129 return make_slice_transform(range, from, from + range);
130 }
131 else
132 {
133 // remove dimension for just value
134 return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
135 }
136 },
137 Number<Tuple<Ts...>::Size()>{});
138 // Remove empty tuples (deleted elements) and return
139 return UnrollNestedTuple(transforms);
140}
141
142template <index_t i, typename LowerIndex>
143__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
144{
145 // There is no output for Freeze transform
146 return Sequence<>{};
147}
148
149template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
150__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
151{
152 return Sequence<i>{};
153}
154
155template <index_t i>
156__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
157{
158 return Tuple<>{};
159}
160
161template <index_t i, typename... Transforms>
162__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
163{
164 constexpr auto num_transforms = Tuple<Transforms...>::Size();
165 // Deduce Sequence element for specific transform
166 const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
167 if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
168 {
169 const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
170 return concat_tuple(make_tuple(current_elem), next_tuple);
171 }
172 else
173 {
174 // Increase i if current_elem is Slice transform
175 const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
176 return concat_tuple(make_tuple(current_elem), next_tuple);
177 }
178}
179
180template <typename... Ts, typename Shape, typename UnrolledDescriptor>
181__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
182 const Shape& shape,
183 const UnrolledDescriptor& flatten_desc)
184{
185 constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
186
187 const auto transforms = GenerateSliceTransforms(idx, shape);
188 using TransformsTupleType = decltype(transforms);
189
190 const auto lower_dims =
191 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
192 const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
193 return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
194}
195} // namespace detail
196} // namespace
198
209template <MemoryTypeEnum BufferAddressSpace,
210 typename ElementType,
211 typename Shape,
212 typename UnrolledDescriptorType>
213struct Tensor
214{
215 public:
217 Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
218 using TensorElementType = std::conditional_t<
219 is_scalar_type<ElementType>::value,
220 ElementType,
221 typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType
222
223 static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
224 static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
225 BufferAddressSpace == MemoryTypeEnum ::Vgpr);
226
227 __host__ __device__ Tensor() = delete;
228 __host__ __device__ constexpr Tensor(ElementType* pointer,
230 : layout_(layout),
231 buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
232 multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
233 base_offset_(0)
234 {
235 static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
236 }
237
238 __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
239 : layout_(layout),
240 multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
241 base_offset_(0)
242 {
243 static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
244 }
245
246 __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
247 {
248 return layout_;
249 }
250
257 template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
258 __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
259 {
260 static_assert(IsDynamicBuffer, "Register slice is not supported");
261 const auto& shape = layout_.GetShape();
262 auto new_shape = detail::GetSlicedShape(idx, shape);
263
264 const auto& flatten_desc = layout_.GetUnrolledDescriptor();
265 auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
266 const auto new_layout =
268 // Update embed offset
269 base_offset_ -= new_layout(make_tuple(Number<0>{}));
270 return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
271 }
272
273 template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
274 __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
275 {
276 return this->operator[](idx);
277 }
278
279 template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
280 __host__ __device__ auto operator()(Idxs... idxs)
281 {
282 return this->operator[](make_tuple(idxs...));
283 }
284
291 template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
292 __host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
293 {
294 if constexpr(IsDynamicBuffer)
295 {
296 const index_t offset = layout_(idx) + base_offset_;
297 return buffer_[offset];
298 }
299 else
300 {
301 constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
302 Shape{},
303 UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
304 // Calculate and apply base offset in compile-time
305 constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
306 Shape{},
307 UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
308 return buffer_[Number<index_offset + base_offset>{}];
309 }
310 }
311
312 template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
313 __host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
314 {
315 return this->operator[](idx);
316 }
317
318 template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
319 __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
320 {
321 return this->operator[](make_tuple(idxs...));
322 }
323
330 template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
331 __host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
332 {
333 if constexpr(IsDynamicBuffer)
334 {
335 const index_t offset = layout_(idx) + base_offset_;
336 return buffer_(offset);
337 }
338 else
339 {
340 constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
341 Shape{},
342 UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
343 // Apply embed offset (calculate in compiletime)
344 constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
345 Shape{},
346 UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
347 return buffer_(Number<index_offset + base_offset>{});
348 }
349 }
350
351 template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
352 __host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
353 {
354 return this->operator[](idx);
355 }
356
357 template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
358 __host__ __device__ TensorElementType& operator()(Idxs... idxs)
359 {
360 return this->operator[](make_tuple(idxs...));
361 }
362
368 __host__ __device__ constexpr auto GetMergedNestingDescriptor()
369 {
370 return layout_.GetMergedNestingDescriptor();
371 }
372
378 __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }
379
380 __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
381 __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
382
388 __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }
389
395 template <typename MultiIdxOffsets>
396 __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
397 {
398 multi_idx_offset_ = multi_idx_offset;
399 base_offset_ += layout_(multi_idx_offset);
400 }
401
402 private:
403 // Disable from doxygen docs generation
405 using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
406 ElementType,
407 ElementSpaceSize,
408 true /*InvalidElementUseNumericalZeroValue*/>;
409 using StaticBufferType = std::conditional_t<
410 is_scalar_type<ElementType>::value,
411 StaticBuffer<BufferAddressSpace,
412 ElementType,
413 size(Shape{}),
414 true /*InvalidElementUseNumericalZeroValue*/>,
415 StaticBufferTupleOfVector<BufferAddressSpace,
416 TensorElementType,
417 size(Shape{}) /
418 scalar_type<std::remove_const_t<ElementType>>::vector_size,
419 scalar_type<std::remove_const_t<ElementType>>::vector_size,
420 true /*InvalidElementUseNumericalZeroValue*/>>;
421 // If register use static buffer, else use dynamic buffer
422 using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;
423
424 const Layout<Shape, UnrolledDescriptorType> layout_;
425 Buffer buffer_;
426 // We use multi_idx_offset_ to enable the creation of a descriptor in
427 // compile time for partitions or tiles if tile shape and thread layout
428 // is known at compile time (We can use the same descriptor for each
429 // thread). Additionally, the copy between the static and dynamic buffer
430 // requires a descriptor known at compile time, so we can shift data using
431 // such multi_idx_offset_.
432 MultiIndex<Shape::Size()> multi_idx_offset_;
433 // Base offset and multi index offset are corresponding to exactly the
434 // same element in tensor ( and in physical memory ). Multi index offset
435 // is multi dimensional index. However base offset is calculated using
436 // tensor descriptor (thus all it's transforms) and is linear (1D).
437 // We store base_offset_ to avoid multiple recalculations.
438 index_t base_offset_;
440};
441
442} // namespace wrapper
443} // namespace ck
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition layout_utils.hpp:431
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
Definition ck.hpp:268
__host__ __device__ constexpr auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition tuple_helper.hpp:52
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<> &element)
Definition tuple_helper.hpp:120
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
const GenericPointer< typename T::ValueType > & pointer
Definition pointer.h:1514
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
__host__ __device__ constexpr auto GetElementSpaceSize() const
Definition layout.hpp:297
__host__ __device__ constexpr const Layout< Shape, UnrolledDescriptorType > & GetLayout() const
Definition tensor.hpp:246
__host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
Apply multi index offset on the tensor.
Definition tensor.hpp:396
__host__ __device__ constexpr auto & GetBuffer() const
Definition tensor.hpp:381
decltype(Layout< Shape, UnrolledDescriptorType >{ Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()) ElementSpaceSize
Definition tensor.hpp:216
static constexpr bool IsDynamicBuffer
Definition tensor.hpp:224
__host__ __device__ auto operator[](const Tuple< Ts... > &idx)
Get the new sliced tensor.
Definition tensor.hpp:258
__host__ __device__ constexpr auto & GetBuffer()
Definition tensor.hpp:380
std::conditional_t< is_scalar_type< ElementType >::value, ElementType, typename scalar_type< std::remove_const_t< ElementType > >::type > TensorElementType
Definition tensor.hpp:218
__host__ __device__ constexpr auto & GetMultiIdxOffsets() const
Get multi index offset to the data.
Definition tensor.hpp:388
static constexpr MemoryTypeEnum TensorBufferAddressSpace
Definition tensor.hpp:223
__host__ __device__ Tensor()=delete
__host__ __device__ constexpr Tensor(const Layout< Shape, UnrolledDescriptorType > &layout)
Definition tensor.hpp:238
__host__ __device__ TensorElementType * GetPointer() const
Get pointer to the data.
Definition tensor.hpp:378
std::size_t GetElementSpaceSize() const
Definition library/utility/host_tensor.hpp:810
__host__ __device__ constexpr auto GetMergedNestingDescriptor()
Get descriptor with all nested dimensions merged.
Definition tensor.hpp:368
__host__ __device__ constexpr Tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Definition tensor.hpp:228
__host__ __device__ constexpr const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition tensor_utils.hpp:162
AddressSpaceEnum MemoryTypeEnum
Memory type, allowed members:
Definition tensor_utils.hpp:30
constexpr auto make_tensor(ElementType *pointer, const Layout< Shape, UnrolledDescriptorType > &layout)
Make tensor function.
Definition tensor_utils.hpp:112