mxfp4 works
This commit is contained in:
parent 4cf90c9762
commit 6295e53216
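For orientation, here is a minimal sketch of the Python-level usage this commit enables, based on the tests added below (group size 32 is the only mxfp4 configuration exercised here; the shapes are illustrative assumptions):

import mlx.core as mx

# Quantize a weight matrix with the new mxfp4 mode. Unlike "affine" mode,
# it returns only (w_q, scales) -- there is no biases array.
w = mx.random.normal(shape=(64, 256))
w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")

# Dequantize back, and run a quantized matmul against the packed weights.
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
x = mx.random.normal(shape=(1, 256))
y = mx.quantized_matmul(
    x, w_q, scales, transpose=True, group_size=32, mode="mxfp4"
)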
@@ -407,6 +407,51 @@ void _qmm_dispatch(
  }
}

// template <typename T>
// void _qmm_mxfp4_dispatch_typed(
//     array& out,
//     const array& x,
//     const array& w,
//     const array& scales,
//     bool transposed_w) {
//   int K = x.shape(-1);
//   int M = x.ndim() > 1 ? x.shape(-2) : 1;
//   int N = out.shape(-1);
//   int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
//   int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
//   int batch_size = x.size() / (K * M);
//
//   auto out_ptr = out.data<T>();
//   auto x_ptr = x.data<T>();
//   auto w_ptr = w.data<uint32_t>();
//   auto scales_ptr = scales.data<T>();
//   for (int i = 0; i < batch_size; i++) {
//     _qmm_mxfp4_dispatch_typed<T>(
//         out_ptr + i * M * N,
//         x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
//         w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
//         scales_ptr + elem_to_loc(i * g_els, scales.shape(),
//         scales.strides()), M, N, K, transposed_w);
//   }
// }
//
//
// void _qmm_mxfp4_dispatch(
//     array& out,
//     const array& x,
//     const array& w,
//     const array& scales,
//     bool transposed_w) {
//   switch (x.dtype()) {
//     case bfloat16:
//       _qmm_mxfp4_dispatch_typed<bfloat16>(out, x, w, scales, transposed_w);
//       break;
//     default:
//       throw std::invalid_argument(
//           "[quantized_matmul] only bfloat is supported for mxfp4");
//   }
// }

template <typename T>
void _bs_qmm_dispatch_typed(
    array& out,

@@ -521,7 +566,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& x_pre = inputs[0];
  auto& w_pre = inputs[1];
  auto& scales_pre = inputs[2];
  auto& biases_pre = inputs[3];

  std::vector<array> temps;
  auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {

@@ -537,7 +581,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto x = ensure_row_contiguous(x_pre);
  auto w = ensure_row_contiguous(w_pre);
  auto scales = ensure_row_contiguous(scales_pre);
  auto biases = ensure_row_contiguous(biases_pre);

  out.set_data(allocator::malloc(out.nbytes()));

@@ -546,18 +589,31 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  encoder.set_input_array(x);
  encoder.set_input_array(w);
  encoder.set_input_array(scales);
  encoder.set_input_array(biases);
  encoder.set_output_array(out);
  encoder.dispatch([out = array::unsafe_weak_copy(out),
                    x = array::unsafe_weak_copy(x),
                    w = array::unsafe_weak_copy(w),
                    scales = array::unsafe_weak_copy(scales),
                    biases = array::unsafe_weak_copy(biases),
                    group_size_ = group_size_,
                    bits_ = bits_,
                    transpose_ = transpose_]() mutable {
    _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
  });
  if (mode_ == "affine") {
    auto biases = ensure_row_contiguous(inputs[3]);
    encoder.set_input_array(biases);
    encoder.dispatch([out = array::unsafe_weak_copy(out),
                      x = array::unsafe_weak_copy(x),
                      w = array::unsafe_weak_copy(w),
                      scales = array::unsafe_weak_copy(scales),
                      biases = array::unsafe_weak_copy(biases),
                      group_size_ = group_size_,
                      bits_ = bits_,
                      transpose_ = transpose_]() mutable {
      _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
    });
  } else {
    // encoder.dispatch([out = array::unsafe_weak_copy(out),
    //                   x = array::unsafe_weak_copy(x),
    //                   w = array::unsafe_weak_copy(w),
    //                   scales = array::unsafe_weak_copy(scales),
    //                   group_size_ = group_size_,
    //                   bits_ = bits_,
    //                   transpose_ = transpose_]() mutable {
    //   _qmm_mxfp4_dispatch(out, x, w, scales, transpose_);
    // });
  }
}

void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {

@@ -705,7 +761,7 @@ void dispatch_quantize(
      w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}

void fast::AffineQuantize::eval_cpu(
void fast::Quantize::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto ensure_row_contiguous = [s = stream()](const array& arr) {

@@ -764,7 +820,7 @@ void fast::AffineQuantize::eval_cpu(
      }
    } else {
      throw std::runtime_error(
          "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
          "[fast::Quantize::eval_cpu] Only supports floating point inputs");
    }
  });
}
@@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
      reduction/reduce_col.h
      reduction/reduce_row.h)
  build_kernel(quantized quantized.h ${STEEL_HEADERS})
  build_kernel(fp4_quantized fp4_quantized.h ${STEEL_HEADERS})
  build_kernel(scan scan.h)
  build_kernel(softmax softmax.h)
  build_kernel(logsumexp logsumexp.h)
mlx/backend/metal/kernels/fp4_quantized.h (new file, 1789 lines)
File diff suppressed because it is too large.
mlx/backend/metal/kernels/fp4_quantized.metal (new file, 126 lines)

@@ -0,0 +1,126 @@
// Copyright © 2025 Apple Inc.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"

#define instantiate_quantized(name, type) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4", \
      name, \
      type, \
      32, \
      uint8_t)

#define instantiate_quantized_batched(name, type, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_batch_" #batched, \
      name, \
      type, \
      32, \
      batched, \
      uint8_t)

#define instantiate_quantized_aligned(name, type, aligned) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_alN_" #aligned, \
      name, \
      type, \
      32, \
      aligned, \
      uint8_t)

#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
      name, \
      type, \
      32, \
      aligned, \
      batched, \
      uint8_t)

#define instantiate_quantized_quad(name, type, D, batched) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
      name, \
      type, \
      32, \
      D, \
      batched, \
      uint8_t)

#define instantiate_quantized_split_k(name, type, split_k) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_spk_" #split_k, \
      name, \
      type, \
      32, \
      split_k, \
      uint8_t)

#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
  instantiate_kernel( \
      #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
      func, \
      type, \
      32, \
      uint8_t, \
      bm, \
      bn, \
      bk, \
      wm, \
      wn, \
      transpose)

#define instantiate_quantized_batched_wrap(name, type) \
  instantiate_quantized_batched(name, type, 1) \
  instantiate_quantized_batched(name, type, 0)

#define instantiate_quantized_all_batched(type) \
  instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
  instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
  instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
  instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)

#define instantiate_quantized_all_single(type) \
  instantiate_quantized(mxfp4_gather_qmv_fast, type) \
  instantiate_quantized(mxfp4_gather_qmv, type) \
  instantiate_quantized(mxfp4_gather_qvm, type) \
  instantiate_quantized(mxfp4_gather_qmm_n, type)

#define instantiate_quantized_all_aligned(type) \
  instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
  instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
  instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)

#define instantiate_quantized_all_quad(type) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
  instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)

#define instantiate_quantized_all_splitk(type) \
  instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
  instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)

#define instantiate_quantized_all_rhs(type) \
  instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
  instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)

#define instantiate_quantized_types(type) \
  instantiate_quantized_all_batched(type) \
  instantiate_quantized_all_quad(type) \
  instantiate_quantized_all_splitk(type) \
  instantiate_quantized_all_single(type) \
  instantiate_quantized_all_aligned(type) \
  instantiate_quantized_all_rhs(type)

instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on
@@ -1434,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets(
}

template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
[[kernel]] void affine_qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1486,7 +1486,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
}

template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
[[kernel]] void affine_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1538,7 +1538,7 @@ template <typename T, int group_size, int bits, bool batched>
}

template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
[[kernel]] void affine_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1590,7 +1590,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}

template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
[[kernel]] void affine_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1642,7 +1642,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}

template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
[[kernel]] void affine_qvm_split_k(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1706,7 +1706,7 @@ template <
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_t(
[[kernel]] void affine_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1764,7 +1764,7 @@ template <
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void qmm_n(
[[kernel]] void affine_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1817,7 +1817,7 @@ template <
}

template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv_fast(
[[kernel]] void affine_gather_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1879,7 +1879,7 @@ template <typename T, int group_size, int bits>
}

template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv(
[[kernel]] void affine_gather_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -1941,7 +1941,7 @@ template <typename T, int group_size, int bits>
}

template <typename T, int group_size, int bits>
[[kernel]] void gather_qvm(
[[kernel]] void affine_gather_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -2010,7 +2010,7 @@ template <
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void gather_qmm_t(
[[kernel]] void affine_gather_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -2077,7 +2077,7 @@ template <
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void gather_qmm_n(
[[kernel]] void affine_gather_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],

@@ -2234,7 +2234,7 @@ template <
    int WM,
    int WN,
    bool transpose>
[[kernel]] void gather_qmm_rhs(
[[kernel]] void affine_gather_qmm_rhs(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
@@ -79,40 +79,40 @@
  instantiate_quantized_batched(name, type, group_size, bits, 0)

#define instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
  instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \
  instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \
  instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \
  instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits)

#define instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized(affine_quantize, type, group_size, bits) \
  instantiate_quantized(affine_dequantize, type, group_size, bits) \
  instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
  instantiate_quantized(gather_qmv, type, group_size, bits) \
  instantiate_quantized(gather_qvm, type, group_size, bits) \
  instantiate_quantized(gather_qmm_n, type, group_size, bits)
  instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \
  instantiate_quantized(affine_gather_qmv, type, group_size, bits) \
  instantiate_quantized(affine_gather_qvm, type, group_size, bits) \
  instantiate_quantized(affine_gather_qmm_n, type, group_size, bits)

#define instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
  instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
  instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
  instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \
  instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \
  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \
  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \
  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \
  instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0)

#define instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
  instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \
  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \
  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1) \
  instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0)

#define instantiate_quantized_all_splitk(type, group_size, bits) \
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
  instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \
  instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)

#define instantiate_quantized_all_rhs(type, group_size, bits) \
  instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
  instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
  instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_single(type, group_size, bits) \
@@ -99,7 +99,7 @@ inline int add_strides_and_shapes(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    int offset) {
  if (skip) {
    return 0;

@@ -109,16 +109,18 @@ inline int add_strides_and_shapes(

  int x_batch_ndims = x.ndim() - 2;
  int w_batch_ndims = w.ndim() - 2;
  compute_encoder.set_bytes(x_batch_ndims, offset);
  compute_encoder.set_vector_bytes(x.shape(), offset + 1);
  compute_encoder.set_vector_bytes(x.strides(), offset + 2);
  compute_encoder.set_bytes(w_batch_ndims, offset + 3);
  compute_encoder.set_vector_bytes(w.shape(), offset + 4);
  compute_encoder.set_vector_bytes(w.strides(), offset + 5);
  compute_encoder.set_vector_bytes(scales.strides(), offset + 6);
  compute_encoder.set_vector_bytes(biases.strides(), offset + 7);
  compute_encoder.set_bytes(x_batch_ndims, offset++);
  compute_encoder.set_vector_bytes(x.shape(), offset++);
  compute_encoder.set_vector_bytes(x.strides(), offset++);
  compute_encoder.set_bytes(w_batch_ndims, offset++);
  compute_encoder.set_vector_bytes(w.shape(), offset++);
  compute_encoder.set_vector_bytes(w.strides(), offset++);
  compute_encoder.set_vector_bytes(scales.strides(), offset++);
  if (biases) {
    compute_encoder.set_vector_bytes(biases->strides(), offset++);
  }

  return 8;
  return offset;
}

inline int add_gather_strides_and_shapes(

@@ -130,12 +132,12 @@ inline int add_gather_strides_and_shapes(
      lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()});
  int ndims = shape.size();

  compute_encoder.set_bytes(ndims, offset);
  compute_encoder.set_vector_bytes(shape, offset + 1);
  compute_encoder.set_vector_bytes(strides[0], offset + 2);
  compute_encoder.set_vector_bytes(strides[1], offset + 3);
  compute_encoder.set_bytes(ndims, offset++);
  compute_encoder.set_vector_bytes(shape, offset++);
  compute_encoder.set_vector_bytes(strides[0], offset++);
  compute_encoder.set_vector_bytes(strides[1], offset++);

  return 4;
  return offset;
}

} // namespace

@@ -144,7 +146,7 @@ void qmv_quad(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    array& out,
    int group_size,
    int bits,

@@ -152,7 +154,8 @@ void qmv_quad(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  constexpr int quads_per_simd = 8;

@@ -165,9 +168,10 @@ void qmv_quad(
  std::string kname;
  kname.reserve(64);
  std::string type_string = get_type_string(x.dtype());

  concatenate(
      kname,
      "qmv_quad_",
      mode + "_qmv_quad_",
      type_string,
      "_gs_",
      group_size,

@@ -177,20 +181,23 @@ void qmv_quad(
      K,
      B > 1 ? "_batch_1" : "_batch_0");
  auto template_def = get_template_definition(
      kname, "qmv_quad", type_string, group_size, bits, K, B > 1);
      kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1);

  auto kernel = get_quantized_kernel(d, kname, template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(out, 4);
  compute_encoder.set_bytes(K, 5);
  compute_encoder.set_bytes(N, 6);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

@@ -199,7 +206,7 @@ void qmv(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    array& out,
    int group_size,
    int bits,

@@ -207,7 +214,8 @@ void qmv(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  int bn = 8;

@@ -219,9 +227,10 @@ void qmv(
  kname.reserve(64);
  std::string type_string = get_type_string(x.dtype());
  bool fast = N % bn == 0 && K % 512 == 0;

  concatenate(
      kname,
      fast ? "qmv_fast_" : "qmv_",
      mode + (fast ? "_qmv_fast_" : "_qmv_"),
      type_string,
      "_gs_",
      group_size,

@@ -229,20 +238,28 @@ void qmv(
      bits,
      B > 1 ? "_batch_1" : "_batch_0");
  auto template_def = get_template_definition(
      kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1);
      kname,
      mode + (fast ? "_qmv_fast" : "_qmv"),
      type_string,
      group_size,
      bits,
      B > 1);

  auto kernel = get_quantized_kernel(d, kname, template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(out, 4);
  compute_encoder.set_bytes(K, 5);
  compute_encoder.set_bytes(N, 6);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -251,7 +268,7 @@ void qvm_split_k(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    array& out,
    int group_size,
    int bits,

@@ -259,7 +276,8 @@ void qvm_split_k(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int split_k = K > 8192 ? 32 : 8;
  int split_D = (K + split_k - 1) / split_k;
  int B = out.size() / M / N;

@@ -283,7 +301,6 @@ void qvm_split_k(
  auto w_shape = w.shape();
  auto w_strides = w.strides();
  auto s_strides = scales.strides();
  auto b_strides = biases.strides();

  // Add split_k dim with reshapes
  x_shape.insert(x_shape.end() - 2, split_k);

@@ -297,7 +314,6 @@ void qvm_split_k(
  w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
  w_batch_ndims += 1;
  s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
  b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));

  int final_block_size = K - (split_k - 1) * split_D;

@@ -315,7 +331,7 @@ void qvm_split_k(
  kname.reserve(64);
  concatenate(
      kname,
      "qvm_split_k_",
      mode + "_qvm_split_k_",
      type_string,
      "_gs_",
      group_size,

@@ -324,30 +340,37 @@ void qvm_split_k(
      "_spk_",
      split_k);
  auto template_def = get_template_definition(
      kname, "qvm_split_k", type_string, group_size, bits, split_k);
      kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k);

  // Encode and dispatch kernel
  auto kernel = get_quantized_kernel(d, kname, template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(intermediate, 4);
  compute_encoder.set_bytes(split_D, 5);
  compute_encoder.set_bytes(N, 6);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_output_array(intermediate, c++);
  compute_encoder.set_bytes(split_D, c++);
  compute_encoder.set_bytes(N, c++);

  compute_encoder.set_bytes(x_batch_ndims, 7);
  compute_encoder.set_vector_bytes(x_shape, 8);
  compute_encoder.set_vector_bytes(x_strides, 9);
  compute_encoder.set_bytes(w_batch_ndims, 10);
  compute_encoder.set_vector_bytes(w_shape, 11);
  compute_encoder.set_vector_bytes(w_strides, 12);
  compute_encoder.set_vector_bytes(s_strides, 13);
  compute_encoder.set_vector_bytes(b_strides, 14);
  compute_encoder.set_bytes(final_block_size, 15);
  compute_encoder.set_bytes(x_batch_ndims, c++);
  compute_encoder.set_vector_bytes(x_shape, c++);
  compute_encoder.set_vector_bytes(x_strides, c++);
  compute_encoder.set_bytes(w_batch_ndims, c++);
  compute_encoder.set_vector_bytes(w_shape, c++);
  compute_encoder.set_vector_bytes(w_strides, c++);
  compute_encoder.set_vector_bytes(s_strides, c++);
  if (biases) {
    auto b_strides = biases->strides();
    b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1));
    compute_encoder.set_vector_bytes(b_strides, c++);
  }
  compute_encoder.set_bytes(final_block_size, c++);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

@@ -364,7 +387,7 @@ void qvm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    array& out,
    int group_size,
    int bits,

@@ -372,7 +395,8 @@ void qvm(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  int bn = 64;

@@ -385,7 +409,7 @@ void qvm(
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname,
      "qvm_",
      mode + "_qvm_",
      type_string,
      "_gs_",
      group_size,

@@ -393,20 +417,23 @@ void qvm(
      bits,
      B > 1 ? "_batch_1" : "_batch_0");
  auto template_def = get_template_definition(
      kname, "qvm", type_string, group_size, bits, B > 1);
      kname, mode + "_qvm", type_string, group_size, bits, B > 1);

  auto kernel = get_quantized_kernel(d, kname, template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(out, 4);
  compute_encoder.set_bytes(K, 5);
  compute_encoder.set_bytes(N, 6);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -415,7 +442,7 @@ void qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    array& out,
    bool transpose,
    int group_size,

@@ -424,7 +451,8 @@ void qmm(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  int wm = 2;

@@ -441,7 +469,7 @@ void qmm(
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname,
      transpose ? "qmm_t_" : "qmm_n_",
      mode + (transpose ? "_qmm_t_" : "_qmm_n_"),
      type_string,
      "_gs_",
      group_size,

@@ -452,25 +480,34 @@ void qmm(
  std::string template_def;
  if (transpose) {
    template_def = get_template_definition(
        kname, "qmm_t", type_string, group_size, bits, aligned, batched);
        kname,
        mode + "_qmm_t",
        type_string,
        group_size,
        bits,
        aligned,
        batched);
  } else {
    template_def = get_template_definition(
        kname, "qmm_n", type_string, group_size, bits, batched);
        kname, mode + "_qmm_n", type_string, group_size, bits, batched);
  }

  auto kernel = get_quantized_kernel(d, kname, template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_output_array(out, 4);
  compute_encoder.set_bytes(K, 5);
  compute_encoder.set_bytes(N, 6);
  compute_encoder.set_bytes(M, 7);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  compute_encoder.set_bytes(M, c++);
  add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

@@ -479,7 +516,7 @@ void gather_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,

@@ -490,7 +527,8 @@ void gather_qmm(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  int wm = 2;

@@ -507,7 +545,7 @@ void gather_qmm(
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname,
      transpose ? "gather_qmm_t_" : "gather_qmm_n_",
      mode + (transpose ? "_gather_qmm_t_" : "_gather_qmm_n_"),
      type_string,
      "_gs_",
      group_size,

@@ -517,30 +555,31 @@ void gather_qmm(
  std::string template_def;
  if (transpose) {
    template_def = get_template_definition(
        kname, "gather_qmm_t", type_string, group_size, bits, aligned);
        kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned);
  } else {
    template_def = get_template_definition(
        kname, "gather_qmm_n", type_string, group_size, bits);
        kname, mode + "_gather_qmm_n", type_string, group_size, bits);
  }

  auto kernel = get_quantized_kernel(d, kname, template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_input_array(lhs_indices, 4);
  compute_encoder.set_input_array(rhs_indices, 5);
  compute_encoder.set_output_array(out, 6);
  compute_encoder.set_bytes(K, 7);
  compute_encoder.set_bytes(N, 8);
  compute_encoder.set_bytes(M, 9);
  int n =
      add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10);
  add_gather_strides_and_shapes(
      compute_encoder, lhs_indices, rhs_indices, 10 + n);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(lhs_indices, c++);
  compute_encoder.set_input_array(rhs_indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  compute_encoder.set_bytes(M, c++);
  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -549,7 +588,7 @@ void gather_qmv(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,

@@ -559,7 +598,8 @@ void gather_qmv(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  int bn = 8;

@@ -573,7 +613,7 @@ void gather_qmv(
  bool fast = N % bn == 0 && K % 512 == 0;
  concatenate(
      kname,
      fast ? "gather_qmv_fast_" : "gather_qmv_",
      mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"),
      type_string,
      "_gs_",
      group_size,

@@ -581,7 +621,7 @@ void gather_qmv(
      bits);
  auto template_def = get_template_definition(
      kname,
      fast ? "gather_qmv_fast" : "gather_qmv",
      mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"),
      type_string,
      group_size,
      bits);

@@ -590,19 +630,20 @@ void gather_qmv(
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_input_array(lhs_indices, 4);
  compute_encoder.set_input_array(rhs_indices, 5);
  compute_encoder.set_output_array(out, 6);
  compute_encoder.set_bytes(K, 7);
  compute_encoder.set_bytes(N, 8);
  int n =
      add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
  add_gather_strides_and_shapes(
      compute_encoder, lhs_indices, rhs_indices, 9 + n);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(lhs_indices, c++);
  compute_encoder.set_input_array(rhs_indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

@@ -611,7 +652,7 @@ void gather_qvm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    const std::optional<array>& biases,
    const array& lhs_indices,
    const array& rhs_indices,
    array& out,

@@ -621,7 +662,8 @@ void gather_qvm(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string& mode) {
  int B = out.size() / M / N;

  int bn = 64;

@@ -633,27 +675,34 @@ void gather_qvm(
  kname.reserve(64);
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits);
      kname,
      mode + "_gather_qvm_",
      type_string,
      "_gs_",
      group_size,
      "_b_",
      bits);
  auto template_def = get_template_definition(
      kname, "gather_qvm", type_string, group_size, bits);
      kname, mode + "_gather_qvm", type_string, group_size, bits);

  auto kernel = get_quantized_kernel(d, kname, template_def);
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(biases, 2);
  compute_encoder.set_input_array(x, 3);
  compute_encoder.set_input_array(lhs_indices, 4);
  compute_encoder.set_input_array(rhs_indices, 5);
  compute_encoder.set_output_array(out, 6);
  compute_encoder.set_bytes(K, 7);
  compute_encoder.set_bytes(N, 8);
  int n =
      add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
  add_gather_strides_and_shapes(
      compute_encoder, lhs_indices, rhs_indices, 9 + n);
  int c = 0;
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases) {
    compute_encoder.set_input_array(*biases, c++);
  }
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(lhs_indices, c++);
  compute_encoder.set_input_array(rhs_indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(K, c++);
  compute_encoder.set_bytes(N, c++);
  c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++);
  add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@@ -662,7 +711,7 @@ void gather_qmm_rhs(
    const array& x_,
    const array& w_,
    const array& scales_,
    const array& biases_,
    const std::optional<array>& biases_,
    const array& indices_,
    array& out,
    bool transpose,

@@ -672,7 +721,8 @@ void gather_qmm_rhs(
    int N,
    int K,
    metal::Device& d,
    const Stream& s) {
    const Stream& s,
    const std::string mode) {
  // Start by normalizing the indices
  array indices = ensure_row_contiguous(indices_, d, s);

@@ -697,7 +747,6 @@ void gather_qmm_rhs(
  array x = broadcast_with_indices(x_);
  array w = ensure_row_contiguous(w_, d, s);
  array scales = ensure_row_contiguous(scales_, d, s);
  array biases = ensure_row_contiguous(biases_, d, s);

  // TODO: Tune the block sizes
  int bm = 16, bn = 32, bk = 32;

@@ -713,7 +762,7 @@ void gather_qmm_rhs(
  std::string type_string = get_type_string(x.dtype());
  concatenate(
      kname,
      transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_",
      mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"),
      type_string,
      "_gs_",
      group_size,

@@ -770,15 +819,19 @@ void gather_qmm_rhs(
  MTL::Size group_dims(32, wn, wm);
  MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);

  compute_encoder.set_input_array(x, 0);
  compute_encoder.set_input_array(w, 1);
  compute_encoder.set_input_array(scales, 2);
  compute_encoder.set_input_array(biases, 3);
  compute_encoder.set_input_array(indices, 4);
  compute_encoder.set_output_array(out, 5);
  compute_encoder.set_bytes(M, 6);
  compute_encoder.set_bytes(N, 7);
  compute_encoder.set_bytes(K, 8);
  int c = 0;
  compute_encoder.set_input_array(x, c++);
  compute_encoder.set_input_array(w, c++);
  compute_encoder.set_input_array(scales, c++);
  if (biases_) {
    array biases = ensure_row_contiguous(*biases_, d, s);
    compute_encoder.set_input_array(biases, c++);
  }
  compute_encoder.set_input_array(indices, c++);
  compute_encoder.set_output_array(out, c++);
  compute_encoder.set_bytes(M, c++);
  compute_encoder.set_bytes(N, c++);
  compute_encoder.set_bytes(K, c++);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}

@@ -794,7 +847,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  array x = ensure_row_contiguous_matrix(inputs[0], d, s);
  array w = ensure_row_contiguous_matrix(inputs[1], d, s);
  array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
  array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
  std::optional<array> biases = std::nullopt;
  if (inputs.size() == 4) {
    biases = ensure_row_contiguous_matrix(inputs[3], d, s);
  }

  // Extract the matmul shapes
  bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;

@@ -818,30 +874,33 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
        N,
        K,
        d,
        s);
        s,
        mode_);
    return;
  }

  // It is a qmv with a small inner dimension so route to qmv_quad kernel
  if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) {
    qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
    qmv_quad(
        x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
    return;
  }

  // Run of the mill qmv
  if (transpose_) {
    qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
    qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
    return;
  }

  // Run of the mill qvm
  if (K < 1024) {
    qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
    qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
    return;
  }

  // Qvm with large dimension so route to a split K kernel for more parallelism
  qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
  qvm_split_k(
      x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
  return;
}

@@ -854,9 +913,12 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  array x = ensure_row_contiguous_matrix(inputs[0], d, s);
  array w = ensure_row_contiguous_matrix(inputs[1], d, s);
  array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
  array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
  const array& lhs_indices = inputs[4];
  const array& rhs_indices = inputs[5];
  std::optional<array> biases = std::nullopt;
  if (inputs.size() == 6) {
    biases = ensure_row_contiguous_matrix(inputs[3], d, s);
  }
  const array& lhs_indices = inputs[inputs.size() - 2];
  const array& rhs_indices = inputs[inputs.size() - 1];

  int K = x.shape(-1);
  int M = x.shape(-2);

@@ -884,7 +946,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        N,
        K,
        d,
        s);
        s,
        mode_);
    return;
  }

@@ -905,7 +968,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        N,
        K,
        d,
        s);
        s,
        mode_);
    return;
  }

@@ -924,7 +988,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
        N,
        K,
        d,
        s);
        s,
        mode_);
    return;
  }

@@ -942,10 +1007,11 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      N,
      K,
      d,
      s);
      s,
      mode_);
}

void fast::AffineQuantize::eval_gpu(
void fast::Quantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& w_pre = inputs[0];
mlx/ops.cpp (27 changed lines)
@@ -4089,7 +4089,7 @@ array quantized_matmul(
    inputs = {
        astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
  } else {
    throw std::invalid_argument("ERROR!");
    inputs = {x, w, scales};
  }

  if (x.ndim() > 2 && w.ndim() > 2) {

@@ -4568,7 +4568,23 @@ array gather_qmm(
  auto out_shape = lhs_indices.shape();
  out_shape.push_back(x.shape(-2));
  out_shape.push_back(w_outer_dims);

  std::vector<array> inputs;
  if (mode == "affine") {
    inputs = {
        astype(x, out_type, s),
        std::move(w),
        astype(scales, out_type, s),
        astype(*biases, out_type, s),
        std::move(lhs_indices),
        std::move(rhs_indices)};
  } else {
    inputs = {
        astype(x, out_type, s),
        std::move(w),
        std::move(scales),
        std::move(lhs_indices),
        std::move(rhs_indices)};
  }
  return array(
      std::move(out_shape),
      out_type,

@@ -4580,12 +4596,7 @@ array gather_qmm(
          transpose,
          sorted_indices && !rhs_indices_,
          sorted_indices && !lhs_indices_),
      {astype(x, out_type, s),
       std::move(w),
       astype(scales, out_type, s),
       astype(*biases, out_type, s),
       std::move(lhs_indices),
       std::move(rhs_indices)});
      std::move(inputs));
}

array tensordot(
@@ -3243,6 +3243,10 @@ std::vector<array> QuantizedMatmul::vjp(
      throw std::runtime_error(
          "[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
    } else {
      if (mode_ == "mxfp4") {
        throw std::runtime_error(
            "[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization.");
      }
      if (!dsb) {
        int ndim = primals[1].ndim();
        auto fc = flatten(cotangents[0], 0, -ndim, stream());

@@ -3372,14 +3376,19 @@ std::vector<array> GatherQMM::vjp(
    // gradient wrt to the indices is undefined
    else if (arg > 3) {
      throw std::runtime_error(
          "GatherQMM::vjp cannot compute the gradient wrt the indices.");
          "[GatherQMM::vjp] cannot compute the gradient wrt the indices.");
    }

    // gradient wrt to w_q, scales or biases
    else if (arg == 1) {
      throw std::runtime_error(
          "GatherQMM::vjp no gradient wrt the quantized weights.");
          "[GatherQMM::vjp] no gradient wrt the quantized weights.");
    } else {
      if (mode_ == "mxfp4") {
        throw std::runtime_error(
            "[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization.");
      }

      if (!dsb) {
        auto shape = w.shape();
        shape.pop_back();
@@ -98,11 +98,11 @@ class QuantizedEmbedding(Module):
        # Initialize the quantized weight
        scale = math.sqrt(1 / dims)
        weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
        self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
        if mode == "affine":
            self.scales, self.biases = scales_biases
        else:
            self.scales = scales_biases
            (self.scales,) = scales_biases
        self.num_embeddings = num_embeddings
        self.dims = dims

@@ -155,12 +155,16 @@ class QuantizedEmbedding(Module):
        """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
        embedding_dims, dims = embedding_layer.weight.shape
        ql = cls(embedding_dims, dims, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
        ql.weight, *scales_biases = mx.quantize(
            embedding_layer.weight,
            group_size,
            bits,
            mode=mode,
        )
        if mode == "affine":
            ql.scales, ql.biases = scales_biases
        else:
            (ql.scales,) = scales_biases
        return ql


@@ -210,11 +214,11 @@ class QuantizedLinear(Module):
            high=scale,
            shape=(output_dims, input_dims),
        )
        self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
        self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
        if mode == "affine":
            self.scales, self.biases = scales_biases
        else:
            self.scales = scales_biases
            (self.scales,) = scales_biases

        # And bias if needed
        if bias:

@@ -257,7 +261,7 @@ class QuantizedLinear(Module):
        """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, group_size, bits)
        ql.weight, scales_biases = mx.quantize(
        ql.weight, *scales_biases = mx.quantize(
            linear_layer.weight,
            group_size,
            bits,

@@ -266,7 +270,7 @@ class QuantizedLinear(Module):
        if mode == "affine":
            ql.scales, ql.biases = scales_biases
        else:
            ql.scales = scales_biases
            (ql.scales,) = scales_biases

        if "bias" in linear_layer:
            ql.bias = linear_layer.bias
@@ -198,6 +198,12 @@ class TestBase(mlx_tests.MLXTestCase):
        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))

        nn.quantize(m, group_size=32, mode="mxfp4")
        self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
        self.assertTrue(isinstance(m.layers[1], nn.ReLU))
        self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
        self.assertTrue(isinstance(m.layers[2].scales, mx.array))

    def test_quantize_freeze(self):
        lin = nn.Linear(512, 512)
        qlin = lin.to_quantized()
@@ -218,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_mxfp4_qmv(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [256, 512, 67],  # M
            [64, 128],  # N
            [0, 1, 3, 8],  # B
        )
        for M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=32):
                x_shape = (3, 1, N) if B == 0 else (B, 1, N)
                w_shape = (M, N) if B == 0 else (B, M, N)
                x = mx.random.normal(shape=x_shape, key=k1)
                w = mx.random.normal(shape=w_shape, key=k2)
                w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
                w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
                y_q = mx.quantized_matmul(
                    x,
                    w_q,
                    scales,
                    transpose=True,
                    group_size=32,
                    mode="mxfp4",
                )
                y_hat = x @ mx.swapaxes(w_hat, -1, -2)
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qvm(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

@@ -283,6 +311,38 @@ class TestQuantized(mlx_tests.MLXTestCase):
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 2e-3)

    def test_mxfp4_qvm(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [32, 128, 256],  # M
            [128, 256, 67],  # N
            [0, 1, 3, 8],  # B
        )
        # Add a splitk
        tests = list(tests)
        tests.append((128, 16384, 0))

        for M, N, B in tests:
            with self.subTest(shape=(B, M, N)):
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                x = mx.random.normal(shape=x_shape, key=k1)
                w = mx.random.normal(shape=w_shape, key=k2)
                w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
                w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
                y_q = mx.quantized_matmul(
                    x,
                    w_q,
                    scales,
                    transpose=False,
                    group_size=32,
                    mode="mxfp4",
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_mode_error_cases(self):
        w = mx.random.normal(shape=(256, 256))
        x = mx.random.normal(shape=(1, 256))

@@ -475,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_gather_qmm(self):
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
        def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
            if mode == "affine":
                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
            else:
                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
                b = None
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

@@ -494,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
            transpose=True,
            group_size=64,
            bits=4,
            mode="affine",
        ):
            with self.subTest(
                M=M,

@@ -507,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
                transpose=transpose,
                group_size=group_size,
                bits=bits,
                mode=mode,
            ):
                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                w = mx.random.normal(
                    shape=batch_B + ((N, K) if transpose else (K, N))
                ).astype(dtype)
                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
                w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)

                if lhs_indices is not None:
                    lhs_indices = mx.array(lhs_indices)

@@ -530,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
                    transpose=transpose,
                    group_size=group_size,
                    bits=bits,
                    mode=mode,
                )

                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

        inputs = (

@@ -575,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
                "batch_B": (4, 1),
                "rhs_indices": ((2,), (0,), (1,)),
            },
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
                "group_size": 32,
                "mode": "mxfp4",
            },
        )

        for kwargs in inputs:

@@ -618,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
            self.assertTrue(mx.allclose(g1, g2, atol=1e-4))

    def test_gather_qmm_sorted(self):
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
        def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
            if mode == "affine":
                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
            else:
                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
                b = None

            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

@@ -640,19 +719,21 @@ class TestQuantized(mlx_tests.MLXTestCase):

        parameters = [
            # L, K, D, E, I, transpose
            (32, 512, 512, 4, 2, True),
            (32, 512, 544, 4, 2, True),
            (133, 512, 512, 4, 2, True),
            (133, 512, 555, 4, 2, True),
            (133, 512, 512, 4, 2, True),
            (64, 512, 512, 4, 2, False),
            (64, 512, 544, 4, 2, False),
            (133, 512, 512, 4, 2, False),
            (133, 512, 544, 4, 2, False),
            (133, 512, 555, 4, 2, False),
            (64, 512, 512, 4, 2, False),
            (32, 512, 512, 4, 2, True, "affine"),
            (32, 512, 544, 4, 2, True, "mxfp4"),
            (133, 512, 512, 4, 2, True, "affine"),
            (133, 512, 555, 4, 2, True, "affine"),
            (133, 512, 512, 4, 2, True, "affine"),
            (64, 512, 512, 4, 2, False, "affine"),
            (64, 512, 544, 4, 2, False, "mxfp4"),
            (133, 512, 512, 4, 2, False, "affine"),
            (133, 512, 544, 4, 2, False, "affine"),
            (133, 512, 555, 4, 2, False, "affine"),
            (64, 512, 512, 4, 2, False, "affine"),
        ]
        for L, K, D, E, I, transpose in parameters:
        for L, K, D, E, I, transpose, mode in parameters:
            if mode == "mxfp4":
                group_size = 32
            K, D = (K, D) if transpose else (D, K)
            ishape = (L, I)
            xshape = (L, 1, 1, K)

@@ -661,14 +742,28 @@ class TestQuantized(mlx_tests.MLXTestCase):
            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
            x = mx.random.normal(xshape) / K**0.5
            w = mx.random.normal(wshape) / K**0.5
            w, *wq = quantize(w, transpose=transpose)
            w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)

            y1 = mx.gather_mm(x, w, rhs_indices=indices)
            y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
            y2 = mx.gather_qmm(
                x,
                *wq,
                group_size=group_size,
                mode=mode,
                transpose=transpose,
                rhs_indices=indices
            )
            xs, idx, inv_order = gather_sort(x, indices)
            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)

            y4 = mx.gather_qmm(
                xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
                xs,
                *wq,
                group_size=group_size,
                mode=mode,
                rhs_indices=idx,
                transpose=transpose,
                sorted_indices=True
            )
            y3 = scatter_unsort(y3, inv_order, indices.shape)
            y4 = scatter_unsort(y4, inv_order, indices.shape)
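As a closing illustration, a hedged sketch of the module-level path the nn test above exercises (layer sizes are illustrative, not from the commit; in mxfp4 mode a quantized layer carries scales but no biases parameter):

import mlx.core as mx
import mlx.nn as nn

# Quantize a small model in place with the mxfp4 mode; eligible layers are
# swapped for their quantized counterparts.
model = nn.Sequential(nn.Embedding(64, 256), nn.ReLU(), nn.Linear(256, 256))
nn.quantize(model, group_size=32, mode="mxfp4")

assert isinstance(model.layers[0], nn.QuantizedEmbedding)
assert isinstance(model.layers[2], nn.QuantizedLinear)
assert isinstance(model.layers[2].scales, mx.array)
# Unlike "affine" mode, the mxfp4 layers do not define a `biases` attribute.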