philox_rand.hpp Source File

philox_rand.hpp Source File

Composable Kernel: philox_rand.hpp Source File
philox_rand.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
10// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
11class philox
12{
13 public:
14 CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
15 : seed(reinterpret_cast<const uint2&>(seed_))
16 {
17
18 ull2* tmp = reinterpret_cast<ull2*>(&counter);
19 tmp->x = offset_;
20 }
21
22 CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
23 {
24
25 uint4 counter_ = counter;
26 ull2* tmp = reinterpret_cast<ull2*>(&counter_);
27 tmp->y = subsequence;
28
29 uint2 key_ = seed;
30// 7-round philox
31#pragma unroll
32 for(int i = 0; i < 6; i++)
33 {
34 counter_ = philox_single_round(counter_, key_);
35 key_.x += kPhilox10A;
36 key_.y += kPhilox10B;
37 }
38 uint4 output = philox_single_round(counter_, key_);
39 return output;
40 }
41
43 const unsigned long long subsequence) const
44 {
45 uint4 tmp_ph;
46 tmp_ph = get_philox_4x32(subsequence);
47
48 uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
49
50 out_tmp[0] = tmp_ph.x;
51 out_tmp[1] = tmp_ph.y;
52 out_tmp[2] = tmp_ph.z;
53 out_tmp[3] = tmp_ph.w;
54 }
55
57 const unsigned long long subsequence,
58 const index_t idx0,
59 const index_t idx1) const
60 {
61 uint4 tmp_ph;
62 tmp_ph = get_philox_4x32(subsequence);
63
64 uint32x4_t tmp;
65 tmp[0] = tmp_ph.x;
66 tmp[1] = tmp_ph.y;
67 tmp[2] = tmp_ph.z;
68 tmp[3] = tmp_ph.w;
69 uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
70 out_tmp[0] = tmp[idx0];
71 out_tmp[1] = tmp[idx1];
72 }
73
75 get_random_4x8(uint8_t* out, const unsigned long long subsequence, const index_t idx) const
76 {
77 uint4 tmp_ph;
78 tmp_ph = get_philox_4x32(subsequence);
79
80 uint32x4_t tmp;
81 tmp[0] = tmp_ph.x;
82 tmp[1] = tmp_ph.y;
83 tmp[2] = tmp_ph.z;
84 tmp[3] = tmp_ph.w;
85 uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
86 out_tmp[0] = tmp[idx];
87 }
88
89 private:
90 struct ull2
91 {
92 uint64_t x;
93 uint64_t y;
94 };
95 uint4 counter;
96 const uint2 seed;
97
98 CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
99 {
100 uint2* res;
101 unsigned long long tmp;
102 tmp = static_cast<unsigned long long>(a) * b;
103 res = reinterpret_cast<uint2*>(&tmp);
104 return *res;
105 }
106
107 CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
108 {
109
110 uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
111 uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
112 uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
113 return ret;
114 }
115
116 static const unsigned long kPhilox10A = 0x9E3779B9;
117 static const unsigned long kPhilox10B = 0xBB67AE85;
118 static const unsigned long kPhiloxSA = 0xD2511F53;
119 static const unsigned long kPhiloxSB = 0xCD9E8D57;
120};
121
122} // namespace ck_tile
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
Definition philox_rand.hpp:22
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t idx) const
Definition philox_rand.hpp:75
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
Definition philox_rand.hpp:14
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t idx0, const index_t idx1) const
Definition philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition philox_rand.hpp:42
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
uint32_t uint32x4_t
Definition vector_type.hpp:164
int32_t index_t
Definition integer.hpp:9
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
unsigned __int64 uint64_t
Definition stdint.h:136