reference_grouped_conv_fwd.hpp Source File

reference_grouped_conv_fwd.hpp Source File

Composable Kernel: reference_grouped_conv_fwd.hpp Source File
reference_grouped_conv_fwd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdlib>
7#include <thread>
8
9#include "ck_tile/core.hpp"
12
13namespace ck_tile {
14
15template <ck_tile::index_t NDimSpatial,
16 typename InDataType,
17 typename WeiDataType,
18 typename OutDataType,
19 typename Elfunc = ck_tile::element_wise::PassThrough,
20 typename Tuple = ck_tile::tuple<>>
22 const HostTensor<WeiDataType>& weight,
24 std::vector<ck_tile::long_index_t> conv_strides,
25 std::vector<ck_tile::long_index_t> conv_dilations,
26 std::vector<ck_tile::long_index_t> in_left_pads,
27 std::vector<ck_tile::long_index_t>,
28 Elfunc elfunc = Elfunc{},
29 Tuple ds = {})
30{
31 if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
32 weight.get_num_of_dimension() == NDimSpatial + 3 &&
33 output.get_num_of_dimension() == NDimSpatial + 3))
34 {
35 throw std::runtime_error("wrong! inconsistent dimension");
36 }
37
38 if constexpr(NDimSpatial == 1)
39 {
40 auto func = [&](auto g, auto n, auto k, auto wo) {
41 float v_acc = 0;
42
43 for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
44 {
45 for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
46 {
47 auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
48 static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
49 static_cast<ck_tile::long_index_t>(in_left_pads[0]);
50
51 if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
52 {
53 InDataType v_in = input(g, n, c, wi);
54 WeiDataType v_wei = weight(g, k, c, x);
55 v_acc += ck_tile::type_convert<float>(v_in) *
57 }
58 }
59 }
60 if constexpr(Tuple::size() > 0)
61 elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
62 else
63 elfunc(v_acc, v_acc);
64 OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
65 output(g, n, k, wo) = v_acc_out;
66 };
67
69 output.get_lengths()[0],
70 output.get_lengths()[1],
71 output.get_lengths()[2],
72 output.get_lengths()[3])(std::thread::hardware_concurrency());
73 }
74 else if constexpr(NDimSpatial == 2)
75 {
76 auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
77 float v_acc = 0;
78
79 for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
80 {
81 for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
82 {
83 auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
84 static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
85 static_cast<ck_tile::long_index_t>(in_left_pads[0]);
86
87 for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
88 {
89 auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
90 static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
91 static_cast<ck_tile::long_index_t>(in_left_pads[1]);
92
93 if(hi >= 0 &&
95 wi >= 0 &&
97 {
98 InDataType v_in = input(g, n, c, hi, wi);
99 WeiDataType v_wei = weight(g, k, c, y, x);
100
101 v_acc += ck_tile::type_convert<float>(v_in) *
103 }
104 }
105 }
106 }
107 if constexpr(Tuple::size() > 0)
108 elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
109 else
110 elfunc(v_acc, v_acc);
111 OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
112 output(g, n, k, ho, wo) = v_acc_out;
113 };
114
116 output.get_lengths()[0],
117 output.get_lengths()[1],
118 output.get_lengths()[2],
119 output.get_lengths()[3],
120 output.get_lengths()[4])(std::thread::hardware_concurrency());
121 }
122 else if constexpr(NDimSpatial == 3)
123 {
124 auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
125 float v_acc = 0;
126
127 for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
128 {
129 for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
130 {
131 auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
132 static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
133 static_cast<ck_tile::long_index_t>(in_left_pads[0]);
134 for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
135 {
136 auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
137 static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
138 static_cast<ck_tile::long_index_t>(in_left_pads[1]);
139 for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
140 {
141 auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
142 static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
143 static_cast<ck_tile::long_index_t>(in_left_pads[2]);
144 if(di >= 0 &&
146 hi >= 0 &&
148 wi >= 0 &&
150 {
151 InDataType v_in = input(g, n, c, di, hi, wi);
152 WeiDataType v_wei = weight(g, k, c, z, y, x);
153
154 v_acc += ck_tile::type_convert<float>(v_in) *
156 }
157 }
158 }
159 }
160 }
161 if constexpr(Tuple::size() > 0)
162 elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
163 else
164 elfunc(v_acc, v_acc);
165 OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
166 output(g, n, k, d_o, ho, wo) = v_acc_out;
167 };
168
170 output.get_lengths()[0],
171 output.get_lengths()[1],
172 output.get_lengths()[2],
173 output.get_lengths()[3],
174 output.get_lengths()[4],
175 output.get_lengths()[5])(std::thread::hardware_concurrency());
176 }
177 else
178 {
179 throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
180 }
181}
182} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >, Elfunc elfunc=Elfunc{}, Tuple ds={})
Definition reference_grouped_conv_fwd.hpp:21
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396