gemm_quant_kernel.hpp Source File

gemm_quant_kernel.hpp Source File

Composable Kernel: gemm_quant_kernel.hpp Source File
gemm_quant_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <string>
7
8#include "ck_tile/core.hpp"
16
17namespace ck_tile {
18
19namespace detail {
20// Helper templates for safe type extraction
// Each helper below is a pair: a primary template that supplies a fallback, and a
// partial specialization selected through std::void_t when the queried member of T
// exists.
// NOTE(review): the line naming each primary template (`struct <name>`) is absent from
// this listing; the names are presumably those of the specialization that follows each
// one - confirm against the real header.

// Fallback: `type` is `Default`.
21template <typename, typename Default, typename = void>
23{
24 using type = Default;
25};
26
// Selected when `T::AQLayout` names a type: `type` is `T::AQLayout`.
27template <typename T, typename Default>
28struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
29{
30 using type = typename T::AQLayout;
31};
32
// Fallback: `type` is `Default`.
33template <typename, typename Default, typename = void>
35{
36 using type = Default;
37};
38
// Selected when `T::BQLayout` names a type: `type` is `T::BQLayout`.
39template <typename T, typename Default>
40struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
41{
42 using type = typename T::BQLayout;
43};
44
// Fallback: `type` is `Default`.
45template <typename, typename Default, typename = void>
47{
48 using type = Default;
49};
50
// Selected when `T::AQDataType` names a type: `type` is `T::AQDataType`.
51template <typename T, typename Default>
52struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
53{
54 using type = typename T::AQDataType;
55};
56
// Fallback: `type` is `Default`.
57template <typename, typename Default, typename = void>
59{
60 using type = Default;
61};
62
// Selected when `T::BQDataType` names a type: `type` is `T::BQDataType`.
63template <typename T, typename Default>
64struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
65{
66 using type = typename T::BQDataType;
67};
68
// Fallback: `value` is false.
69template <typename, typename = void>
71{
72 static constexpr bool value = false;
73};
74
// Selected when the expression `T::PreshuffleQuant` is well-formed: `value` forwards it.
75template <typename T>
76struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
77{
78 static constexpr bool value = T::PreshuffleQuant;
79};
80
// Fallback: `value` is false.
81template <typename, typename = void>
83{
84 static constexpr bool value = false;
85};
86
// Selected when the expression `T::PreshuffleB` is well-formed: `value` forwards it.
87template <typename T>
88struct is_preshuffleB_enabled<T, std::void_t<decltype(T::PreshuffleB)>>
89{
90 static constexpr bool value = T::PreshuffleB;
91};
92} // namespace detail
93
95{
98 index_t N_,
99 index_t K_,
100 index_t QK_A_,
101 index_t QK_B_,
102 index_t stride_A_,
103 index_t stride_B_,
104 index_t stride_C_,
105 index_t stride_AQ_,
106 index_t stride_BQ_)
107 : M(M_),
108 N(N_),
109 K(K_),
110 QK_A(QK_A_),
111 QK_B(QK_B_),
112 stride_A(stride_A_),
113 stride_B(stride_B_),
114 stride_C(stride_C_),
115 stride_AQ(stride_AQ_),
116 stride_BQ(stride_BQ_)
117 {
118 }
119
130};
131
133{
135 CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_,
136 const void* b_ptr_,
137 void* c_ptr_,
138 const void* aq_ptr_,
139 const void* bq_ptr_,
140 index_t k_batch_,
141 index_t M_,
142 index_t N_,
143 index_t K_,
144 index_t QK_A_,
145 index_t QK_B_,
146 index_t stride_A_,
147 index_t stride_B_,
148 index_t stride_C_,
149 index_t stride_AQ_,
150 index_t stride_BQ_)
152 M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_),
153 a_ptr(a_ptr_),
154 b_ptr(b_ptr_),
155 aq_ptr(aq_ptr_),
156 bq_ptr(bq_ptr_),
157 c_ptr(c_ptr_),
158 k_batch(k_batch_)
159 {
160 }
161
162 const void* a_ptr = nullptr;
163 const void* b_ptr = nullptr;
164 const void* aq_ptr = nullptr;
165 const void* bq_ptr = nullptr;
166 void* c_ptr = nullptr;
168};
169
189
190template <typename TilePartitioner_,
191 typename GemmPipeline_,
192 typename EpiloguePipeline_,
193 QuantType QuantType_>
195{
202
207
208 static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
209 static constexpr bool PreshuffleQuant =
212
217
222
223 static constexpr auto I0 = number<0>(); // A Tensor
224 static constexpr auto I1 = number<1>(); // AQ Tensor
225 static constexpr auto I2 = number<2>(); // B Tensor
226 static constexpr auto I3 = number<3>(); // BQ Tensor
227 static constexpr auto I4 = number<4>(); // C Tensor
228
229 static constexpr auto kQuantType = QuantType_;
230
231 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
232 {
233 // clang-format off
234 return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
235 // clang-format on
236 }
237
238 CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
239 {
240 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
241 }
242
243 CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
244
245 CK_TILE_HOST static constexpr QuantGemmKernelArgs
247 {
248 return QuantGemmKernelArgs{hostArgs.a_ptr,
249 hostArgs.b_ptr,
250 hostArgs.aq_ptr,
251 hostArgs.bq_ptr,
252 hostArgs.c_ptr,
253 hostArgs.M,
254 hostArgs.N,
255 hostArgs.K,
256 hostArgs.QK_A,
257 hostArgs.QK_B,
258 hostArgs.stride_A,
259 hostArgs.stride_B,
260 hostArgs.stride_C,
261 hostArgs.stride_AQ,
262 hostArgs.stride_BQ,
263 hostArgs.k_batch};
264 }
265
267 {
268 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
269 }
270
272 {
273 __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
274 const std::size_t k_id = blockIdx.z)
275 {
276 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
277 const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
278 const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
279
280 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
281 {
283 }
284 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
285 {
286 a_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_A);
287 }
288
289 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
290 {
291 b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B);
292 }
293 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
294 {
296 }
297
298 if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
299 {
301 }
302 else
303 {
304 splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
305 }
306 }
307
311 };
312
314 {
315 if(kargs.k_batch != 1)
316 {
317 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
318 {
319 CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
320 }
321 return false;
322 }
323
324 if constexpr(kQuantType == QuantType::AQuantGrouped)
325 {
326 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
327 if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
328 {
329 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
330 {
331 CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
332 }
333 return false;
334 }
335 }
336
337 if constexpr(kQuantType == QuantType::BQuantGrouped)
338 {
339 static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
340 if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
341 {
342 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
343 {
344 CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
345 }
346 return false;
347 }
348 }
349
350 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
351 {
352 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
353 GemmPipeline::kPadK == false)
354 {
355 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
356 {
357 CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
358 "without padding!");
359 }
360 return false;
361 }
362 if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
363 {
364 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
365 {
366 CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
367 }
368 return false;
369 }
370 }
371 else
372 {
373 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
374 {
375 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
376 {
378 "Can't support M that is not a multiple of MPerBlock without padding!");
379 }
380 return false;
381 }
382 if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
383 {
384 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
385 {
386 CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
387 }
388 return false;
389 }
390 }
391
392 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
393 {
394 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
395 {
396 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
397 {
399 "Can't support N that is not a multiple of NPerBlock without padding!");
400 }
401 return false;
402 }
403 if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
404 {
405 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
406 {
407 CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
408 }
409 return false;
410 }
411 }
412 else
413 {
414 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
415 GemmPipeline::kPadK == false)
416 {
417 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
418 {
419 CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
420 "without padding!");
421 }
422 return false;
423 }
424 if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
425 {
426 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
427 {
428 CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
429 }
430 return false;
431 }
432 }
433
434 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
435 {
436 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
437 {
438 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
439 {
441 "Can't support N that is not a multiple of NPerBlock without padding!");
442 }
443 return false;
444 }
445 if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
446 {
447 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
448 {
449 CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
450 }
451 return false;
452 }
453 }
454 else
455 {
456 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
457 {
458 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
459 {
461 "Can't support M that is not a multiple of MPerBlock without padding!");
462 }
463 return false;
464 }
465 if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
466 {
467 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
468 {
469 CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
470 }
471 return false;
472 }
473 }
474 return true;
475 }
476
477 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
479 const BDataType* b_ptr,
480 const AQDataType* aq_ptr,
481 const BQDataType* bq_ptr,
482 CDataType* c_ptr,
483 const QuantGemmKernelArgs& kargs,
484 const SplitKBatchOffset& splitk_batch_offset)
485 {
486
487 static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
488 const auto& a_tensor_view = [&]() {
489 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
490 {
492 a_ptr,
493 make_tuple(kargs.M, splitk_batch_offset.splitted_k),
494 make_tuple(kargs.stride_A, 1),
495 number<GemmPipeline::GetVectorSizeA()>{},
496 number<1>{});
497 }
498 else
499 {
501 a_ptr,
502 make_tuple(splitk_batch_offset.splitted_k, kargs.M),
503 make_tuple(kargs.stride_A, 1),
504 number<GemmPipeline::GetVectorSizeA()>{},
505 number<1>{});
506 }
507 }();
508
509 const auto get_padding_size = [](index_t length, index_t alignment) {
510 return ck_tile::integer_least_multiple(length, alignment) - length;
511 };
512
513 const auto& aq_tensor_view = [&]() {
515 {
516 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
517 const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
518 const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
519
520 const auto aq_desc =
522 make_tuple(aq_x, 1),
523 number<GemmPipeline::GetVectorSizeAQ()>{},
524 number<1>{});
525
526 const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
527 const auto aq_pad0_desc = transform_tensor_descriptor(
528 aq_desc,
531 make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))),
534
535 const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1];
536 const auto wave_tile_size =
537 TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
538 const auto wave_tile_count_x =
539 ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size);
540 const auto aq_unmerge_pad0_desc = transform_tensor_descriptor(
541 aq_pad0_desc,
544 make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))),
547
548 const auto aq_pad1_desc = transform_tensor_descriptor(
549 aq_unmerge_pad0_desc,
552 make_pass_through_transform(wave_tile_count_x),
554 wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))),
557
558 const auto pad_wave_size =
560 const auto aq_merge_pad1_desc = transform_tensor_descriptor(
561 aq_pad1_desc,
562 make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)),
563 make_pass_through_transform(pad_wave_size)),
566
567 return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
568 }
569 else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
570 {
571 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
573 aq_ptr,
574 make_tuple(kargs.M, kargs.QK_A),
575 make_tuple(kargs.stride_AQ, 1),
576 number<GemmPipeline::GetVectorSizeAQ()>{},
577 number<1>{});
578 }
579 else if constexpr(kQuantType == QuantType::RowColQuant)
580 {
582 aq_ptr,
583 make_tuple(kargs.M, kargs.N),
584 make_tuple(1, 0), // broadcasting over n
585 number<1>{},
586 number<1>{});
587 }
588 else
589 {
590 return nullptr; // TODO: use some other "empty" type for this
591 }
592 }();
593
594 const auto& b_tensor_view = [&]() {
595 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
596 {
597 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
598 {
599 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
600 const index_t K0 = splitk_batch_offset.splitted_k / K1;
601 constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
602 const auto b_k0_n_k1_desc =
604 make_tuple(kargs.N * K1, K1, I1),
606 number<1>{});
607 const auto b_n_k_desc = transform_tensor_descriptor(
608 b_k0_n_k1_desc,
613 return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
614 }
615 else
616 {
618 b_ptr,
619 make_tuple(splitk_batch_offset.splitted_k, kargs.N),
620 make_tuple(kargs.stride_B, 1),
621 number<GemmPipeline::GetVectorSizeB()>{},
622 number<1>{});
623 }
624 }
625 else
626 {
627 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
628 {
629 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
630 const index_t K0 = splitk_batch_offset.splitted_k / K1;
631 constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
632 const auto b_k0_n_k1_desc =
634 make_tuple(kargs.N * K1, K1, I1),
636 number<1>{});
637 const auto b_n_k_desc = transform_tensor_descriptor(
638 b_k0_n_k1_desc,
643 return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
644 }
645 else
646 {
647 if constexpr(PreshuffleB)
648 {
649 index_t kFlatK =
650 GemmPipeline::flatKPerWarp *
651 (splitk_batch_offset.splitted_k /
652 TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
653 index_t kFlatN = kargs.N * kargs.K / kFlatK;
654
656 b_ptr,
657 make_tuple(kFlatN, kFlatK),
658 make_tuple(kFlatK, 1),
659 number<GemmPipeline::GetVectorSizeB()>{},
660 number<1>{});
661 }
662 else
663 {
665 b_ptr,
666 make_tuple(kargs.N, splitk_batch_offset.splitted_k),
667 make_tuple(kargs.stride_B, 1),
668 number<GemmPipeline::GetVectorSizeB()>{},
669 number<1>{});
670 }
671 }
672 }
673 }();
674
675 const auto& bq_tensor_view = [&]() {
676 if constexpr(kQuantType == QuantType::RowColQuant)
677 {
679 bq_ptr,
680 make_tuple(kargs.M, kargs.N),
681 make_tuple(0, 1), // broadcasting over m
682 number<1>{},
683 number<1>{});
684 }
685 else if constexpr(kQuantType == QuantType::BQuantGrouped)
686 {
687 static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
690 bq_ptr,
691 make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)),
692 make_tuple(1, kargs.stride_BQ),
693 number<GemmPipeline::GetVectorSizeBQ()>{},
694 number<1>{});
695 }
696 else
697 {
698 return nullptr; // TODO: use some other "empty" type for this
699 }
700 }();
701
702 // TODO: enable vector write for C in ColMajor
703 const auto& c_tensor_view = [&]() {
704 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
705 {
707 c_ptr,
708 make_tuple(kargs.M, kargs.N),
709 make_tuple(kargs.stride_C, 1),
710 number<EpiloguePipeline::GetVectorSizeC()>{},
711 number<1>{});
712 }
713 else
714 {
716 c_ptr,
717 make_tuple(kargs.M, kargs.N),
718 make_tuple(1, kargs.stride_C),
719 number<1>{},
720 number<1>{});
721 }
722 }();
723
724 return make_tuple(
725 a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
726 }
727
728 template <typename TensorView>
729 CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
730 {
731 const auto& a_pad_view = [&]() {
732 const auto& a_tensor_view = views.at(I0);
733 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
734 {
735 return pad_tensor_view(a_tensor_view,
739 }
740 else
741 {
742 return pad_tensor_view(a_tensor_view,
746 }
747 }();
748
749 // no padding
750 const auto& aq_pad_view = [&]() { return views.at(I1); }();
751
752 const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view
753
754 const auto& b_pad_view = [&]() {
755 const auto& b_tensor_view = views.at(I2);
756 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
757 {
758 return pad_tensor_view(b_tensor_view,
762 }
763 else
764 {
765 return pad_tensor_view(b_tensor_view,
769 }
770 }();
771
772 // no padding
773 const auto& bq_pad_view = [&]() { return views.at(I3); }();
774
775 // TODO vector write in for C in ColMajor
776 const auto& c_pad_view = [&]() {
777 const auto& c_tensor_view = views.at(I4);
778 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
779 {
780 return pad_tensor_view(c_tensor_view,
784 }
785 else
786 {
787 return pad_tensor_view(c_tensor_view,
791 }
792 }();
793 if constexpr(PreshuffleB)
794 {
795
796 return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
797 }
798 else
799 {
800 return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
801 }
802 }
803
804 template <typename PadView>
805 CK_TILE_DEVICE static auto
806 MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
807 {
808
809 const auto& a_pad_view = views.at(I0);
810 const auto& aq_pad_view = views.at(I1);
811 const auto& b_pad_view = views.at(I2);
812 const auto& bq_pad_view = views.at(I3);
813 const auto& c_pad_view = views.at(I4);
814 const auto& a_block_window = [&]() {
815 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
816 {
817 return make_tile_window(a_pad_view,
820 {i_m, 0});
821 }
822 else
823 {
824 return make_tile_window(a_pad_view,
827 {0, i_m});
828 }
829 }();
830
831 const auto& aq_block_window = [&]() {
833 {
834 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
836 constexpr auto block_m = TilePartitioner::MPerBlock;
837 constexpr auto warp_m = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
838 constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK;
839 constexpr auto tile_window_width =
840 ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size());
841 constexpr auto tile_window_height = block_m / warp_m;
842 auto block_m_idx = i_m / block_m;
843 return make_tile_window(
844 aq_pad_view,
846 {block_m_idx * tile_window_height, 0});
847 }
848 else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant)
849 {
850 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
852 constexpr auto block_m = TilePartitioner::MPerBlock;
853 constexpr auto block_k = TilePartitioner::KPerBlock;
854 return make_tile_window(
855 aq_pad_view,
856 make_tuple(number<block_m>{}, number<block_k / QuantGroupSize::kK>{}),
857 {i_m, 0});
858 }
859 else if constexpr(kQuantType == QuantType::RowColQuant)
860 {
861 return make_tile_window(aq_pad_view,
864 {i_m, i_n});
865 }
866 else
867 {
868 return nullptr; // TODO: use some other "empty" type?
869 }
870 }();
871
872 const auto& b_block_window = [&]() {
873 if constexpr(PreshuffleB)
874 {
875
876 return make_tile_window(
877 b_pad_view,
880 {static_cast<int>(i_n / TilePartitioner::BlockGemmShape::WarpTile::at(I1)), 0});
881 }
882 else
883 {
884 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
885 {
886 return make_tile_window(b_pad_view,
889 {i_n, 0});
890 }
891 else
892 {
893 return make_tile_window(b_pad_view,
896 {0, i_n});
897 }
898 }
899 }();
900
901 const auto& bq_block_window = [&]() {
902 if constexpr(kQuantType == QuantType::RowColQuant)
903 {
904 return make_tile_window(bq_pad_view,
907 {i_m, i_n});
908 }
909 else if constexpr(kQuantType == QuantType::BQuantGrouped)
910 {
911 static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
913 return make_tile_window(
914 bq_pad_view,
916 number<TilePartitioner::NPerBlock / QuantGroupSize::kN>{}),
917 {0, i_n / QuantGroupSize::kN});
918 }
919 else
920 {
921 return nullptr; // TODO: use some other "empty" type here
922 }
923 }();
924
925 auto c_block_window = make_tile_window(
926 c_pad_view,
928 {i_m, i_n});
929
930 return make_tuple(
931 a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
932 }
933
950 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
951 CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
952 const BDataType* b_ptr,
953 const AQDataType* aq_ptr,
954 const BQDataType* bq_ptr,
955 CDataType* c_ptr,
956 void* smem_ptr_0,
957 const QuantGemmKernelArgs& kargs,
958 const SplitKBatchOffset& splitk_batch_offset,
959 const index_t block_idx_m,
960 const index_t block_idx_n)
961 {
962 // Create Gemm tensor views, pad views and tile windows
963 const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
964 a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
965
966 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
967 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
968
969 const index_t num_loop =
970 amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
971
972 // Run GEMM cooperatively by whole workgroup.
973 const auto& a_block_window = gemm_tile_windows.at(I0);
974 const auto& b_block_window = gemm_tile_windows.at(I2);
975
976 const auto& c_block_tile = [&]() {
977 if constexpr(kQuantType == QuantType::AQuantGrouped)
978 {
979 const auto& aq_block_window = gemm_tile_windows.at(I1);
980 return GemmPipeline{}.template operator()(
981 a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
982 }
983 else if constexpr(kQuantType == QuantType::BQuantGrouped)
984 {
985 const auto& bq_block_window = gemm_tile_windows.at(I3);
986 return GemmPipeline{}.template operator()(
987 a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
988 }
989 else if constexpr(kQuantType == QuantType::RowColQuant ||
991 {
992 return GemmPipeline{}.template operator()(
993 a_block_window, b_block_window, num_loop, smem_ptr_0);
994 }
995 }();
996
997 // Run Epilogue Pipeline
998 auto& c_block_window = gemm_tile_windows.at(I4);
999
1000 if constexpr(kQuantType == QuantType::AQuantGrouped ||
1002 {
1003 EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
1004 }
1005 else if constexpr(kQuantType == QuantType::RowColQuant)
1006 {
1007 const auto& aq_block_window = gemm_tile_windows.at(I1);
1008 const auto& bq_block_window = gemm_tile_windows.at(I3);
1009 EpiloguePipeline{}(c_block_window,
1010 c_block_tile,
1011 c_block_window,
1012 smem_ptr_0,
1013 aq_block_window,
1014 bq_block_window);
1015 }
1016 else if constexpr(kQuantType == QuantType::TensorQuant)
1017 {
1018 // TODO: why doesn't readfirstlane work here?
1019 // const AccDataType aq_scale =
1020 // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
1021 // const AccDataType bq_scale =
1022 // __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
1023 const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
1024 const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
1026 c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
1027 }
1028 }
1029
1044 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
1045 CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
1046 const BDataType* b_ptr,
1047 const AQDataType* aq_ptr,
1048 const BQDataType* bq_ptr,
1049 CDataType* c_ptr,
1050 void* smem_ptr_0,
1051 void* smem_ptr_1,
1052 const QuantGemmKernelArgs& kargs,
1053 const SplitKBatchOffset& splitk_batch_offset,
1054 const index_t block_idx_m,
1055 const index_t block_idx_n)
1056 {
1057 // Create Gemm tensor views, pad views and tile windows
1058 const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
1059 a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
1060
1061 const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
1062 auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
1063
1064 const index_t num_loop = __builtin_amdgcn_readfirstlane(
1065 TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1066
1067 // Run GEMM cooperatively by whole workgroup.
1068 const auto& a_block_window = gemm_tile_windows.at(I0);
1069 const auto& b_block_window = gemm_tile_windows.at(I2);
1070
1071 const auto& c_block_tile = [&]() {
1072 if constexpr(kQuantType == QuantType::BQuantGrouped)
1073 {
1074 const auto& bq_block_window = gemm_tile_windows.at(I3);
1075 return GemmPipeline{}.template operator()(a_block_window,
1076 b_block_window,
1077 bq_block_window,
1078 num_loop,
1079 smem_ptr_0,
1080 smem_ptr_1);
1081 }
1082 else
1083 {
1084 return nullptr;
1085 }
1086 }();
1087
1088 // Run Epilogue Pipeline
1089 auto& c_block_window = gemm_tile_windows.at(I4);
1090
1091 if constexpr(kQuantType == QuantType::BQuantGrouped)
1092 {
1093 EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
1094 }
1095 else
1096 {
1097 return;
1098 // throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or
1099 // RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped,
1100 // "DoubleSmemBuffer Not implemented");
1101 }
1102 }
1103
1105 {
1106 const auto blockId = amd_wave_read_first_lane(blockIdx.x);
1107 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1108 const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
1109 const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
1110
1111 const SplitKBatchOffset splitk_batch_offset(kargs);
1112 // options
1113 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
1114 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
1115 const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
1116 const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
1117 CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
1118
1119 // allocate LDS
1120 __shared__ char smem_ptr_0[GetSmemSize()];
1121 assert(kargs.k_batch == 1);
1122 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1123 {
1124 __shared__ char smem_ptr_1[GetSmemSize()];
1125
1126 RunGemm2LDS(a_ptr,
1127 b_ptr,
1128 aq_ptr,
1129 bq_ptr,
1130 c_ptr,
1131 smem_ptr_0,
1132 smem_ptr_1,
1133 kargs,
1134 splitk_batch_offset,
1135 i_m,
1136 i_n);
1137 }
1138 else
1139 {
1140 RunGemm(a_ptr,
1141 b_ptr,
1142 aq_ptr,
1143 bq_ptr,
1144 c_ptr,
1145 smem_ptr_0,
1146 kargs,
1147 splitk_batch_offset,
1148 i_m,
1149 i_n);
1150 }
1151 }
1152};
1153
1154} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
QuantType
Definition tile_gemm_quant_traits.hpp:12
@ BQuantGrouped
Definition tile_gemm_quant_traits.hpp:14
@ RowColQuant
Definition tile_gemm_quant_traits.hpp:15
@ TensorQuant
Definition tile_gemm_quant_traits.hpp:16
@ AQuantGrouped
Definition tile_gemm_quant_traits.hpp:13
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1584
@ Default
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:15
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
STL namespace.
unsigned int uint32_t
Definition stdint.h:126
Definition gemm_quant_kernel.hpp:133
void * c_ptr
Definition gemm_quant_kernel.hpp:166
const void * aq_ptr
Definition gemm_quant_kernel.hpp:164
const void * bq_ptr
Definition gemm_quant_kernel.hpp:165
const void * b_ptr
Definition gemm_quant_kernel.hpp:163
CK_TILE_HOST QuantGemmHostArgs()=default
index_t k_batch
Definition gemm_quant_kernel.hpp:167
const void * a_ptr
Definition gemm_quant_kernel.hpp:162
CK_TILE_HOST QuantGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, const void *aq_ptr_, const void *bq_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t QK_A_, index_t QK_B_, index_t stride_A_, index_t stride_B_, index_t stride_C_, index_t stride_AQ_, index_t stride_BQ_)
Definition gemm_quant_kernel.hpp:135
Definition gemm_quant_kernel.hpp:272
__device__ SplitKBatchOffset(const QuantGemmKernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition gemm_quant_kernel.hpp:273
index_t a_k_split_offset
Definition gemm_quant_kernel.hpp:308
index_t b_k_split_offset
Definition gemm_quant_kernel.hpp:309
index_t splitted_k
Definition gemm_quant_kernel.hpp:310
Definition gemm_quant_kernel.hpp:171
index_t k_batch
Definition gemm_quant_kernel.hpp:187
index_t stride_BQ
Definition gemm_quant_kernel.hpp:186
const void * b_ptr
Definition gemm_quant_kernel.hpp:173
void * c_ptr
Definition gemm_quant_kernel.hpp:176
const void * aq_ptr
Definition gemm_quant_kernel.hpp:174
index_t stride_A
Definition gemm_quant_kernel.hpp:182
index_t M
Definition gemm_quant_kernel.hpp:177
const void * a_ptr
Definition gemm_quant_kernel.hpp:172
const void * bq_ptr
Definition gemm_quant_kernel.hpp:175
index_t QK_B
Definition gemm_quant_kernel.hpp:181
index_t K
Definition gemm_quant_kernel.hpp:179
index_t QK_A
Definition gemm_quant_kernel.hpp:180
index_t stride_AQ
Definition gemm_quant_kernel.hpp:185
index_t N
Definition gemm_quant_kernel.hpp:178
index_t stride_C
Definition gemm_quant_kernel.hpp:184
index_t stride_B
Definition gemm_quant_kernel.hpp:183
Definition gemm_quant_kernel.hpp:195
static constexpr auto I4
Definition gemm_quant_kernel.hpp:227
static constexpr auto I3
Definition gemm_quant_kernel.hpp:226
static constexpr bool PreshuffleB
Definition gemm_quant_kernel.hpp:211
remove_cvref_t< typename detail::get_bq_data_type_or< GemmPipeline, AccDataType >::type > BQDataType
Definition gemm_quant_kernel.hpp:220
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition gemm_quant_kernel.hpp:197
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_quant_kernel.hpp:266
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition gemm_quant_kernel.hpp:198
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition gemm_quant_kernel.hpp:729
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition gemm_quant_kernel.hpp:196
remove_cvref_t< typename EpiloguePipeline::AccDataType > AccDataType
Definition gemm_quant_kernel.hpp:216
static constexpr auto I0
Definition gemm_quant_kernel.hpp:223
CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
Definition gemm_quant_kernel.hpp:1104
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition gemm_quant_kernel.hpp:215
static constexpr index_t kBlockSize
Definition gemm_quant_kernel.hpp:208
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition gemm_quant_kernel.hpp:200
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
Definition gemm_quant_kernel.hpp:238
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition gemm_quant_kernel.hpp:201
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, const QuantGemmKernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition gemm_quant_kernel.hpp:478
remove_cvref_t< typename detail::get_aq_layout_or< GemmPipeline, typename GemmPipeline::ALayout >::type > AQLayout
Definition gemm_quant_kernel.hpp:203
static constexpr auto I1
Definition gemm_quant_kernel.hpp:224
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition gemm_quant_kernel.hpp:199
remove_cvref_t< typename detail::get_aq_data_type_or< GemmPipeline, AccDataType >::type > AQDataType
Definition gemm_quant_kernel.hpp:218
static constexpr bool PreshuffleQuant
Definition gemm_quant_kernel.hpp:209
static CK_TILE_HOST bool IsSupportedArgument(const QuantGemmKernelArgs &kargs)
Definition gemm_quant_kernel.hpp:313
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition gemm_quant_kernel.hpp:214
static constexpr auto I2
Definition gemm_quant_kernel.hpp:225
remove_cvref_t< typename detail::get_bq_layout_or< GemmPipeline, typename GemmPipeline::BLayout >::type > BQLayout
Definition gemm_quant_kernel.hpp:205
static CK_TILE_DEVICE void RunGemm(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, const QuantGemmKernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively by the whole workgroup.
Definition gemm_quant_kernel.hpp:951
static CK_TILE_HOST const std::string GetName()
Definition gemm_quant_kernel.hpp:231
static CK_TILE_HOST constexpr auto BlockSize()
Definition gemm_quant_kernel.hpp:243
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition gemm_quant_kernel.hpp:806
static CK_TILE_HOST constexpr QuantGemmKernelArgs MakeKernelArgs(const QuantGemmHostArgs &hostArgs)
Definition gemm_quant_kernel.hpp:246
static CK_TILE_DEVICE void RunGemm2LDS(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, void *smem_ptr_1, const QuantGemmKernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively by the whole workgroup.
Definition gemm_quant_kernel.hpp:1045
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Definition gemm_quant_kernel.hpp:213
static constexpr auto kQuantType
Definition gemm_quant_kernel.hpp:229
index_t stride_AQ
Definition gemm_quant_kernel.hpp:128
index_t N
Definition gemm_quant_kernel.hpp:121
index_t K
Definition gemm_quant_kernel.hpp:122
index_t stride_BQ
Definition gemm_quant_kernel.hpp:129
index_t stride_C
Definition gemm_quant_kernel.hpp:127
index_t stride_B
Definition gemm_quant_kernel.hpp:126
index_t stride_A
Definition gemm_quant_kernel.hpp:125
CK_TILE_HOST QuantGemmProblem(index_t M_, index_t N_, index_t K_, index_t QK_A_, index_t QK_B_, index_t stride_A_, index_t stride_B_, index_t stride_C_, index_t stride_AQ_, index_t stride_BQ_)
Definition gemm_quant_kernel.hpp:97
index_t QK_A
Definition gemm_quant_kernel.hpp:123
index_t QK_B
Definition gemm_quant_kernel.hpp:124
CK_TILE_HOST QuantGemmProblem()=default
index_t M
Definition gemm_quant_kernel.hpp:120
Definition gemm_quant_kernel.hpp:47
Default type
Definition gemm_quant_kernel.hpp:48
typename T::AQLayout type
Definition gemm_quant_kernel.hpp:30
Definition gemm_quant_kernel.hpp:23
Default type
Definition gemm_quant_kernel.hpp:24
Definition gemm_quant_kernel.hpp:59
Default type
Definition gemm_quant_kernel.hpp:60
typename T::BQLayout type
Definition gemm_quant_kernel.hpp:42
Definition gemm_quant_kernel.hpp:35
Default type
Definition gemm_quant_kernel.hpp:36
static constexpr bool value
Definition gemm_quant_kernel.hpp:90
Definition gemm_quant_kernel.hpp:83
static constexpr bool value
Definition gemm_quant_kernel.hpp:84
Definition gemm_quant_kernel.hpp:71
static constexpr bool value
Definition gemm_quant_kernel.hpp:72
Definition tile/core/container/sequence.hpp:49
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145