mxfp4 works

Awni Hannun 2025-08-19 07:49:56 -07:00
parent 4cf90c9762
commit 6295e53216
12 changed files with 2420 additions and 257 deletions
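For context, the user-facing path this commit wires up (and which the added Python tests further down exercise end to end): mx.quantize with mode="mxfp4" returns only packed weights and scales, and mx.dequantize / mx.quantized_matmul / mx.gather_qmm accept the same mode. A minimal sketch with arbitrary shapes, mirroring those tests:

import mlx.core as mx

# mxfp4 is 4-bit with a fixed group size of 32; unlike the affine mode there
# are no biases, only per-group scales.
w = mx.random.normal(shape=(256, 512))
w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")

# Fused quantized matmul against the packed weights.
x = mx.random.normal(shape=(1, 512))
y = mx.quantized_matmul(x, w_q, scales, transpose=True, group_size=32, mode="mxfp4")

# Reference result via explicit dequantization.
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
y_ref = x @ mx.swapaxes(w_hat, -1, -2)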

View File

@ -407,6 +407,51 @@ void _qmm_dispatch(
}
}
// template <typename T>
// void _qmm_mxfp4_dispatch_typed(
// array& out,
// const array& x,
// const array& w,
// const array& scales,
// bool transposed_w) {
// int K = x.shape(-1);
// int M = x.ndim() > 1 ? x.shape(-2) : 1;
// int N = out.shape(-1);
// int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
// int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
// int batch_size = x.size() / (K * M);
//
// auto out_ptr = out.data<T>();
// auto x_ptr = x.data<T>();
// auto w_ptr = w.data<uint32_t>();
// auto scales_ptr = scales.data<T>();
// for (int i = 0; i < batch_size; i++) {
// _qmm_mxfp4_dispatch_typed<T>(
// out_ptr + i * M * N,
// x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
// w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
// scales_ptr + elem_to_loc(i * g_els, scales.shape(),
// scales.strides()), M, N, K, transposed_w);
// }
// }
//
//
// void _qmm_mxfp4_dispatch(
// array& out,
// const array& x,
// const array& w,
// const array& scales,
// bool transposed_w) {
// switch (x.dtype()) {
// case bfloat16:
// _qmm_mxfp4_dispatch_typed<bfloat16>(out, x, w, scales, transposed_w);
// break;
// default:
// throw std::invalid_argument(
// "[quantized_matmul] only bfloat is supported for mxfp4");
// }
// }
template <typename T>
void _bs_qmm_dispatch_typed(
array& out,
@ -521,7 +566,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
std::vector<array> temps;
auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) {
@ -537,7 +581,6 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc(out.nbytes()));
@ -546,18 +589,31 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
if (mode_ == "affine") {
auto biases = ensure_row_contiguous(inputs[3]);
encoder.set_input_array(biases);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
} else {
// encoder.dispatch([out = array::unsafe_weak_copy(out),
// x = array::unsafe_weak_copy(x),
// w = array::unsafe_weak_copy(w),
// scales = array::unsafe_weak_copy(scales),
// group_size_ = group_size_,
// bits_ = bits_,
// transpose_ = transpose_]() mutable {
// _qmm_mxfp4_dispatch(out, x, w, scales, transpose_);
// });
}
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@ -705,7 +761,7 @@ void dispatch_quantize(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::AffineQuantize::eval_cpu(
void fast::Quantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [s = stream()](const array& arr) {
@ -764,7 +820,7 @@ void fast::AffineQuantize::eval_cpu(
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
"[fast::Quantize::eval_cpu] Only supports floating point inputs");
}
});
}

View File

@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT)
reduction/reduce_col.h
reduction/reduce_row.h)
build_kernel(quantized quantized.h ${STEEL_HEADERS})
build_kernel(fp4_quantized fp4_quantized.h ${STEEL_HEADERS})
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(logsumexp logsumexp.h)

File diff suppressed because it is too large
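The suppressed file is presumably the new fp4_quantized.h kernel header that goes with the CMake entry and instantiations below. As a rough guide to what those kernels compute, here is an illustrative Python sketch of mxfp4 dequantization, assuming the OCP microscaling convention (signed e2m1 elements with one shared power-of-two E8M0 scale per group of 32) and low-nibble-first packing; the helper name and layout details are assumptions, not taken from this commit:

import numpy as np

# Assumed e2m1 magnitudes for a 4-bit code (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequantize_mxfp4_reference(w_q, scales, group_size=32):
    # w_q: uint8 array of packed codes (two 4-bit elements per byte), last axis K/2.
    # scales: uint8 array of biased E8M0 exponents, last axis K/group_size.
    codes = np.stack([w_q & 0x0F, w_q >> 4], axis=-1).reshape(*w_q.shape[:-1], -1)
    sign = np.where(codes & 0x8, -1.0, 1.0)
    vals = sign * E2M1[codes & 0x7]
    scale = np.exp2(scales.astype(np.int32) - 127)  # E8M0: value = 2**(s - 127)
    vals = vals.reshape(*vals.shape[:-1], -1, group_size) * scale[..., None]
    return vals.reshape(*vals.shape[:-2], -1)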

View File

@ -0,0 +1,126 @@
// Copyright © 2025 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/fp4_quantized.h"
#define instantiate_quantized(name, type) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4", \
name, \
type, \
32, \
uint8_t)
#define instantiate_quantized_batched(name, type, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_batch_" #batched, \
name, \
type, \
32, \
batched, \
uint8_t)
#define instantiate_quantized_aligned(name, type, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned, \
name, \
type, \
32, \
aligned, \
uint8_t)
#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \
name, \
type, \
32, \
aligned, \
batched, \
uint8_t)
#define instantiate_quantized_quad(name, type, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \
name, \
type, \
32, \
D, \
batched, \
uint8_t)
#define instantiate_quantized_split_k(name, type, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_spk_" #split_k, \
name, \
type, \
32, \
split_k, \
uint8_t)
#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \
instantiate_kernel( \
#name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \
func, \
type, \
32, \
uint8_t, \
bm, \
bn, \
bk, \
wm, \
wn, \
transpose)
#define instantiate_quantized_batched_wrap(name, type) \
instantiate_quantized_batched(name, type, 1) \
instantiate_quantized_batched(name, type, 0)
#define instantiate_quantized_all_batched(type) \
instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \
instantiate_quantized_batched_wrap(mxfp4_qmv, type) \
instantiate_quantized_batched_wrap(mxfp4_qvm, type) \
instantiate_quantized_batched_wrap(mxfp4_qmm_n, type)
#define instantiate_quantized_all_single(type) \
instantiate_quantized(mxfp4_gather_qmv_fast, type) \
instantiate_quantized(mxfp4_gather_qmv, type) \
instantiate_quantized(mxfp4_gather_qvm, type) \
instantiate_quantized(mxfp4_gather_qmm_n, type)
#define instantiate_quantized_all_aligned(type) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \
instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \
instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0)
#define instantiate_quantized_all_quad(type) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \
instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0)
#define instantiate_quantized_all_splitk(type) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \
instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32)
#define instantiate_quantized_all_rhs(type) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_types(type) \
instantiate_quantized_all_batched(type) \
instantiate_quantized_all_quad(type) \
instantiate_quantized_all_splitk(type) \
instantiate_quantized_all_single(type) \
instantiate_quantized_all_aligned(type) \
instantiate_quantized_all_rhs(type)
instantiate_quantized_types(float)
instantiate_quantized_types(bfloat16_t)
instantiate_quantized_types(float16_t)
// clang-format on
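For reference, the hard-coded "_gs_32_b_4" suffix in the macros above has to match the kernel names that the host code in quantized.cpp (further down in this diff) now builds by prefixing the mode string. An illustrative composition, assuming get_type_string(float16) yields "float16_t":

# Illustrative only: mirrors the concatenate(...) calls in the host code below.
mode, type_string, group_size, bits, batched = "mxfp4", "float16_t", 32, 4, 1
kname = f"{mode}_qmv_fast_{type_string}_gs_{group_size}_b_{bits}_batch_{batched}"
print(kname)  # -> mxfp4_qmv_fast_float16_t_gs_32_b_4_batch_1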

View File

@ -1434,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets(
}
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
[[kernel]] void affine_qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1486,7 +1486,7 @@ template <typename T, int group_size, int bits, int D, bool batched>
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
[[kernel]] void affine_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1538,7 +1538,7 @@ template <typename T, int group_size, int bits, bool batched>
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
[[kernel]] void affine_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1590,7 +1590,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
[[kernel]] void affine_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1642,7 +1642,7 @@ template <typename T, const int group_size, const int bits, bool batched>
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
[[kernel]] void affine_qvm_split_k(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1706,7 +1706,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_t(
[[kernel]] void affine_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1764,7 +1764,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_n(
[[kernel]] void affine_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1817,7 +1817,7 @@ template <
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv_fast(
[[kernel]] void affine_gather_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1879,7 +1879,7 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qmv(
[[kernel]] void affine_gather_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -1941,7 +1941,7 @@ template <typename T, int group_size, int bits>
}
template <typename T, int group_size, int bits>
[[kernel]] void gather_qvm(
[[kernel]] void affine_gather_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -2010,7 +2010,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_t(
[[kernel]] void affine_gather_qmm_t(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -2077,7 +2077,7 @@ template <
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void gather_qmm_n(
[[kernel]] void affine_gather_qmm_n(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
@ -2234,7 +2234,7 @@ template <
int WM,
int WN,
bool transpose>
[[kernel]] void gather_qmm_rhs(
[[kernel]] void affine_gather_qmm_rhs(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],

View File

@ -79,40 +79,40 @@
instantiate_quantized_batched(name, type, group_size, bits, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(gather_qmv, type, group_size, bits) \
instantiate_quantized(gather_qvm, type, group_size, bits) \
instantiate_quantized(gather_qmm_n, type, group_size, bits)
instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \
instantiate_quantized(affine_gather_qmv, type, group_size, bits) \
instantiate_quantized(affine_gather_qvm, type, group_size, bits) \
instantiate_quantized(affine_gather_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_all_rhs(type, group_size, bits) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \

View File

@ -99,7 +99,7 @@ inline int add_strides_and_shapes(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
int offset) {
if (skip) {
return 0;
@ -109,16 +109,18 @@ inline int add_strides_and_shapes(
int x_batch_ndims = x.ndim() - 2;
int w_batch_ndims = w.ndim() - 2;
compute_encoder.set_bytes(x_batch_ndims, offset);
compute_encoder.set_vector_bytes(x.shape(), offset + 1);
compute_encoder.set_vector_bytes(x.strides(), offset + 2);
compute_encoder.set_bytes(w_batch_ndims, offset + 3);
compute_encoder.set_vector_bytes(w.shape(), offset + 4);
compute_encoder.set_vector_bytes(w.strides(), offset + 5);
compute_encoder.set_vector_bytes(scales.strides(), offset + 6);
compute_encoder.set_vector_bytes(biases.strides(), offset + 7);
compute_encoder.set_bytes(x_batch_ndims, offset++);
compute_encoder.set_vector_bytes(x.shape(), offset++);
compute_encoder.set_vector_bytes(x.strides(), offset++);
compute_encoder.set_bytes(w_batch_ndims, offset++);
compute_encoder.set_vector_bytes(w.shape(), offset++);
compute_encoder.set_vector_bytes(w.strides(), offset++);
compute_encoder.set_vector_bytes(scales.strides(), offset++);
if (biases) {
compute_encoder.set_vector_bytes(biases->strides(), offset++);
}
return 8;
return offset;
}
inline int add_gather_strides_and_shapes(
@ -130,12 +132,12 @@ inline int add_gather_strides_and_shapes(
lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()});
int ndims = shape.size();
compute_encoder.set_bytes(ndims, offset);
compute_encoder.set_vector_bytes(shape, offset + 1);
compute_encoder.set_vector_bytes(strides[0], offset + 2);
compute_encoder.set_vector_bytes(strides[1], offset + 3);
compute_encoder.set_bytes(ndims, offset++);
compute_encoder.set_vector_bytes(shape, offset++);
compute_encoder.set_vector_bytes(strides[0], offset++);
compute_encoder.set_vector_bytes(strides[1], offset++);
return 4;
return offset;
}
} // namespace
@ -144,7 +146,7 @@ void qmv_quad(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@ -152,7 +154,8 @@ void qmv_quad(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
constexpr int quads_per_simd = 8;
@ -165,9 +168,10 @@ void qmv_quad(
std::string kname;
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
"qmv_quad_",
mode + "_qmv_quad_",
type_string,
"_gs_",
group_size,
@ -177,20 +181,23 @@ void qmv_quad(
K,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, "qmv_quad", type_string, group_size, bits, K, B > 1);
kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -199,7 +206,7 @@ void qmv(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@ -207,7 +214,8 @@ void qmv(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 8;
@ -219,9 +227,10 @@ void qmv(
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
bool fast = N % bn == 0 && K % 512 == 0;
concatenate(
kname,
fast ? "qmv_fast_" : "qmv_",
mode + (fast ? "_qmv_fast_" : "_qmv_"),
type_string,
"_gs_",
group_size,
@ -229,20 +238,28 @@ void qmv(
bits,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1);
kname,
mode + (fast ? "_qmv_fast" : "_qmv"),
type_string,
group_size,
bits,
B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -251,7 +268,7 @@ void qvm_split_k(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@ -259,7 +276,8 @@ void qvm_split_k(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int split_k = K > 8192 ? 32 : 8;
int split_D = (K + split_k - 1) / split_k;
int B = out.size() / M / N;
@ -283,7 +301,6 @@ void qvm_split_k(
auto w_shape = w.shape();
auto w_strides = w.strides();
auto s_strides = scales.strides();
auto b_strides = biases.strides();
// Add split_k dim with reshapes
x_shape.insert(x_shape.end() - 2, split_k);
@ -297,7 +314,6 @@ void qvm_split_k(
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
w_batch_ndims += 1;
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
int final_block_size = K - (split_k - 1) * split_D;
@ -315,7 +331,7 @@ void qvm_split_k(
kname.reserve(64);
concatenate(
kname,
"qvm_split_k_",
mode + "_qvm_split_k_",
type_string,
"_gs_",
group_size,
@ -324,30 +340,37 @@ void qvm_split_k(
"_spk_",
split_k);
auto template_def = get_template_definition(
kname, "qvm_split_k", type_string, group_size, bits, split_k);
kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k);
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder.set_bytes(split_D, 5);
compute_encoder.set_bytes(N, 6);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(intermediate, c++);
compute_encoder.set_bytes(split_D, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(x_batch_ndims, 7);
compute_encoder.set_vector_bytes(x_shape, 8);
compute_encoder.set_vector_bytes(x_strides, 9);
compute_encoder.set_bytes(w_batch_ndims, 10);
compute_encoder.set_vector_bytes(w_shape, 11);
compute_encoder.set_vector_bytes(w_strides, 12);
compute_encoder.set_vector_bytes(s_strides, 13);
compute_encoder.set_vector_bytes(b_strides, 14);
compute_encoder.set_bytes(final_block_size, 15);
compute_encoder.set_bytes(x_batch_ndims, c++);
compute_encoder.set_vector_bytes(x_shape, c++);
compute_encoder.set_vector_bytes(x_strides, c++);
compute_encoder.set_bytes(w_batch_ndims, c++);
compute_encoder.set_vector_bytes(w_shape, c++);
compute_encoder.set_vector_bytes(w_strides, c++);
compute_encoder.set_vector_bytes(s_strides, c++);
if (biases) {
auto b_strides = biases->strides();
b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1));
compute_encoder.set_vector_bytes(b_strides, c++);
}
compute_encoder.set_bytes(final_block_size, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@ -364,7 +387,7 @@ void qvm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
int group_size,
int bits,
@ -372,7 +395,8 @@ void qvm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 64;
@ -385,7 +409,7 @@ void qvm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
"qvm_",
mode + "_qvm_",
type_string,
"_gs_",
group_size,
@ -393,20 +417,23 @@ void qvm(
bits,
B > 1 ? "_batch_1" : "_batch_0");
auto template_def = get_template_definition(
kname, "qvm", type_string, group_size, bits, B > 1);
kname, mode + "_qvm", type_string, group_size, bits, B > 1);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -415,7 +442,7 @@ void qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
array& out,
bool transpose,
int group_size,
@ -424,7 +451,8 @@ void qmm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int wm = 2;
@ -441,7 +469,7 @@ void qmm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "qmm_t_" : "qmm_n_",
mode + (transpose ? "_qmm_t_" : "_qmm_n_"),
type_string,
"_gs_",
group_size,
@ -452,25 +480,34 @@ void qmm(
std::string template_def;
if (transpose) {
template_def = get_template_definition(
kname, "qmm_t", type_string, group_size, bits, aligned, batched);
kname,
mode + "_qmm_t",
type_string,
group_size,
bits,
aligned,
batched);
} else {
template_def = get_template_definition(
kname, "qmm_n", type_string, group_size, bits, batched);
kname, mode + "_qmm_n", type_string, group_size, bits, batched);
}
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder.set_bytes(K, 5);
compute_encoder.set_bytes(N, 6);
compute_encoder.set_bytes(M, 7);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -479,7 +516,7 @@ void gather_qmm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@ -490,7 +527,8 @@ void gather_qmm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int wm = 2;
@ -507,7 +545,7 @@ void gather_qmm(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "gather_qmm_t_" : "gather_qmm_n_",
mode + (transpose ? "_gather_qmm_t_" : "_gather_qmm_n_"),
type_string,
"_gs_",
group_size,
@ -517,30 +555,31 @@ void gather_qmm(
std::string template_def;
if (transpose) {
template_def = get_template_definition(
kname, "gather_qmm_t", type_string, group_size, bits, aligned);
kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned);
} else {
template_def = get_template_definition(
kname, "gather_qmm_n", type_string, group_size, bits);
kname, mode + "_gather_qmm_n", type_string, group_size, bits);
}
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
compute_encoder.set_bytes(M, 9);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 10 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(M, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -549,7 +588,7 @@ void gather_qmv(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@ -559,7 +598,8 @@ void gather_qmv(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 8;
@ -573,7 +613,7 @@ void gather_qmv(
bool fast = N % bn == 0 && K % 512 == 0;
concatenate(
kname,
fast ? "gather_qmv_fast_" : "gather_qmv_",
mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"),
type_string,
"_gs_",
group_size,
@ -581,7 +621,7 @@ void gather_qmv(
bits);
auto template_def = get_template_definition(
kname,
fast ? "gather_qmv_fast" : "gather_qmv",
mode + (fast ? "_gather_qmv_fast" : "_gather_qmv"),
type_string,
group_size,
bits);
@ -590,19 +630,20 @@ void gather_qmv(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 9 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -611,7 +652,7 @@ void gather_qvm(
const array& x,
const array& w,
const array& scales,
const array& biases,
const std::optional<array>& biases,
const array& lhs_indices,
const array& rhs_indices,
array& out,
@ -621,7 +662,8 @@ void gather_qvm(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string& mode) {
int B = out.size() / M / N;
int bn = 64;
@ -633,27 +675,34 @@ void gather_qvm(
kname.reserve(64);
std::string type_string = get_type_string(x.dtype());
concatenate(
kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits);
kname,
mode + "_gather_qvm_",
type_string,
"_gs_",
group_size,
"_b_",
bits);
auto template_def = get_template_definition(
kname, "gather_qvm", type_string, group_size, bits);
kname, mode + "_gather_qvm", type_string, group_size, bits);
auto kernel = get_quantized_kernel(d, kname, template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder.set_bytes(K, 7);
compute_encoder.set_bytes(N, 8);
int n =
add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9);
add_gather_strides_and_shapes(
compute_encoder, lhs_indices, rhs_indices, 9 + n);
int c = 0;
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases) {
compute_encoder.set_input_array(*biases, c++);
}
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(lhs_indices, c++);
compute_encoder.set_input_array(rhs_indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.set_bytes(N, c++);
c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++);
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -662,7 +711,7 @@ void gather_qmm_rhs(
const array& x_,
const array& w_,
const array& scales_,
const array& biases_,
const std::optional<array>& biases_,
const array& indices_,
array& out,
bool transpose,
@ -672,7 +721,8 @@ void gather_qmm_rhs(
int N,
int K,
metal::Device& d,
const Stream& s) {
const Stream& s,
const std::string mode) {
// Start by normalizing the indices
array indices = ensure_row_contiguous(indices_, d, s);
@ -697,7 +747,6 @@ void gather_qmm_rhs(
array x = broadcast_with_indices(x_);
array w = ensure_row_contiguous(w_, d, s);
array scales = ensure_row_contiguous(scales_, d, s);
array biases = ensure_row_contiguous(biases_, d, s);
// TODO: Tune the block sizes
int bm = 16, bn = 32, bk = 32;
@ -713,7 +762,7 @@ void gather_qmm_rhs(
std::string type_string = get_type_string(x.dtype());
concatenate(
kname,
transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_",
mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"),
type_string,
"_gs_",
group_size,
@ -770,15 +819,19 @@ void gather_qmm_rhs(
MTL::Size group_dims(32, wn, wm);
MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(indices, 4);
compute_encoder.set_output_array(out, 5);
compute_encoder.set_bytes(M, 6);
compute_encoder.set_bytes(N, 7);
compute_encoder.set_bytes(K, 8);
int c = 0;
compute_encoder.set_input_array(x, c++);
compute_encoder.set_input_array(w, c++);
compute_encoder.set_input_array(scales, c++);
if (biases_) {
array biases = ensure_row_contiguous(*biases_, d, s);
compute_encoder.set_input_array(biases, c++);
}
compute_encoder.set_input_array(indices, c++);
compute_encoder.set_output_array(out, c++);
compute_encoder.set_bytes(M, c++);
compute_encoder.set_bytes(N, c++);
compute_encoder.set_bytes(K, c++);
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
@ -794,7 +847,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
array x = ensure_row_contiguous_matrix(inputs[0], d, s);
array w = ensure_row_contiguous_matrix(inputs[1], d, s);
array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
std::optional<array> biases = std::nullopt;
if (inputs.size() == 4) {
biases = ensure_row_contiguous_matrix(inputs[3], d, s);
}
// Extract the matmul shapes
bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
@ -818,30 +874,33 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode_);
return;
}
// It is a qmv with a small inner dimension so route to qmv_quad kernel
if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) {
qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qmv_quad(
x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
return;
}
// Run of the mill qmv
if (transpose_) {
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
return;
}
// Run of the mill qvm
if (K < 1024) {
qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
return;
}
// Qvm with large dimension so route to a split K kernel for more parallelism
qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
qvm_split_k(
x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_);
return;
}
@ -854,9 +913,12 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
array x = ensure_row_contiguous_matrix(inputs[0], d, s);
array w = ensure_row_contiguous_matrix(inputs[1], d, s);
array scales = ensure_row_contiguous_matrix(inputs[2], d, s);
array biases = ensure_row_contiguous_matrix(inputs[3], d, s);
const array& lhs_indices = inputs[4];
const array& rhs_indices = inputs[5];
std::optional<array> biases = std::nullopt;
if (inputs.size() == 6) {
biases = ensure_row_contiguous_matrix(inputs[3], d, s);
}
const array& lhs_indices = inputs[inputs.size() - 2];
const array& rhs_indices = inputs[inputs.size() - 1];
int K = x.shape(-1);
int M = x.shape(-2);
@ -884,7 +946,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode_);
return;
}
@ -905,7 +968,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode_);
return;
}
@ -924,7 +988,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode_);
return;
}
@ -942,10 +1007,11 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
N,
K,
d,
s);
s,
mode_);
}
void fast::AffineQuantize::eval_gpu(
void fast::Quantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];

View File

@ -4089,7 +4089,7 @@ array quantized_matmul(
inputs = {
astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)};
} else {
throw std::invalid_argument("ERROR!");
inputs = {x, w, scales};
}
if (x.ndim() > 2 && w.ndim() > 2) {
@ -4568,7 +4568,23 @@ array gather_qmm(
auto out_shape = lhs_indices.shape();
out_shape.push_back(x.shape(-2));
out_shape.push_back(w_outer_dims);
std::vector<array> inputs;
if (mode == "affine") {
inputs = {
astype(x, out_type, s),
std::move(w),
astype(scales, out_type, s),
astype(*biases, out_type, s),
std::move(lhs_indices),
std::move(rhs_indices)};
} else {
inputs = {
astype(x, out_type, s),
std::move(w),
std::move(scales),
std::move(lhs_indices),
std::move(rhs_indices)};
}
return array(
std::move(out_shape),
out_type,
@ -4580,12 +4596,7 @@ array gather_qmm(
transpose,
sorted_indices && !rhs_indices_,
sorted_indices && !lhs_indices_),
{astype(x, out_type, s),
std::move(w),
astype(scales, out_type, s),
astype(*biases, out_type, s),
std::move(lhs_indices),
std::move(rhs_indices)});
std::move(inputs));
}
array tensordot(

View File

@ -3243,6 +3243,10 @@ std::vector<array> QuantizedMatmul::vjp(
throw std::runtime_error(
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
} else {
if (mode_ == "mxfp4") {
throw std::runtime_error(
"[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization.");
}
if (!dsb) {
int ndim = primals[1].ndim();
auto fc = flatten(cotangents[0], 0, -ndim, stream());
@ -3372,14 +3376,19 @@ std::vector<array> GatherQMM::vjp(
// gradient wrt to the indices is undefined
else if (arg > 3) {
throw std::runtime_error(
"GatherQMM::vjp cannot compute the gradient wrt the indices.");
"[GatherQMM::vjp] cannot compute the gradient wrt the indices.");
}
// gradient wrt to w_q, scales or biases
else if (arg == 1) {
throw std::runtime_error(
"GatherQMM::vjp no gradient wrt the quantized weights.");
"[GatherQMM::vjp] no gradient wrt the quantized weights.");
} else {
if (mode_ == "mxfp4") {
throw std::runtime_error(
"[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization.");
}
if (!dsb) {
auto shape = w.shape();
shape.pop_back();

View File

@ -98,11 +98,11 @@ class QuantizedEmbedding(Module):
# Initialize the quantized weight
scale = math.sqrt(1 / dims)
weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale)
self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
if mode == "affine":
self.scales, self.biases = scales_biases
else:
self.scales = scales_biases
(self.scales,) = scales_biases
self.num_embeddings = num_embeddings
self.dims = dims
@ -155,12 +155,16 @@ class QuantizedEmbedding(Module):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
ql = cls(embedding_dims, dims, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
ql.weight, *scales_biases = mx.quantize(
embedding_layer.weight,
group_size,
bits,
mode=mode,
)
if mode == "affine":
ql.scales, ql.biases = scales_biases
else:
(ql.scales,) = scales_biases
return ql
@ -210,11 +214,11 @@ class QuantizedLinear(Module):
high=scale,
shape=(output_dims, input_dims),
)
self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode)
if mode == "affine":
self.scales, self.biases = scales_biases
else:
self.scales = scales_biases
(self.scales,) = scales_biases
# And bias if needed
if bias:
@ -257,7 +261,7 @@ class QuantizedLinear(Module):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, scales_biases = mx.quantize(
ql.weight, *scales_biases = mx.quantize(
linear_layer.weight,
group_size,
bits,
@ -266,7 +270,7 @@ class QuantizedLinear(Module):
if mode == "affine":
ql.scales, ql.biases = scales_biases
else:
ql.scales = scales_biases
(ql.scales,) = scales_biases
if "bias" in linear_layer:
ql.bias = linear_layer.bias
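A short sketch of using these layers with the new mode, mirroring the nn test in the next file; the three-layer model is an arbitrary stand-in for the one the test builds:

import mlx.core as mx
import mlx.nn as nn

# With mode="mxfp4" the quantized layers keep only `scales`; no `biases`
# attribute is set, matching the unpacking above.
model = nn.Sequential(nn.Embedding(64, 64), nn.ReLU(), nn.Linear(64, 64))
nn.quantize(model, group_size=32, mode="mxfp4")

assert isinstance(model.layers[0], nn.QuantizedEmbedding)
assert isinstance(model.layers[2], nn.QuantizedLinear)
assert isinstance(model.layers[2].scales, mx.array)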

View File

@ -198,6 +198,12 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
nn.quantize(m, group_size=32, mode="mxfp4")
self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding))
self.assertTrue(isinstance(m.layers[1], nn.ReLU))
self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear))
self.assertTrue(isinstance(m.layers[2].scales, mx.array))
def test_quantize_freeze(self):
lin = nn.Linear(512, 512)
qlin = lin.to_quantized()

View File

@ -218,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_mxfp4_qmv(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[256, 512, 67], # M
[64, 128], # N
[0, 1, 3, 8], # B
)
for M, N, B in tests:
with self.subTest(shape=(B, M, N), group_size=32):
x_shape = (3, 1, N) if B == 0 else (B, 1, N)
w_shape = (M, N) if B == 0 else (B, M, N)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
y_q = mx.quantized_matmul(
x,
w_q,
scales,
transpose=True,
group_size=32,
mode="mxfp4",
)
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
@ -283,6 +311,38 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_mxfp4_qvm(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
tests = product(
[32, 128, 256], # M
[128, 256, 67], # N
[0, 1, 3, 8], # B
)
# Add a splitk
tests = list(tests)
tests.append((128, 16384, 0))
for M, N, B in tests:
with self.subTest(shape=(B, M, N)):
x_shape = (1, N) if B == 0 else (B, 1, N)
w_shape = (N, M) if B == 0 else (B, N, M)
x = mx.random.normal(shape=x_shape, key=k1)
w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")
y_q = mx.quantized_matmul(
x,
w_q,
scales,
transpose=False,
group_size=32,
mode="mxfp4",
)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_mode_error_cases(self):
w = mx.random.normal(shape=(256, 256))
x = mx.random.normal(shape=(1, 256))
@ -475,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
def test_gather_qmm(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@ -494,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=True,
group_size=64,
bits=4,
mode="affine",
):
with self.subTest(
M=M,
@ -507,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
):
x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
w = mx.random.normal(
shape=batch_B + ((N, K) if transpose else (K, N))
).astype(dtype)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits)
w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)
if lhs_indices is not None:
lhs_indices = mx.array(lhs_indices)
@ -530,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase):
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mode,
)
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
inputs = (
@ -575,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
"batch_B": (4, 1),
"rhs_indices": ((2,), (0,), (1,)),
},
{
"batch_A": (1,),
"lhs_indices": (0,),
"batch_B": (3,),
"rhs_indices": (2, 1),
"group_size": 32,
"mode": "mxfp4",
},
)
for kwargs in inputs:
@ -618,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
def test_gather_qmm_sorted(self):
def quantize(w, transpose=True, group_size=64, bits=4):
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
if mode == "affine":
qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
else:
qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
b = None
w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
if transpose:
w_hat = w_hat.swapaxes(-1, -2)
return w_hat, qw, s, b
@ -640,19 +719,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
parameters = [
# L, K, D, E, I, transpose
(32, 512, 512, 4, 2, True),
(32, 512, 544, 4, 2, True),
(133, 512, 512, 4, 2, True),
(133, 512, 555, 4, 2, True),
(133, 512, 512, 4, 2, True),
(64, 512, 512, 4, 2, False),
(64, 512, 544, 4, 2, False),
(133, 512, 512, 4, 2, False),
(133, 512, 544, 4, 2, False),
(133, 512, 555, 4, 2, False),
(64, 512, 512, 4, 2, False),
(32, 512, 512, 4, 2, True, "affine"),
(32, 512, 544, 4, 2, True, "mxfp4"),
(133, 512, 512, 4, 2, True, "affine"),
(133, 512, 555, 4, 2, True, "affine"),
(133, 512, 512, 4, 2, True, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
(64, 512, 544, 4, 2, False, "mxfp4"),
(133, 512, 512, 4, 2, False, "affine"),
(133, 512, 544, 4, 2, False, "affine"),
(133, 512, 555, 4, 2, False, "affine"),
(64, 512, 512, 4, 2, False, "affine"),
]
for L, K, D, E, I, transpose in parameters:
for L, K, D, E, I, transpose, mode in parameters:
if mode == "mxfp4":
group_size = 32
K, D = (K, D) if transpose else (D, K)
ishape = (L, I)
xshape = (L, 1, 1, K)
@ -661,14 +742,28 @@ class TestQuantized(mlx_tests.MLXTestCase):
indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
x = mx.random.normal(xshape) / K**0.5
w = mx.random.normal(wshape) / K**0.5
w, *wq = quantize(w, transpose=transpose)
w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)
y1 = mx.gather_mm(x, w, rhs_indices=indices)
y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices)
y2 = mx.gather_qmm(
x,
*wq,
group_size=group_size,
mode=mode,
transpose=transpose,
rhs_indices=indices
)
xs, idx, inv_order = gather_sort(x, indices)
y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)
y4 = mx.gather_qmm(
xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True
xs,
*wq,
group_size=group_size,
mode=mode,
rhs_indices=idx,
transpose=transpose,
sorted_indices=True
)
y3 = scatter_unsort(y3, inv_order, indices.shape)
y4 = scatter_unsort(y4, inv_order, indices.shape)