mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 11:38:06 +08:00
144 lines
8.0 KiB
Metal
144 lines
8.0 KiB
Metal
// Copyright © 2023-2024 Apple Inc.
|
|
|
|
// clang-format off
|
|
#include "mlx/backend/metal/kernels/utils.h"
|
|
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
|
#include "mlx/backend/metal/kernels/quantized.h"
|
|
|
|
#define instantiate_quantized(name, type, group_size, bits) \
|
|
instantiate_kernel( \
|
|
#name "_" #type "_gs_" #group_size "_b_" #bits, \
|
|
name, \
|
|
type, \
|
|
group_size, \
|
|
bits)
|
|
|
|
#define instantiate_quantized_batched(name, type, group_size, bits, batched) \
|
|
instantiate_kernel( \
|
|
#name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
|
|
name, \
|
|
type, \
|
|
group_size, \
|
|
bits, \
|
|
batched)
|
|
|
|
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
|
|
instantiate_kernel( \
|
|
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
|
|
name, \
|
|
type, \
|
|
group_size, \
|
|
bits, \
|
|
aligned)
|
|
|
|
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \
|
|
instantiate_kernel( \
|
|
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
|
|
name, \
|
|
type, \
|
|
group_size, \
|
|
bits, \
|
|
aligned, \
|
|
batched)
|
|
|
|
#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
|
|
instantiate_kernel( \
|
|
#name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
|
|
name, \
|
|
type, \
|
|
group_size, \
|
|
bits, \
|
|
D, \
|
|
batched)
|
|
|
|
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
|
|
instantiate_kernel( \
|
|
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
|
|
name, \
|
|
type, \
|
|
group_size, \
|
|
bits, \
|
|
split_k)
|
|
|
|
#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \
|
|
instantiate_kernel( \
|
|
#name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
|
|
func, \
|
|
type, \
|
|
group_size, \
|
|
bits, \
|
|
bm, \
|
|
bn, \
|
|
bk, \
|
|
wm, \
|
|
wn, \
|
|
transpose)
|
|
|
|
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
|
|
instantiate_quantized_batched(name, type, group_size, bits, 1) \
|
|
instantiate_quantized_batched(name, type, group_size, bits, 0)
|
|
|
|
#define instantiate_quantized_all_batched(type, group_size, bits) \
|
|
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
|
|
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
|
|
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
|
|
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
|
|
|
|
#define instantiate_quantized_all_single(type, group_size, bits) \
|
|
instantiate_quantized(affine_quantize, type, group_size, bits) \
|
|
instantiate_quantized(affine_dequantize, type, group_size, bits) \
|
|
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
|
|
instantiate_quantized(gather_qmv, type, group_size, bits) \
|
|
instantiate_quantized(gather_qvm, type, group_size, bits) \
|
|
instantiate_quantized(gather_qmm_n, type, group_size, bits)
|
|
|
|
#define instantiate_quantized_all_aligned(type, group_size, bits) \
|
|
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
|
|
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
|
|
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
|
|
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
|
|
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
|
|
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
|
|
|
|
#define instantiate_quantized_all_quad(type, group_size, bits) \
|
|
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
|
|
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
|
|
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
|
|
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
|
|
|
|
#define instantiate_quantized_all_splitk(type, group_size, bits) \
|
|
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
|
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
|
|
|
#define instantiate_quantized_all_rhs(type, group_size, bits) \
|
|
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
|
|
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
|
|
|
|
#define instantiate_quantized_funcs(type, group_size, bits) \
|
|
instantiate_quantized_all_single(type, group_size, bits) \
|
|
instantiate_quantized_all_batched(type, group_size, bits) \
|
|
instantiate_quantized_all_aligned(type, group_size, bits) \
|
|
instantiate_quantized_all_quad(type, group_size, bits) \
|
|
instantiate_quantized_all_splitk(type, group_size, bits) \
|
|
instantiate_quantized_all_rhs(type, group_size, bits)
|
|
|
|
#define instantiate_quantized_types(group_size, bits) \
|
|
instantiate_quantized_funcs(float, group_size, bits) \
|
|
instantiate_quantized_funcs(float16_t, group_size, bits) \
|
|
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
|
|
|
|
#define instantiate_quantized_groups(bits) \
|
|
instantiate_quantized_types(128, bits) \
|
|
instantiate_quantized_types(64, bits) \
|
|
instantiate_quantized_types(32, bits)
|
|
|
|
#define instantiate_quantized_all() \
|
|
instantiate_quantized_groups(2) \
|
|
instantiate_quantized_groups(3) \
|
|
instantiate_quantized_groups(4) \
|
|
instantiate_quantized_groups(5) \
|
|
instantiate_quantized_groups(6) \
|
|
instantiate_quantized_groups(8)
|
|
|
|
instantiate_quantized_all() // clang-format on
|