From d7acf59fd0058569dfeaaf9acedff82202686ff8 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Tue, 18 Mar 2025 18:52:22 -0700 Subject: [PATCH] add trellis quant mode --- mlx/backend/metal/kernels/quantized.h | 428 ++++++++++++++++++++++++-- mlx/backend/metal/primitives.cpp | 51 +-- mlx/backend/metal/quantized.cpp | 188 ++++++++++- mlx/backend/metal/reduce.h | 7 + mlx/fast.cpp | 52 +++- mlx/fast.h | 5 +- mlx/fast_primitives.h | 32 ++ mlx/ops.cpp | 75 +++-- mlx/ops.h | 5 +- mlx/primitives.cpp | 3 + mlx/primitives.h | 22 +- python/mlx/nn/layers/embedding.py | 9 +- python/mlx/nn/layers/linear.py | 13 +- python/mlx/nn/layers/quantized.py | 53 +++- python/src/ops.cpp | 14 +- python/tests/test_quantized.py | 3 + 16 files changed, 852 insertions(+), 108 deletions(-) diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index af9d7860e..b36b67c7d 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl( } } +template +float inst3(uint16_t xi) { + uint32_t x = xi; + x = a * x + b; + x = (x & 0b10001111111111111000111111111111) ^ m; + auto xf = reinterpret_cast(&x); + return xf[0] + xf[1]; +} + +template +METAL_FUNC void qmv_trellis_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int packs_per_thread = 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 
4 : 3; + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int reads_per = 16 / bits; + constexpr int local_w_size = + results_per_simdgroup * values_per_thread / reads_per; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread uint16_t w_thread[local_w_size]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + T scale = scales[0]; + + for (int k = 0; k < in_vec_size; k += block_size) { +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_thread; i++) { + x_thread[i] = x[i]; + } + +#pragma clang loop unroll(full) + for (int row = 0; row < results_per_simdgroup; row++) { +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_thread / reads_per; i++) { + auto wl = (const device uint16_t*)(ws + row * in_vec_size_w); + w_thread[row * values_per_thread / reads_per + i] = wl[i]; + } + } + +#pragma clang loop unroll(full) + for (int row = 0; row < results_per_simdgroup; row++) { +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_thread / reads_per; i++) { + int index = row * values_per_thread / reads_per + i; + uint16_t w0 = w_thread[index]; + uint16_t w1 = w_thread[(index + 1) % local_w_size]; + + uint16_t wx = w0 ^ w1; + uint16_t wx1 = wx ^ 1; + uint16_t wf = w0 ^ (1 << bits); + + if (bits == 2) { + result[row] += x_thread[8 * i] * inst3(w0); + result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3)); + result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf)); + result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f)); + result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff)); + result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff)); + result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff)); + result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff)); + } else if (bits == 4) { + result[row] += x_thread[4 * i] * inst3(w0); + result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf)); + result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff)); + result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff)); + } + } + } + + ws += block_size * bytes_per_pack / pack_factor; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(scale * result[row]); + } + } +} + template METAL_FUNC void qmv_impl( const device uint32_t* w, @@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } -template +template < + typename T, + int group_size, + int bits, + int D, + bool batched, + bool trellis = false> [[kernel]] void qmv_quad( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1354,7 +1469,12 @@ template quad_lid); } -template +template < + typename T, + int group_size, + int bits, + bool batched, + bool trellis = false> [[kernel]] void qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1393,20 +1513,39 @@ template b_strides, tid); } - qmv_fast_impl( - w, - scales, - biases, 
- x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); + if (trellis) { + qmv_trellis_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); + } else { + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); + } } -template +template < + typename T, + const int group_size, + const int bits, + bool batched, + bool trellis = false> [[kernel]] void qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1458,7 +1597,12 @@ template simd_lid); } -template +template < + typename T, + const int group_size, + const int bits, + bool batched, + bool trellis = false> [[kernel]] void qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1572,6 +1716,7 @@ template < const int bits, const bool aligned_N, const bool batched, + bool trellis = false, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1630,6 +1775,7 @@ template < const int group_size, const int bits, const bool batched, + bool trellis = false, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1685,7 +1831,7 @@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } -template +template [[kernel]] void bs_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1734,20 +1880,34 @@ template s_strides, b_strides, tid); - qmv_fast_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); + if (trellis) { + qmv_trellis_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); + } else { + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); + } } -template +template [[kernel]] void bs_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1809,7 +1969,7 @@ template simd_lid); } -template +template [[kernel]] void bs_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], @@ -1876,6 +2036,7 @@ template < const int group_size, const int bits, const bool aligned_N, + bool trellis = false, const int BM = 32, const int BK = 32, const int BN = 32> @@ -1943,6 +2104,7 @@ template < typename T, const int group_size, const int bits, + bool trellis = false, const int BM = 32, const int BK = 32, const int BN = 32> @@ -2157,3 +2319,211 @@ template } } } + +template < + typename T, + const bool use_overlap, + const int bits = 2, + const int timesteps = 128> +[[kernel]] void trellis_viterbi( + const device T* w [[buffer(0)]], + device float16_t* score [[buffer(1)]], + device uint8_t* pointers [[buffer(2)]], + const device uint16_t* overlap [[buffer(3)]], + uint3 tid [[thread_position_in_grid]]) { + constexpr uint16_t L = 16; + constexpr uint L2 = 1 << L; + + uint16_t idx = tid.y * 16; + + threadgroup float16_t swap_V[16384]; + + thread float16_t min_V[16] = {0}; + + for (uint16_t t = 0; t < timesteps; t++) { + uint16_t tt = t % 8 == 0 ? L / bits : t % 8; + uint16_t shift = ((tt - 1) % (L / bits)) * bits; + uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2; + + uint16_t s000 = 1 << (shift - 6); + uint16_t s0 = 1 << (shift - 2); + uint16_t s1 = 1 << (shift); + uint16_t s2 = 1 << (shift + 2); + uint16_t s4 = 1 << (shift + 4); + + if (t > 1) { + uint16_t i = 0; + uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12)); + uint16_t hoff = shift > 4 ? 
4 : shift == 4 ? 16 : 1; + uint16_t ind = idx; + + if (shift == 0) { + ind >>= 2; + } else if (shift == 14) { + ind = (ind & 0xfff) + (ind >> 12); + } else if (shift == 2) { + } else if (shift == 4) { + ind = ((ind >> 4) & 0x3) + (ind & ~0x3f); + } else if (shift == 6) { + ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2; + } else { + ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 + + ((ind / s1) % 4) + (ind / s2) * s2; + } + + for (uint16_t high = 0; high < 4; high++) { + uint16_t sub_ind = ind; + for (uint16_t low = 0; low < 4; low++) { + swap_V[sub_ind] = min_V[i]; + i++; + sub_ind += loff; + } + ind += hoff; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint16_t i = 0; i < 16; i++) { + min_V[i] = swap_V[idx + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + uint16_t rolled_t = use_overlap ? t : (t + 64) % 128; + T w_t = w[tid.x * timesteps + rolled_t]; + + for (uint16_t i = 0; i < 4; i++) { + thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY}; + thread uint16_t min_idx[4] = {0}; + + uint16_t ii = idx * 4 + i * 16; + uint16_t big_idx = ii; + if (shift > 0 && shift < 14) { + big_idx = ((ii / s2) % 4) + (ii / s4 * s4); + if (shift > 2) { + big_idx += ((ii / 16) % s0) * 4; + } + } else if (shift == 14 && t > 0) { + big_idx >>= 2; + } + + uint16_t loff = t == 0 ? 4 : s1; + uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2; + + for (uint16_t high = 0; high < 4; high++) { + uint16_t sub_ind = big_idx; + for (uint16_t low = 0; low < 4; low++) { + float mse = inst3(sub_ind ^ flip) - w_t; + mse *= mse; + + float16_t new_val = min_V[i * 4 + high] + mse; + if (new_val < min_val[low]) { + min_val[low] = new_val; + min_idx[low] = high; + } + sub_ind += loff; + } + big_idx += hoff; + } + + for (uint16_t j = 0; j < 4; j++) { + min_V[i * 4 + j] = min_val[j]; + pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] = + min_idx[j]; + } + } + if (t == 0 && use_overlap) { + uint16_t over = overlap[tid.x * 128 + 64]; + over = over & ((1 << 14) - 1); + for (uint16_t i = 0; i < 16; i++) { + uint16_t rs = (over >> 2) ^ 1; + uint16_t ls = (idx + i) & ((1 << 12) - 1); + min_V[i] = rs == ls ? min_V[i] : INFINITY; + } + } + } + if (use_overlap) { + uint16_t over = overlap[tid.x * 128 + 64]; + over = over & ((1 << 14) - 1); + uint16_t node = + (over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4; + for (uint16_t i = 0; i < 16; i++) { + min_V[i] = (idx + i) == node ? 
min_V[i] : INFINITY; + } + } + for (uint16_t i = 0; i < 16; i++) { + score[tid.x * L2 / 4 + idx + i] = min_V[i]; + } +} + +uint16_t remove_bits(uint16_t i, uint16_t shift) { + uint16_t lower = i & ((1 << shift) - 1); + uint16_t upper = i & ~((1 << (shift + 2)) - 1); + return lower + (upper >> 2); +} + +uint16_t swap_bits(uint16_t i, uint16_t shift) { + uint16_t diff = ((i >> shift) ^ i) & 0x3; + i = i ^ diff; + i ^= diff << shift; + return i; +} + +template +[[kernel]] void trellis_backtrack( + const device uint32_t* start [[buffer(0)]], + const device uint8_t* pointers [[buffer(1)]], + device uint16_t* out [[buffer(2)]], + const device uint16_t* overlap [[buffer(3)]], + uint3 tid [[thread_position_in_grid]]) { + constexpr uint16_t L = 16; + + uint16_t node = start[tid.x]; + + uint16_t dir = + pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node]; + + node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4; + node ^= 1; + node += dir * 16384; + + out[tid.x * timesteps + timesteps - 1] = node; + + for (int t = timesteps - 2; t >= 0; t--) { + uint16_t shift = (t % (L / bits)) * bits; + uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift); + uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1; + uint16_t i = (node & mask) ^ flip; + + if (shift > 0) { + i = remove_bits(i, shift); + } + + if (t == 0) { + i >>= 2; + } + + if (t % 2 == 1 || t == 0) { + i ^= 1; + } + + shift = shift == 0 ? L : shift; + + if (t > 0) { + i = swap_bits(i, shift - 2); + } + + shift = shift == L ? 0 : shift; + + uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i]; + if ((t % 8 == 1 && t > 1) || t == 0) { + last_p ^= 1; + } + + node = ((node & mask) ^ flip) | (last_p << shift); + if (t == 0 && use_overlap) { + uint16_t over = overlap[tid.x * 128 + 64]; + over = over & ((1 << 14) - 1); + node = (node & 0xfffc) + (over & 0x3); + } + out[tid.x * timesteps + t] = node; + } +} \ No newline at end of file diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6946ffb9e..e47ee9030 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -16,6 +16,8 @@ #include "mlx/scheduler.h" #include "mlx/utils.h" +#include + namespace mlx::core { template @@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatch_threads(grid_dims, group_dims); } -void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - auto& in = inputs[0]; - out.set_data(allocator::malloc(out.nbytes())); - auto& s = stream(); +void arg_reduce_dispatch( + const array& in, + array& out, + int axis, + std::string op_name, + const Stream& s) { auto& d = metal::device(s.device); - std::string op_name; - switch (reduce_type_) { - case ArgReduce::ArgMin: - op_name = "argmin_"; - break; - case ArgReduce::ArgMax: - op_name = "argmax_"; - break; - } // Prepare the shapes, strides and axis arguments. 
auto in_strides = in.strides(); auto shape = in.shape(); auto out_strides = out.strides(); - auto axis_stride = in_strides[axis_]; - size_t axis_size = shape[axis_]; + auto axis_stride = in_strides[axis]; + size_t axis_size = shape[axis]; if (out_strides.size() == in_strides.size()) { - out_strides.erase(out_strides.begin() + axis_); + out_strides.erase(out_strides.begin() + axis); } - in_strides.erase(in_strides.begin() + axis_); - shape.erase(shape.begin() + axis_); + in_strides.erase(in_strides.begin() + axis); + shape.erase(shape.begin() + axis); size_t ndim = shape.size(); // ArgReduce @@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { int n_reads = 4; auto& compute_encoder = d.get_command_encoder(s.index); { - auto kernel = d.get_kernel(op_name + type_to_name(in)); + auto kernel = d.get_kernel(op_name + "_" + type_to_name(in)); NS::UInteger thread_group_size = std::min( (axis_size + n_reads - 1) / n_reads, kernel->maxTotalThreadsPerThreadgroup()); @@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { } } +void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + std::string op_name; + switch (reduce_type_) { + case ArgReduce::ArgMin: + op_name = "argmin"; + break; + case ArgReduce::ArgMax: + op_name = "argmax"; + break; + } + auto& in = inputs[0]; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + arg_reduce_dispatch(in, out, axis_, op_name, s); +} + void AsType::eval_gpu(const std::vector& inputs, array& out) { CopyType ctype = inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General; diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 8d1d176c4..964eb081e 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -7,11 +7,14 @@ #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/reduce.h" +#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" #include "mlx/utils.h" +#include + namespace mlx::core { void launch_qmm( @@ -31,6 +34,7 @@ void launch_qmm( bool gather, bool aligned, bool quad, + const std::string& mode, const Stream& s) { auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; @@ -54,8 +58,12 @@ void launch_qmm( }; auto x = ensure_row_contiguous_last_dims(x_pre); auto w = ensure_row_contiguous_last_dims(w_pre); - auto scales = ensure_row_contiguous_last_dims(scales_pre); - auto biases = ensure_row_contiguous_last_dims(biases_pre); + auto scales = scales_pre; + auto biases = biases_pre; + if (mode == "affine") { + scales = ensure_row_contiguous_last_dims(scales_pre); + biases = ensure_row_contiguous_last_dims(biases_pre); + } int x_batch_ndims = x.ndim() - 2; auto& x_shape = x.shape(); @@ -68,6 +76,8 @@ void launch_qmm( std::string aligned_n = (O % 32) == 0 ? 
"true" : "false"; + bool is_trellis = (mode == "trellis"); + std::ostringstream kname; auto type_string = get_type_string(x.dtype()); kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits; @@ -80,24 +90,47 @@ void launch_qmm( if (!gather) { kname << "_batch_" << batched; } + if (mode == "trellis") { + kname << "_mode_" << is_trellis; + } // Encode and dispatch kernel std::string template_def; if (quad) { template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, D, batched); + kname.str(), + name, + type_string, + group_size, + bits, + D, + batched, + is_trellis); } else if (aligned && !gather) { template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n, batched); + kname.str(), + name, + type_string, + group_size, + bits, + aligned_n, + batched, + is_trellis); } else if (!gather && !aligned) { template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, batched); + kname.str(), name, type_string, group_size, bits, batched, is_trellis); } else if (aligned && gather) { template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits, aligned_n); + kname.str(), + name, + type_string, + group_size, + bits, + aligned_n, + is_trellis); } else { template_def = get_template_definition( - kname.str(), name, type_string, group_size, bits); + kname.str(), name, type_string, group_size, bits, is_trellis); } auto& d = metal::device(s.device); auto kernel = get_quantized_kernel(d, kname.str(), template_def); @@ -276,6 +309,7 @@ void qmm_op( int group_size, int bits, bool gather, + const std::string& mode, const Stream& s) { out.set_data(allocator::malloc(out.nbytes())); @@ -354,7 +388,7 @@ void qmm_op( group_dims = MTL::Size(simdgroup_size, 1, 1); grid_dims = MTL::Size(B, (O + bo - 1) / bo, N); quad = true; - } else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) { + } else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) { name += "qmv_fast"; int bo = 8; int bd = 32; @@ -420,19 +454,34 @@ void qmm_op( gather, aligned, quad, + mode, s); } void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 4); qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream()); + inputs, + out, + transpose_, + group_size_, + bits_, + /*gather=*/false, + mode_, + stream()); } void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 6); qmm_op( - inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream()); + inputs, + out, + transpose_, + group_size_, + bits_, + /*gather=*/true, + mode_, + stream()); } void fast::AffineQuantize::eval_gpu( @@ -516,4 +565,123 @@ void fast::AffineQuantize::eval_gpu( d.add_temporaries(std::move(copies), s.index); } +void viterbi( + array& w, + array& scores, + array& pointers, + array& start, + array& overlap, + bool use_overlap, + const Stream& s) { + int B = scores.shape(0); + auto& d = metal::device(s.device); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_output_array(scores, 1); + compute_encoder.set_output_array(pointers, 2); + if (use_overlap) { + compute_encoder.set_input_array(overlap, 3); + } + + std::ostringstream kname; + auto type_string = get_type_string(w.dtype()); + kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap; + auto template_def = get_template_definition( + kname.str(), "trellis_viterbi", 
type_string, use_overlap); + auto kernel = get_quantized_kernel(d, kname.str(), template_def); + compute_encoder.set_compute_pipeline_state(kernel); + + auto group_dims = MTL::Size(1, 1024, 1); + auto grid_dims = MTL::Size(B, 1024, 1); + + compute_encoder.dispatch_threads(grid_dims, group_dims); + arg_reduce_dispatch(scores, start, 1, "argmin", s); +} + +void viterbi_backtrack( + array& start, + array& pointers, + array& out, + array& overlap, + bool use_overlap, + const Stream& s) { + int B = start.shape(0); + + auto& d = metal::device(s.device); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_input_array(start, 0); + compute_encoder.set_input_array(pointers, 1); + compute_encoder.set_output_array(out, 2); + if (use_overlap) { + compute_encoder.set_input_array(overlap, 3); + } + + std::ostringstream kname; + kname << "trellis_backtrack" << "_overlap_" << use_overlap; + auto template_def = + get_template_definition(kname.str(), "trellis_backtrack", use_overlap); + auto kernel = get_quantized_kernel(d, kname.str(), template_def); + compute_encoder.set_compute_pipeline_state(kernel); + + auto group_dims = MTL::Size(256, 1, 1); + auto grid_dims = MTL::Size(B, 1, 1); + + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + +void fast::TrellisQuantize::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& w_pre = inputs[0]; + auto& out = outputs[0]; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + + std::vector copies; + auto ensure_row_contiguous = [&copies, &s](const array& arr) { + if (arr.flags().row_contiguous) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + return arr_copy; + } + }; + auto w = ensure_row_contiguous(w_pre); + + int B = w.shape(0); + int T = w.shape(1); + + constexpr int num_states = 1 << 14; + + array scores({B, num_states}, float16, nullptr, {}); + scores.set_data(allocator::malloc_or_wait(scores.nbytes())); + copies.push_back(scores); + + array pointers({B, T, num_states}, uint8, nullptr, {}); + pointers.set_data(allocator::malloc_or_wait(pointers.nbytes())); + copies.push_back(pointers); + + array start({B}, uint32, nullptr, {}); + start.set_data(allocator::malloc_or_wait(start.nbytes())); + copies.push_back(start); + + array rolled({B, T}, uint16, nullptr, {}); + rolled.set_data(allocator::malloc_or_wait(rolled.nbytes())); + copies.push_back(rolled); + + viterbi(w, scores, pointers, start, out, false, s); + viterbi_backtrack(start, pointers, rolled, out, false, s); + + viterbi(w, scores, pointers, start, rolled, true, s); + viterbi_backtrack(start, pointers, out, rolled, true, s); + + d.add_temporaries(std::move(copies), s.index); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/reduce.h b/mlx/backend/metal/reduce.h index a997d7e24..c5b57c7db 100644 --- a/mlx/backend/metal/reduce.h +++ b/mlx/backend/metal/reduce.h @@ -38,4 +38,11 @@ void strided_reduce_general_dispatch( metal::Device& d, const Stream& s); +void arg_reduce_dispatch( + const array& in, + array& out, + int axis, + std::string op_name, + const Stream& s); + } // namespace mlx::core diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 77210f713..402c4999f 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -11,6 +11,8 @@ #include "mlx/transforms.h" #include "mlx/transforms_impl.h" +#include + namespace mlx::core::fast { std::vector Custom::vjp( @@ -832,7 
+834,7 @@ array pack_and_quantize( return packed_w; } -std::tuple +std::vector affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { auto s = to_stream(s_); @@ -1028,6 +1030,54 @@ array affine_dequantize( return fallback({w, scales, biases})[0]; } +std::vector +trellis_quantize(const array& w_, int bits, StreamOrDevice s_) { + if (bits != 2) { + throw std::invalid_argument( + "Only 2 bit Trellis quants are currently supported."); + } + + int Tx = 4; + int Ty = 32; + int batch_size = 256; + + auto s = to_stream(s_); + + int L = 16; + int M = w_.shape(-2); + int T = Tx * Ty; + auto scale = std(astype(w_, float32, s), s); + auto w = divide(w_, scale, s); + w = astype(w, float16, s); + + w = reshape(w, {M / Tx, Tx, -1, Ty}, s); + w = transpose(w, {0, 2, 1, 3}, s); + w = reshape(w, {-1, T}, s); + + auto fallback = [bits, s](const std::vector& inputs) mutable + -> std::vector { return {inputs[0]}; }; + + auto q = zeros({w.shape(0), w.shape(1) * bits / L}, uint16, s); + for (int i = 0; i < w.shape(0); i += batch_size) { + auto w_batch = slice(w, {i, 0}, {i + batch_size, w.shape(-1)}, s); + auto q_batch = array( + w_batch.shape(), + uint16, + std::make_shared(s, fallback, bits, true), + {w_batch}); + q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s); + q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s); + eval(q); + } + + q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s); + q = transpose(q, {0, 2, 1, 3}, s); + q = reshape(q, {M, -1}, s); + q = view(q, uint32, s); + + return {q, scale, scale}; +} + bool AffineQuantize::is_equivalent(const Primitive& other) const { const AffineQuantize& p_other = static_cast(other); return ( diff --git a/mlx/fast.h b/mlx/fast.h index 7aebe3863..c00e7c5be 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -52,7 +52,7 @@ array scaled_dot_product_attention( const std::vector& mask_arrs = {}, StreamOrDevice s = {}); -std::tuple affine_quantize( +std::vector affine_quantize( const array& w, int group_size = 64, int bits = 4, @@ -66,6 +66,9 @@ array affine_dequantize( int bits = 4, StreamOrDevice s = {}); +std::vector +trellis_quantize(const array& w, int bits = 4, StreamOrDevice s = {}); + typedef std::variant TemplateArg; typedef std::function( diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4d9e505ee..f9a46fccb 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -269,6 +269,38 @@ class AffineQuantize : public Custom { bool dequantize_; }; +class TrellisQuantize : public Custom { + public: + explicit TrellisQuantize( + Stream stream, + std::function(std::vector)> fallback, + int bits, + bool dequantize) + : Custom(stream, fallback), bits_(bits), dequantize_(dequantize) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + }; + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_PRINT(TrellisQuantize); + + // bool is_equivalent(const Primitive& other) const override; + // std::vector output_shapes(const std::vector& inputs) + // override; + auto state() const { + return std::make_tuple(nullptr, bits_, dequantize_); + } + + private: + std::function(std::vector)> fallback_; + int bits_; + bool dequantize_; +}; + struct CustomKernelShapeInfo { bool shape = false; bool strides = false; diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6d1116905..08d406527 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -17,6 +17,8 @@ #include "mlx/transforms_impl.h" #include "mlx/utils.h" +#include + namespace 
mlx::core { namespace { @@ -79,7 +81,8 @@ std::pair extract_quantized_matmul_dims( const array& biases, bool transpose, int group_size, - int bits) { + int bits, + const std::string& mode) { if (w.dtype() != uint32) { std::ostringstream msg; msg << "[" << tag << "] The weight matrix should be uint32 " @@ -87,12 +90,23 @@ std::pair extract_quantized_matmul_dims( throw std::invalid_argument(msg.str()); } - if (scales.shape() != biases.shape()) { - std::ostringstream msg; - msg << "[" << tag << "] Scales and biases should have the same shape. " - << "Received scales with shape " << scales.shape() - << " and biases with " << biases.shape(); - throw std::invalid_argument(msg.str()); + if (mode == "affine") { + if (scales.shape() != biases.shape()) { + std::ostringstream msg; + msg << "[" << tag << "] Scales and biases should have the same shape. " + << "Received scales with shape " << scales.shape() + << " and biases with " << biases.shape(); + throw std::invalid_argument(msg.str()); + } + + if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { + std::ostringstream msg; + msg << "[" << tag << "] The shapes of the weight and scales are " + << "incompatible based on bits and group_size. w.shape() == " + << w.shape() << " and scales.shape() == " << scales.shape() + << " with group_size=" << group_size << " and bits=" << bits; + throw std::invalid_argument(msg.str()); + } } if (!std::equal( @@ -105,15 +119,6 @@ std::pair extract_quantized_matmul_dims( throw std::invalid_argument(msg.str()); } - if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) { - std::ostringstream msg; - msg << "[" << tag << "] The shapes of the weight and scales are " - << "incompatible based on bits and group_size. w.shape() == " - << w.shape() << " and scales.shape() == " << scales.shape() - << " with group_size=" << group_size << " and bits=" << bits; - throw std::invalid_argument(msg.str()); - } - int x_inner_dims = x.shape(-1); // Calculate the expanded w's dims @@ -717,6 +722,9 @@ array slice( << "array with dimension " << a.ndim() << "."; throw std::invalid_argument(msg.str()); } + // std::cout << "start " << start << std::endl; + // std::cout << "stop " << stop << std::endl; + // std::cout << "strides " << strides << std::endl; auto [has_neg_strides, out_shape] = normalize_slice(a.shape(), start, stop, strides); @@ -3969,10 +3977,19 @@ array quantized_matmul( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { // Check and extract the quantized matrix shape against x auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( - "quantized_matmul", x, w, scales, biases, transpose, group_size, bits); + "quantized_matmul", + x, + w, + scales, + biases, + transpose, + group_size, + bits, + mode); auto dtype = result_type(x, scales, biases); if (!issubdtype(dtype, floating)) { @@ -3996,16 +4013,26 @@ array quantized_matmul( std::move(out_shape), dtype, std::make_shared( - to_stream(s), group_size, bits, transpose), + to_stream(s), group_size, bits, transpose, mode), std::move(inputs)); } -std::tuple quantize( +std::vector quantize( const array& w, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode, /* = affine */ StreamOrDevice s /* = {} */) { - return fast::affine_quantize(w, group_size, bits, s); + if (mode == "affine") { + return fast::affine_quantize(w, group_size, bits, s); + } else if (mode == "trellis") { + return fast::trellis_quantize(w, bits, s); + } else { + 
std::ostringstream msg; + msg << "[quantize] Unsupported quantization mode " << mode << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } } array dequantize( @@ -4028,14 +4055,15 @@ array gather_qmm( bool transpose /* = true */, int group_size /* = 64 */, int bits /* = 4 */, + const std::string& mode /* = "affine" */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( - x, w, scales, biases, transpose, group_size, bits, s); + x, w, scales, biases, transpose, group_size, bits, mode, s); } auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( - "gather_qmm", x, w, scales, biases, transpose, group_size, bits); + "gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode); // Extract indices and broadcast them array lhs_indices = indices_or_default(lhs_indices_, x, s); @@ -4067,7 +4095,8 @@ array gather_qmm( return array( std::move(out_shape), out_type, - std::make_shared(to_stream(s), group_size, bits, transpose), + std::make_shared( + to_stream(s), group_size, bits, transpose, mode), {astype(x, out_type, s), w, astype(scales, out_type, s), diff --git a/mlx/ops.h b/mlx/ops.h index ce3d4ff44..d5c4db06b 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1323,13 +1323,15 @@ array quantized_matmul( bool transpose = true, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ -std::tuple quantize( +std::vector quantize( const array& w, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ @@ -1352,6 +1354,7 @@ array gather_qmm( bool transpose = true, int group_size = 64, int bits = 4, + const std::string& mode = "affine", StreamOrDevice s = {}); /** Returns a contraction of a and b over multiple dimensions. 
*/ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 8328d96da..eab1c3bd8 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3012,6 +3012,7 @@ std::vector QuantizedMatmul::vjp( !transpose_, group_size_, bits_, + mode_, stream())); } @@ -3040,6 +3041,7 @@ std::vector QuantizedMatmul::jvp( transpose_, group_size_, bits_, + mode_, stream())}; } @@ -3098,6 +3100,7 @@ std::vector GatherQMM::vjp( !transpose_, group_size_, bits_, + mode_, stream()), -3, stream()), diff --git a/mlx/primitives.h b/mlx/primitives.h index 7738b273b..42493aa53 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1552,11 +1552,13 @@ class QuantizedMatmul : public UnaryPrimitive { Stream stream, int group_size, int bits, - bool transpose) + bool transpose, + const std::string mode) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - transpose_(transpose) {} + transpose_(transpose), + mode_(mode) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1567,22 +1569,29 @@ class QuantizedMatmul : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(group_size_, bits_, transpose_); + return std::make_tuple(group_size_, bits_, transpose_, mode_); } private: int group_size_; int bits_; bool transpose_; + const std::string mode_; }; class GatherQMM : public UnaryPrimitive { public: - explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) + explicit GatherQMM( + Stream stream, + int group_size, + int bits, + bool transpose, + const std::string& mode) : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - transpose_(transpose) {} + transpose_(transpose), + mode_(mode) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1592,13 +1601,14 @@ class GatherQMM : public UnaryPrimitive { DEFINE_PRINT(GatherQMM) bool is_equivalent(const Primitive& other) const override; auto state() const { - return std::make_tuple(group_size_, bits_, transpose_); + return std::make_tuple(group_size_, bits_, transpose_, mode_); } private: int group_size_; int bits_; bool transpose_; + const std::string mode_; }; class RandomBits : public UnaryPrimitive { diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py index 1e15a59cc..c7995cb85 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import math +from typing import Literal import mlx.core as mx from mlx.nn.layers.base import Module @@ -39,6 +40,12 @@ class Embedding(Module): """ return x @ self.weight.T - def to_quantized(self, group_size: int = 64, bits: int = 4): + def to_quantized( + self, + group_size: int = 64, + bits: int = 4, + mode: Literal["affine", "trellis"] = "affine", + fake: bool = False, + ): """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.""" return QuantizedEmbedding.from_embedding(self, group_size, bits) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 63caa911c..038dc3c58 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -1,11 +1,12 @@ # Copyright © 2023 Apple Inc. 
import math -from typing import Any +from typing import Any, Literal import mlx.core as mx from mlx.nn.layers.base import Module from mlx.nn.layers.quantized import QuantizedLinear +from mlx.nn.layers.viterbi import quantize as trellis_quantize class Identity(Module): @@ -70,9 +71,15 @@ class Linear(Module): x = x @ self["weight"].T return x - def to_quantized(self, group_size: int = 64, bits: int = 4): + def to_quantized( + self, + group_size: int = 64, + bits: int = 4, + mode: Literal["affine", "trellis"] = "affine", + fake: bool = False, + ): """Return a :obj:`QuantizedLinear` layer that approximates this layer.""" - return QuantizedLinear.from_linear(self, group_size, bits) + return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake) class Bilinear(Module): diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 823a0084f..2dca6cc8b 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -1,10 +1,11 @@ # Copyright © 2023-2024 Apple Inc. import math -from typing import Callable, Optional, Union +from typing import Callable, Literal, Optional, Union import mlx.core as mx from mlx.nn.layers.base import Module +from mlx.nn.layers.viterbi import quantize as trellis_quantize from mlx.utils import tree_map_with_path @@ -12,7 +13,9 @@ def quantize( model: Module, group_size: int = 64, bits: int = 4, + mode: Literal["affine", "trellis"] = "affine", class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None, + fake: bool = False, ): """Quantize the sub-modules of a module according to a predicate. @@ -21,7 +24,7 @@ def quantize( will be quantized. Note also, the module is updated in-place. Args: - model (mlx.nn.Module): The model whose leaf modules may be quantized. + model (mlx.nn.Module):, mode: Literal["affine", "trellis"] = "affine" The model whose leaf modules may be quantized. group_size (int): The quantization group size (see :func:`mlx.core.quantize`). Default: ``64``. 
bits (int): The number of bits per parameter (see @@ -36,12 +39,15 @@ def quantize( class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized")) def _maybe_quantize(path, m): + print(path) if bool_or_params := class_predicate(path, m): if hasattr(m, "to_quantized"): if isinstance(bool_or_params, bool): - return m.to_quantized(group_size=group_size, bits=bits) + return m.to_quantized( + group_size=group_size, bits=bits, mode=mode, fake=fake + ) elif isinstance(bool_or_params, dict): - return m.to_quantized(**bool_or_params) + return m.to_quantized(**bool_or_params, fake=fake) else: raise ValueError( "``class_predicate`` must return a bool" @@ -131,7 +137,11 @@ class QuantizedEmbedding(Module): @classmethod def from_embedding( - cls, embedding_layer: Module, group_size: int = 64, bits: int = 4 + cls, + embedding_layer: Module, + group_size: int = 64, + bits: int = 4, + mode: Literal["affine", "trellis"] = "affine", ): """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" embedding_dims, dims = embedding_layer.weight.shape @@ -170,12 +180,14 @@ class QuantizedLinear(Module): bias: bool = True, group_size: int = 64, bits: int = 4, + mode: Literal["affine", "trellis"] = "affine", ): super().__init__() # Quantization config self.group_size = group_size self.bits = bits + self.mode = mode # Initialize the quantized weight scale = math.sqrt(1 / input_dims) @@ -216,19 +228,40 @@ class QuantizedLinear(Module): transpose=True, group_size=self.group_size, bits=self.bits, + mode=self.mode, ) if "bias" in self: x = x + self["bias"] return x @classmethod - def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): + def from_linear( + cls, + linear_layer: Module, + group_size: int = 64, + bits: int = 4, + mode: Literal["affine", "trellis"] = "affine", + fake: bool = False, + ): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape - ql = cls(input_dims, output_dims, False, group_size, bits) - ql.weight, ql.scales, ql.biases = mx.quantize( - linear_layer.weight, group_size, bits - ) + ql = cls(input_dims, output_dims, False, group_size, bits, mode) + if mode == "trellis": + if fake: + ql.weight = mx.zeros( + (output_dims, input_dims // 32 * bits), dtype=mx.uint32 + ) + ql.scales = mx.array(0.0) + ql.biases = mx.array(0.0) + else: + ql.weight, ql.scales, ql.biases = mx.quantize( + linear_layer.weight, bits=bits, mode="trellis" + ) + else: + ql.weight, ql.scales, ql.biases = mx.quantize( + linear_layer.weight, group_size, bits, mode="affine" + ) + if "bias" in linear_layer: ql.bias = linear_layer.bias diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 7f06a4ddf..f9530ce2b 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) { "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. 
The quantization uses one floating point scale and bias per ``group_size`` of @@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The mode to use for quantization. + Default: ``affine``. Returns: array: The result of the multiplication of ``x`` with ``w``. @@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) { "group_size"_a = 64, "bits"_a = 4, nb::kw_only(), + "mode"_a = "affine", "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, mode: Literal['affine', 'trellis'], stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), R"pbdoc( Quantize the matrix ``w`` using ``bits`` bits per element. @@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) { scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element of ``w`` in the returned quantized matrix. Default: ``4``. + mode (str): The quantization mode to use. Default: ``affine``. Returns: tuple: A tuple containing @@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) { "transpose"_a = true, "group_size"_a = 64, "bits"_a = 4, + "mode"_a = "affine", nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) { shares a scale and bias. Default: ``64``. bits (int, optional): The number of bits occupied by each element in ``w``. Default: ``4``. + mode (str, optional): The mode to use for quantization. + Default: ``affine``. Returns: array: The result of the multiplication of ``x`` with ``w`` diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 160eb6400..0bea55970 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -10,6 +10,9 @@ import mlx_tests class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) + w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis") + print(w_q, scales, biases) + for gs in [32, 64, 128]: for b in [2, 3, 6, 4, 8]: with self.subTest(gs=gs, b=b):
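
The inst3 helper above is a compute-based codebook in the style of QTIP's "3INST" decoder: a 16-bit trellis state is expanded into a weight value by one 32-bit multiply-add, a mask/xor, and a reinterpretation of the result as two float16 halves that are summed. A NumPy sketch of that decode follows; the constants a, b and m stand in for the kernel's template parameters, whose actual values are not visible in this diff.

    import numpy as np

    def inst3(xi: int, a: int, b: int, m: int) -> float:
        # 32-bit multiply-add "hash" of the 16-bit trellis state.
        x = (a * xi + b) & 0xFFFFFFFF
        # Mask each 16-bit half (keeping the sign plus the low exponent and
        # mantissa bits) and xor in the fixed pattern m, as the kernel does.
        x = (x & 0b10001111111111111000111111111111) ^ m
        # Reinterpret the 32-bit word as two float16 values and sum them.
        lo, hi = np.frombuffer(np.uint32(x).tobytes(), dtype=np.float16)
        return float(lo) + float(hi)

Because consecutive trellis states overlap in all but the newest 2 (or 4) of their 16 bits, qmv_trellis_impl reconstructs the k-th state from a packed pair w0, w1 by splicing the low bits of w1 into w0 (the w0 ^ (wx & mask) pattern, with the wx1 variants appearing to account for a parity flip in the state encoding) instead of reading a full 16-bit state per weight.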
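
The trellis_viterbi and trellis_backtrack kernels implement standard Viterbi dynamic programming over squared error, specialised to the 16-bit shift-register trellis (kept in a compressed 2^14-state layout with float16 scores in threadgroup memory). Stripped of that layout, the search has the shape of the sketch below; codebook, predecessors and the dense tables are simplifications for illustration, not the kernel's actual data structures.

    import numpy as np

    def viterbi_quantize_1d(w, codebook, predecessors):
        # w: 1-D float array of weights to quantize.
        # codebook[s]: the value decoded from state s.
        # predecessors[s]: iterable of states allowed to precede state s.
        T, S = len(w), len(codebook)
        score = (codebook - w[0]) ** 2          # best error for paths ending in s
        pointers = np.zeros((T, S), dtype=np.int32)
        for t in range(1, T):
            new_score = np.full(S, np.inf)
            for s in range(S):
                for p in predecessors[s]:
                    cand = score[p] + (codebook[s] - w[t]) ** 2
                    if cand < new_score[s]:
                        new_score[s], pointers[t, s] = cand, p
            score = new_score
        # Backtrack from the best final state to recover the code sequence.
        path = np.empty(T, dtype=np.int32)
        path[-1] = int(np.argmin(score))
        for t in range(T - 1, 0, -1):
            path[t - 1] = pointers[t, path[t]]
        return path

In TrellisQuantize::eval_gpu the search runs twice per 128-value block: the first viterbi / viterbi_backtrack pair produces a provisional sequence, and the second pass is constrained (via the overlap buffer and use_overlap) to the state chosen at position 64 of that sequence, which appears to make the 128-value trellis wrap around on itself consistently.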
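
At the Python level the new mode flows through mx.quantize, mx.quantized_matmul / mx.gather_qmm and nn.quantize / QuantizedLinear.from_linear. A minimal usage sketch, assuming this patch (and the mlx.nn.layers.viterbi helper module it imports) is applied; for mode="trellis" the group_size argument is not used and the returned scales and biases are both the single per-tensor scale.

    import mlx.core as mx
    import mlx.nn as nn

    w = mx.random.normal(shape=(512, 512))

    # Per-tensor trellis quantization; only 2-bit is supported for now.
    w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")

    x = mx.random.normal(shape=(1, 512))
    y = mx.quantized_matmul(
        x, w_q, scales, biases, transpose=True, bits=2, mode="trellis"
    )

    # Quantize the Linear layers of a model in place. fake=True skips the
    # Viterbi search and just allocates zero weights in the trellis layout.
    model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
    nn.quantize(model, bits=2, mode="trellis")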