generic_memory_space_atomic.hpp Source File

generic_memory_space_atomic.hpp Source File#

Composable Kernel: generic_memory_space_atomic.hpp Source File
tile/core/arch/generic_memory_space_atomic.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
8
// Non-zero when the compiler exposes both packed (2-lane) global atomic-add builtins,
// for fp16 and for bf16. The expansion is parenthesized so that it behaves as a single
// operand wherever the macro is used (e.g. `#if !HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN`).
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN                          \
    (__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) &&  \
     __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16))
12
13namespace ck_tile {
14
15template <typename T, typename ComputeType>
20
22{
23 bf16x2_t rtn;
24 rtn[0] = add<bf16_t, float>(a[0], b[0]);
25 rtn[1] = add<bf16_t, float>(a[1], b[1]);
26 return rtn;
27}
28
30{
31 bf16x4_t rtn;
32 rtn[0] = add<bf16_t, float>(a[0], b[0]);
33 rtn[1] = add<bf16_t, float>(a[1], b[1]);
34 rtn[2] = add<bf16_t, float>(a[2], b[2]);
35 rtn[3] = add<bf16_t, float>(a[3], b[3]);
36 return rtn;
37}
38
40{
41 fp16x2_t rtn;
42 rtn[0] = add<fp16_t, float>(a[0], b[0]);
43 rtn[1] = add<fp16_t, float>(a[1], b[1]);
44 return rtn;
45}
46
48{
49 fp8x4_t rtn;
50 rtn[0] = add<fp8_t, float>(a[0], b[0]);
51 rtn[1] = add<fp8_t, float>(a[1], b[1]);
52 rtn[2] = add<fp8_t, float>(a[2], b[2]);
53 rtn[3] = add<fp8_t, float>(a[3], b[3]);
54 return rtn;
55}
56
58{
59 fp8x8_t rtn;
60 rtn[0] = add<fp8_t, float>(a[0], b[0]);
61 rtn[1] = add<fp8_t, float>(a[1], b[1]);
62 rtn[2] = add<fp8_t, float>(a[2], b[2]);
63 rtn[3] = add<fp8_t, float>(a[3], b[3]);
64 rtn[4] = add<fp8_t, float>(a[4], b[4]);
65 rtn[5] = add<fp8_t, float>(a[5], b[5]);
66 rtn[6] = add<fp8_t, float>(a[6], b[6]);
67 rtn[7] = add<fp8_t, float>(a[7], b[7]);
68 return rtn;
69}
70
72{
73 bf8x4_t rtn;
74 rtn[0] = add<bf8_t, float>(a[0], b[0]);
75 rtn[1] = add<bf8_t, float>(a[1], b[1]);
76 rtn[2] = add<bf8_t, float>(a[2], b[2]);
77 rtn[3] = add<bf8_t, float>(a[3], b[3]);
78 return rtn;
79}
80
82{
83 bf8x8_t rtn;
84 rtn[0] = add<bf8_t, float>(a[0], b[0]);
85 rtn[1] = add<bf8_t, float>(a[1], b[1]);
86 rtn[2] = add<bf8_t, float>(a[2], b[2]);
87 rtn[3] = add<bf8_t, float>(a[3], b[3]);
88 rtn[4] = add<bf8_t, float>(a[4], b[4]);
89 rtn[5] = add<bf8_t, float>(a[5], b[5]);
90 rtn[6] = add<bf8_t, float>(a[6], b[6]);
91 rtn[7] = add<bf8_t, float>(a[7], b[7]);
92 return rtn;
93}
94
95// Caution: DO NOT REMOVE
96// intentionally have only declaration but no definition to cause compilation failure when trying to
97// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
98// each datatype.
99template <typename X>
100CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
101
102template <>
104{
105 union U32BF162_ADDR
106 {
107 uint32_t* u32_a;
108 bf16x2_t* bf162_a;
109 };
110
111 union U32BF162
112 {
113 uint32_t u32;
114 bf16x2_t bf162;
115 };
116
117 U32BF162_ADDR dword_addr;
118 U32BF162 cur_v;
119 U32BF162 new_;
120 uint32_t old_v, new_v;
121 dword_addr.bf162_a = p_dst;
122 cur_v.u32 = *dword_addr.u32_a;
123
124 do
125 {
126 old_v = cur_v.u32;
127 new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
128 new_v = new_.u32;
129 cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
130 } while(cur_v.u32 != old_v);
131}
132
133template <>
135{
136 // Union to treat the pointer as either bf16x4_t* or uint64_t*:
137 union U64BF164_ADDR
138 {
139 uint64_t* u64_a;
140 bf16x4_t* bf164_a;
141 };
142
143 // Union to treat the data as either bf16x4_t or 64-bit integer
144 union U64BF164
145 {
146 uint64_t u64;
147 bf16x4_t bf164;
148 };
149
150 U64BF164_ADDR addr;
151 addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
152
153 // First read (non-atomic) of the old value
154 U64BF164 cur_v;
155 cur_v.u64 = *addr.u64_a;
156
157 U64BF164 new_v_union;
158 uint64_t old_v, new_v;
159
160 do
161 {
162 // old 64 bits
163 old_v = cur_v.u64;
164
165 // Add elementwise in bf16
166 new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
167 new_v = new_v_union.u64;
168
169 // Attempt the 64-bit CAS
170 cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
171
172 } while(cur_v.u64 != old_v);
173}
174
175template <>
177{
178 union U32FP84_ADDR
179 {
180 uint32_t* u32_a;
181 fp8x4_t* fp84_a;
182 };
183
184 union U32FP84
185 {
186 uint32_t u32;
187 fp8x4_t fp84;
188 };
189
190 U32FP84_ADDR dword_addr;
191 U32FP84 cur_v;
192 U32FP84 new_;
193 uint32_t old_v, new_v;
194
195 dword_addr.fp84_a = p_dst;
196 cur_v.u32 = *dword_addr.u32_a;
197
198 do
199 {
200 old_v = cur_v.u32;
201 new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
202 new_v = new_.u32;
203 cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
204 } while(cur_v.u32 != old_v);
205}
206
207template <>
209{
210 union U32BF84_ADDR
211 {
212 uint32_t* u32_a;
213 bf8x4_t* bf84_a;
214 };
215
216 union U32BF84
217 {
218 uint32_t u32;
219 bf8x4_t bf84;
220 };
221
222 U32BF84_ADDR dword_addr;
223 U32BF84 cur_v;
224 U32BF84 new_;
225 uint32_t old_v, new_v;
226
227 dword_addr.bf84_a = p_dst;
228 cur_v.u32 = *dword_addr.u32_a;
229
230 do
231 {
232 old_v = cur_v.u32;
233 new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
234 new_v = new_.u32;
235 cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
236 } while(cur_v.u32 != old_v);
237}
238
239//
240// Atomic add for fp8x8_t
241//
242template <>
244{
245 // Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
246 union U64FP88_ADDR
247 {
248 uint64_t* u64_a; // pointer to 64-bit integer
249 fp8x8_t* fp88_a; // pointer to fp8x8_t
250 };
251
252 union U64FP88
253 {
254 uint64_t u64;
255 fp8x8_t fp88;
256 };
257
258 U64FP88_ADDR dword_addr;
259 U64FP88 cur_v;
260 U64FP88 new_v_union;
261 uint64_t old_v, new_v;
262
263 // Point to the destination as both fp8x8_t* and uint64_t*.
264 dword_addr.fp88_a = p_dst;
265 // Initial read of 64 bits from memory
266 cur_v.u64 = *dword_addr.u64_a;
267
268 do
269 {
270 old_v = cur_v.u64;
271 // Add each fp8 element using your add_fp8x8_t(...) routine
272 new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
273 new_v = new_v_union.u64;
274
275 // Attempt 64-bit CAS
276 cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
277 } while(cur_v.u64 != old_v);
278}
279
280//
281// Atomic add for bf8x8_t
282//
283template <>
285{
286 union U64BF88_ADDR
287 {
288 uint64_t* u64_a;
289 bf8x8_t* bf88_a;
290 };
291
292 union U64BF88
293 {
294 uint64_t u64;
295 bf8x8_t bf88;
296 };
297
298 U64BF88_ADDR dword_addr;
299 U64BF88 cur_v;
300 U64BF88 new_v_union;
301 uint64_t old_v, new_v;
302
303 dword_addr.bf88_a = p_dst;
304 // Read the original 64 bits
305 cur_v.u64 = *dword_addr.u64_a;
306
307 do
308 {
309 old_v = cur_v.u64;
310 // Add each bf8 element using your add_bf8x8_t(...) routine
311 new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
312 new_v = new_v_union.u64;
313
314 // 64-bit CAS loop
315 cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
316 } while(cur_v.u64 != old_v);
317}
318
319//
320// Atomic add for fp16x2_t
321//
322template <>
324{
325#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
326 __builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
327#else
328 union U32F162_ADDR
329 {
330 uint32_t* u32_a;
331 fp16x2_t* f162_a;
332 };
333
334 union U32F162
335 {
336 uint32_t u32;
337 fp16x2_t f162;
338 };
339
340 U32F162_ADDR dword_addr;
341 U32F162 cur_v;
342 U32F162 new_;
343 uint32_t old_v, new_v;
344 dword_addr.f162_a = p_dst;
345 cur_v.u32 = *dword_addr.u32_a;
346
347 do
348 {
349 old_v = cur_v.u32;
350 new_.f162 = add_f16x2_t(cur_v.f162, x);
351 new_v = new_.u32;
352 cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
353 } while(cur_v.u32 != old_v);
354#endif
355}
356
357template <typename T, index_t N>
359{
360 static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
361 (std::is_same<T, uint32_t>::value && (N == 1)) ||
362 (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
363 (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
364 (std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
365 (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
366 (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
367 (std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
368 "The granularity of the thread buffer is unsupported on the hardware!");
369
370 constexpr auto I0 = number<0>{};
371 constexpr auto I1 = number<1>{};
372 constexpr auto I2 = number<2>{};
373 constexpr auto I3 = number<3>{};
374
375 if constexpr(std::is_same<T, float>::value)
376 {
377 if constexpr(N == 1)
378 {
379 atomicAdd(p_dst, bit_cast<float>(x));
380 }
381 else if constexpr(N == 2)
382 {
383 atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
384 atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
385 }
386 else if constexpr(N == 4)
387 {
388 atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
389 atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
390 atomicAdd(c_style_pointer_cast<float*>(p_dst) + 2, x.template get_as<float>()[I2]);
391 atomicAdd(c_style_pointer_cast<float*>(p_dst) + 3, x.template get_as<float>()[I3]);
392 }
393 }
394 else if constexpr(std::is_same<T, double>::value)
395 {
396 if constexpr(N == 1)
397 {
398 return atomicAdd(p_dst, bit_cast<double>(x));
399 }
400 else if constexpr(N == 2)
401 {
402 atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
403 atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
404 }
405 }
406 else if constexpr(std::is_same<T, int32_t>::value)
407 {
408 if constexpr(N == 1)
409 {
410 atomicAdd(p_dst, bit_cast<int32_t>(x));
411 }
412 }
413 else if constexpr(std::is_same<T, uint32_t>::value)
414 {
415 if constexpr(N == 1)
416 {
417 atomicAdd(p_dst, bit_cast<uint32_t>(x));
418 }
419 }
420 else if constexpr(std::is_same<T, bf16_t>::value)
421 {
422 if constexpr(N == 2)
423 {
424 atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
425 }
426 else if constexpr(N == 4)
427 {
428 atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
429 }
430 else if constexpr(N == 8)
431 {
432 atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
434 x.template get_as<bf16x4_t>()[I1]);
435 }
436 }
437 else if constexpr(std::is_same<T, fp8_t>::value)
438 {
439 if constexpr(N == 4)
440 {
441 atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
442 }
443 if constexpr(N == 8)
444 {
445 atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
446 }
447 if constexpr(N == 16)
448 {
449 atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
450 atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
451 }
452 }
453 else if constexpr(std::is_same<T, bf8_t>::value)
454 {
455 if constexpr(N == 4)
456 {
457 atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
458 }
459 if constexpr(N == 8)
460 {
461 atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
462 }
463 if constexpr(N == 16)
464 {
465 atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
466 atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
467 }
468 }
469 else if constexpr(std::is_same<T, fp16_t>::value)
470 {
471 static_for<0, N / 2, 1>{}([&](auto i) {
473 x.template get_as<fp16x2_t>()[i]);
474 });
475 }
476}
477
478template <typename T, index_t N>
480{
481 static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
482 (std::is_same<T, uint32_t>::value && (N == 1)) ||
483 (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
484 (std::is_same<T, double>::value && (N == 1)),
485 "wrong! not implemented");
486
487 constexpr auto I0 = number<0>{};
488 constexpr auto I1 = number<1>{};
489
490 if constexpr(std::is_same<T, float>::value)
491 {
492 if constexpr(N == 1)
493 {
494 atomicMax(p_dst, bit_cast<float>(x));
495 }
496 else if constexpr(N == 2)
497 {
498 atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
499 atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
500 }
501 }
502 else if constexpr(std::is_same<T, double>::value)
503 {
504 if constexpr(N == 1)
505 {
506 atomicMax(p_dst, bit_cast<double>(x));
507 }
508 }
509 else if constexpr(std::is_same<T, int32_t>::value)
510 {
511 if constexpr(N == 1)
512 {
513 atomicMax(p_dst, bit_cast<int32_t>(x));
514 }
515 }
516 else if constexpr(std::is_same<T, uint32_t>::value)
517 {
518 if constexpr(N == 1)
519 {
520 atomicMax(p_dst, bit_cast<uint32_t>(x));
521 }
522 }
523}
524
525} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
CK_TILE_DEVICE void atomic_add< bf16x2_t >(bf16x2_t *p_dst, const bf16x2_t &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:103
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
fp8_t fp8x4_t
Definition vector_type.hpp:228
CK_TILE_DEVICE void atomic_add_g(T *p_dst, const thread_buffer< T, N > &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:358
CK_TILE_DEVICE void atomic_add< fp8x8_t >(fp8x8_t *p_dst, fp8x8_t const &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:243
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t &a, const fp8x8_t &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:57
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:16
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
CK_TILE_DEVICE void atomic_add< bf16x4_t >(bf16x4_t *p_dst, bf16x4_t const &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:134
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t &a, const bf8x8_t &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:81
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t &a, const fp8x4_t &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:47
CK_TILE_DEVICE void atomic_add< fp8x4_t >(fp8x4_t *p_dst, const fp8x4_t &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:176
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
Definition type_traits.hpp:104
bf8_t bf8x8_t
Definition vector_type.hpp:238
bfloat16_t bf16x4_t
Definition vector_type.hpp:146
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t &a, const bf16x2_t &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:21
CK_TILE_DEVICE void atomic_add< fp16x2_t >(fp16x2_t *p_dst, fp16x2_t const &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:323
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t &a, const fp16x2_t &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:39
bf8_t bf8x4_t
Definition vector_type.hpp:237
fp8_t fp8x8_t
Definition vector_type.hpp:229
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE void atomic_max_g(T *p_dst, const thread_buffer< T, N > &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:479
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t &a, const bf16x4_t &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:29
CK_TILE_DEVICE void atomic_add< bf8x4_t >(bf8x4_t *p_dst, const bf8x4_t &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:208
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t &a, const bf8x4_t &b)
Definition tile/core/arch/generic_memory_space_atomic.hpp:71
CK_TILE_DEVICE void atomic_add< bf8x8_t >(bf8x8_t *p_dst, bf8x8_t const &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:284
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned int uint32_t
Definition stdint.h:126
unsigned __int64 uint64_t
Definition stdint.h:136
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67