mirror of https://github.com/ml-explore/mlx.git
Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize
* add full quantize kernel
* fused kernel with scale/bias computation
* fix docstring
* fix no jit error
* fix test
* test fix
* reduce fast api to only affine_quantize
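The kernels added in this commit implement per-group affine quantization: each group of group_size weights is stored as bits-bit codes q with a per-group scale and bias such that w is approximately scale * q + bias. For orientation, a minimal host-side C++ sketch of the same arithmetic (illustrative only, not MLX API; the committed Metal kernels appear in the diff below):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one group of `group_size` floats to `bits`-bit codes; mirrors the
// scale/bias selection in the affine_quantize kernel added below.
void quantize_group(const float* w, int group_size, int bits,
                    std::vector<std::uint8_t>& codes, float& scale, float& bias) {
  const float n_bins = float((1 << bits) - 1);
  const float w_min = *std::min_element(w, w + group_size);
  const float w_max = *std::max_element(w, w + group_size);
  scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  // Anchor the grid on whichever edge has the larger magnitude, so that edge
  // is exactly representable (an integer number of scale steps from zero).
  const bool side = std::fabs(w_min) > std::fabs(w_max);
  scale = side ? scale : -scale;
  const float edge = side ? w_min : w_max;
  const float q0 = std::round(edge / scale);
  scale = (q0 == 0.0f) ? scale : edge / q0;
  bias = (q0 == 0.0f) ? 0.0f : edge;
  codes.resize(group_size);
  for (int i = 0; i < group_size; ++i)
    codes[i] = std::uint8_t(std::min(std::round((w[i] - bias) / scale), n_bins));
}

// Dequantization is the affine map back: w ~= scale * q + bias.
float dequantize_one(std::uint8_t q, float scale, float bias) {
  return scale * float(q) + bias;
}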
mlx/backend/metal/kernels/quantized.h
@@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl(
 
 template <
     typename T,
-    const int BM,
-    const int BK,
-    const int BN,
     const int group_size,
     const int bits,
-    const bool aligned_N>
+    const bool aligned_N,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
 METAL_FUNC void qmm_t_impl(
     const device T* x,
     const device uint32_t* w,
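The reorder in this hunk (and the next) moves the tile sizes BM/BK/BN behind the quantization parameters so they can take defaults; call sites and the JIT's generated template definitions then only spell the parameters that vary. A standalone C++ sketch of the mechanism (illustrative, not the Metal source):

// Trailing non-type template parameters can default, so only the varying
// parameters need to be spelled at the call site.
template <typename T, int group_size, int bits,
          int BM = 32, int BK = 32, int BN = 32>
void qmm_like() {
  // tile sizes BM/BK/BN default to 32 unless overridden
}

int main() {
  qmm_like<float, 64, 4>();              // uses BM = BK = BN = 32
  qmm_like<float, 64, 4, 64, 32, 64>();  // explicit tiles still possible
}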
@@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl(
 
 template <
     typename T,
-    const int BM,
-    const int BK,
-    const int BN,
     const int group_size,
-    const int bits>
+    const int bits,
+    const int BM = 32,
+    const int BK = 32,
+    const int BN = 32>
 METAL_FUNC void qmm_n_impl(
     const device T* x,
     const device uint32_t* w,
@@ -1099,7 +1099,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BN * BK_padded];
 
-  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
       x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
 
@@ -1131,7 +1131,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BK * BN_padded];
 
-  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
 
@@ -1382,7 +1382,7 @@ template <
       s_strides,
       b_strides,
       tid);
-  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
       x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
 
@@ -1450,6 +1450,147 @@ template <
       s_strides,
       b_strides,
       tid);
-  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
       x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_quantize(
+    const device T* w [[buffer(0)]],
+    device uint8_t* out [[buffer(1)]],
+    device T* scales [[buffer(2)]],
+    device T* biases [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr T eps = T(1e-7);
+  constexpr int simd_size = 32;
+  constexpr int uint8_bits = 8;
+  constexpr T n_bins = (1 << bits) - 1;
+  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr int values_per_reduce = group_size / simd_size;
+  constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
+  constexpr int writes_per_pack =
+      writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
+
+  static_assert(
+      group_size % simd_size == 0,
+      "Group size must be divisible by simd size.");
+
+  int in_index = index * values_per_reduce;
+  int out_index = index * writes_per_pack;
+
+  T w_thread[values_per_reduce];
+  T w_min = Limits<T>::max;
+  T w_max = 0;
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    T val = w[in_index + i];
+    w_thread[i] = val;
+    w_min = min(w_min, val);
+    w_max = max(w_max, val);
+  }
+
+  w_min = simd_min(w_min);
+  w_max = simd_max(w_max);
+
+  T scale = max((w_max - w_min) / n_bins, eps);
+  bool side = abs(w_min) > abs(w_max);
+  scale = side ? scale : -scale;
+  T edge = side ? w_min : w_max;
+  T q0 = round(edge / scale);
+  bool at_zero = q0 == 0.0f;
+  scale = at_zero ? scale : edge / q0;
+  T bias = at_zero ? T(0) : edge;
+
+  // Write out the scales and biases
+  int gindex = in_index / group_size;
+  if (in_index % group_size == 0) {
+    scales[gindex] = scale;
+    biases[gindex] = bias;
+  }
+
+  uint8_t output = 0;
+#pragma clang loop unroll(full)
+  for (int i = 0; i < values_per_reduce; i++) {
+    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
+    if (bits == 8) {
+      output = val;
+    } else {
+      output += val << (bits * (i % packs_per_int));
+    }
+
+    if (packs_per_int < values_per_reduce &&
+        i % packs_per_int == packs_per_int - 1) {
+      out[out_index + i / packs_per_int] = output;
+      output = 0;
+    } else {
+#pragma clang loop unroll(full)
+      for (int j = 0; j < writes_per_reduce - 1; j++) {
+        uint8_t sval = simd_shuffle_down(val, j + 1);
+        output += sval << (bits * (values_per_reduce + j + i));
+      }
+    }
+  }
+  if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
+    out[out_index / writes_per_reduce] = output;
+  }
+}
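The packing logic in affine_quantize above is the subtle part: each output byte carries packs_per_int = 8 / bits codes, and when a thread's values_per_reduce codes do not fill a whole byte, partial bytes are merged across SIMD lanes via simd_shuffle_down. A host-side sketch of the resulting byte layout (illustrative C++, not the kernel):

#include <cstddef>
#include <cstdint>
#include <vector>

// Pack `bits`-bit codes little-end-first into bytes, matching the kernel's
// `val << (bits * (i % packs_per_int))` accumulation.
std::vector<std::uint8_t> pack(const std::vector<std::uint8_t>& codes, int bits) {
  const int packs_per_int = 8 / bits;  // 4, 2, or 1 for bits = 2, 4, 8
  std::vector<std::uint8_t> out(codes.size() / packs_per_int, 0);
  for (std::size_t i = 0; i < codes.size(); ++i)
    out[i / packs_per_int] |=
        std::uint8_t(codes[i] << (bits * (i % packs_per_int)));
  return out;
}
// e.g. bits = 4: codes {3, 10} pack into one byte 0xA3 (3 in the low nibble).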
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_quantize_scales_biases(
+    const device T* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    device uint8_t* out [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr int uint8_bits = 8;
+  constexpr int packs_per_int = uint8_bits / bits;
+  constexpr T n_bins = (1 << bits) - 1;
+
+  int in_index = index * packs_per_int;
+  int gindex = in_index / group_size;
+  T scale = scales[gindex];
+  T bias = biases[gindex];
+
+  uint8_t output = 0;
+#pragma clang loop unroll(full)
+  for (int i = 0; i < packs_per_int; i++) {
+    uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
+    if (bits == 8) {
+      output = val;
+    } else {
+      output += val << (bits * i);
+    }
+  }
+  out[index] = output;
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void affine_dequantize(
+    const device uint8_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    device T* out [[buffer(3)]],
+    uint index [[thread_position_in_grid]]) {
+  constexpr int uint8_bits = 8;
+  constexpr int packs_per_int = uint8_bits / bits;
+
+  int oindex = index * packs_per_int;
+  int gindex = oindex / group_size;
+  T scale = scales[gindex];
+  T bias = biases[gindex];
+  uint val = w[index];
+
+#pragma clang loop unroll(full)
+  for (int i = 0; i < packs_per_int; i++) {
+    uint8_t d;
+    if (bits == 2) {
+      d = (val >> (bits * i)) & 0x03;
+    } else if (bits == 4) {
+      d = (val >> (bits * i)) & 0x0f;
+    } else if (bits == 8) {
+      d = val;
+    }
+    out[oindex + i] = scale * d + bias;
+  }
+}
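For symmetry, a host-side unpack sketch matching affine_dequantize's shift-and-mask walk (illustrative C++, not the kernel):

#include <cstdint>
#include <vector>

// Unpack with the same masks the kernel uses (0x03 / 0x0f / the full byte),
// then apply the affine map out = scale * d + bias.
std::vector<float> unpack(const std::vector<std::uint8_t>& packed, int bits,
                          float scale, float bias) {
  const int packs_per_int = 8 / bits;
  const std::uint8_t mask = std::uint8_t((1 << bits) - 1);
  std::vector<float> out;
  out.reserve(packed.size() * packs_per_int);
  for (std::uint8_t byte : packed)
    for (int i = 0; i < packs_per_int; ++i)
      out.push_back(scale * float((byte >> (bits * i)) & mask) + bias);
  return out;
}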
mlx/backend/metal/kernels/quantized.metal
@@ -5,241 +5,67 @@
 #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
 #include "mlx/backend/metal/kernels/quantized.h"
 
 
-#define instantiate_qmv_fast(itype, group_size, bits) \
+#define instantiate_quantized(name, type, group_size, bits) \
   instantiate_kernel( \
-      "qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
-      qmv_fast, \
-      itype, \
+      #name "_" #type "_gs_" #group_size "_b_" #bits, \
+      name, \
+      type, \
       group_size, \
       bits)
 
-#define instantiate_qmv_fast_types(group_size, bits) \
-  instantiate_qmv_fast(float, group_size, bits) \
-  instantiate_qmv_fast(float16_t, group_size, bits) \
-  instantiate_qmv_fast(bfloat16_t, group_size, bits)
+#define instantiate_quantized_types(name, group_size, bits) \
+  instantiate_quantized(name, float, group_size, bits) \
+  instantiate_quantized(name, float16_t, group_size, bits) \
+  instantiate_quantized(name, bfloat16_t, group_size, bits)
 
-instantiate_qmv_fast_types(128, 2)
-instantiate_qmv_fast_types(128, 4)
-instantiate_qmv_fast_types(128, 8)
-instantiate_qmv_fast_types( 64, 2)
-instantiate_qmv_fast_types( 64, 4)
-instantiate_qmv_fast_types( 64, 8)
-instantiate_qmv_fast_types( 32, 2)
-instantiate_qmv_fast_types( 32, 4)
-instantiate_qmv_fast_types( 32, 8)
+#define instantiate_quantized_groups(name, bits) \
+  instantiate_quantized_types(name, 128, bits) \
+  instantiate_quantized_types(name, 64, bits) \
+  instantiate_quantized_types(name, 32, bits)
 
-#define instantiate_qmv(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qmv_" #itype "_gs_" #group_size "_b_" #bits, \
-      qmv, \
-      itype, \
-      group_size, \
-      bits)
+#define instantiate_quantized_all(name) \
+  instantiate_quantized_groups(name, 2) \
+  instantiate_quantized_groups(name, 4) \
+  instantiate_quantized_groups(name, 8)
 
-#define instantiate_qmv_types(group_size, bits) \
-  instantiate_qmv(float, group_size, bits) \
-  instantiate_qmv(float16_t, group_size, bits) \
-  instantiate_qmv(bfloat16_t, group_size, bits)
+instantiate_quantized_all(qmv_fast)
+instantiate_quantized_all(qmv)
+instantiate_quantized_all(qvm)
+instantiate_quantized_all(qmm_n)
+instantiate_quantized_all(bs_qmv_fast)
+instantiate_quantized_all(bs_qmv)
+instantiate_quantized_all(bs_qvm)
+instantiate_quantized_all(bs_qmm_n)
+instantiate_quantized_all(affine_quantize)
+instantiate_quantized_all(affine_quantize_scales_biases)
+instantiate_quantized_all(affine_dequantize)
 
-instantiate_qmv_types(128, 2)
-instantiate_qmv_types(128, 4)
-instantiate_qmv_types(128, 8)
-instantiate_qmv_types( 64, 2)
-instantiate_qmv_types( 64, 4)
-instantiate_qmv_types( 64, 8)
-instantiate_qmv_types( 32, 2)
-instantiate_qmv_types( 32, 4)
-instantiate_qmv_types( 32, 8)
+#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
+      name, \
+      type, \
+      group_size, \
+      bits, \
+      aligned)
 
-#define instantiate_qvm(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qvm_" #itype "_gs_" #group_size "_b_" #bits, \
-      qvm, \
-      itype, \
-      group_size, \
-      bits)
+#define instantiate_quantized_types_aligned(name, group_size, bits) \
+  instantiate_quantized_aligned(name, float, group_size, bits, true) \
+  instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
+  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
+  instantiate_quantized_aligned(name, float, group_size, bits, false) \
+  instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
+  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
 
-#define instantiate_qvm_types(group_size, bits) \
-  instantiate_qvm(float, group_size, bits) \
-  instantiate_qvm(float16_t, group_size, bits) \
-  instantiate_qvm(bfloat16_t, group_size, bits)
+#define instantiate_quantized_groups_aligned(name, bits) \
+  instantiate_quantized_types_aligned(name, 128, bits) \
+  instantiate_quantized_types_aligned(name, 64, bits) \
+  instantiate_quantized_types_aligned(name, 32, bits)
 
-instantiate_qvm_types(128, 2)
-instantiate_qvm_types(128, 4)
-instantiate_qvm_types(128, 8)
-instantiate_qvm_types( 64, 2)
-instantiate_qvm_types( 64, 4)
-instantiate_qvm_types( 64, 8)
-instantiate_qvm_types( 32, 2)
-instantiate_qvm_types( 32, 4)
-instantiate_qvm_types( 32, 8)
+#define instantiate_quantized_all_aligned(name) \
+  instantiate_quantized_groups_aligned(name, 2) \
+  instantiate_quantized_groups_aligned(name, 4) \
+  instantiate_quantized_groups_aligned(name, 8) \
 
-#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \
-  instantiate_kernel( \
-      "qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
-      qmm_t, \
-      itype, \
-      group_size, \
-      bits, \
-      aligned_N)
-
-#define instantiate_qmm_t_types(group_size, bits) \
-  instantiate_qmm_t(float, group_size, bits, false) \
-  instantiate_qmm_t(float16_t, group_size, bits, false) \
-  instantiate_qmm_t(bfloat16_t, group_size, bits, false) \
-  instantiate_qmm_t(float, group_size, bits, true) \
-  instantiate_qmm_t(float16_t, group_size, bits, true) \
-  instantiate_qmm_t(bfloat16_t, group_size, bits, true)
-
-instantiate_qmm_t_types(128, 2)
-instantiate_qmm_t_types(128, 4)
-instantiate_qmm_t_types(128, 8)
-instantiate_qmm_t_types( 64, 2)
-instantiate_qmm_t_types( 64, 4)
-instantiate_qmm_t_types( 64, 8)
-instantiate_qmm_t_types( 32, 2)
-instantiate_qmm_t_types( 32, 4)
-instantiate_qmm_t_types( 32, 8)
-
-#define instantiate_qmm_n(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
-      qmm_n, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_qmm_n_types(group_size, bits) \
-  instantiate_qmm_n(float, group_size, bits) \
-  instantiate_qmm_n(float16_t, group_size, bits) \
-  instantiate_qmm_n(bfloat16_t, group_size, bits)
-
-instantiate_qmm_n_types(128, 2)
-instantiate_qmm_n_types(128, 4)
-instantiate_qmm_n_types(128, 8)
-instantiate_qmm_n_types( 64, 2)
-instantiate_qmm_n_types( 64, 4)
-instantiate_qmm_n_types( 64, 8)
-instantiate_qmm_n_types( 32, 2)
-instantiate_qmm_n_types( 32, 4)
-instantiate_qmm_n_types( 32, 8)
-
-#define instantiate_bs_qmv_fast(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
-      bs_qmv_fast, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qmv_fast_types(group_size, bits) \
-  instantiate_bs_qmv_fast(float, group_size, bits) \
-  instantiate_bs_qmv_fast(float16_t, group_size, bits) \
-  instantiate_bs_qmv_fast(bfloat16_t, group_size, bits)
-
-instantiate_bs_qmv_fast_types(128, 2)
-instantiate_bs_qmv_fast_types(128, 4)
-instantiate_bs_qmv_fast_types(128, 8)
-instantiate_bs_qmv_fast_types( 64, 2)
-instantiate_bs_qmv_fast_types( 64, 4)
-instantiate_bs_qmv_fast_types( 64, 8)
-instantiate_bs_qmv_fast_types( 32, 2)
-instantiate_bs_qmv_fast_types( 32, 4)
-instantiate_bs_qmv_fast_types( 32, 8)
-
-#define instantiate_bs_qmv(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \
-      bs_qmv, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qmv_types(group_size, bits) \
-  instantiate_bs_qmv(float, group_size, bits) \
-  instantiate_bs_qmv(float16_t, group_size, bits) \
-  instantiate_bs_qmv(bfloat16_t, group_size, bits)
-
-instantiate_bs_qmv_types(128, 2)
-instantiate_bs_qmv_types(128, 4)
-instantiate_bs_qmv_types(128, 8)
-instantiate_bs_qmv_types( 64, 2)
-instantiate_bs_qmv_types( 64, 4)
-instantiate_bs_qmv_types( 64, 8)
-instantiate_bs_qmv_types( 32, 2)
-instantiate_bs_qmv_types( 32, 4)
-instantiate_bs_qmv_types( 32, 8)
-
-#define instantiate_bs_qvm(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \
-      bs_qvm, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qvm_types(group_size, bits) \
-  instantiate_bs_qvm(float, group_size, bits) \
-  instantiate_bs_qvm(float16_t, group_size, bits) \
-  instantiate_bs_qvm(bfloat16_t, group_size, bits)
-
-instantiate_bs_qvm_types(128, 2)
-instantiate_bs_qvm_types(128, 4)
-instantiate_bs_qvm_types(128, 8)
-instantiate_bs_qvm_types( 64, 2)
-instantiate_bs_qvm_types( 64, 4)
-instantiate_bs_qvm_types( 64, 8)
-instantiate_bs_qvm_types( 32, 2)
-instantiate_bs_qvm_types( 32, 4)
-instantiate_bs_qvm_types( 32, 8)
-
-#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \
-  instantiate_kernel( \
-      "bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
-      bs_qmm_t, \
-      itype, \
-      group_size, \
-      bits, \
-      aligned_N)
-
-#define instantiate_bs_qmm_t_types(group_size, bits) \
-  instantiate_bs_qmm_t(float, group_size, bits, false) \
-  instantiate_bs_qmm_t(float16_t, group_size, bits, false) \
-  instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \
-  instantiate_bs_qmm_t(float, group_size, bits, true) \
-  instantiate_bs_qmm_t(float16_t, group_size, bits, true) \
-  instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true)
-
-instantiate_bs_qmm_t_types(128, 2)
-instantiate_bs_qmm_t_types(128, 4)
-instantiate_bs_qmm_t_types(128, 8)
-instantiate_bs_qmm_t_types( 64, 2)
-instantiate_bs_qmm_t_types( 64, 4)
-instantiate_bs_qmm_t_types( 64, 8)
-instantiate_bs_qmm_t_types( 32, 2)
-instantiate_bs_qmm_t_types( 32, 4)
-instantiate_bs_qmm_t_types( 32, 8)
-
-#define instantiate_bs_qmm_n(itype, group_size, bits) \
-  instantiate_kernel( \
-      "bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
-      bs_qmm_n, \
-      itype, \
-      group_size, \
-      bits)
-
-#define instantiate_bs_qmm_n_types(group_size, bits) \
-  instantiate_bs_qmm_n(float, group_size, bits) \
-  instantiate_bs_qmm_n(float16_t, group_size, bits) \
-  instantiate_bs_qmm_n(bfloat16_t, group_size, bits)
-
-instantiate_bs_qmm_n_types(128, 2)
-instantiate_bs_qmm_n_types(128, 4)
-instantiate_bs_qmm_n_types(128, 8)
-instantiate_bs_qmm_n_types( 64, 2)
-instantiate_bs_qmm_n_types( 64, 4)
-instantiate_bs_qmm_n_types( 64, 8)
-instantiate_bs_qmm_n_types( 32, 2)
-instantiate_bs_qmm_n_types( 32, 4)
-instantiate_bs_qmm_n_types( 32, 8) // clang-format on
+instantiate_quantized_all_aligned(qmm_t)
+instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
mlx/backend/metal/quantized.cpp
@@ -7,6 +7,7 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/utils.h"
+#include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -47,8 +48,8 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
     auto type_string = get_type_string(x.dtype());
-    kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
-          << "_fast";
+    kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -270,8 +271,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     std::ostringstream kname;
     auto type_string = get_type_string(x.dtype());
-    kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
-          << bits_ << "_fast";
+    kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
+          << bits_;
 
     // Encode and dispatch kernel
     auto& compute_encoder = d.get_command_encoder(s.index);
@@ -513,4 +514,82 @@
   }
 }
 
+void fast::AffineQuantize::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  bool compute_scale_bias = inputs.size() == 1;
+
+  auto& w_pre = inputs[0];
+  auto& out = outputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  std::vector<array> copies;
+  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      return arr_copy;
+    }
+  };
+  auto w = ensure_row_contiguous(w_pre);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_input_array(w, 0);
+  if (!compute_scale_bias) {
+    auto& scales_pre = inputs[1];
+    auto& biases_pre = inputs[2];
+    auto scales = ensure_row_contiguous(scales_pre);
+    auto biases = ensure_row_contiguous(biases_pre);
+    compute_encoder.set_input_array(scales, 1);
+    compute_encoder.set_input_array(biases, 2);
+    compute_encoder.set_output_array(out, 3);
+  } else {
+    auto& scales = outputs[1];
+    auto& biases = outputs[2];
+    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
+    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+    compute_encoder.set_output_array(out, 1);
+    compute_encoder.set_output_array(scales, 2);
+    compute_encoder.set_output_array(biases, 3);
+  }
+
+  std::ostringstream kname;
+  auto type_string = dequantize_ ? get_type_string(out.dtype())
+                                 : get_type_string(w_pre.dtype());
+  auto kernel_func = "affine_quantize_scales_biases";
+  if (dequantize_) {
+    kernel_func = "affine_dequantize";
+  } else if (compute_scale_bias) {
+    kernel_func = "affine_quantize";
+  }
+  kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
+        << bits_;
+  auto template_def = get_template_definition(
+      kname.str(), kernel_func, type_string, group_size_, bits_);
+  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  compute_encoder->setComputePipelineState(kernel);
+
+  // Treat uint32 as uint8 in kernel
+  constexpr int uint8_per_uint32 = 4;
+  constexpr int simd_size = 32;
+  int packs_per_int = 8 / bits_;
+  int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
+  size_t nthreads =
+      dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
+
+  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  auto group_dims = MTL::Size(thread_group_size, 1, 1);
+  auto grid_dims = MTL::Size(nthreads, 1, 1);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
+
+  d.get_command_buffer(s.index)->addCompletedHandler(
+      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
+}
+
 } // namespace mlx::core
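A worked check of the grid-size arithmetic in eval_gpu above, assuming a 4096x4096 float weight matrix with group_size_ = 64 and bits_ = 4 (illustrative host C++; in the quantize paths each thread covers per_thread inputs, while in the dequantize path w.size() counts packed uint32 words, hence the factor uint8_per_uint32 = 4):

#include <cstddef>
#include <cstdio>

int main() {
  const int group_size = 64, bits = 4, simd_size = 32;
  const std::size_t w_elems = 4096u * 4096u;  // 16777216 float weights
  const int packs_per_int = 8 / bits;         // 2 codes per packed byte
  // Fused quantize (compute_scale_bias): one thread per group_size/simd_size
  // inputs, so a full simdgroup covers exactly one quantization group.
  const int per_thread = group_size / simd_size;                    // 2
  std::printf("quantize threads:   %zu\n", w_elems / per_thread);   // 8388608
  // Dequantize: the packed buffer holds w_elems / packs_per_int bytes, stored
  // as that many / 4 uint32 words; one thread is launched per packed byte.
  const std::size_t packed_words = w_elems / packs_per_int / 4;     // 2097152
  std::printf("dequantize threads: %zu\n", packed_words * 4);       // 8388608
  return 0;
}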