device_grouped_conv_bwd_weight_multiple_d.hpp Source File

device_grouped_conv_bwd_weight_multiple_d.hpp Source File

Composable Kernel: device_grouped_conv_bwd_weight_multiple_d.hpp Source File
device_grouped_conv_bwd_weight_multiple_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7
9
10namespace ck {
11namespace tensor_operation {
12namespace device {
13
// Compile-time switch, defaulting to 1. The code that reads it is not visible here; per its
// name it presumably disables automatic split-K deduction for one-stage kernels.
// NOTE(review): confirm against the implementations that test this macro.
//
// Preprocessor macros ignore C++ namespaces: although this appears inside
// `ck::tensor_operation::device`, the macro is visible to everything included afterwards.
// The #ifndef guard lets a build override the default (e.g. -DDISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS=0)
// instead of producing a macro-redefinition diagnostic; an identical earlier definition is
// left untouched.
#ifndef DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
#endif
15
16template <ck::index_t NDimSpatial,
17 typename InLayout,
18 typename WeiLayout,
19 typename OutLayout,
20 typename DsLayout,
21 typename InDataType,
22 typename WeiDataType,
23 typename OutDataType,
24 typename DsDataType,
25 typename InElementwiseOperation,
26 typename WeiElementwiseOperation,
27 typename OutElementwiseOperation,
28 typename ComputeTypeA = InDataType,
29 typename ComputeTypeB = ComputeTypeA>
31{
32 static constexpr index_t NumDTensor = DsLayout::Size();
33
34 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
35 const void* p_in_grid,
36 void* p_wei_grid,
37 const void* p_out_grid,
38 const std::array<const void*, NumDTensor>& p_ds,
39 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
40 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
41 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
42 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
43 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
44 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
45 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_lengths,
46 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_k_c_xs_strides,
47 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
48 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
49 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
50 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
51 InElementwiseOperation in_element_op,
52 WeiElementwiseOperation wei_element_op,
53 OutElementwiseOperation out_element_op,
54 const ck::index_t split_k) = 0;
55
56 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
57};
58
59} // namespace device
60} // namespace tensor_operation
61} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t — `ck::index_t` is an alias of `int32_t`
Definition ck.hpp:299
Definition device_grouped_conv_bwd_weight_multiple_d.hpp:31
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< const void *, NumDTensor > &p_ds, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_k_c_xs_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k)=0
static constexpr index_t NumDTensor
Definition device_grouped_conv_bwd_weight_multiple_d.hpp:32
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0