2024-03-06 09:32:19 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-12-19 15:18:57 +08:00
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
// clang-format off
|
2023-12-19 15:18:57 +08:00
|
|
|
#include "mlx/backend/metal/kernels/utils.h"
|
2024-01-18 04:42:39 +08:00
|
|
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
2024-06-13 00:47:12 +08:00
|
|
|
#include "mlx/backend/metal/kernels/quantized.h"
|
|
|
|
|
|
|
|
|
|
|
|
#define instantiate_qmv_fast(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
|
|
|
|
qmv_fast, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
|
|
|
|
|
|
|
#define instantiate_qmv_fast_types(group_size, bits) \
|
|
|
|
instantiate_qmv_fast(float, group_size, bits) \
|
|
|
|
instantiate_qmv_fast(float16_t, group_size, bits) \
|
|
|
|
instantiate_qmv_fast(bfloat16_t, group_size, bits)
|
|
|
|
|
|
|
|
instantiate_qmv_fast_types(128, 2)
|
|
|
|
instantiate_qmv_fast_types(128, 4)
|
|
|
|
instantiate_qmv_fast_types(128, 8)
|
|
|
|
instantiate_qmv_fast_types( 64, 2)
|
|
|
|
instantiate_qmv_fast_types( 64, 4)
|
|
|
|
instantiate_qmv_fast_types( 64, 8)
|
|
|
|
instantiate_qmv_fast_types( 32, 2)
|
|
|
|
instantiate_qmv_fast_types( 32, 4)
|
|
|
|
instantiate_qmv_fast_types( 32, 8)
|
|
|
|
|
|
|
|
#define instantiate_qmv(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"qmv_" #itype "_gs_" #group_size "_b_" #bits, \
|
|
|
|
qmv, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
2024-01-18 04:42:39 +08:00
|
|
|
|
2024-04-30 22:18:09 +08:00
|
|
|
#define instantiate_qmv_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qmv(float, group_size, bits) \
|
|
|
|
instantiate_qmv(float16_t, group_size, bits) \
|
|
|
|
instantiate_qmv(bfloat16_t, group_size, bits)
|
2023-12-19 15:18:57 +08:00
|
|
|
|
|
|
|
instantiate_qmv_types(128, 2)
|
|
|
|
instantiate_qmv_types(128, 4)
|
|
|
|
instantiate_qmv_types(128, 8)
|
|
|
|
instantiate_qmv_types( 64, 2)
|
|
|
|
instantiate_qmv_types( 64, 4)
|
|
|
|
instantiate_qmv_types( 64, 8)
|
2024-01-21 22:19:05 +08:00
|
|
|
instantiate_qmv_types( 32, 2)
|
|
|
|
instantiate_qmv_types( 32, 4)
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qmv_types( 32, 8)
|
2024-04-30 22:18:09 +08:00
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
#define instantiate_qvm(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"qvm_" #itype "_gs_" #group_size "_b_" #bits, \
|
|
|
|
qvm, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
2024-04-30 22:18:09 +08:00
|
|
|
|
|
|
|
#define instantiate_qvm_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qvm(float, group_size, bits) \
|
|
|
|
instantiate_qvm(float16_t, group_size, bits) \
|
|
|
|
instantiate_qvm(bfloat16_t, group_size, bits)
|
2024-01-04 06:22:36 +08:00
|
|
|
|
|
|
|
instantiate_qvm_types(128, 2)
|
|
|
|
instantiate_qvm_types(128, 4)
|
|
|
|
instantiate_qvm_types(128, 8)
|
|
|
|
instantiate_qvm_types( 64, 2)
|
|
|
|
instantiate_qvm_types( 64, 4)
|
|
|
|
instantiate_qvm_types( 64, 8)
|
2024-01-21 22:19:05 +08:00
|
|
|
instantiate_qvm_types( 32, 2)
|
|
|
|
instantiate_qvm_types( 32, 4)
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qvm_types( 32, 8)
|
2024-04-30 22:18:09 +08:00
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
|
|
|
|
qmm_t, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits, \
|
|
|
|
aligned_N)
|
2023-12-19 15:18:57 +08:00
|
|
|
|
2024-04-30 22:18:09 +08:00
|
|
|
#define instantiate_qmm_t_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qmm_t(float, group_size, bits, false) \
|
|
|
|
instantiate_qmm_t(float16_t, group_size, bits, false) \
|
|
|
|
instantiate_qmm_t(bfloat16_t, group_size, bits, false) \
|
|
|
|
instantiate_qmm_t(float, group_size, bits, true) \
|
|
|
|
instantiate_qmm_t(float16_t, group_size, bits, true) \
|
|
|
|
instantiate_qmm_t(bfloat16_t, group_size, bits, true)
|
2023-12-19 15:18:57 +08:00
|
|
|
|
|
|
|
instantiate_qmm_t_types(128, 2)
|
|
|
|
instantiate_qmm_t_types(128, 4)
|
|
|
|
instantiate_qmm_t_types(128, 8)
|
|
|
|
instantiate_qmm_t_types( 64, 2)
|
|
|
|
instantiate_qmm_t_types( 64, 4)
|
|
|
|
instantiate_qmm_t_types( 64, 8)
|
2024-01-21 22:19:05 +08:00
|
|
|
instantiate_qmm_t_types( 32, 2)
|
|
|
|
instantiate_qmm_t_types( 32, 4)
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qmm_t_types( 32, 8)
|
2024-01-04 06:22:36 +08:00
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
#define instantiate_qmm_n(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
|
|
|
|
qmm_n, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
2024-01-04 06:22:36 +08:00
|
|
|
|
2024-04-30 22:18:09 +08:00
|
|
|
#define instantiate_qmm_n_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qmm_n(float, group_size, bits) \
|
|
|
|
instantiate_qmm_n(float16_t, group_size, bits) \
|
|
|
|
instantiate_qmm_n(bfloat16_t, group_size, bits)
|
2024-01-04 06:22:36 +08:00
|
|
|
|
|
|
|
instantiate_qmm_n_types(128, 2)
|
|
|
|
instantiate_qmm_n_types(128, 4)
|
|
|
|
instantiate_qmm_n_types(128, 8)
|
|
|
|
instantiate_qmm_n_types( 64, 2)
|
|
|
|
instantiate_qmm_n_types( 64, 4)
|
|
|
|
instantiate_qmm_n_types( 64, 8)
|
2024-01-21 22:19:05 +08:00
|
|
|
instantiate_qmm_n_types( 32, 2)
|
|
|
|
instantiate_qmm_n_types( 32, 4)
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_qmm_n_types( 32, 8)
|
|
|
|
|
|
|
|
#define instantiate_bs_qmv_fast(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
|
|
|
|
bs_qmv_fast, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
|
|
|
|
|
|
|
#define instantiate_bs_qmv_fast_types(group_size, bits) \
|
|
|
|
instantiate_bs_qmv_fast(float, group_size, bits) \
|
|
|
|
instantiate_bs_qmv_fast(float16_t, group_size, bits) \
|
|
|
|
instantiate_bs_qmv_fast(bfloat16_t, group_size, bits)
|
|
|
|
|
|
|
|
instantiate_bs_qmv_fast_types(128, 2)
|
|
|
|
instantiate_bs_qmv_fast_types(128, 4)
|
|
|
|
instantiate_bs_qmv_fast_types(128, 8)
|
|
|
|
instantiate_bs_qmv_fast_types( 64, 2)
|
|
|
|
instantiate_bs_qmv_fast_types( 64, 4)
|
|
|
|
instantiate_bs_qmv_fast_types( 64, 8)
|
|
|
|
instantiate_bs_qmv_fast_types( 32, 2)
|
|
|
|
instantiate_bs_qmv_fast_types( 32, 4)
|
|
|
|
instantiate_bs_qmv_fast_types( 32, 8)
|
|
|
|
|
|
|
|
#define instantiate_bs_qmv(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \
|
|
|
|
bs_qmv, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
#define instantiate_bs_qmv_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_bs_qmv(float, group_size, bits) \
|
|
|
|
instantiate_bs_qmv(float16_t, group_size, bits) \
|
|
|
|
instantiate_bs_qmv(bfloat16_t, group_size, bits)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
instantiate_bs_qmv_types(128, 2)
|
|
|
|
instantiate_bs_qmv_types(128, 4)
|
|
|
|
instantiate_bs_qmv_types(128, 8)
|
|
|
|
instantiate_bs_qmv_types( 64, 2)
|
|
|
|
instantiate_bs_qmv_types( 64, 4)
|
|
|
|
instantiate_bs_qmv_types( 64, 8)
|
|
|
|
instantiate_bs_qmv_types( 32, 2)
|
|
|
|
instantiate_bs_qmv_types( 32, 4)
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_bs_qmv_types( 32, 8)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
#define instantiate_bs_qvm(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \
|
|
|
|
bs_qvm, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
#define instantiate_bs_qvm_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_bs_qvm(float, group_size, bits) \
|
|
|
|
instantiate_bs_qvm(float16_t, group_size, bits) \
|
|
|
|
instantiate_bs_qvm(bfloat16_t, group_size, bits)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
instantiate_bs_qvm_types(128, 2)
|
|
|
|
instantiate_bs_qvm_types(128, 4)
|
|
|
|
instantiate_bs_qvm_types(128, 8)
|
|
|
|
instantiate_bs_qvm_types( 64, 2)
|
|
|
|
instantiate_bs_qvm_types( 64, 4)
|
|
|
|
instantiate_bs_qvm_types( 64, 8)
|
|
|
|
instantiate_bs_qvm_types( 32, 2)
|
|
|
|
instantiate_bs_qvm_types( 32, 4)
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_bs_qvm_types( 32, 8)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
|
|
|
|
bs_qmm_t, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits, \
|
|
|
|
aligned_N)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
#define instantiate_bs_qmm_t_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_bs_qmm_t(float, group_size, bits, false) \
|
|
|
|
instantiate_bs_qmm_t(float16_t, group_size, bits, false) \
|
|
|
|
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \
|
|
|
|
instantiate_bs_qmm_t(float, group_size, bits, true) \
|
|
|
|
instantiate_bs_qmm_t(float16_t, group_size, bits, true) \
|
|
|
|
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
instantiate_bs_qmm_t_types(128, 2)
|
|
|
|
instantiate_bs_qmm_t_types(128, 4)
|
|
|
|
instantiate_bs_qmm_t_types(128, 8)
|
|
|
|
instantiate_bs_qmm_t_types( 64, 2)
|
|
|
|
instantiate_bs_qmm_t_types( 64, 4)
|
|
|
|
instantiate_bs_qmm_t_types( 64, 8)
|
|
|
|
instantiate_bs_qmm_t_types( 32, 2)
|
|
|
|
instantiate_bs_qmm_t_types( 32, 4)
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_bs_qmm_t_types( 32, 8)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
2024-06-13 00:47:12 +08:00
|
|
|
#define instantiate_bs_qmm_n(itype, group_size, bits) \
|
|
|
|
instantiate_kernel( \
|
|
|
|
"bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
|
|
|
|
bs_qmm_n, \
|
|
|
|
itype, \
|
|
|
|
group_size, \
|
|
|
|
bits)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
#define instantiate_bs_qmm_n_types(group_size, bits) \
|
2024-06-13 00:47:12 +08:00
|
|
|
instantiate_bs_qmm_n(float, group_size, bits) \
|
|
|
|
instantiate_bs_qmm_n(float16_t, group_size, bits) \
|
|
|
|
instantiate_bs_qmm_n(bfloat16_t, group_size, bits)
|
2024-05-17 06:24:14 +08:00
|
|
|
|
|
|
|
instantiate_bs_qmm_n_types(128, 2)
|
|
|
|
instantiate_bs_qmm_n_types(128, 4)
|
|
|
|
instantiate_bs_qmm_n_types(128, 8)
|
|
|
|
instantiate_bs_qmm_n_types( 64, 2)
|
|
|
|
instantiate_bs_qmm_n_types( 64, 4)
|
|
|
|
instantiate_bs_qmm_n_types( 64, 8)
|
|
|
|
instantiate_bs_qmm_n_types( 32, 2)
|
|
|
|
instantiate_bs_qmm_n_types( 32, 4)
|
|
|
|
instantiate_bs_qmm_n_types( 32, 8) // clang-format on
|