diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 877fa4522..e51f6f1c3 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -62,10 +62,17 @@ def matmul(x, y): def _quant_matmul(x, w, s, b, transpose, group_size, bits): ys = [] - for i in range(10): + for i in range(100): ys.append( mx.quantized_matmul( - x, w, s, b, transpose=transpose, group_size=group_size, bits=bits + x, + w, + s, + b, + transpose=transpose, + group_size=group_size, + bits=bits, + mode=mx.QuantizationMode.DEFAULT, ) ) mx.eval(ys) diff --git a/mlx/array.h b/mlx/array.h index 3f000e9b2..1ee172de8 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "mlx/allocator.h" @@ -565,4 +566,6 @@ inline constexpr bool is_arrays_v = (is_array_v && ...); template using enable_for_arrays_t = typename std::enable_if_t>; +enum QuantizationMode { DEFAULT, NF4 }; + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 28a055576..2f3c52775 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -7,8 +7,28 @@ using namespace metal; #define MLX_MTL_CONST static constant constexpr const +enum QuantizationMode { DEFAULT, NF4 }; + MLX_MTL_CONST int SIMD_SIZE = 32; +constexpr constant static float nf4_values[16] = { + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0}; + template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( @@ -21,9 +41,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; } } @@ -31,9 +51,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; } } @@ -59,9 +79,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; } for (int i = N; i < values_per_thread; i++) { x_thread[i] = 0; @@ -72,9 +92,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; } for (int i = 
N; i < values_per_thread; i++) { x_thread[i] = 0; @@ -95,25 +115,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { } template -inline U qdot( +METAL_FUNC U qdot_affine( const device uint8_t* w, const thread U* x_thread, U scale, U bias, U sum) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - U accum = 0; if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); + x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) + + x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) + + x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6)); } } @@ -122,9 +138,9 @@ inline U qdot( for (int i = 0; i < (values_per_thread / 4); i++) { accum += (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); + x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000)); } } @@ -134,10 +150,98 @@ inline U qdot( } } - return scale * accum + sum * bias; + U out = scale * accum + sum * bias; + return out; } template +inline U qdot_nf4(const device uint8_t* w, const thread U* x_thread, U scale) { + U accum = 0; + for (int i = 0; i < (values_per_thread / 2); i++) { + accum += + (x_thread[2 * i] * nf4_values[w[i] & 0x000f] + + x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]); + } + U out = scale * accum; + return out; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + if (mode == QuantizationMode::DEFAULT) { + return qdot_affine( + w, x_thread, scale, bias, sum); + } else { + return qdot_nf4(w, x_thread, scale); + } +} + +template +inline U qdot_safe_affine( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) + + x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) + + x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6)); + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000)); + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + U out = scale * accum + sum * bias; + return out; +} + +template +inline U qdot_safe_nf4( + const device uint8_t* w, + const thread U* x_thread, + U scale, + int N) { + U accum = 0; + for (int i = 0; i < (N / 2); i++) { + accum += + (x_thread[2 * i] * nf4_values[w[i] & 0x000f] + + x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]); + } + U out = scale * accum; + return out; +} + +template inline U qdot_safe( const device uint8_t* w, const thread U* x_thread, @@ -148,46 +252,17 @@ inline U qdot_safe( static_assert( bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); - - U accum = 0; - - if (bits == 2) { - for (int i = 0; i < (N / 
4); i++) { - accum += - (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); - } + if (mode == QuantizationMode::DEFAULT) { + return qdot_safe_affine( + w, x_thread, scale, bias, sum, N); + } else { + return qdot_safe_nf4(w, x_thread, scale, N); } - - else if (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - accum += x_thread[i] * w[i]; - } - } - - return scale * accum + sum * bias; } template inline void -qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { - static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); - +qouter_affine(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { @@ -213,13 +288,34 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { } } -template +template inline void -dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { +qouter_nf4(const thread uint8_t* w, U x, U scale, thread U* result) { + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (scale * nf4_values[w[i] & 0x0f]); + result[2 * i + 1] += x * (scale * nf4_values[(w[i] & 0xf0) >> 4]); + } +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}"); + if (mode == QuantizationMode::DEFAULT) { + return qouter_affine(w, x, scale, bias, result); + } else { + return qouter_nf4(w, x, scale, result); + } +} +template +inline void dequantize_affine( + const device uint8_t* w, + U scale, + U bias, + threadgroup U* w_local) { if (bits == 2) { U s[4] = { scale, @@ -249,6 +345,28 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } +template +inline void +dequantize_nf4(const device uint8_t* w, U scale, threadgroup U* w_local) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * static_cast(nf4_values[w[i] & 0x0f]); + w_local[2 * i + 1] = scale * static_cast(nf4_values[(w[i] & 0xf0) >> 4]); + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "Template undefined for bits not in {2, 4, 8}"); + if (mode == QuantizationMode::DEFAULT) { + return dequantize_affine(w, scale, bias, w_local); + } else { + return dequantize_nf4(w, scale, w_local); + } +} + template < typename T, short BROWS, @@ -257,7 +375,8 @@ template < short reduction_dim, short tgp_size, short group_size, - short bits> + short bits, + QuantizationMode mode> struct QuantizedBlockLoader { static_assert( BCOLS <= group_size, @@ -318,7 +437,7 @@ struct QuantizedBlockLoader { T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { - dequantize( + dequantize( (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); } } @@ -345,7 +464,7 @@ struct QuantizedBlockLoader { T scale = *scales; T bias = *biases; for (int i = 0; i < n_reads; i++) { - dequantize( + 
dequantize( (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); } } @@ -371,7 +490,11 @@ struct QuantizedBlockLoader { } }; -template +template < + typename T, + int group_size, + int bits, + QuantizationMode mode = QuantizationMode::DEFAULT> METAL_FUNC void qmv_fast_impl( const device uint32_t* w, const device T* scales, @@ -417,13 +540,14 @@ METAL_FUNC void qmv_fast_impl( const device T* bl = biases + row * in_vec_size_g; U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + U b = mode == QuantizationMode::DEFAULT ? bl[0] : 0; + result[row] += + qdot(wl, x_thread, s, b, sum); } w += block_size / pack_factor; - scales += block_size / group_size; biases += block_size / group_size; + scales += block_size / group_size; x += block_size; } @@ -435,7 +559,11 @@ METAL_FUNC void qmv_fast_impl( } } -template +template < + typename T, + int group_size, + int bits, + QuantizationMode mode = QuantizationMode::DEFAULT> METAL_FUNC void qmv_impl( const device uint32_t* w, const device T* scales, @@ -493,7 +621,7 @@ METAL_FUNC void qmv_impl( U s = sl[0]; U b = bl[0]; result[row] += - qdot(wl, x_thread, s, b, sum); + qdot(wl, x_thread, s, b, sum); } w += block_size / pack_factor; @@ -516,7 +644,8 @@ METAL_FUNC void qmv_impl( U s = sl[0]; U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); + result[row] += + qdot(wl, x_thread, s, b, sum); } for (int row = 0; out_row + row < out_vec_size; row++) { @@ -548,7 +677,7 @@ METAL_FUNC void qmv_impl( U s = sl[0]; U b = bl[0]; result[row] += - qdot(wl, x_thread, s, b, sum); + qdot(wl, x_thread, s, b, sum); } w += block_size / pack_factor; @@ -571,7 +700,7 @@ METAL_FUNC void qmv_impl( U s = sl[0]; U b = bl[0]; - result[row] += qdot_safe( + result[row] += qdot_safe( wl, x_thread, s, b, sum, remaining); } @@ -584,7 +713,11 @@ METAL_FUNC void qmv_impl( } } -template +template < + typename T, + const int group_size, + const int bits, + const QuantizationMode mode = QuantizationMode::DEFAULT> METAL_FUNC void qvm_impl( const device T* x, const device uint32_t* w, @@ -636,7 +769,7 @@ METAL_FUNC void qvm_impl( bias = *biases; w_local = *((device vec_w*)w); - qouter( + qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); x += blocksize; @@ -651,7 +784,7 @@ METAL_FUNC void qvm_impl( bias = *biases; w_local = *((device vec_w*)w); - qouter( + qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); x += blocksize; @@ -669,7 +802,7 @@ METAL_FUNC void qvm_impl( scale = 0; bias = 0; } - qouter( + qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); } @@ -695,7 +828,8 @@ template < const int BN, const int group_size, const int bits, - const bool aligned_N> + const bool aligned_N, + const QuantizationMode mode = QuantizationMode::DEFAULT> METAL_FUNC void qmm_t_impl( const device T* x, const device uint32_t* w, @@ -734,7 +868,8 @@ METAL_FUNC void qmm_t_impl( 1, WM * WN * SIMD_SIZE, group_size, - bits>; + bits, + mode>; // Set the block const int K_w = K / pack_factor; @@ -816,7 +951,8 @@ template < const int BK, const int BN, const int group_size, - const int bits> + const int bits, + const QuantizationMode mode = QuantizationMode::DEFAULT> METAL_FUNC void qmm_n_impl( const device T* x, const device uint32_t* w, @@ -856,7 +992,8 @@ METAL_FUNC void qmm_n_impl( 0, WM * WN * SIMD_SIZE, group_size, - bits>; + bits, + mode>; // Set the block const int y_row = tid.y * BM; @@ -996,7 +1133,11 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } -template +template < + typename T, + int 
group_size, + int bits, + QuantizationMode mode = QuantizationMode::DEFAULT> [[kernel]] void qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1008,7 +1149,7 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - qmv_fast_impl( + qmv_fast_impl( w, scales, biases, @@ -1021,7 +1162,11 @@ template simd_lid); } -template +template < + typename T, + const int group_size, + const int bits, + const QuantizationMode mode = QuantizationMode::DEFAULT> [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1033,7 +1178,7 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - qmv_impl( + qmv_impl( w, scales, biases, @@ -1046,7 +1191,11 @@ template simd_lid); } -template +template < + typename T, + const int group_size, + const int bits, + const QuantizationMode mode = QuantizationMode::DEFAULT> [[kernel]] void qvm( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], @@ -1058,7 +1207,7 @@ template uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - qvm_impl( + qvm_impl( x, w, scales, @@ -1076,6 +1225,7 @@ template < const int group_size, const int bits, const bool aligned_N, + const QuantizationMode mode = QuantizationMode::DEFAULT, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1099,7 +1249,7 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BN * BK_padded]; - qmm_t_impl( + qmm_t_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } @@ -1107,6 +1257,7 @@ template < typename T, const int group_size, const int bits, + const QuantizationMode mode = QuantizationMode::DEFAULT, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1131,7 +1282,7 @@ template < threadgroup T Xs[BM * BK_padded]; threadgroup T Ws[BK * BN_padded]; - qmm_n_impl( + qmm_n_impl( x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); } diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 0651db872..b3a110359 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -6,28 +6,33 @@ #include "mlx/backend/metal/kernels/quantized.h" -#define instantiate_qmv_fast(itype, group_size, bits) \ - instantiate_kernel( \ - "qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \ - qmv_fast, \ - itype, \ - group_size, \ - bits) +#define instantiate_qmv_fast(itype, group_size, bits, mode) \ + instantiate_kernel( \ + "qmv_" #itype "_gs_" #group_size "_b_" #bits "_" #mode "_fast", \ + qmv_fast, \ + itype, \ + group_size, \ + bits, \ + mode) -#define instantiate_qmv_fast_types(group_size, bits) \ - instantiate_qmv_fast(float, group_size, bits) \ - instantiate_qmv_fast(float16_t, group_size, bits) \ - instantiate_qmv_fast(bfloat16_t, group_size, bits) +#define instantiate_qmv_fast_types(group_size, bits, mode) \ + instantiate_qmv_fast(float, group_size, bits, mode) \ + instantiate_qmv_fast(float16_t, group_size, bits, mode) \ + instantiate_qmv_fast(bfloat16_t, group_size, bits, mode) -instantiate_qmv_fast_types(128, 2) -instantiate_qmv_fast_types(128, 4) -instantiate_qmv_fast_types(128, 8) -instantiate_qmv_fast_types( 64, 2) -instantiate_qmv_fast_types( 64, 4) 
-instantiate_qmv_fast_types( 64, 8) -instantiate_qmv_fast_types( 32, 2) -instantiate_qmv_fast_types( 32, 4) -instantiate_qmv_fast_types( 32, 8) +instantiate_qmv_fast_types(128, 2, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types(128, 4, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types(128, 8, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types( 64, 2, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types( 64, 4, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types( 64, 8, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types( 32, 2, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types( 32, 4, QuantizationMode::DEFAULT) +instantiate_qmv_fast_types( 32, 8, QuantizationMode::DEFAULT) + +instantiate_qmv_fast_types(128, 4, QuantizationMode::NF4) +instantiate_qmv_fast_types(64, 4, QuantizationMode::NF4) +instantiate_qmv_fast_types(32, 4, QuantizationMode::NF4) #define instantiate_qmv(itype, group_size, bits) \ instantiate_kernel( \ @@ -37,8 +42,8 @@ instantiate_qmv_fast_types( 32, 8) group_size, \ bits) -#define instantiate_qmv_types(group_size, bits) \ - instantiate_qmv(float, group_size, bits) \ +#define instantiate_qmv_types(group_size, bits) \ + instantiate_qmv(float, group_size, bits) \ instantiate_qmv(float16_t, group_size, bits) \ instantiate_qmv(bfloat16_t, group_size, bits) @@ -60,8 +65,8 @@ instantiate_qmv_types( 32, 8) group_size, \ bits) -#define instantiate_qvm_types(group_size, bits) \ - instantiate_qvm(float, group_size, bits) \ +#define instantiate_qvm_types(group_size, bits) \ + instantiate_qvm(float, group_size, bits) \ instantiate_qvm(float16_t, group_size, bits) \ instantiate_qvm(bfloat16_t, group_size, bits) @@ -84,12 +89,12 @@ instantiate_qvm_types( 32, 8) bits, \ aligned_N) -#define instantiate_qmm_t_types(group_size, bits) \ - instantiate_qmm_t(float, group_size, bits, false) \ - instantiate_qmm_t(float16_t, group_size, bits, false) \ +#define instantiate_qmm_t_types(group_size, bits) \ + instantiate_qmm_t(float, group_size, bits, false) \ + instantiate_qmm_t(float16_t, group_size, bits, false) \ instantiate_qmm_t(bfloat16_t, group_size, bits, false) \ - instantiate_qmm_t(float, group_size, bits, true) \ - instantiate_qmm_t(float16_t, group_size, bits, true) \ + instantiate_qmm_t(float, group_size, bits, true) \ + instantiate_qmm_t(float16_t, group_size, bits, true) \ instantiate_qmm_t(bfloat16_t, group_size, bits, true) instantiate_qmm_t_types(128, 2) diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index f0a64d5e6..109d5c65d 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -9,6 +9,8 @@ #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" +#include + namespace mlx::core { void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { @@ -42,18 +44,27 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int D = x.shape(-1); int B = x.size() / D; int O = out.shape(-1); + auto mode_string = mode_ == QuantizationMode::DEFAULT + ? 
"QuantizationMode::DEFAULT" + : "QuantizationMode::NF4"; + // auto mode_string = "default"; if (transpose_) { // Route to the fast qmv kernel that has no bounds checking if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) { std::ostringstream kname; auto type_string = get_type_string(x.dtype()); kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_ - << "_fast"; + << "_" << mode_string << "_fast"; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto template_def = get_template_definition( - kname.str(), "qmv_fast", type_string, group_size_, bits_); + kname.str(), + "qmv_fast", + type_string, + group_size_, + bits_, + mode_string); auto kernel = get_quantized_kernel(d, kname.str(), template_def); compute_encoder->setComputePipelineState(kernel); @@ -77,12 +88,13 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { else if (B < 6) { std::ostringstream kname; auto type_string = get_type_string(x.dtype()); - kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_; + kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_ + << "_" << mode_string; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto template_def = get_template_definition( - kname.str(), "qmv", type_string, group_size_, bits_); + kname.str(), "qmv", type_string, group_size_, bits_, mode_string); auto kernel = get_quantized_kernel(d, kname.str(), template_def); compute_encoder->setComputePipelineState(kernel); @@ -108,12 +120,18 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { std::string aligned_n = (O % 32) == 0 ? "true" : "false"; auto type_string = get_type_string(x.dtype()); kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_ << "_alN_" << aligned_n; + << bits_ << "_alN_" << aligned_n << "_" << mode_string; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto template_def = get_template_definition( - kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n); + kname.str(), + "qmm_t", + type_string, + group_size_, + bits_, + aligned_n, + mode_string); auto kernel = get_quantized_kernel(d, kname.str(), template_def); compute_encoder->setComputePipelineState(kernel); @@ -141,12 +159,13 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { if (B < 4) { std::ostringstream kname; auto type_string = get_type_string(x.dtype()); - kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_; + kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_ + << "_" << mode_string; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto template_def = get_template_definition( - kname.str(), "qvm", type_string, group_size_, bits_); + kname.str(), "qvm", type_string, group_size_, bits_, mode_string); auto kernel = get_quantized_kernel(d, kname.str(), template_def); compute_encoder->setComputePipelineState(kernel); @@ -171,12 +190,12 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { std::ostringstream kname; auto type_string = get_type_string(x.dtype()); kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_" - << bits_; + << bits_ << "_" << mode_string; // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto template_def = get_template_definition( - kname.str(), "qmm_n", type_string, group_size_, bits_); + kname.str(), "qmm_n", type_string, 
group_size_, bits_, mode_string); auto kernel = get_quantized_kernel(d, kname.str(), template_def); compute_encoder->setComputePipelineState(kernel); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index d38b8e6f7..271bf6fab 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -12,6 +12,8 @@ #include "mlx/transforms.h" #include "mlx/utils.h" +#include + namespace mlx::core { namespace { @@ -72,7 +74,8 @@ std::pair extract_quantized_matmul_dims( const array& biases, bool transpose, int group_size, - int bits) { + int bits, + QuantizationMode mode) { if (w.dtype() != uint32) { std::ostringstream msg; msg << "[" << tag << "] The weight matrix should be uint32 " @@ -80,7 +83,7 @@ std::pair extract_quantized_matmul_dims( throw std::invalid_argument(msg.str()); } - if (scales.shape() != biases.shape()) { + if (mode == QuantizationMode::DEFAULT && scales.shape() != biases.shape()) { std::ostringstream msg; msg << "[" << tag << "] Scales and biases should have the same shape. " << "Received scales with shape " << scales.shape() @@ -3287,10 +3290,19 @@ array quantized_matmul( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + QuantizationMode mode /* = DEFAULT */, StreamOrDevice s /* = {} */) { // Check and extract the quantized matrix shape against x auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( - "quantized_matmul", x, w, scales, biases, transpose, group_size, bits); + "quantized_matmul", + x, + w, + scales, + biases, + transpose, + group_size, + bits, + mode); if (w.ndim() != 2) { std::ostringstream msg; @@ -3315,7 +3327,7 @@ array quantized_matmul( std::move(out_shape), dtype, std::make_shared( - to_stream(s), group_size, bits, transpose), + to_stream(s), group_size, bits, transpose, mode), {astype(x, dtype, s), w, astype(scales, dtype, s), @@ -3326,6 +3338,7 @@ std::tuple quantize( const array& w, int group_size /* = 64 */, int bits /* = 4 */, + QuantizationMode mode /* = DEFAULT */, StreamOrDevice s /* = {} */) { if (group_size != 32 && group_size != 64 && group_size != 128) { std::ostringstream msg; @@ -3341,6 +3354,13 @@ std::tuple quantize( throw std::invalid_argument(msg.str()); } + if (mode == QuantizationMode::NF4 && bits != 4) { + std::ostringstream msg; + msg << "[quantize] The requested number of bits " << bits + << " is not supported. 
NF4 only supports 4 bits"; + throw std::invalid_argument(msg.str()); + } + if (w.ndim() < 2) { std::ostringstream msg; msg << "[quantize] The matrix to be quantized must have at least 2 dimension " @@ -3382,34 +3402,62 @@ std::tuple quantize( auto wshape = w.shape(); wshape.back() = -1; - // Compute scales and biases - array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); - array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + auto packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); + auto biases = array({0.0}); + auto scales = array({0.0}); + if (mode == QuantizationMode::NF4) { + // For NF4, we quantize using a fixed array of quantiles generated for the + // normal distribution + auto quantiles = array( + {-1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0}); + scales = max(abs(packed_w, s), /* axis= */ -1, /* keepdims= */ true, s); + array scaled_w = expand_dims(divide(packed_w, scales, s), -1, s); + packed_w = + argmin(square(subtract(scaled_w, quantiles, s), s), -1, false, s); + // TODO figure out zero scale case + } else { + array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); + array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); - array mask = greater(abs(w_min, s), abs(w_max, s), s); - array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); - scales = where(mask, scales, negative(scales), s); - array edge = where(mask, w_min, w_max, s); - array q0 = round(divide(edge, scales, s), s); - scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); - array biases = where(equal(q0, zero, s), zero, edge); + array mask = greater(abs(w_min, s), abs(w_max, s), s); + scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); + scales = where(mask, scales, negative(scales), s); + array edge = where(mask, w_min, w_max, s); + array q0 = round(divide(edge, scales, s), s); + scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); + biases = where(equal(q0, zero, s), zero, edge); - // Quantize and pack w - packed_w = astype( - clip( - round(divide(subtract(packed_w, biases, s), scales, s), s), - zero, - n_bins), - uint32); + packed_w = astype( + clip( + round(divide(subtract(packed_w, biases, s), scales, s), s), + zero, + n_bins), + uint32); + biases = reshape(biases, wshape, s); + } + + // Pack bits into uint32s packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); packed_w = sum( multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); return std::make_tuple( - reshape(packed_w, wshape, s), - reshape(scales, wshape, s), - reshape(biases, wshape, s)); + reshape(packed_w, wshape, s), reshape(scales, wshape, s), biases); } array dequantize( @@ -3418,6 +3466,7 @@ array dequantize( const array& biases, int group_size /* = 64 */, int bits /* = 4 */, + QuantizationMode mode /* = DEFAULT */, StreamOrDevice s /* = {} */) { if (bits <= 0) { std::ostringstream msg; @@ -3429,7 +3478,8 @@ array dequantize( msg << "[dequantize] Invalid value for group_size: " << group_size; throw std::invalid_argument(msg.str()); } - if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { + if (w.ndim() < 2 
|| scales.ndim() < 2 || + (biases.ndim() < 2 && mode == QuantizationMode::DEFAULT)) { std::ostringstream msg; msg << "[quantize] The matrix to be quantized must have at least 2 dimension " << "but it has only " << w.ndim() << "."; @@ -3443,7 +3493,8 @@ array dequantize( sshape.back() = -1; bshape.back() = -1; - if (wshape != sshape || wshape != bshape) { + if (wshape != sshape || + (wshape != bshape && mode == QuantizationMode::DEFAULT)) { throw std::invalid_argument( "[dequantize] Shape of scales and biases does not match the matrix"); } @@ -3484,10 +3535,31 @@ array dequantize( // Dequantize wshape.push_back(group_size); w_full = reshape(w_full, wshape, s); - w_full = multiply(w_full, expand_dims(scales, -1, s), s); - w_full = add(w_full, expand_dims(biases, -1, s), s); + if (mode == QuantizationMode::NF4) { + auto quantiles = array( + {-1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0}); + w_full = take(quantiles, w_full, 0, s); + w_full = multiply(w_full, expand_dims(scales, -1, s), s); + } else { + w_full = multiply(w_full, expand_dims(scales, -1, s), s); + w_full = add(w_full, expand_dims(biases, -1, s), s); + } w_full = reshape(w_full, sshape, s); - return w_full; } @@ -3501,14 +3573,15 @@ array gather_qmm( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + QuantizationMode mode /* = DEFAULT */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( - x, w, scales, biases, transpose, group_size, bits, s); + x, w, scales, biases, transpose, group_size, bits, mode, s); } auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( - "gather_qmm", x, w, scales, biases, transpose, group_size, bits); + "gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode); // Extract indices and broadcast them array lhs_indices = indices_or_default(lhs_indices_, x, s); @@ -3529,7 +3602,8 @@ array gather_qmm( auto out = array( std::move(out_shape), out_type, - std::make_shared(to_stream(s), group_size, bits, transpose), + std::make_shared( + to_stream(s), group_size, bits, transpose, mode), {astype(x, out_type, s), w, astype(scales, out_type, s), diff --git a/mlx/ops.h b/mlx/ops.h index 069400ba8..a92172326 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1236,6 +1236,7 @@ array quantized_matmul( bool transpose = true, int group_size = 64, int bits = 4, + QuantizationMode mode = QuantizationMode::DEFAULT, StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ @@ -1243,6 +1244,7 @@ std::tuple quantize( const array& w, int group_size = 64, int bits = 4, + QuantizationMode mode = QuantizationMode::DEFAULT, StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ @@ -1252,6 +1254,7 @@ array dequantize( const array& biases, int group_size = 64, int bits = 4, + QuantizationMode mode = QuantizationMode::DEFAULT, StreamOrDevice s = {}); /** Compute matrix products with matrix-level gather. */ @@ -1265,6 +1268,7 @@ array gather_qmm( bool transpose = true, int group_size = 64, int bits = 4, + QuantizationMode mode = QuantizationMode::DEFAULT, StreamOrDevice s = {}); /** Returns a contraction of a and b over multiple dimensions. 
*/ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index b5cd9ac25..2abc8fc60 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -2359,6 +2359,7 @@ std::vector QuantizedMatmul::vjp( !transpose_, group_size_, bits_, + mode_, stream())); } @@ -2424,6 +2425,7 @@ std::vector GatherQMM::vjp( !transpose_, group_size_, bits_, + mode_, stream()), -3, stream()), diff --git a/mlx/primitives.h b/mlx/primitives.h index 8f49a4c1d..228687a0d 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1445,11 +1445,13 @@ class QuantizedMatmul : public UnaryPrimitive { Stream stream, int group_size, int bits, - bool transpose) + bool transpose, + QuantizationMode mode) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - transpose_(transpose) {} + transpose_(transpose), + mode_(mode) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1463,17 +1465,24 @@ class QuantizedMatmul : public UnaryPrimitive { int group_size_; int bits_; bool transpose_; + QuantizationMode mode_; void eval(const std::vector& inputs, array& out); }; class GatherQMM : public UnaryPrimitive { public: - explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) + explicit GatherQMM( + Stream stream, + int group_size, + int bits, + bool transpose, + QuantizationMode mode) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - transpose_(transpose) {} + transpose_(transpose), + mode_(mode) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1487,6 +1496,7 @@ class GatherQMM : public UnaryPrimitive { int group_size_; int bits_; bool transpose_; + QuantizationMode mode_; void eval(const std::vector& inputs, array& out); }; diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index b8d727d88..45e3f7796 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -164,12 +164,14 @@ class QuantizedLinear(Module): bias: bool = True, group_size: int = 64, bits: int = 4, + mode: mx.QuantizationMode = mx.QuantizationMode.DEFAULT, ): super().__init__() # Quantization config self.group_size = group_size self.bits = bits + self.mode = mode # Initialize the quantized weight scale = math.sqrt(1 / input_dims) @@ -210,18 +212,25 @@ class QuantizedLinear(Module): transpose=True, group_size=self.group_size, bits=self.bits, + mode=self.mode, ) if "bias" in self: x = x + self["bias"] return x @classmethod - def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): + def from_linear( + cls, + linear_layer: Module, + group_size: int = 64, + bits: int = 4, + mode: mx.QuantizationMode = mx.QuantizationMode.DEFAULT, + ): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape - ql = cls(input_dims, output_dims, False, group_size, bits) + ql = cls(input_dims, output_dims, False, group_size, bits, mode) ql.weight, ql.scales, ql.biases = mx.quantize( - linear_layer.weight, group_size, bits + linear_layer.weight, group_size=group_size, bits=bits, mode=mode ) if "bias" in linear_layer: ql.bias = linear_layer.bias diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8bfe4e124..8c3ac3c9a 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3616,6 +3616,10 @@ void init_ops(nb::module_& m) { array: An array of the same type as ``a`` rounded 
to the given number of decimals. )pbdoc"); + nb::enum_(m, "QuantizationMode") + .value("DEFAULT", QuantizationMode::DEFAULT) + .value("NF4", QuantizationMode::NF4) + .export_values(); m.def( "quantized_matmul", &quantized_matmul, nb::arg(), nb::arg(), "scales"_a, "biases"_a, "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = QuantizationMode::DEFAULT, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -3648,6 +3653,7 @@ void init_ops(nb::module_& m) { shares a scale and bias. (default: ``64``) bits (int, optional): The number of bits occupied by each element in ``w``. (default: ``4``) + mode (QuantizationMode, optional): The mode of quantization: see QuantizationMode (default: ``QuantizationMode.DEFAULT``) Returns: array: The result of the multiplication of ``x`` with ``w``. )pbdoc"); m.def( "quantize", &quantize, nb::arg(), "group_size"_a = 64, "bits"_a = 4, + "mode"_a = QuantizationMode::DEFAULT, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), + "def quantize(w: array, /, group_size: int = 64, bits : int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), R"pbdoc( Quantize the matrix ``w`` using ``bits`` bits per element. @@ -3703,6 +3710,8 @@ void init_ops(nb::module_& m) { scale and bias. (default: ``64``) bits (int, optional): The number of bits occupied by each element of ``w`` in the returned quantized matrix. (default: ``4``) + mode (QuantizationMode, optional): The mode of quantization: see + QuantizationMode (default: ``QuantizationMode.DEFAULT``) Returns: tuple: A tuple containing @@ -3719,6 +3728,7 @@ void init_ops(nb::module_& m) { "biases"_a, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = QuantizationMode::DEFAULT, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -3743,12 +3753,14 @@ void init_ops(nb::module_& m) { scale and bias. (default: ``64``) bits (int, optional): The number of bits occupied by each element in ``w``. (default: ``4``) + mode (QuantizationMode, optional): The mode of quantization: see + QuantizationMode (default: ``QuantizationMode.DEFAULT``) Returns: array: The dequantized version of ``w`` )pbdoc"); m.def( - "gater_qmm", + "gather_qmm", &gather_qmm, nb::arg(), nb::arg(), "scales"_a, "biases"_a, "lhs_indices"_a = nb::none(), "rhs_indices"_a = nb::none(), "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = QuantizationMode::DEFAULT, nb::kw_only(), "stream"_a = nb::none(), nb::sig( @@ -3788,6 +3801,8 @@ void init_ops(nb::module_& m) { shares a scale and bias. (default: ``64``) bits (int, optional): The number of bits occupied by each element in ``w``. 
(default: ``4``) + mode (QuantizationMode, optional): The mode of quantization: see + QuantizationMode (default: ``QuantizationMode.DEFAULT``) Returns: array: The result of the multiplication of ``x`` with ``w`` diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 92ad3d3e7..0797ba11b 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -115,18 +115,22 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 4, 8], # bits + # [2, 4, 8], # bits + [4], # bits [512, 1024], # M [512, 1024], # N + [mx.QuantizationMode.DEFAULT, mx.QuantizationMode.NF4], ) - for group_size, bits, M, N in tests: - with self.subTest(shape=(M, N), group_size=group_size, bits=bits): + for group_size, bits, M, N, mode in tests: + with self.subTest( + shape=(M, N), group_size=group_size, bits=bits, mode=mode + ): x = mx.random.normal(shape=(1, N), key=k1) w = mx.random.normal(shape=(M, N), key=k2) - w_q, scales, biases = mx.quantize(w, group_size, bits) - w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) + w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode) + w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode) y_q = mx.quantized_matmul( - x, w_q, scales, biases, True, group_size, bits + x, w_q, scales, biases, True, group_size, bits, mode=mode ) y_hat = x @ w_hat.T self.assertEqual(y_q.shape, y_hat.shape) @@ -137,18 +142,21 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 4, 8], # bits + [4], # bits [512, 1024], # M [512, 1024], # N + [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT], ) - for group_size, bits, M, N in tests: - with self.subTest(shape=(M, N), group_size=group_size, bits=bits): + for group_size, bits, M, N, mode in tests: + with self.subTest( + shape=(M, N), group_size=group_size, bits=bits, mode=mode + ): x = mx.random.normal(shape=(1, N), key=k1) w = mx.random.normal(shape=(N, M), key=k2) - w_q, scales, biases = mx.quantize(w, group_size, bits) - w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) + w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode) + w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode) y_q = mx.quantized_matmul( - x, w_q, scales, biases, False, group_size, bits + x, w_q, scales, biases, False, group_size, bits, mode ) y_hat = x @ w_hat self.assertEqual(y_q.shape, y_hat.shape) @@ -171,37 +179,47 @@ class TestQuantized(mlx_tests.MLXTestCase): mx.eval(y) def test_small_matrix(self): - w = mx.random.normal(shape=(8, 256)) - w_q, scales, biases = mx.quantize(w) - w_hat = mx.dequantize(w_q, scales, biases) + for mode in [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT]: + with self.subTest(mode=mode): + w = mx.random.normal(shape=(8, 256)) + w_q, scales, biases = mx.quantize(w, mode=mode) + w_hat = mx.dequantize(w_q, scales, biases, mode=mode) - # Test qmv - x = mx.random.normal(shape=(1, 256)) - y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) - y_hat = x @ w_hat.T - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + # Test qmv + x = mx.random.normal(shape=(1, 256)) + y_q = mx.quantized_matmul( + x, w_q, scales, biases, transpose=True, mode=mode + ) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - 
y_hat).abs().max(), 1e-3) - # Test qmm_t - x = mx.random.normal(shape=(10, 256)) - y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) - y_hat = x @ w_hat.T - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + # Test qmm_t + x = mx.random.normal(shape=(10, 256)) + y_q = mx.quantized_matmul( + x, w_q, scales, biases, transpose=True, mode=mode + ) + y_hat = x @ w_hat.T + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) - # Test qmv - x = mx.random.normal(shape=(1, 8)) - y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) - y_hat = x @ w_hat - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + # Test qmv + x = mx.random.normal(shape=(1, 8)) + y_q = mx.quantized_matmul( + x, w_q, scales, biases, transpose=False, mode=mode + ) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) - # Test qmm - x = mx.random.normal(shape=(10, 8)) - y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) - y_hat = x @ w_hat - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + # Test qmm + x = mx.random.normal(shape=(10, 8)) + y_q = mx.quantized_matmul( + x, w_q, scales, biases, transpose=False, mode=mode + ) + y_hat = x @ w_hat + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) def test_non_multiples(self): w = mx.random.normal(shape=(33, 256))
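# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): a minimal, hypothetical example of the
# NF4 API this diff introduces, assuming the Python bindings are built from
# this branch. It mirrors test_small_matrix above. In NF4 mode each weight is
# stored as a 4-bit index into 16 fixed normal-distribution quantiles, scaled
# by the group's absolute maximum, so only bits=4 is valid.
import mlx.core as mx

w = mx.random.normal(shape=(8, 256))
x = mx.random.normal(shape=(1, 256))

# Quantize with the new mode keyword; in this diff NF4 still returns a
# placeholder biases array so the call signature matches the affine mode.
w_q, scales, biases = mx.quantize(
    w, group_size=64, bits=4, mode=mx.QuantizationMode.NF4
)

# Dequantize for a reference result and compare against the fused kernel.
w_hat = mx.dequantize(
    w_q, scales, biases, group_size=64, bits=4, mode=mx.QuantizationMode.NF4
)
y_ref = x @ w_hat.T
y_q = mx.quantized_matmul(
    x, w_q, scales, biases,
    transpose=True, group_size=64, bits=4, mode=mx.QuantizationMode.NF4,
)
assert (y_q - y_ref).abs().max() < 1e-3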