amd_smfmac.hpp Source File

amd_smfmac.hpp Source File

Composable Kernel: amd_smfmac.hpp Source File
amd_smfmac.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#include "ck/ck.hpp"
5#pragma once
6
7namespace ck {
8
9template <index_t MPerWave, index_t NPerWave>
11
12// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
13// indices from reg_idx
14template <>
16{
17 template <class FloatC, index_t abid = 0>
18 __device__ static void
19 Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
20 {
21#if defined(__gfx94__)
22 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
23 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
24#else
25 ignore = reg_a;
26 ignore = reg_b;
27 ignore = reg_c;
28 ignore = reg_idx;
29#endif
30 }
31};
32
33template <index_t MPerWave, index_t NPerWave>
35
36template <>
38{
39 template <class FloatC, index_t abid = 0>
40 __device__ static void
41 Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
42 {
43#if defined(__gfx94__)
44 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
45 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
46#else
47 ignore = reg_a;
48 ignore = reg_b;
49 ignore = reg_c;
50 ignore = reg_idx;
51#endif
52 }
53};
54
55template <index_t MPerWave, index_t NPerWave>
57
58template <>
60{
61 template <class FloatC, index_t abid = 0>
62 __device__ static void
63 Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
64 {
65#if defined(__gfx94__)
66 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
67 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
68#else
69 ignore = reg_a;
70 ignore = reg_b;
71 ignore = reg_c;
72 ignore = reg_idx;
73#endif
74 }
75};
76
77template <index_t MPerWave, index_t NPerWave>
79
80template <>
82{
83 template <class FloatC, index_t abid = 0>
84 __device__ static void
85 Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
86 {
87#if defined(__gfx94__)
88 reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
89 reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
90#else
91 ignore = reg_a;
92 ignore = reg_b;
93 ignore = reg_c;
94 ignore = reg_idx;
95#endif
96 }
97};
98
99} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition dtype_vector.hpp:2162
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition dtype_vector.hpp:2161
typename vector_type< half_t, 4 >::type half4_t
Definition dtype_vector.hpp:2154
static __device__ void Run(const bhalf4_t &reg_a, const bhalf8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition amd_smfmac.hpp:41
Definition amd_smfmac.hpp:34
static __device__ void Run(const half4_t &reg_a, const half8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition amd_smfmac.hpp:19
Definition amd_smfmac.hpp:10
static __device__ void Run(const bhalf4_t &reg_a, const bhalf8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition amd_smfmac.hpp:85
Definition amd_smfmac.hpp:78
static __device__ void Run(const half4_t &reg_a, const half8_t &reg_b, const index_t &reg_idx, FloatC &reg_c)
Definition amd_smfmac.hpp:63
Definition amd_smfmac.hpp:56