device_batchnorm_forward_impl_obsolete.hpp Source File

device_batchnorm_forward_impl_obsolete.hpp Source File

Composable Kernel: device_batchnorm_forward_impl_obsolete.hpp Source File
device_batchnorm_forward_impl_obsolete.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename XDataType,
25 typename YDataType,
26 typename AccDataType,
27 typename ScaleDataType,
28 typename BiasDataType,
29 typename MeanVarDataType,
30 typename YElementwiseOp,
31 index_t Rank,
32 index_t NumBatchNormReduceDim,
33 bool UseMultiblockInK,
34 index_t BlockSize,
35 index_t MThreadClusterSize,
36 index_t KThreadClusterSize,
37 index_t MThreadSliceSize,
38 index_t KThreadSliceSize,
39 index_t XSrcYDstVectorDim,
40 index_t XSrcVectorSize,
41 index_t YDstVectorSize,
42 index_t ScaleSrcVectorSize,
43 index_t BiasSrcVectorSize,
44 index_t MeanVarSrcDstVectorSize>
45struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
46 YDataType,
47 AccDataType,
48 ScaleDataType,
49 BiasDataType,
50 MeanVarDataType,
51 YElementwiseOp,
52 Rank,
53 NumBatchNormReduceDim>
54{
55 static_assert(Rank <= 6, "Bigger Rank size is not supported!");
56 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
57 "Invalid thread cluster size assignments!");
58
59 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
60 (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
61 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
62
63 static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
64
65 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
66 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
67
// Builds the padded 2-D (M, K) descriptor used for the X and Y tensors.
//   xyLengths / xyStrides  - per-dimension lengths and strides (Rank entries each)
//   blkGroupSize           - multiplied with the per-block work size to get the padded K extent
//   numBlockTileIteration  - number of K_BlockTileSize tiles counted per block
// Returns the (M, K) descriptor after right-padding both dimensions.
// NOTE(review): listing lines 82, 86, 94, 109 and 110 are absent from this extract (the
// ReduceDims alias and the tails of three calls), so the comments below describe only
// what the visible lines establish.
68 static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
69 const std::array<index_t, Rank>& xyStrides,
70 int blkGroupSize,
71 int numBlockTileIteration)
72 {
// Convert the std::array inputs into tuples so they can feed the descriptor helpers.
73 const auto tupleXYLengths =
74 generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
75 const auto tupleXYStrides =
76 generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
77 
78 const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);
79 
// Merge the leading NumInvariantDim dimensions into M and the dimensions starting at
// index NumInvariantDim into K.
80 const auto grid_desc_m_k = [&]() {
81 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
83 
84 const auto reduceDimLengths =
85 generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
87 const auto invariantDimLengths =
88 generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
89 
90 return transform_tensor_descriptor(raw_grid_desc,
91 make_tuple(make_merge_transform(invariantDimLengths),
92 make_merge_transform(reduceDimLengths)),
93 make_tuple(InvariantDims{}, ReduceDims{}),
95 }();
96 
97 const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
98 const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
99 
// M is padded up to the next multiple of M_BlockTileSize; K is padded up to
// K_BlockTileSize * numBlockTileIteration * blkGroupSize.
// NOTE(review): kPad is negative if that product is smaller than reduceLength - confirm
// that callers always pass values that cover the whole reduce length.
100 const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
101 const auto mPad =
102 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
103 const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;
104 
105 auto grid_desc_m_k_padded =
106 transform_tensor_descriptor(grid_desc_m_k,
107 make_tuple(make_right_pad_transform(invariantLength, mPad),
108 make_right_pad_transform(reduceLength, kPad)),
111 
112 return (grid_desc_m_k_padded);
113 };
114
// Builds the padded 2-D (M, G) descriptor with lengths (invariantLength, blkGroupSize)
// and strides (1, invariantLength).
// Only M is right-padded (to a multiple of M_BlockTileSize); G is passed through unchanged.
// NOTE(review): listing lines 127-128 (the tail of the transform call) are absent from
// this extract.
115 static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
116 {
117 const auto grid_desc_m_g = make_naive_tensor_descriptor(
118 make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));
119 
120 const auto mPad =
121 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
122 
123 auto grid_desc_m_g_padded =
124 transform_tensor_descriptor(grid_desc_m_g,
125 make_tuple(make_right_pad_transform(invariantLength, mPad),
126 make_pass_through_transform(blkGroupSize)),
129 
130 return (grid_desc_m_g_padded);
131 };
132
// Builds the padded 2-D (M, K) descriptor with lengths (invariantLength, blkGroupSize)
// and strides (1, invariantLength); here blkGroupSize plays the role of the reduce length.
// M is right-padded to a multiple of M_BlockTileSize and K to a multiple of
// KThreadClusterSize.
// NOTE(review): listing lines 148-149 (the tail of the transform call) are absent from
// this extract.
133 static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
134 {
135 const auto reduceLength = blkGroupSize;
136 const auto grid_desc_m_k = make_naive_tensor_descriptor(
137 make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));
138 
139 const auto mPad =
140 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
141 const auto kPad =
142 math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;
143 
144 auto grid_desc_m_k_padded =
145 transform_tensor_descriptor(grid_desc_m_k,
146 make_tuple(make_right_pad_transform(invariantLength, mPad),
147 make_right_pad_transform(reduceLength, kPad)),
150 
151 return (grid_desc_m_k_padded);
152 };
153
// Builds the padded 1-D (M) descriptor from NumInvariantDim lengths and strides.
// All dimensions are merged into a single M dimension, which is then right-padded to a
// multiple of M_BlockTileSize.
// NOTE(review): listing lines 168-169 and 179-180 (the tails of both transform calls)
// are absent from this extract.
154 static auto
155 MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
156 const std::array<index_t, NumInvariantDim>& strides)
157 {
// Convert the std::array inputs into tuples for the descriptor helpers.
158 const auto tupleLengths =
159 generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
160 const auto tupleStrides =
161 generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});
162 
163 auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
164 
165 auto grid_desc_m = transform_tensor_descriptor(
166 raw_grid_desc,
167 make_tuple(make_merge_transform(tupleLengths)),
170 
171 const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
172 
173 const auto mPad =
174 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
175 
176 auto grid_desc_m_padded =
177 transform_tensor_descriptor(grid_desc_m,
178 make_tuple(make_right_pad_transform(invariantLength, mPad)),
181 return (grid_desc_m_padded);
182 };
183
184 using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
186
187 struct Argument : public BaseArgument
188 {
189 Argument(const std::array<index_t, Rank> xyLengths,
190 const std::array<index_t, Rank> xStrides,
191 const std::array<index_t, Rank> yStrides,
192 const std::array<int, NumBatchNormReduceDim> reduceDims,
193 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
194 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
195 const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
196 const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
197 const XDataType* p_x,
198 const ScaleDataType* p_scale,
199 const BiasDataType* p_bias,
200 const YElementwiseOp y_elementwise_op,
201 double epsilon,
202 YDataType* p_y,
203 MeanVarDataType* resultSaveMean,
204 MeanVarDataType* resultSaveInvVariance,
205 double averageFactor,
206 MeanVarDataType* resultRunningMean,
207 MeanVarDataType* resultRunningVariance)
208 : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
209 bnScaleStrides_(bnScaleStrides),
210 bnBiasStrides_(bnBiasStrides),
211 bnMeanVarStrides_(bnMeanVarStrides),
212 p_x_(p_x),
213 p_scale_(p_scale),
214 p_bias_(p_bias),
215 y_elementwise_op_(y_elementwise_op),
216 p_y_(p_y),
217 resultSaveMean_(resultSaveMean),
218 resultSaveInvVariance_(resultSaveInvVariance),
219 resultRunningMean_(resultRunningMean),
220 resultRunningVariance_(resultRunningVariance)
221 {
222 xyLengths_ =
224 xStrides_ =
226 yStrides_ =
228
231
234
236 (resultRunningMean != nullptr && resultRunningVariance != nullptr);
237 saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr);
238
239 if(UseMultiblockInK)
240 {
241 int iterations = 1;
242 while(true)
243 {
244 int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
245 (K_BlockTileSize * iterations);
246
247 // we want blkGroupSize to be no more than 16
248 if(testBlkGroupSize <= 16)
249 break;
250
251 iterations++;
252 };
253
254 blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
255 (K_BlockTileSize * iterations);
256
257 numBlockTileIteration_ = iterations;
258 }
259 else
260 {
261 blkGroupSize_ = 1;
263 };
264
266
272 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
274 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
276 MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
277 }
278
279 AccDataType epsilon_;
280 AccDataType averageFactor_;
281
284
285 std::array<index_t, Rank> xyLengths_;
286 std::array<index_t, Rank> xStrides_;
287 std::array<index_t, Rank> yStrides_;
288
289 std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
290 std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
291 std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
292 std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;
293
294 const XDataType* p_x_;
295 const ScaleDataType* p_scale_;
296 const BiasDataType* p_bias_;
297 const YElementwiseOp y_elementwise_op_;
298 YDataType* p_y_;
299
300 MeanVarDataType* resultSaveMean_;
301 MeanVarDataType* resultSaveInvVariance_;
302
303 MeanVarDataType* resultRunningMean_;
304 MeanVarDataType* resultRunningVariance_;
305
308
309 int blkGroupSize_;
311 size_t gridSize_;
312
318
319 void* workspace_mean_;
321 void* workspace_count_;
322 };
323
// Returns the number of workspace bytes needed for the given argument.
// The result is 0 unless UseMultiblockInK is set and the argument's blkGroupSize_ is
// greater than 1. Otherwise it is the sum of three buffers (mean, variance, count), each
// holding invariant_length_ * blkGroupSize_ elements plus 64 extra bytes.
// SetWorkSpacePointer rounds each sub-buffer size up to a multiple of 64, so the extra
// 64 bytes per buffer leave room for that rounding.
// NOTE(review): the dynamic_cast result is dereferenced without a null check; passing a
// BaseArgument that is not this class's Argument is undefined behaviour.
324 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
325 {
326 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
327 
328 size_t workspace_size = 0;
329 
330 if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
331 {
332 // workspace for the intermediate Welford mean (MeanVarDataType elements)
333 workspace_size +=
334 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
335 
336 // workspace for the intermediate Welford variance (MeanVarDataType elements)
337 workspace_size +=
338 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;
339 
340 // workspace for the intermediate Welford count (int32_t elements)
341 workspace_size +=
342 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;
343 }
344 
345 return (workspace_size);
346 };
347
349 void* p_workspace,
350 const StreamConfig& = StreamConfig{}) const override
351 {
352 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
353
354 pArg_->p_workspace_ = p_workspace;
355
356 if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
357 {
358
359 // setup buffer used for intermediate welford mean
360 pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
361
362 index_t mean_space_sz =
363 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
364
365 mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
366
367 // setup buffer used for intermediate welford variance
368 pArg_->workspace_variance_ =
369 reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;
370
371 index_t variance_space_sz =
372 pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);
373
374 variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
375
376 // setup buffer used for intermediate welford count
377 pArg_->workspace_count_ =
378 reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;
379 };
380 };
381
382 struct Invoker : public BaseInvoker
383 {
384 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
385 {
386 float avg_time = 0;
387
388 if(UseMultiblockInK && arg.blkGroupSize_ > 1)
389 {
390 using GetReduceCountPerThreadFunctor =
392
393 GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
394 arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);
395
396 const auto mean_var_count_grid_desc_m_g =
398 arg.invariant_length_, arg.blkGroupSize_);
399
400 const auto mean_var_count_grid_desc_m_k =
402 arg.invariant_length_, arg.blkGroupSize_);
403
404 using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
405 using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
406
407 using GridwiseMultiblockWelfordFirstHalf_ =
409 AccDataType,
410 MeanVarDataType,
412 MeanVarCountGridDesc_M_G,
413 GetReduceCountPerThreadFunctor,
414 BlockSize,
415 MThreadClusterSize,
416 KThreadClusterSize,
417 MThreadSliceSize,
418 KThreadSliceSize,
419 XSrcYDstVectorDim,
420 XSrcVectorSize>;
421
422 using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
424 YDataType,
425 AccDataType,
426 ScaleDataType,
427 BiasDataType,
428 MeanVarDataType,
429 YElementwiseOp,
431 MeanVarCountGridDesc_M_K,
434 BlockSize,
435 MThreadClusterSize,
436 KThreadClusterSize,
437 MThreadSliceSize,
438 KThreadSliceSize,
439 XSrcYDstVectorDim,
440 XSrcVectorSize,
441 YDstVectorSize,
442 ScaleSrcVectorSize,
443 BiasSrcVectorSize,
444 MeanVarSrcDstVectorSize>;
445
446 const auto kern_multiblock_welford_first_half =
447 kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
448 XDataType,
449 MeanVarDataType,
451 MeanVarCountGridDesc_M_G,
452 GetReduceCountPerThreadFunctor>;
453
454 const auto kern_welford_second_half_batchnorm_forward_final =
456 GridwiseWelfordSecondHalfBatchNormForwardFinal_,
457 XDataType,
458 YDataType,
459 AccDataType,
460 ScaleDataType,
461 BiasDataType,
462 MeanVarDataType,
463 YElementwiseOp,
465 MeanVarCountGridDesc_M_K,
468
469 avg_time +=
470 launch_and_time_kernel(stream_config,
471 kern_multiblock_welford_first_half,
472 dim3(arg.gridSize_),
473 dim3(BlockSize),
474 0,
475 arg.x_grid_desc_m_k_,
476 mean_var_count_grid_desc_m_g,
477 get_reduce_count_per_thread,
478 arg.numBlockTileIteration_,
479 arg.p_x_,
480 static_cast<MeanVarDataType*>(arg.workspace_mean_),
481 static_cast<MeanVarDataType*>(arg.workspace_variance_),
482 static_cast<int32_t*>(arg.workspace_count_));
483
484 avg_time +=
485 launch_and_time_kernel(stream_config,
486 kern_welford_second_half_batchnorm_forward_final,
487 dim3(arg.gridSize_),
488 dim3(BlockSize),
489 0,
490 arg.x_grid_desc_m_k_,
491 arg.y_grid_desc_m_k_,
492 mean_var_count_grid_desc_m_k,
493 arg.scale_grid_desc_m_,
494 arg.bias_grid_desc_m_,
495 arg.mean_var_grid_desc_m_,
496 arg.blkGroupSize_,
497 arg.numBlockTileIteration_,
498 arg.epsilon_,
499 static_cast<MeanVarDataType*>(arg.workspace_mean_),
500 static_cast<MeanVarDataType*>(arg.workspace_variance_),
501 static_cast<int32_t*>(arg.workspace_count_),
502 arg.p_x_,
503 arg.p_scale_,
504 arg.p_bias_,
505 arg.y_elementwise_op_,
506 arg.p_y_,
507 arg.updateMovingAverage_,
508 arg.averageFactor_,
509 arg.resultRunningMean_,
510 arg.resultRunningVariance_,
511 arg.saveMeanInvVariance_,
512 arg.resultSaveMean_,
513 arg.resultSaveInvVariance_);
514 }
515 else
516 {
517 using GetReduceCountPerThreadFunctor =
519
520 GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
521 arg.numBlockTileIteration_, arg.reduce_length_);
522
523 using GridwiseBatchNormForwardWithBlockwiseWelford_ =
525 YDataType,
526 AccDataType,
527 ScaleDataType,
528 BiasDataType,
529 MeanVarDataType,
530 YElementwiseOp,
534 GetReduceCountPerThreadFunctor,
535 BlockSize,
536 MThreadClusterSize,
537 KThreadClusterSize,
538 MThreadSliceSize,
539 KThreadSliceSize,
540 XSrcYDstVectorDim,
541 XSrcVectorSize,
542 YDstVectorSize,
543 ScaleSrcVectorSize,
544 BiasSrcVectorSize,
545 MeanVarSrcDstVectorSize>;
546
547 const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
548 GridwiseBatchNormForwardWithBlockwiseWelford_,
549 XDataType,
550 YDataType,
551 AccDataType,
552 ScaleDataType,
553 BiasDataType,
554 MeanVarDataType,
555 YElementwiseOp,
559 GetReduceCountPerThreadFunctor>;
560
561 avg_time += launch_and_time_kernel(stream_config,
562 kern_batchnorm_fwd,
563 dim3(arg.gridSize_),
564 dim3(BlockSize),
565 0,
566 arg.x_grid_desc_m_k_,
567 arg.y_grid_desc_m_k_,
568 arg.scale_grid_desc_m_,
569 arg.bias_grid_desc_m_,
570 arg.mean_var_grid_desc_m_,
571 get_reduce_count_per_thread,
572 arg.numBlockTileIteration_,
573 arg.epsilon_,
574 arg.p_x_,
575 arg.p_scale_,
576 arg.p_bias_,
577 arg.y_elementwise_op_,
578 arg.p_y_,
579 arg.updateMovingAverage_, // true or false
580 arg.averageFactor_,
581 arg.resultRunningMean_,
582 arg.resultRunningVariance_,
583 arg.saveMeanInvVariance_, // true or false
584 arg.resultSaveMean_,
585 arg.resultSaveInvVariance_);
586 };
587
588 return (avg_time);
589 };
590
// Type-erased entry point: downcasts pArg to Argument and forwards to the typed Run
// overload, returning its result.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
591 float Run(const BaseArgument* pArg,
592 const StreamConfig& stream_config = StreamConfig{}) override
593 {
594 return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
595 };
596 };
597
// Returns true when the argument's strides and lengths are compatible with the vector
// sizes this instantiation was compiled with, false otherwise.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
598 bool IsSupportedArgument(const BaseArgument* pArg) override
599 {
600 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
601 
// X/Y vector access: the dimension that is checked depends on XSrcYDstVectorDim.
602 if constexpr(XSrcYDstVectorDim == 0)
603 {
// Dimension NumInvariantDim - 1 must have stride 1 in both X and Y ...
604 if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
605 pArg_->yStrides_[NumInvariantDim - 1] != 1)
606 return false;
607 
// ... and a length divisible by both the X and the Y vector size.
608 if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
609 pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
610 return false;
611 }
612 else
613 {
// Same two checks, applied to the last dimension (Rank - 1).
614 if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
615 return false;
616 
617 if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
618 pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
619 return false;
620 };
621 
// Scale / bias: a stride other than 1 in dimension NumInvariantDim - 1 is accepted only
// when the corresponding vector size is 1.
622 if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
623 return false;
624 if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
625 return false;
626 
// The length of that dimension must be divisible by the scale and bias vector sizes.
627 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
628 return false;
629 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
630 return false;
631 
// Mean / variance: the same stride and divisibility rules with MeanVarSrcDstVectorSize.
632 if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
633 return false;
634 
635 if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
636 return false;
637 
638 bool is_valid = true;
639 
// Each compared X/Y length must equal the matching scale/bias/mean/var length.
// NOTE(review): the statement that opens this loop over I (listing line 640) is absent
// from this extract; presumably a compile-time loop over the invariant dimensions -
// confirm against the original header.
641 if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
642 is_valid = false;
643 });
644 
645 if(!is_valid)
646 return false;
647 
648 return true;
649 };
650
// Type-erased factory: casts the void* buffers to the data types of this instantiation
// and returns a newly constructed Argument as a std::unique_ptr<BaseArgument>.
// All length, stride and reduceDims arrays are forwarded unchanged.
// Note the ordering difference: this signature takes epsilon before y_elementwise_op,
// whereas the Argument constructor takes y_elementwise_op before epsilon; the call below
// passes them in the constructor's order.
651 std::unique_ptr<BaseArgument> MakeArgumentPointer(
652 const std::array<index_t, Rank> xyLengths,
653 const std::array<index_t, Rank> xStrides,
654 const std::array<index_t, Rank> yStrides,
655 const std::array<int, NumBatchNormReduceDim> reduceDims,
656 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
657 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
658 const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
659 const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
660 const void* p_x,
661 const void* p_scale,
662 const void* p_bias,
663 double epsilon,
664 const YElementwiseOp y_elementwise_op,
665 void* p_y,
666 void* resultSaveMean,
667 void* resultSaveInvVariance,
668 double averageFactor,
669 void* resultRunningMean,
670 void* resultRunningVariance) override
671 {
672 return std::make_unique<Argument>(xyLengths,
673 xStrides,
674 yStrides,
675 reduceDims,
676 bnScaleBiasMeanVarLengths,
677 bnScaleStrides,
678 bnBiasStrides,
679 bnMeanVarStrides,
680 static_cast<const XDataType*>(p_x),
681 static_cast<const ScaleDataType*>(p_scale),
682 static_cast<const BiasDataType*>(p_bias),
683 y_elementwise_op,
684 epsilon,
685 static_cast<YDataType*>(p_y),
686 static_cast<MeanVarDataType*>(resultSaveMean),
687 static_cast<MeanVarDataType*>(resultSaveInvVariance),
688 averageFactor,
689 static_cast<MeanVarDataType*>(resultRunningMean),
690 static_cast<MeanVarDataType*>(resultRunningVariance));
691 };
692
// Returns a newly constructed, default-initialised Invoker as a
// std::unique_ptr<BaseInvoker>.
693 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
694 {
695 return std::make_unique<Invoker>();
696 };
697
// Returns a text identifier for this instantiation, assembled from its compile-time
// parameters: BlockSize, the M and K thread-cluster and thread-slice sizes,
// XSrcYDstVectorDim, and the X, scale, bias, mean/var and Y vector sizes.
698 std::string GetTypeString() const override
699 {
700 auto str = std::stringstream();
701 
// Formatting is disabled so that each field group stays on its own source line.
702 // clang-format off
703 str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
704 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
705 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
706 str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
707 str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
708 // clang-format on
709 
710 return str.str();
711 }
712};
713
714} // namespace device
715} // namespace tensor_operation
716} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__global__ void kernel_multiblock_welford_first_half(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition gridwise_multiblock_welford_first_half.hpp:21
int32_t index_t
Definition ck.hpp:299
__global__ void kernel_welford_second_half_batchnorm_forward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:27
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__global__ void kernel_batchnorm_forward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:27
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:94
Definition gridwise_multiblock_welford_first_half.hpp:55
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:102
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batchnorm_forward.hpp:26
Definition device_batchnorm_forward_impl.hpp:190
MeanVarDataType * resultRunningMean_
Definition device_batchnorm_forward_impl.hpp:305
long_index_t reduce_length_
Definition device_batchnorm_forward_impl.hpp:309
const ScaleDataType * p_scale_
Definition device_batchnorm_forward_impl.hpp:297
bool updateMovingAverage_
Definition device_batchnorm_forward_impl.hpp:284
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_
Definition device_batchnorm_forward_impl.hpp:317
std::array< index_t, Rank > xStrides_
Definition device_batchnorm_forward_impl.hpp:288
const XDataType * p_x_
Definition device_batchnorm_forward_impl.hpp:296
std::array< index_t, Rank > xyLengths_
Definition device_batchnorm_forward_impl.hpp:287
int blkGroupSize_
Definition device_batchnorm_forward_impl.hpp:311
XYGridDesc_M_K x_grid_desc_m_k_
Definition device_batchnorm_forward_impl.hpp:315
XYGridDesc_M_K y_grid_desc_m_k_
Definition device_batchnorm_forward_impl.hpp:316
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_
Definition device_batchnorm_forward_impl.hpp:318
bool saveMeanInvVariance_
Definition device_batchnorm_forward_impl.hpp:285
long_index_t invariant_length_
Definition device_batchnorm_forward_impl.hpp:308
MeanVarDataType * resultRunningVariance_
Definition device_batchnorm_forward_impl.hpp:306
Argument(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > yStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides, const XDataType *p_x, const ScaleDataType *p_scale, const BiasDataType *p_bias, const YElementwiseOp y_elementwise_op, double epsilon, YDataType *p_y, MeanVarDataType *resultSaveMean, MeanVarDataType *resultSaveInvVariance, double averageFactor, MeanVarDataType *resultRunningMean, MeanVarDataType *resultRunningVariance)
Definition device_batchnorm_forward_impl_obsolete.hpp:189
AccDataType averageFactor_
Definition device_batchnorm_forward_impl.hpp:282
const BiasDataType * p_bias_
Definition device_batchnorm_forward_impl.hpp:298
AccDataType epsilon_
Definition device_batchnorm_forward_impl.hpp:281
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_
Definition device_batchnorm_forward_impl.hpp:319
std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides_
Definition device_batchnorm_forward_impl.hpp:293
int numBlockTileIteration_
Definition device_batchnorm_forward_impl.hpp:312
void * workspace_count_
Definition device_batchnorm_forward_impl.hpp:323
const YElementwiseOp y_elementwise_op_
Definition device_batchnorm_forward_impl.hpp:299
void * workspace_mean_
Definition device_batchnorm_forward_impl.hpp:321
std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides_
Definition device_batchnorm_forward_impl.hpp:294
MeanVarDataType * resultSaveMean_
Definition device_batchnorm_forward_impl.hpp:302
YDataType * p_y_
Definition device_batchnorm_forward_impl.hpp:300
size_t gridSize_
Definition device_batchnorm_forward_impl.hpp:313
MeanVarDataType * resultSaveInvVariance_
Definition device_batchnorm_forward_impl.hpp:303
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths_
Definition device_batchnorm_forward_impl.hpp:291
void * workspace_variance_
Definition device_batchnorm_forward_impl.hpp:322
std::array< index_t, Rank > yStrides_
Definition device_batchnorm_forward_impl.hpp:289
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides_
Definition device_batchnorm_forward_impl.hpp:292
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batchnorm_forward_impl_obsolete.hpp:384
float Run(const BaseArgument *pArg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batchnorm_forward_impl_obsolete.hpp:591
Definition device_batchnorm_forward_impl.hpp:56
static auto MakeXY2dDescriptor(const std::array< index_t, Rank > &xyLengths, const std::array< index_t, Rank > &xyStrides, int blkGroupSize, int numBlockTileIteration)
Definition device_batchnorm_forward_impl_obsolete.hpp:68
bool IsSupportedArgument(const BaseArgument *pArg) override
Definition device_batchnorm_forward_impl_obsolete.hpp:598
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batchnorm_forward_impl_obsolete.hpp:693
static constexpr index_t K_BlockTileSize
Definition device_batchnorm_forward_impl.hpp:68
std::string GetTypeString() const override
Definition device_batchnorm_forward_impl_obsolete.hpp:698
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_batchnorm_forward_impl_obsolete.hpp:348
decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)) XYGridDesc_M_K
Definition device_batchnorm_forward_impl.hpp:186
static constexpr index_t M_BlockTileSize
Definition device_batchnorm_forward_impl.hpp:67
decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})) ScaleBiasMeanVarGridDesc_M
Definition device_batchnorm_forward_impl.hpp:187
static auto MakeScaleBiasMeanVar1dDescriptor(const std::array< index_t, NumInvariantDim > &lengths, const std::array< index_t, NumInvariantDim > &strides)
Definition device_batchnorm_forward_impl_obsolete.hpp:155
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > yStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides, const void *p_x, const void *p_scale, const void *p_bias, double epsilon, const YElementwiseOp y_elementwise_op, void *p_y, void *resultSaveMean, void *resultSaveInvVariance, double averageFactor, void *resultRunningMean, void *resultRunningVariance) override
Definition device_batchnorm_forward_impl_obsolete.hpp:651
static constexpr index_t NumInvariantDim
Definition device_batchnorm_forward_impl.hpp:65
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_batchnorm_forward_impl_obsolete.hpp:324
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_forward_impl_obsolete.hpp:133
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
Definition device_batchnorm_forward_impl_obsolete.hpp:115