split_k_utils.hpp Source File

split_k_utils.hpp Source File

Composable Kernel: split_k_utils.hpp Source File
split_k_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5#include <numeric>
6#include <hip/hip_runtime.h>
7#include "ck/utility/env.hpp"
10#include "ck/ck.hpp"
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
17{
19 {
20 hipDeviceProp_t dev_prop;
21 hipDevice_t dev;
22 hip_check_error(hipGetDevice(&dev));
23 hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
24
25 num_cu_ = dev_prop.multiProcessorCount;
26 };
28};
29
30inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
31{
32 static DeviceProperties device_properties;
33 const int max_capacity = max_occupancy * device_properties.num_cu_;
34
35 ck::index_t k_batch = 1;
36 const auto optimal_split =
37 static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / grid_size));
38 if(optimal_split > 1)
39 {
40 k_batch = optimal_split;
41 }
42
43 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
44 {
45 std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: "
46 << max_occupancy << std::endl;
47 std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
48 std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
49 }
50 return k_batch;
51}
52
53template <ck::index_t NDimSpatial>
54inline auto
55get_bwd_weight_gemm_sizes(const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
56 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths)
57{
58 static constexpr auto I1 = Number<1>{};
59 static constexpr auto I2 = Number<2>{};
60
61 // The input array has elements in the order: G, N, K, Do, Ho, Wo
62 // GemmK = N * Do * Ho * Wo for the BWD weight pass.
63 constexpr index_t spatial_offset = 3;
64 const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
65 end(a_g_n_k_wos_lengths),
66 index_t{1},
67 std::multiplies<>{});
68 const auto gemmK = a_g_n_k_wos_lengths[I1] * DoHoWo;
69
70 // The GEMM M dimension is the number of output channels.
71 const auto gemmM = e_g_k_c_xs_lengths[I1];
72
73 // The output array has elements in the order: G, K, C, X, Y, Z
74 // GemmN = C * X * Y * Z for the BWD weight pass.
75 const index_t XYZ = std::accumulate(begin(e_g_k_c_xs_lengths) + spatial_offset,
76 end(e_g_k_c_xs_lengths),
77 index_t{1},
78 std::multiplies<>{});
79 const auto gemmN = e_g_k_c_xs_lengths[I2] * XYZ;
80 return std::make_tuple(gemmM, gemmN, gemmK);
81}
82
83template <ck::index_t MPerBlock, ck::index_t NPerBlock>
85{
86 const auto M0 = math::integer_divide_ceil(gemmM, MPerBlock);
87 const auto N0 = math::integer_divide_ceil(gemmN, NPerBlock);
88 return M0 * N0;
89}
90
91} // namespace device
92} // namespace tensor_operation
93} // namespace ck
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition split_k_utils.hpp:55
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition split_k_utils.hpp:84
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
Definition split_k_utils.hpp:17
DeviceProperties()
Definition split_k_utils.hpp:18
int num_cu_
Definition split_k_utils.hpp:27
#define CK_ENV(name)
Definition utility/env.hpp:129