load_interleaved_pk_type.hpp Source File

load_interleaved_pk_type.hpp Source File

Composable Kernel: load_interleaved_pk_type.hpp Source File
load_interleaved_pk_type.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck_tile {
10
11template <class T>
12struct is_pk_int4 : std::false_type
13{
14};
15template <>
16struct is_pk_int4<pk_int4_t> : std::true_type
17{
18};
19
// Class template holding a single static device helper, load_interleaved_pk_type().
// Template parameters:
//   ComputeDataType - element type of the vectors written into the destination tile.
//   UnaryOpSize     - number of ComputeDataType elements handled per element-wise op call.
// NOTE(review): listing line 21, which declares and names this class template, is
// missing from this extract; only the template header and the body are visible.
// Confirm the declaration against the original header.
20template <typename ComputeDataType, index_t UnaryOpSize>
22{
// Loads the tile behind `warp_window` with load_tile() and applies a
// PassThroughPack8 element-wise op, group by group, into `warp_tile`'s thread buffer.
//   warp_tile   - taken by non-const reference; presumably the destination of the op.
//   warp_window - window handed to load_tile(); read-only here.
23 template <typename WarpWindow, typename WarpTile>
24 CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
25 const WarpWindow& warp_window)
26 {
27 const element_wise::PassThroughPack8 elementwise_op{};
28
// The destination thread buffer must split evenly into UnaryOpSize-wide groups;
// thread_buffer_size is the number of such groups, not the raw element count.
29 static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
30 constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
31 const auto in_dstr_tensors = load_tile(warp_window);
32
// Compiler extension vector type: UnaryOpSize lanes of ComputeDataType.
33 using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize)));
// NOTE(review): listing line 34 is missing from this extract. It must introduce `i`
// and open the callable closed by `});` below - presumably a compile-time loop
// over [0, thread_buffer_size). Confirm against the original header.
// Each call pairs the i-th ComputeVectorType view of warp_tile's buffer with the
// i-th pk_int4x4_t view of the loaded tile's buffer.
35 elementwise_op(warp_tile.get_thread_buffer().template get_as<ComputeVectorType>()(i),
36 in_dstr_tensors.get_thread_buffer().template get_as<pk_int4x4_t>()[i]);
37 });
38 }
39};
40
// Device-side entry point that fills `dst` from the window `src`, choosing between
// two compile-time-visible branches.
//   BDataType, ComputeDataType, UnaryOpSize - not used in any line visible here;
//     presumably consumed by the missing first branch (see note below).
//   dst - destination tile, assigned in the fallback branch.
//   src - source window passed to load_tile() in the fallback branch.
// NOTE(review): listing lines 48 and 50 are missing from this extract, so the
// condition selecting the first branch and the statement inside it are not visible.
// Presumably the condition tests whether BDataType is the packed int4 type and the
// branch calls load_interleaved_pk_type(); confirm against the original header.
41template <typename BDataType,
42 typename ComputeDataType,
43 index_t UnaryOpSize,
44 typename WarpTile,
45 typename WarpWindow>
46CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
47{
49 {
51 }
52 else
53 {
// Fallback: assign the result of a plain load_tile() on the window to dst.
54 dst = load_tile(src);
55 }
56}
57
58} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE void load_int4_tile(WarpTile &dst, const WarpWindow &src)
Definition load_interleaved_pk_type.hpp:46
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition load_interleaved_pk_type.hpp:22
static CK_TILE_DEVICE void load_interleaved_pk_type(WarpTile &warp_tile, const WarpWindow &warp_window)
Definition load_interleaved_pk_type.hpp:24
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:351
Definition load_interleaved_pk_type.hpp:13
Definition pk_int4.hpp:21
Definition tile/core/utility/functional.hpp:43