diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index 1c02c4e61..f9cfb347d 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -407,6 +407,51 @@ void _qmm_dispatch( } } +// template +// void _qmm_mxfp4_dispatch_typed( +// array& out, +// const array& x, +// const array& w, +// const array& scales, +// bool transposed_w) { +// int K = x.shape(-1); +// int M = x.ndim() > 1 ? x.shape(-2) : 1; +// int N = out.shape(-1); +// int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; +// int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0; +// int batch_size = x.size() / (K * M); +// +// auto out_ptr = out.data(); +// auto x_ptr = x.data(); +// auto w_ptr = w.data(); +// auto scales_ptr = scales.data(); +// for (int i = 0; i < batch_size; i++) { +// _qmm_mxfp4_dispatch_typed( +// out_ptr + i * M * N, +// x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()), +// w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()), +// scales_ptr + elem_to_loc(i * g_els, scales.shape(), +// scales.strides()), M, N, K, transposed_w); +// } +// } +// +// +// void _qmm_mxfp4_dispatch( +// array& out, +// const array& x, +// const array& w, +// const array& scales, +// bool transposed_w) { +// switch (x.dtype()) { +// case bfloat16: +// _qmm_mxfp4_dispatch_typed(out, x, w, scales, transposed_w); +// break; +// default: +// throw std::invalid_argument( +// "[quantized_matmul] only bfloat is supported for mxfp4"); +// } +// } + template void _bs_qmm_dispatch_typed( array& out, @@ -521,7 +566,6 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; - auto& biases_pre = inputs[3]; std::vector temps; auto ensure_row_contiguous = [s = stream(), &temps](const array& arr) { @@ -537,7 +581,6 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { auto x = ensure_row_contiguous(x_pre); auto w = ensure_row_contiguous(w_pre); auto scales = ensure_row_contiguous(scales_pre); - auto biases = ensure_row_contiguous(biases_pre); out.set_data(allocator::malloc(out.nbytes())); @@ -546,18 +589,31 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { encoder.set_input_array(x); encoder.set_input_array(w); encoder.set_input_array(scales); - encoder.set_input_array(biases); encoder.set_output_array(out); - encoder.dispatch([out = array::unsafe_weak_copy(out), - x = array::unsafe_weak_copy(x), - w = array::unsafe_weak_copy(w), - scales = array::unsafe_weak_copy(scales), - biases = array::unsafe_weak_copy(biases), - group_size_ = group_size_, - bits_ = bits_, - transpose_ = transpose_]() mutable { - _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); - }); + if (mode_ == "affine") { + auto biases = ensure_row_contiguous(inputs[3]); + encoder.set_input_array(biases); + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + scales = array::unsafe_weak_copy(scales), + biases = array::unsafe_weak_copy(biases), + group_size_ = group_size_, + bits_ = bits_, + transpose_ = transpose_]() mutable { + _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); + }); + } else { + // encoder.dispatch([out = array::unsafe_weak_copy(out), + // x = array::unsafe_weak_copy(x), + // w = array::unsafe_weak_copy(w), + // scales = array::unsafe_weak_copy(scales), + // group_size_ = group_size_, + // bits_ = bits_, + // transpose_ = transpose_]() 
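// Sketch (hypothetical reference, not part of this patch): the commented-out
// mxfp4 CPU path above would eventually dispatch into a kernel along these
// lines. Two E2M1 codes are packed per byte (low nibble first) and each group
// of 32 values along K shares one E8M0 scale decoded as 2^(s - 127). The name
// qmm_mxfp4_ref and the transposed [N, K] weight layout are assumptions made
// for illustration.
#include <cmath>
#include <cstdint>

static const float mxfp4_lut[16] = {
    +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

template <typename T>
void qmm_mxfp4_ref(
    T* y,                  // [M, N] output
    const T* x,            // [M, K] activations
    const uint8_t* w,      // [N, K / 2] packed 4-bit codes
    const uint8_t* scales, // [N, K / 32] E8M0 biased exponents
    int M,
    int N,
    int K) {
  constexpr int group_size = 32;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        uint8_t byte = w[n * (K / 2) + k / 2];
        uint8_t code = (k % 2 == 0) ? (byte & 0x0f) : (byte >> 4);
        int e = scales[n * (K / group_size) + k / group_size];
        acc += float(x[m * K + k]) * std::ldexp(1.0f, e - 127) * mxfp4_lut[code];
      }
      y[m * N + n] = static_cast<T>(acc);
    }
  }
}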
mutable { + // _qmm_mxfp4_dispatch(out, x, w, scales, transpose_); + // }); + } } void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { @@ -705,7 +761,7 @@ void dispatch_quantize( w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); } -void fast::AffineQuantize::eval_cpu( +void fast::Quantize::eval_cpu( const std::vector& inputs, std::vector& outputs) { auto ensure_row_contiguous = [s = stream()](const array& arr) { @@ -764,7 +820,7 @@ void fast::AffineQuantize::eval_cpu( } } else { throw std::runtime_error( - "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs"); + "[fast::Quantize::eval_cpu] Only supports floating point inputs"); } }); } diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 4069d8c21..c3164af0c 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -109,6 +109,7 @@ if(NOT MLX_METAL_JIT) reduction/reduce_col.h reduction/reduce_row.h) build_kernel(quantized quantized.h ${STEEL_HEADERS}) + build_kernel(fp4_quantized fp4_quantized.h ${STEEL_HEADERS}) build_kernel(scan scan.h) build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) diff --git a/mlx/backend/metal/kernels/fp4_quantized.h b/mlx/backend/metal/kernels/fp4_quantized.h new file mode 100644 index 000000000..0ce9bc35d --- /dev/null +++ b/mlx/backend/metal/kernels/fp4_quantized.h @@ -0,0 +1,1789 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / 4; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +inline void load_vector(const device T* x, thread U* x_thread) { + for (int i = 0; i < values_per_thread; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } +} + +template +inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { + for (int i = 0; i < N; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } +} + +constant float MXFP4_LUT[16] = { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + +template +inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { + U accum = 0; + + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] + + x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] + + x_thread[4 * i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] + + x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]); + } + return scale * accum; +} + +template +inline U +qdot_safe(const device uint8_t* w, const thread U* x_thread, S scale, int N) { + U accum = 0; + + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * MXFP4_LUT[ws[i] & 0x000f] + + x_thread[4 * i + 1] * MXFP4_LUT[(ws[i] & 0x00f0) >> 4] + + x_thread[4 * 
i + 2] * MXFP4_LUT[(ws[i] & 0x0f00) >> 8] + + x_thread[4 * i + 3] * MXFP4_LUT[(ws[i] & 0xf000) >> 12]); + } + return scale * accum; +} + +template +inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) { + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * scale * MXFP4_LUT[w[i] & 0x0f]; + result[2 * i + 1] += x * scale * MXFP4_LUT[(w[i] & 0xf0) >> 4]; + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, threadgroup U* w_local) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * static_cast(MXFP4_LUT[w[i] & 0x0f]); + w_local[2 * i + 1] = scale * static_cast(MXFP4_LUT[(w[i] & 0xf0) >> 4]); + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + typename S> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + + MLX_MTL_CONST short pack_factor = get_pack_factor<8>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device S* scales; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device S* scales_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
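// Sketch (illustrative, not part of the patch): the MXFP4_LUT used by qdot,
// qouter and dequantize above is simply the 16 values of the E2M1 mini-float
// format (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit), so the
// hard-coded table can be cross-checked against a direct decoder like this.
#include <cmath>
#include <cstdint>

float decode_e2m1(uint8_t code) {
  float sign = (code & 0x8) ? -1.0f : 1.0f;
  int exp = (code >> 1) & 0x3;
  int man = code & 0x1;
  // exp == 0 is the subnormal range (0.0 or 0.5); otherwise (1 + man/2) * 2^(exp - 1).
  float mag = (exp == 0) ? 0.5f * man : (1.0f + 0.5f * man) * std::ldexp(1.0f, exp - 1);
  return sign * mag;
}
// decode_e2m1(0x5) == 3.0f and decode_e2m1(0xF) == -6.0f, matching MXFP4_LUT.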
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = metal::pow(T(2.0), static_cast(*scales) - 127); + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = metal::pow(T(2.0), static_cast(*scales) - 127); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + } + } else { + scales++; + } + } else { + scales += group_stride; + } + } +}; + +template +METAL_FUNC void mxfp4_qmv_quad_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 8; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device S* sl = scales + row * in_vec_size_g * quads_per_simd; + + U s = metal::pow(2.0f, static_cast(sl[0]) - 127); + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void mxfp4_qmv_fast_impl( + const device uint32_t* 
w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int packs_per_thread = 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = metal::pow(2.0f, static_cast(sl[0]) - 127); + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void mxfp4_qmv_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws 
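// Sketch (illustrative host-side analogue, not part of the patch): each group
// scale is stored as an E8M0 biased exponent s and decoded in these kernels as
// metal::pow(2, s - 127). On the host the same decode can use ldexp, or can
// place s directly into a float's exponent field (valid for 0 < s < 255;
// s == 0 would be 2^-127, a subnormal, so it needs the ldexp path).
#include <cmath>
#include <cstdint>
#include <cstring>

inline float e8m0_to_float(uint8_t s) {
  return std::ldexp(1.0f, int(s) - 127); // 2^(s - 127)
}

inline float e8m0_to_float_bits(uint8_t s) {
  uint32_t bits = uint32_t(s) << 23; // sign = 0, exponent = s, mantissa = 0
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}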
+= + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + S s = sl[0]; + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = metal::pow(2.0f, static_cast(sl[0]) - 127); + result[row] += qdot(wl, x_thread, s); + } + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = metal::pow(2.0f, static_cast(sl[0]) - 127); + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = metal::pow(2.0f, static_cast(sl[0]) - 127); + result[row] += + qdot_safe(wl, x_thread, s, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void mxfp4_qvm_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int tn = 32 / pack_factor; + constexpr int block_size = SIMD_SIZE; + + using W_T = uint32_t; + const device W_T* ws = (const device W_T*)w; + + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } 
vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 0; + thread U x_local = 0; + + // Adjust positions + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = metal::pow(2.0f, static_cast(*scales) - 127); + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = metal::pow(2.0f, static_cast(*scales) - 127); + w_local = *((device vec_w*)ws); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = metal::pow(2.0f, static_cast(*scales) - 127); + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + const bool aligned_N, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void mxfp4_qmm_t_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + S>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make 
the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void mxfp4_qmm_n_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + S>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + 
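// Note on the branch entered here: when K is not a multiple of BK, the K / BK
// full tiles are processed first and one extra partial tile of
// num_k = K - k_blocks * BK columns is then loaded through load_safe so no
// read runs past the end of x or w along K. For example, K = 100 with BK = 32
// gives three full tiles plus a 4-column tail. The else branch below repeats
// the same structure for the case where all BM output rows are valid.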
const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the 
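// Sketch (illustrative host-side analogue; elem_to_loc_ref is a hypothetical
// name): adjust_matrix_offsets relies on elem_to_loc to turn a flat batch
// index into a memory offset for an arbitrarily strided array, and on
// elem_to_loc_broadcast to do the same walk for two stride arrays sharing one
// shape. The core index walk looks like this:
#include <cstdint>
#include <vector>

int64_t elem_to_loc_ref(
    int64_t elem,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = int(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}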
input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void mxfp4_qmv_quad( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmv_quad_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); +} + +template +[[kernel]] void mxfp4_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void mxfp4_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void 
mxfp4_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void mxfp4_qvm_split_k( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& final_block_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? 
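// Split-K note: tid.z enumerates batch entries times split_k partitions of the
// reduction dimension. Every partition reduces over in_vec_size elements
// except the last partition in each group of split_k, which uses
// final_block_size to cover the remainder when the full K does not divide
// evenly; the per-partition partial outputs are combined in a follow-up
// reduction outside this kernel (not shown here).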
final_block_size : in_vec_size; + + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const bool aligned_N, + const bool batched, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const bool batched, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void mxfp4_gather_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + 
out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void mxfp4_gather_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void mxfp4_gather_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const bool aligned_N, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid 
[[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? 
short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} + +template < + typename T, + int group_size, + typename S, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void mxfp4_gather_qmm_rhs( + const device T* x, + const device uint32_t* w, + const device S* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + S>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? 
y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} diff --git a/mlx/backend/metal/kernels/fp4_quantized.metal b/mlx/backend/metal/kernels/fp4_quantized.metal new file mode 100644 index 000000000..9a817d20d --- /dev/null +++ b/mlx/backend/metal/kernels/fp4_quantized.metal @@ -0,0 +1,126 @@ +// Copyright © 2025 Apple Inc. 
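// Sketch (illustrative, not part of the patch; segment_rows is a hypothetical
// helper): mxfp4_gather_qmm_rhs above walks the expert indices of its BM-row
// block as maximal runs of equal adjacent values, performs one accumulation
// per run, and stores only rows [offset, offset_next). The host-side picture
// of that segmentation:
#include <cstdint>
#include <cstdio>
#include <vector>

void segment_rows(const std::vector<uint32_t>& indices) {
  size_t start = 0;
  while (start < indices.size()) {
    size_t end = start + 1;
    while (end < indices.size() && indices[end] == indices[start]) {
      end++;
    }
    // One GEMM against weight matrix `indices[start]`, writing rows [start, end).
    std::printf("rows [%zu, %zu) -> expert %u\n", start, end, (unsigned)indices[start]);
    start = end;
  }
}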
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/fp4_quantized.h" + +#define instantiate_quantized(name, type) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4", \ + name, \ + type, \ + 32, \ + uint8_t) + +#define instantiate_quantized_batched(name, type, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_batch_" #batched, \ + name, \ + type, \ + 32, \ + batched, \ + uint8_t) + +#define instantiate_quantized_aligned(name, type, aligned) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_alN_" #aligned, \ + name, \ + type, \ + 32, \ + aligned, \ + uint8_t) + +#define instantiate_quantized_aligned_batched(name, type, aligned, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_alN_" #aligned "_batch_" #batched, \ + name, \ + type, \ + 32, \ + aligned, \ + batched, \ + uint8_t) + +#define instantiate_quantized_quad(name, type, D, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_d_" #D "_batch_" #batched, \ + name, \ + type, \ + 32, \ + D, \ + batched, \ + uint8_t) + +#define instantiate_quantized_split_k(name, type, split_k) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_spk_" #split_k, \ + name, \ + type, \ + 32, \ + split_k, \ + uint8_t) + +#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + 32, \ + uint8_t, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + +#define instantiate_quantized_batched_wrap(name, type) \ + instantiate_quantized_batched(name, type, 1) \ + instantiate_quantized_batched(name, type, 0) + +#define instantiate_quantized_all_batched(type) \ + instantiate_quantized_batched_wrap(mxfp4_qmv_fast, type) \ + instantiate_quantized_batched_wrap(mxfp4_qmv, type) \ + instantiate_quantized_batched_wrap(mxfp4_qvm, type) \ + instantiate_quantized_batched_wrap(mxfp4_qmm_n, type) + +#define instantiate_quantized_all_single(type) \ + instantiate_quantized(mxfp4_gather_qmv_fast, type) \ + instantiate_quantized(mxfp4_gather_qmv, type) \ + instantiate_quantized(mxfp4_gather_qvm, type) \ + instantiate_quantized(mxfp4_gather_qmm_n, type) + +#define instantiate_quantized_all_aligned(type) \ + instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, true) \ + instantiate_quantized_aligned(mxfp4_gather_qmm_t, type, false) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 1) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, true, 0) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 1) \ + instantiate_quantized_aligned_batched(mxfp4_qmm_t, type, false, 0) + +#define instantiate_quantized_all_quad(type) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 1) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 64, 0) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 1) \ + instantiate_quantized_quad(mxfp4_qmv_quad, type, 128, 0) + +#define instantiate_quantized_all_splitk(type) \ + instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 8) \ + instantiate_quantized_split_k(mxfp4_qvm_split_k, type, 32) + +#define instantiate_quantized_all_rhs(type) \ + instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nt, type, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(mxfp4_gather_qmm_rhs, mxfp4_gather_qmm_rhs_nn, type, 16, 32, 32, 1, 2, false) + +#define 
instantiate_quantized_types(type) \ + instantiate_quantized_all_batched(type) \ + instantiate_quantized_all_quad(type) \ + instantiate_quantized_all_splitk(type) \ + instantiate_quantized_all_single(type) \ + instantiate_quantized_all_aligned(type) \ + instantiate_quantized_all_rhs(type) + +instantiate_quantized_types(float) +instantiate_quantized_types(bfloat16_t) +instantiate_quantized_types(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 0a40cec00..09a73f887 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -1434,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets( } template -[[kernel]] void qmv_quad( +[[kernel]] void affine_qmv_quad( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1486,7 +1486,7 @@ template } template -[[kernel]] void qmv_fast( +[[kernel]] void affine_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1538,7 +1538,7 @@ template } template -[[kernel]] void qmv( +[[kernel]] void affine_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1590,7 +1590,7 @@ template } template -[[kernel]] void qvm( +[[kernel]] void affine_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1642,7 +1642,7 @@ template } template -[[kernel]] void qvm_split_k( +[[kernel]] void affine_qvm_split_k( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1706,7 +1706,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_t( +[[kernel]] void affine_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1764,7 +1764,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_n( +[[kernel]] void affine_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1817,7 +1817,7 @@ template < } template -[[kernel]] void gather_qmv_fast( +[[kernel]] void affine_gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1879,7 +1879,7 @@ template } template -[[kernel]] void gather_qmv( +[[kernel]] void affine_gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1941,7 +1941,7 @@ template } template -[[kernel]] void gather_qvm( +[[kernel]] void affine_gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -2010,7 +2010,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_t( +[[kernel]] void affine_gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -2077,7 +2077,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_n( +[[kernel]] void affine_gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -2234,7 +2234,7 @@ template < int WM, int WN, 
bool transpose> -[[kernel]] void gather_qmm_rhs( +[[kernel]] void affine_gather_qmm_rhs( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(2)]], diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index de83cb657..56e8e4ca5 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -79,40 +79,40 @@ instantiate_quantized_batched(name, type, group_size, bits, 0) #define instantiate_quantized_all_batched(type, group_size, bits) \ - instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \ - instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \ - instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \ - instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits) + instantiate_quantized_batched_wrap(affine_qmv_fast, type, group_size, bits) \ + instantiate_quantized_batched_wrap(affine_qmv, type, group_size, bits) \ + instantiate_quantized_batched_wrap(affine_qvm, type, group_size, bits) \ + instantiate_quantized_batched_wrap(affine_qmm_n, type, group_size, bits) #define instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \ - instantiate_quantized(gather_qmv_fast, type, group_size, bits) \ - instantiate_quantized(gather_qmv, type, group_size, bits) \ - instantiate_quantized(gather_qvm, type, group_size, bits) \ - instantiate_quantized(gather_qmm_n, type, group_size, bits) + instantiate_quantized(affine_gather_qmv_fast, type, group_size, bits) \ + instantiate_quantized(affine_gather_qmv, type, group_size, bits) \ + instantiate_quantized(affine_gather_qvm, type, group_size, bits) \ + instantiate_quantized(affine_gather_qmm_n, type, group_size, bits) #define instantiate_quantized_all_aligned(type, group_size, bits) \ - instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, true) \ - instantiate_quantized_aligned(gather_qmm_t, type, group_size, bits, false) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \ - instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0) + instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, true) \ + instantiate_quantized_aligned(affine_gather_qmm_t, type, group_size, bits, false) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, true, 0) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t, type, group_size, bits, false, 0) #define instantiate_quantized_all_quad(type, group_size, bits) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \ - instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0) + instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 1) \ + instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 64, 0) \ + instantiate_quantized_quad(affine_qmv_quad, type, group_size, 
bits, 128, 1) \ + instantiate_quantized_quad(affine_qmv_quad, type, group_size, bits, 128, 0) #define instantiate_quantized_all_splitk(type, group_size, bits) \ - instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ - instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) + instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 8) \ + instantiate_quantized_split_k(affine_qvm_split_k, type, group_size, bits, 32) #define instantiate_quantized_all_rhs(type, group_size, bits) \ - instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ - instantiate_gather_qmm_rhs(gather_qmm_rhs, gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) + instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nt, type, group_size, bits, 16, 32, 32, 1, 2, true) \ + instantiate_gather_qmm_rhs(affine_gather_qmm_rhs, affine_gather_qmm_rhs_nn, type, group_size, bits, 16, 32, 32, 1, 2, false) #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 999825043..903c650bc 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -99,7 +99,7 @@ inline int add_strides_and_shapes( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, int offset) { if (skip) { return 0; @@ -109,16 +109,18 @@ inline int add_strides_and_shapes( int x_batch_ndims = x.ndim() - 2; int w_batch_ndims = w.ndim() - 2; - compute_encoder.set_bytes(x_batch_ndims, offset); - compute_encoder.set_vector_bytes(x.shape(), offset + 1); - compute_encoder.set_vector_bytes(x.strides(), offset + 2); - compute_encoder.set_bytes(w_batch_ndims, offset + 3); - compute_encoder.set_vector_bytes(w.shape(), offset + 4); - compute_encoder.set_vector_bytes(w.strides(), offset + 5); - compute_encoder.set_vector_bytes(scales.strides(), offset + 6); - compute_encoder.set_vector_bytes(biases.strides(), offset + 7); + compute_encoder.set_bytes(x_batch_ndims, offset++); + compute_encoder.set_vector_bytes(x.shape(), offset++); + compute_encoder.set_vector_bytes(x.strides(), offset++); + compute_encoder.set_bytes(w_batch_ndims, offset++); + compute_encoder.set_vector_bytes(w.shape(), offset++); + compute_encoder.set_vector_bytes(w.strides(), offset++); + compute_encoder.set_vector_bytes(scales.strides(), offset++); + if (biases) { + compute_encoder.set_vector_bytes(biases->strides(), offset++); + } - return 8; + return offset; } inline int add_gather_strides_and_shapes( @@ -130,12 +132,12 @@ inline int add_gather_strides_and_shapes( lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); int ndims = shape.size(); - compute_encoder.set_bytes(ndims, offset); - compute_encoder.set_vector_bytes(shape, offset + 1); - compute_encoder.set_vector_bytes(strides[0], offset + 2); - compute_encoder.set_vector_bytes(strides[1], offset + 3); + compute_encoder.set_bytes(ndims, offset++); + compute_encoder.set_vector_bytes(shape, offset++); + compute_encoder.set_vector_bytes(strides[0], offset++); + compute_encoder.set_vector_bytes(strides[1], offset++); - return 4; + return offset; } } // namespace @@ -144,7 +146,7 @@ void qmv_quad( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -152,7 +154,8 @@ void qmv_quad( int N, int 
K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; constexpr int quads_per_simd = 8; @@ -165,9 +168,10 @@ void qmv_quad( std::string kname; kname.reserve(64); std::string type_string = get_type_string(x.dtype()); + concatenate( kname, - "qmv_quad_", + mode + "_qmv_quad_", type_string, "_gs_", group_size, @@ -177,20 +181,23 @@ void qmv_quad( K, B > 1 ? "_batch_1" : "_batch_0"); auto template_def = get_template_definition( - kname, "qmv_quad", type_string, group_size, bits, K, B > 1); + kname, mode + "_qmv_quad", type_string, group_size, bits, K, B > 1); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -199,7 +206,7 @@ void qmv( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -207,7 +214,8 @@ void qmv( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 8; @@ -219,9 +227,10 @@ void qmv( kname.reserve(64); std::string type_string = get_type_string(x.dtype()); bool fast = N % bn == 0 && K % 512 == 0; + concatenate( kname, - fast ? "qmv_fast_" : "qmv_", + mode + (fast ? "_qmv_fast_" : "_qmv_"), type_string, "_gs_", group_size, @@ -229,20 +238,28 @@ void qmv( bits, B > 1 ? "_batch_1" : "_batch_0"); auto template_def = get_template_definition( - kname, fast ? "qmv_fast" : "qmv", type_string, group_size, bits, B > 1); + kname, + mode + (fast ? 
"_qmv_fast" : "_qmv"), + type_string, + group_size, + bits, + B > 1); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -251,7 +268,7 @@ void qvm_split_k( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -259,7 +276,8 @@ void qvm_split_k( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int split_k = K > 8192 ? 32 : 8; int split_D = (K + split_k - 1) / split_k; int B = out.size() / M / N; @@ -283,7 +301,6 @@ void qvm_split_k( auto w_shape = w.shape(); auto w_strides = w.strides(); auto s_strides = scales.strides(); - auto b_strides = biases.strides(); // Add split_k dim with reshapes x_shape.insert(x_shape.end() - 2, split_k); @@ -297,7 +314,6 @@ void qvm_split_k( w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1)); w_batch_ndims += 1; s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1)); - b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1)); int final_block_size = K - (split_k - 1) * split_D; @@ -315,7 +331,7 @@ void qvm_split_k( kname.reserve(64); concatenate( kname, - "qvm_split_k_", + mode + "_qvm_split_k_", type_string, "_gs_", group_size, @@ -324,30 +340,37 @@ void qvm_split_k( "_spk_", split_k); auto template_def = get_template_definition( - kname, "qvm_split_k", type_string, group_size, bits, split_k); + kname, mode + "_qvm_split_k", type_string, group_size, bits, split_k); // Encode and dispatch kernel auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(intermediate, 4); - compute_encoder.set_bytes(split_D, 5); - compute_encoder.set_bytes(N, 6); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(intermediate, c++); + compute_encoder.set_bytes(split_D, c++); + compute_encoder.set_bytes(N, c++); - compute_encoder.set_bytes(x_batch_ndims, 7); - compute_encoder.set_vector_bytes(x_shape, 8); - compute_encoder.set_vector_bytes(x_strides, 9); - compute_encoder.set_bytes(w_batch_ndims, 10); - compute_encoder.set_vector_bytes(w_shape, 11); - 
compute_encoder.set_vector_bytes(w_strides, 12); - compute_encoder.set_vector_bytes(s_strides, 13); - compute_encoder.set_vector_bytes(b_strides, 14); - compute_encoder.set_bytes(final_block_size, 15); + compute_encoder.set_bytes(x_batch_ndims, c++); + compute_encoder.set_vector_bytes(x_shape, c++); + compute_encoder.set_vector_bytes(x_strides, c++); + compute_encoder.set_bytes(w_batch_ndims, c++); + compute_encoder.set_vector_bytes(w_shape, c++); + compute_encoder.set_vector_bytes(w_strides, c++); + compute_encoder.set_vector_bytes(s_strides, c++); + if (biases) { + auto b_strides = biases->strides(); + b_strides.insert(b_strides.end() - 2, split_D * biases->shape(-1)); + compute_encoder.set_vector_bytes(b_strides, c++); + } + compute_encoder.set_bytes(final_block_size, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); @@ -364,7 +387,7 @@ void qvm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, int group_size, int bits, @@ -372,7 +395,8 @@ void qvm( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 64; @@ -385,7 +409,7 @@ void qvm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - "qvm_", + mode + "_qvm_", type_string, "_gs_", group_size, @@ -393,20 +417,23 @@ void qvm( bits, B > 1 ? "_batch_1" : "_batch_0"); auto template_def = get_template_definition( - kname, "qvm", type_string, group_size, bits, B > 1); + kname, mode + "_qvm", type_string, group_size, bits, B > 1); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 7); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -415,7 +442,7 @@ void qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, array& out, bool transpose, int group_size, @@ -424,7 +451,8 @@ void qmm( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int wm = 2; @@ -441,7 +469,7 @@ void qmm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - transpose ? "qmm_t_" : "qmm_n_", + mode + (transpose ? 
"_qmm_t_" : "_qmm_n_"), type_string, "_gs_", group_size, @@ -452,25 +480,34 @@ void qmm( std::string template_def; if (transpose) { template_def = get_template_definition( - kname, "qmm_t", type_string, group_size, bits, aligned, batched); + kname, + mode + "_qmm_t", + type_string, + group_size, + bits, + aligned, + batched); } else { template_def = get_template_definition( - kname, "qmm_n", type_string, group_size, bits, batched); + kname, mode + "_qmm_n", type_string, group_size, bits, batched); } auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_output_array(out, 4); - compute_encoder.set_bytes(K, 5); - compute_encoder.set_bytes(N, 6); - compute_encoder.set_bytes(M, 7); - add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -479,7 +516,7 @@ void gather_qmm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, @@ -490,7 +527,8 @@ void gather_qmm( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int wm = 2; @@ -507,7 +545,7 @@ void gather_qmm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - transpose ? "gather_qmm_t_" : "gather_qmm_n_", + mode + (transpose ? 
"_gather_qmm_t_" : "_gather_qmm_n_"), type_string, "_gs_", group_size, @@ -517,30 +555,31 @@ void gather_qmm( std::string template_def; if (transpose) { template_def = get_template_definition( - kname, "gather_qmm_t", type_string, group_size, bits, aligned); + kname, mode + "_gather_qmm_t", type_string, group_size, bits, aligned); } else { template_def = get_template_definition( - kname, "gather_qmm_n", type_string, group_size, bits); + kname, mode + "_gather_qmm_n", type_string, group_size, bits); } auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder.set_bytes(K, 7); - compute_encoder.set_bytes(N, 8); - compute_encoder.set_bytes(M, 9); - int n = - add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 10); - add_gather_strides_and_shapes( - compute_encoder, lhs_indices, rhs_indices, 10 + n); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); + add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -549,7 +588,7 @@ void gather_qmv( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, @@ -559,7 +598,8 @@ void gather_qmv( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 8; @@ -573,7 +613,7 @@ void gather_qmv( bool fast = N % bn == 0 && K % 512 == 0; concatenate( kname, - fast ? "gather_qmv_fast_" : "gather_qmv_", + mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"), type_string, "_gs_", group_size, @@ -581,7 +621,7 @@ void gather_qmv( bits); auto template_def = get_template_definition( kname, - fast ? "gather_qmv_fast" : "gather_qmv", + mode + (fast ? 
"_gather_qmv_fast" : "_gather_qmv"), type_string, group_size, bits); @@ -590,19 +630,20 @@ void gather_qmv( auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder.set_bytes(K, 7); - compute_encoder.set_bytes(N, 8); - int n = - add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); - add_gather_strides_and_shapes( - compute_encoder, lhs_indices, rhs_indices, 9 + n); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); + add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -611,7 +652,7 @@ void gather_qvm( const array& x, const array& w, const array& scales, - const array& biases, + const std::optional& biases, const array& lhs_indices, const array& rhs_indices, array& out, @@ -621,7 +662,8 @@ void gather_qvm( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& mode) { int B = out.size() / M / N; int bn = 64; @@ -633,27 +675,34 @@ void gather_qvm( kname.reserve(64); std::string type_string = get_type_string(x.dtype()); concatenate( - kname, "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits); + kname, + mode + "_gather_qvm_", + type_string, + "_gs_", + group_size, + "_b_", + bits); auto template_def = get_template_definition( - kname, "gather_qvm", type_string, group_size, bits); + kname, mode + "_gather_qvm", type_string, group_size, bits); auto kernel = get_quantized_kernel(d, kname, template_def); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(w, 0); - compute_encoder.set_input_array(scales, 1); - compute_encoder.set_input_array(biases, 2); - compute_encoder.set_input_array(x, 3); - compute_encoder.set_input_array(lhs_indices, 4); - compute_encoder.set_input_array(rhs_indices, 5); - compute_encoder.set_output_array(out, 6); - compute_encoder.set_bytes(K, 7); - compute_encoder.set_bytes(N, 8); - int n = - add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, 9); - add_gather_strides_and_shapes( - compute_encoder, lhs_indices, rhs_indices, 9 + n); + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c++); + 
add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -662,7 +711,7 @@ void gather_qmm_rhs( const array& x_, const array& w_, const array& scales_, - const array& biases_, + const std::optional& biases_, const array& indices_, array& out, bool transpose, @@ -672,7 +721,8 @@ void gather_qmm_rhs( int N, int K, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string mode) { // Start by normalizing the indices array indices = ensure_row_contiguous(indices_, d, s); @@ -697,7 +747,6 @@ void gather_qmm_rhs( array x = broadcast_with_indices(x_); array w = ensure_row_contiguous(w_, d, s); array scales = ensure_row_contiguous(scales_, d, s); - array biases = ensure_row_contiguous(biases_, d, s); // TODO: Tune the block sizes int bm = 16, bn = 32, bk = 32; @@ -713,7 +762,7 @@ void gather_qmm_rhs( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_", + mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"), type_string, "_gs_", group_size, @@ -770,15 +819,19 @@ void gather_qmm_rhs( MTL::Size group_dims(32, wn, wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1); - compute_encoder.set_input_array(x, 0); - compute_encoder.set_input_array(w, 1); - compute_encoder.set_input_array(scales, 2); - compute_encoder.set_input_array(biases, 3); - compute_encoder.set_input_array(indices, 4); - compute_encoder.set_output_array(out, 5); - compute_encoder.set_bytes(M, 6); - compute_encoder.set_bytes(N, 7); - compute_encoder.set_bytes(K, 8); + int c = 0; + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases_) { + array biases = ensure_row_contiguous(*biases_, d, s); + compute_encoder.set_input_array(biases, c++); + } + compute_encoder.set_input_array(indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(M, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(K, c++); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -794,7 +847,10 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { array x = ensure_row_contiguous_matrix(inputs[0], d, s); array w = ensure_row_contiguous_matrix(inputs[1], d, s); array scales = ensure_row_contiguous_matrix(inputs[2], d, s); - array biases = ensure_row_contiguous_matrix(inputs[3], d, s); + std::optional biases = std::nullopt; + if (inputs.size() == 4) { + biases = ensure_row_contiguous_matrix(inputs[3], d, s); + } // Extract the matmul shapes bool non_batched = w.ndim() == 2 && x.flags().row_contiguous; @@ -818,30 +874,33 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode_); return; } // It is a qmv with a small inner dimension so route to qmv_quad kernel if (transpose_ && (K == 128 || K == 64) && is_power_of_2(bits_)) { - qmv_quad(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qmv_quad( + x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_); return; } // Run of the mill qmv if (transpose_) { - qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_); return; } // Run of the mill qvm if (K < 1024) { - qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qvm(x, w, scales, biases, out, group_size_, bits_, M, N, 
K, d, s, mode_); return; } // Qvm with large dimension so route to a split K kernel for more parallelism - qvm_split_k(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s); + qvm_split_k( + x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode_); return; } @@ -854,9 +913,12 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { array x = ensure_row_contiguous_matrix(inputs[0], d, s); array w = ensure_row_contiguous_matrix(inputs[1], d, s); array scales = ensure_row_contiguous_matrix(inputs[2], d, s); - array biases = ensure_row_contiguous_matrix(inputs[3], d, s); - const array& lhs_indices = inputs[4]; - const array& rhs_indices = inputs[5]; + std::optional biases = std::nullopt; + if (inputs.size() == 6) { + biases = ensure_row_contiguous_matrix(inputs[3], d, s); + } + const array& lhs_indices = inputs[inputs.size() - 2]; + const array& rhs_indices = inputs[inputs.size() - 1]; int K = x.shape(-1); int M = x.shape(-2); @@ -884,7 +946,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode_); return; } @@ -905,7 +968,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode_); return; } @@ -924,7 +988,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode_); return; } @@ -942,10 +1007,11 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { N, K, d, - s); + s, + mode_); } -void fast::AffineQuantize::eval_gpu( +void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { auto& w_pre = inputs[0]; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index bde34d54b..c14184a46 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4089,7 +4089,7 @@ array quantized_matmul( inputs = { astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)}; } else { - throw std::invalid_argument("ERROR!"); + inputs = {x, w, scales}; } if (x.ndim() > 2 && w.ndim() > 2) { @@ -4568,7 +4568,23 @@ array gather_qmm( auto out_shape = lhs_indices.shape(); out_shape.push_back(x.shape(-2)); out_shape.push_back(w_outer_dims); - + std::vector inputs; + if (mode == "affine") { + inputs = { + astype(x, out_type, s), + std::move(w), + astype(scales, out_type, s), + astype(*biases, out_type, s), + std::move(lhs_indices), + std::move(rhs_indices)}; + } else { + inputs = { + astype(x, out_type, s), + std::move(w), + std::move(scales), + std::move(lhs_indices), + std::move(rhs_indices)}; + } return array( std::move(out_shape), out_type, @@ -4580,12 +4596,7 @@ array gather_qmm( transpose, sorted_indices && !rhs_indices_, sorted_indices && !lhs_indices_), - {astype(x, out_type, s), - std::move(w), - astype(scales, out_type, s), - astype(*biases, out_type, s), - std::move(lhs_indices), - std::move(rhs_indices)}); + std::move(inputs)); } array tensordot( diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4341235b6..4808018c7 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3243,6 +3243,10 @@ std::vector QuantizedMatmul::vjp( throw std::runtime_error( "[QuantizedMatmul::vjp] no gradient wrt the quantized weights."); } else { + if (mode_ == "mxfp4") { + throw std::runtime_error( + "[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization."); + } if (!dsb) { int ndim = primals[1].ndim(); auto fc = flatten(cotangents[0], 0, -ndim, stream()); @@ -3372,14 +3376,19 @@ std::vector GatherQMM::vjp( // gradient wrt to the indices is undefined else if (arg > 3) { throw std::runtime_error( - "GatherQMM::vjp cannot compute the 
gradient wrt the indices."); + "[GatherQMM::vjp] cannot compute the gradient wrt the indices."); } // gradient wrt to w_q, scales or biases else if (arg == 1) { throw std::runtime_error( - "GatherQMM::vjp no gradient wrt the quantized weights."); + "[GatherQMM::vjp] no gradient wrt the quantized weights."); } else { + if (mode_ == "mxfp4") { + throw std::runtime_error( + "[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization."); + } + if (!dsb) { auto shape = w.shape(); shape.pop_back(); diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index c85b55e90..5d9c7ae5c 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -98,11 +98,11 @@ class QuantizedEmbedding(Module): # Initialize the quantized weight scale = math.sqrt(1 / dims) weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) - self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode) + self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode) if mode == "affine": self.scales, self.biases = scales_biases else: - self.scales = scales_biases + (self.scales,) = scales_biases self.num_embeddings = num_embeddings self.dims = dims @@ -155,12 +155,16 @@ class QuantizedEmbedding(Module): """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" embedding_dims, dims = embedding_layer.weight.shape ql = cls(embedding_dims, dims, group_size, bits) - ql.weight, ql.scales, ql.biases = mx.quantize( + ql.weight, *scales_biases = mx.quantize( embedding_layer.weight, group_size, bits, mode=mode, ) + if mode == "affine": + ql.scales, ql.biases = scales_biases + else: + (ql.scales,) = scales_biases return ql @@ -210,11 +214,11 @@ class QuantizedLinear(Module): high=scale, shape=(output_dims, input_dims), ) - self.weight, scales_biases = mx.quantize(weight, group_size, bits, mode=mode) + self.weight, *scales_biases = mx.quantize(weight, group_size, bits, mode=mode) if mode == "affine": self.scales, self.biases = scales_biases else: - self.scales = scales_biases + (self.scales,) = scales_biases # And bias if needed if bias: @@ -257,7 +261,7 @@ class QuantizedLinear(Module): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape ql = cls(input_dims, output_dims, False, group_size, bits) - ql.weight, scales_biases = mx.quantize( + ql.weight, *scales_biases = mx.quantize( linear_layer.weight, group_size, bits, @@ -266,7 +270,7 @@ class QuantizedLinear(Module): if mode == "affine": ql.scales, ql.biases = scales_biases else: - ql.scales = scales_biases + (ql.scales,) = scales_biases if "bias" in linear_layer: ql.bias = linear_layer.bias diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 296f6ee8d..6ded37227 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -198,6 +198,12 @@ class TestBase(mlx_tests.MLXTestCase): self.assertTrue(isinstance(m.layers[1], nn.ReLU)) self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + nn.quantize(m, group_size=32, mode="mxfp4") + self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding)) + self.assertTrue(isinstance(m.layers[1], nn.ReLU)) + self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) + self.assertTrue(isinstance(m.layers[2].scales, mx.array)) + def test_quantize_freeze(self): lin = nn.Linear(512, 512) qlin = lin.to_quantized() diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index eb2826dc8..f792c8c11 100644 --- 
a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -218,6 +218,34 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_mxfp4_qmv(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + tests = product( + [256, 512, 67], # M + [64, 128], # N + [0, 1, 3, 8], # B + ) + for M, N, B in tests: + with self.subTest(shape=(B, M, N), group_size=32): + x_shape = (3, 1, N) if B == 0 else (B, 1, N) + w_shape = (M, N) if B == 0 else (B, M, N) + x = mx.random.normal(shape=x_shape, key=k1) + w = mx.random.normal(shape=w_shape, key=k2) + w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4") + y_q = mx.quantized_matmul( + x, + w_q, + scales, + transpose=True, + group_size=32, + mode="mxfp4", + ) + y_hat = x @ mx.swapaxes(w_hat, -1, -2) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_qvm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) @@ -283,6 +311,38 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertEqual(y_q.shape, y_hat.shape) self.assertLess((y_q - y_hat).abs().max(), 2e-3) + def test_mxfp4_qvm(self): + key = mx.random.key(0) + k1, k2 = mx.random.split(key) + tests = product( + [32, 128, 256], # M + [128, 256, 67], # N + [0, 1, 3, 8], # B + ) + # Add a splitk + tests = list(tests) + tests.append((128, 16384, 0)) + + for M, N, B in tests: + with self.subTest(shape=(B, M, N)): + x_shape = (1, N) if B == 0 else (B, 1, N) + w_shape = (N, M) if B == 0 else (B, N, M) + x = mx.random.normal(shape=x_shape, key=k1) + w = mx.random.normal(shape=w_shape, key=k2) + w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4") + w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4") + y_q = mx.quantized_matmul( + x, + w_q, + scales, + transpose=False, + group_size=32, + mode="mxfp4", + ) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) + def test_mode_error_cases(self): w = mx.random.normal(shape=(256, 256)) x = mx.random.normal(shape=(1, 256)) @@ -475,9 +535,13 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertLess((y_q - y_hat).abs().max(), 1e-3) def test_gather_qmm(self): - def quantize(w, transpose=True, group_size=64, bits=4): - qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) - w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"): + if mode == "affine": + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + else: + qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + b = None + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode) if transpose: w_hat = w_hat.swapaxes(-1, -2) return w_hat, qw, s, b @@ -494,6 +558,7 @@ class TestQuantized(mlx_tests.MLXTestCase): transpose=True, group_size=64, bits=4, + mode="affine", ): with self.subTest( M=M, @@ -507,12 +572,13 @@ class TestQuantized(mlx_tests.MLXTestCase): transpose=transpose, group_size=group_size, bits=bits, + mode=mode, ): x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype) w = mx.random.normal( shape=batch_B + ((N, K) if transpose else (K, N)) ).astype(dtype) - w_hat, qw, s, b = quantize(w, transpose, group_size, bits) + w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode) if lhs_indices is not None: lhs_indices = 
mx.array(lhs_indices) @@ -530,8 +596,8 @@ class TestQuantized(mlx_tests.MLXTestCase): transpose=transpose, group_size=group_size, bits=bits, + mode=mode, ) - self.assertTrue(mx.allclose(c1, c2, atol=1e-4)) inputs = ( @@ -575,6 +641,14 @@ class TestQuantized(mlx_tests.MLXTestCase): "batch_B": (4, 1), "rhs_indices": ((2,), (0,), (1,)), }, + { + "batch_A": (1,), + "lhs_indices": (0,), + "batch_B": (3,), + "rhs_indices": (2, 1), + "group_size": 32, + "mode": "mxfp4", + }, ) for kwargs in inputs: @@ -618,9 +692,14 @@ class TestQuantized(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(g1, g2, atol=1e-4)) def test_gather_qmm_sorted(self): - def quantize(w, transpose=True, group_size=64, bits=4): - qw, s, b = mx.quantize(w, group_size=group_size, bits=bits) - w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits) + def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"): + if mode == "affine": + qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + else: + qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode) + b = None + + w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode) if transpose: w_hat = w_hat.swapaxes(-1, -2) return w_hat, qw, s, b @@ -640,19 +719,21 @@ class TestQuantized(mlx_tests.MLXTestCase): parameters = [ # L, K, D, E, I, transpose - (32, 512, 512, 4, 2, True), - (32, 512, 544, 4, 2, True), - (133, 512, 512, 4, 2, True), - (133, 512, 555, 4, 2, True), - (133, 512, 512, 4, 2, True), - (64, 512, 512, 4, 2, False), - (64, 512, 544, 4, 2, False), - (133, 512, 512, 4, 2, False), - (133, 512, 544, 4, 2, False), - (133, 512, 555, 4, 2, False), - (64, 512, 512, 4, 2, False), + (32, 512, 512, 4, 2, True, "affine"), + (32, 512, 544, 4, 2, True, "mxfp4"), + (133, 512, 512, 4, 2, True, "affine"), + (133, 512, 555, 4, 2, True, "affine"), + (133, 512, 512, 4, 2, True, "affine"), + (64, 512, 512, 4, 2, False, "affine"), + (64, 512, 544, 4, 2, False, "mxfp4"), + (133, 512, 512, 4, 2, False, "affine"), + (133, 512, 544, 4, 2, False, "affine"), + (133, 512, 555, 4, 2, False, "affine"), + (64, 512, 512, 4, 2, False, "affine"), ] - for L, K, D, E, I, transpose in parameters: + for L, K, D, E, I, transpose, mode in parameters: + if mode == "mxfp4": + group_size = 32 K, D = (K, D) if transpose else (D, K) ishape = (L, I) xshape = (L, 1, 1, K) @@ -661,14 +742,28 @@ class TestQuantized(mlx_tests.MLXTestCase): indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32) x = mx.random.normal(xshape) / K**0.5 w = mx.random.normal(wshape) / K**0.5 - w, *wq = quantize(w, transpose=transpose) + w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose) y1 = mx.gather_mm(x, w, rhs_indices=indices) - y2 = mx.gather_qmm(x, *wq, transpose=transpose, rhs_indices=indices) + y2 = mx.gather_qmm( + x, + *wq, + group_size=group_size, + mode=mode, + transpose=transpose, + rhs_indices=indices + ) xs, idx, inv_order = gather_sort(x, indices) y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True) + y4 = mx.gather_qmm( - xs, *wq, rhs_indices=idx, transpose=transpose, sorted_indices=True + xs, + *wq, + group_size=group_size, + mode=mode, + rhs_indices=idx, + transpose=transpose, + sorted_indices=True ) y3 = scatter_unsort(y3, inv_order, indices.shape) y4 = scatter_unsort(y4, inv_order, indices.shape)
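
Usage sketch (not part of the diff): the tests and the `QuantizedLinear`/`QuantizedEmbedding` changes above rely on `mx.quantize` returning two arrays for `mode="mxfp4"` (weights and scales) but three for `mode="affine"` (weights, scales, biases), which is why they unpack with `*scales_biases`. The snippet below is a minimal illustration of that pattern; the shapes are arbitrary and only the calls shown in the tests above are assumed to exist.

import mlx.core as mx

w = mx.random.normal(shape=(64, 64))

# mxfp4: two return values, no biases (mirrors test_mxfp4_qmv above)
w_q, *scales_biases = mx.quantize(w, group_size=32, mode="mxfp4")
(scales,) = scales_biases
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")

# affine: three return values (mirrors the existing affine path)
w_q, *scales_biases = mx.quantize(w, group_size=64, bits=4, mode="affine")
scales, biases = scales_biases
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4, mode="affine")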
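
A second sketch (also not part of the diff): an end-to-end mxfp4 matmul check in the same style as `test_mxfp4_qmv`, calling `mx.quantized_matmul` with `mode="mxfp4"` and no biases argument. `group_size=32` matches what every mxfp4 test in this patch passes; the shapes and the printed error check are illustrative rather than a guaranteed tolerance.

import mlx.core as mx

x = mx.random.normal(shape=(1, 512))
w = mx.random.normal(shape=(256, 512))

# Quantize to mxfp4 and build a dequantized reference weight
w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4")
w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4")

# Quantized matmul (transposed weights), compared against the reference
y_q = mx.quantized_matmul(
    x, w_q, scales, transpose=True, group_size=32, mode="mxfp4"
)
y_ref = x @ mx.swapaxes(w_hat, -1, -2)

# The new tests bound this difference by 1e-3 for their shapes
print((y_q - y_ref).abs().max())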