diff --git a/docs/src/python/fast.rst b/docs/src/python/fast.rst
index 30ade264e..f78f40563 100644
--- a/docs/src/python/fast.rst
+++ b/docs/src/python/fast.rst
@@ -12,5 +12,4 @@ Fast
    layer_norm
    rope
    scaled_dot_product_attention
-   affine_quantize
    metal_kernel
diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp
index d939334c9..4b8bbdb89 100644
--- a/mlx/backend/common/quantized.cpp
+++ b/mlx/backend/common/quantized.cpp
@@ -6,11 +6,34 @@
 #include "mlx/backend/common/ops.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"

 namespace mlx::core {

 namespace {

+template <typename T, int bits>
+void extract_bits(const uint8_t* w_in, T* w_out) {
+  assert(bits == 3 || bits == 6);
+  if (bits == 3) {
+    w_out[0] = static_cast<T>(w_in[0] & 0x7);
+    w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
+    w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
+    w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
+    w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
+    w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
+    w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
+    w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
+  } else if (bits == 6) {
+    w_out[0] = static_cast<T>(w_in[0] & 0x3f);
+    w_out[1] =
+        static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
+    w_out[2] =
+        static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
+    w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
+  }
+}
+
 template <typename T, int bits, int group_size>
 void _qmm(
     T* result,
@@ -22,13 +45,12 @@ void _qmm(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;
-  const int Ng = N / group_size;
-  const int Nw = N / pack_factor;

   for (int m = 0; m < M; m++) {
-    const uint32_t* w_local = w;
+    const uint8_t* w_local = (const uint8_t*)w;
     const T* scales_local = scales;
     const T* biases_local = biases;

@@ -42,13 +64,25 @@ void _qmm(
       T scale = *scales_local++;
       T bias = *biases_local++;
       for (int ng = 0; ng < packs_in_group; ng++) {
-          uint32_t wi = *w_local++;
-
+        if (bits == 3 || bits == 6) {
+          T wl[pack_factor];
+          extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
-          for (int p = 0; p < pack_factor; p++) {
-            (*result_local++) +=
-                xi * (scale * static_cast<T>(wi & bitmask) + bias);
-            wi >>= bits;
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) += xi * (scale * wl[p] + bias);
+          }
+          w_local += bytes_per_pack;
+
+        } else {
+          uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) +=
+                xi * (scale * static_cast<T>(wi & bitmask) + bias);
+            if (bits != 8) {
+              wi >>= bits;
+            }
+          }
         }
       }
     }
@@ -69,13 +103,12 @@ void _qmm_t(
     int N,
     int K) {
   constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;
-  const int Kg = K / group_size;
-  const int Kw = K / pack_factor;

   for (int m = 0; m < M; m++) {
-    const uint32_t* w_local = w;
+    const uint8_t* w_local = (const uint8_t*)w;
     const T* scales_local = scales;
     const T* biases_local = biases;

@@ -87,12 +120,26 @@ void _qmm_t(
         T bias = *biases_local++;

         for (int kw = 0; kw < packs_in_group; kw++) {
-          uint32_t wi = *w_local++;
-
+          if (bits == 3 || bits == 6) {
+            T wl[pack_factor];
+            extract_bits<T, bits>(w_local, wl);
 #pragma clang loop unroll(full)
-          for (int p = 0; p < pack_factor; p++) {
-            sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
-            wi >>= bits;
+            for (int p = 0; p < pack_factor; p++) {
+              sum += x_local[p] * (scale * wl[p] + bias);
+            }
+            w_local += bytes_per_pack;
+            x_local += pack_factor;
+
+          } else {
+            uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+            for (int p = 0; p < pack_factor; p++) {
+              sum +=
+                  (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
+              if (bits != 8) {
+                wi >>= bits;
+              }
+            }
           }
         }
       }
@@ -104,6 +151,55 @@ void _qmm_t(
   }
 }

+template <typename T, int bits, int group_size>
+void _qmm_dispatch_transpose(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K,
+    bool transposed_w) {
+  if (transposed_w) {
+    return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+  } else {
+    return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
+  }
+}
+
+template <typename T, int bits>
+void _qmm_dispatch_group(
+    T* result,
+    const T* x,
+    const uint32_t* w,
+    const T* scales,
+    const T* biases,
+    int M,
+    int N,
+    int K,
+    int group_size,
+    bool transposed_w) {
+  switch (group_size) {
+    case 32:
+      _qmm_dispatch_transpose<T, bits, 32>(
+          result, x, w, scales, biases, M, N, K, transposed_w);
+      break;
+    case 64:
+      _qmm_dispatch_transpose<T, bits, 64>(
+          result, x, w, scales, biases, M, N, K, transposed_w);
+      break;
+    case 128:
+      _qmm_dispatch_transpose<T, bits, 128>(
+          result, x, w, scales, biases, M, N, K, transposed_w);
+      break;
+    default:
+      throw std::invalid_argument(
+          "Quantization group size must be 32, 64 or 128.");
+  }
+}
+
 template <typename T>
 void _qmm_dispatch_typed(
     T* result,
@@ -118,79 +214,29 @@ void _qmm_dispatch_typed(
     int bits,
     bool transposed_w) {
   switch (bits) {
-    case 2: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
-    case 4: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
-    case 8: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
+    case 2:
+      _qmm_dispatch_group<T, 2>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 3:
+      _qmm_dispatch_group<T, 3>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 4:
+      _qmm_dispatch_group<T, 4>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 6:
+      _qmm_dispatch_group<T, 6>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 8:
+      _qmm_dispatch_group<T, 8>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    default:
+      throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
   }
-  std::ostringstream msg;
-  msg << "Quantization type not supported. Provided bits=" << bits
-      << " and group_size=" << group_size
-      << ". The supported options are bits in "
-      << "{2, 4, 8} and group_size in {64, 128}.";
-  throw std::invalid_argument(msg.str());
 }

 void _qmm_dispatch(
@@ -406,51 +452,52 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
       transpose_);
 }

-template <typename T>
+template <typename T, typename U>
 void quantize(
     const array& w_,
     array& out_,
     array& scales_,
     array& biases_,
     int bits,
-    int group_size,
-    bool compute_scale_bias) {
+    int group_size) {
   const T* w = w_.data<T>();
+
+  auto out = out_.data<U>();
   T* scales = scales_.data<T>();
   T* biases = biases_.data<T>();
-  auto out = out_.data<uint32_t>();

   T n_bins = (1 << bits) - 1;
   T eps = 1e-7;
-  int el_per_int = 32 / bits;
-  int int_per_group = group_size / el_per_int;
+  bool power_of_2_bits = is_power_of_2(bits);
+  int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  // For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
+  int bytes_per_pack = power_of_2_bits ? 1 : 3;
+  int int_per_group = group_size * bytes_per_pack / el_per_int;
   size_t n_groups = w_.size() / group_size;

   for (size_t i = 0; i < n_groups; ++i) {
     size_t w_idx = i * group_size;
-    if (compute_scale_bias) {
-      T w_min = std::numeric_limits<float>::infinity();
-      T w_max = -w_min;
-      for (int j = 0; j < group_size; ++j) {
-        w_max = std::max(w_max, w[w_idx + j]);
-        w_min = std::min(w_min, w[w_idx + j]);
-      }
-      bool mask = std::abs(w_min) > std::abs(w_max);
-      T scale = std::max(T((w_max - w_min) / n_bins), eps);
-      scale = mask ? scale : -scale;
+    T w_min = std::numeric_limits<float>::infinity();
+    T w_max = -w_min;
+    for (int j = 0; j < group_size; ++j) {
+      w_max = std::max(w_max, w[w_idx + j]);
+      w_min = std::min(w_min, w[w_idx + j]);
+    }
+    bool mask = std::abs(w_min) > std::abs(w_max);
+    T scale = std::max(T((w_max - w_min) / n_bins), eps);
+    scale = mask ? scale : -scale;

-      auto edge = mask ? w_min : w_max;
-      auto q0 = std::rint(edge / scale);
-      if (q0 == 0) {
-        scales[i] = scale;
-        biases[i] = 0;
-      } else {
-        scales[i] = edge / q0;
-        biases[i] = edge;
-      }
+    auto edge = mask ? w_min : w_max;
+    auto q0 = std::rint(edge / scale);
+    if (q0 == 0) {
+      scales[i] = scale;
+      biases[i] = 0;
+    } else {
+      scales[i] = edge / q0;
+      biases[i] = edge;
     }
     size_t out_idx = i * int_per_group;
-    for (int j = 0; j < int_per_group; ++j) {
+    for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
       uint32_t out_el = 0;
       for (int k = 0; k < el_per_int; ++k) {
         T w_el = w[w_idx + j * el_per_int + k];
@@ -458,7 +505,13 @@ void quantize(
         w_el = std::min(std::max(w_el, T(0)), n_bins);
         out_el |= static_cast<uint32_t>(w_el) << (k * bits);
       }
-      out[out_idx + j] = out_el;
+      if (power_of_2_bits) {
+        out[out_idx + j] = out_el;
+      } else {
+        out[out_idx + bytes_per_pack * j] = out_el & 0xff;
+        out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
+        out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
+      }
     }
   }
 }
@@ -466,8 +519,6 @@ void quantize(
 void fast::AffineQuantize::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  bool compute_scale_bias = inputs.size() == 1;
-
   auto ensure_row_contiguous = [](const array& arr) {
     if (arr.flags().row_contiguous) {
       return arr;
@@ -482,23 +533,29 @@ void fast::AffineQuantize::eval_cpu(

   auto& out = outputs[0];
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  auto& scales =
-      compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
-  auto& biases =
-      compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
-  if (compute_scale_bias) {
-    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
-    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
-  }
+  auto& scales = outputs[1];
+  auto& biases = outputs[2];
+  scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
+  biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
   if (w.dtype() == float16) {
-    quantize<float16_t>(
-        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+    if (is_power_of_2(bits_)) {
+      quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
+    } else {
+      quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
+    }
   } else if (w.dtype() == bfloat16) {
-    quantize<bfloat16_t>(
-        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+    if (is_power_of_2(bits_)) {
+      quantize<bfloat16_t, uint32_t>(
+          w, out, scales, biases, bits_, group_size_);
+    } else {
+      quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
+    }
   } else if (w.dtype() == float32) {
-    quantize<float>(
-        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+    if (is_power_of_2(bits_)) {
+      quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
+    } else {
+      quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
+    }
   } else {
     throw std::runtime_error(
         "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 90e2dfaf1..b23154a8d 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -13,8 +13,8 @@ MLX_MTL_CONST int QUAD_SIZE = 4;
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector(const device T* x, thread U* x_thread) {
   static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");

   U sum = 0;

@@ -28,6 +28,21 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     }
   }

+  else if (bits == 3) {
+    for (int i = 0; i < values_per_thread; i += 8) {
+      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+          x[i + 6] + x[i + 7];
+      x_thread[i] = x[i];
+      x_thread[i + 1] = x[i + 1] / 8.0f;
+      x_thread[i + 2] =
x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + else if (bits == 4) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -38,6 +53,16 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { sum += x[i]; @@ -51,8 +76,8 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U sum = 0; @@ -64,8 +89,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 3] = x[i + 3] / 64.0f; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; } } @@ -77,8 +115,15 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 3] = x[i + 3] / 4096.0f; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; } } @@ -87,9 +132,10 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { sum += x[i]; x_thread[i] = x[i]; } - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; } return sum; @@ -103,8 +149,8 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U accum = 0; @@ -118,6 +164,26 @@ inline U qdot( } } + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + else if (bits == 4) { const device 
uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (values_per_thread / 4); i++) { @@ -129,6 +195,23 @@ inline U qdot( } } + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + else if (bits == 8) { for (int i = 0; i < values_per_thread; i++) { accum += x_thread[i] * w[i]; @@ -147,8 +230,8 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); U accum = 0; @@ -162,6 +245,26 @@ inline U qdot_safe( } } + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + else if (bits == 4) { const device uint16_t* ws = (const device uint16_t*)w; for (int i = 0; i < (N / 4); i++) { @@ -173,6 +276,23 @@ inline U qdot_safe( } } + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + else if (bits == 8) { for (int i = 0; i < N; i++) { accum += x_thread[i] * w[i]; @@ -186,8 +306,8 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; @@ -199,12 +319,45 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { } } + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + else if (bits == 4) { U s[2] = {scale, scale / 16.0f}; for (int i = 0; i < (values_per_thread / 2); i++) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] 
& 0xf0) + bias); } + + } else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } } else if (bits == 8) { @@ -218,8 +371,8 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); if (bits == 2) { U s[4] = { @@ -235,6 +388,22 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + else if (bits == 4) { U s[2] = {scale, scale / static_cast(16.0f)}; for (int i = 0; i < (N / 2); i++) { @@ -243,6 +412,18 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + else if (bits == 8) { for (int i = 0; i < N; i++) { w_local[i] = scale * w[i] + bias; @@ -267,10 +448,11 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 4 || bits == 8, - "Template undefined for bits not in {2, 4, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, + "Template undefined for bits not in {2, 3, 4, 6, 8}"); - MLX_MTL_CONST short pack_factor = 32 / bits; + MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -286,12 +468,12 @@ struct QuantizedBlockLoader { const short bj; threadgroup T* dst; - const device uint32_t* src; + const device uint8_t* src; const device T* scales; const device T* biases; QuantizedBlockLoader( - const device uint32_t* src_, + const device uint8_t* src_, const device T* scales_, const device T* biases_, const int src_ld_, @@ -300,14 +482,16 @@ struct QuantizedBlockLoader { ushort simd_lane_id [[thread_index_in_simdgroup]]) : src_ld(src_ld_), tile_stride( - reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor), + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), group_step_cnt(0), group_stride(BROWS * src_ld / group_size), thread_idx(simd_group_id * 32 + simd_lane_id), bi(n_reads * thread_idx / BCOLS_PACKED), bj((n_reads * thread_idx) % BCOLS_PACKED), dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld / pack_factor + bj), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), scales(scales_ + bi * src_ld / group_size), biases(biases_ + bi * src_ld / group_size) {} @@ -320,7 +504,7 @@ struct QuantizedBlockLoader { T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); } } @@ -347,7 +531,10 @@ struct QuantizedBlockLoader { T bias = *biases; for (int i = 0; i < n_reads; i++) { dequantize( - (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); } } @@ -410,8 +597,7 @@ METAL_FUNC void qmv_quad_impl( U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_quadgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); const device T* sl = scales + row * in_vec_size_g * quads_per_simd; const device T* bl = biases + row * in_vec_size_g * quads_per_simd; @@ -442,25 +628,34 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int packs_per_thread = bits > 2 ? 2 : 1; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; + // When bits is a power of two, read 1 uint32_t at a time + // When bits is 3 or 6, read 3 uint8_ts at a time + using W_T = + typename ConditionalType::type; + const device W_T* ws = (const device W_T*)w; + typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -470,8 +665,7 @@ METAL_FUNC void qmv_fast_impl( U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -480,7 +674,7 @@ METAL_FUNC void qmv_fast_impl( result[row] += qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -506,21 +700,29 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; + // When bits is a power of two, read 1 uint32_t at a time + // When bits is 3 or 6, read 3 uint8_ts at a time + using W_T = + typename ConditionalType::type; + const device W_T* ws = (const device W_T*)w; + typedef float U; thread U x_thread[values_per_thread]; thread U result[results_per_simdgroup] = {0}; // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; const int in_vec_size_g = in_vec_size / group_size; const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; @@ -533,7 +735,8 @@ METAL_FUNC void qmv_impl( // In this case we need to properly guard all our reads because there isn't // even 1 tile in the matrix if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { - w += out_row * in_vec_size_w + simd_lid * packs_per_thread; + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -544,8 +747,7 @@ METAL_FUNC void qmv_impl( U sum = load_vector(x, x_thread); for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -555,7 +757,7 @@ METAL_FUNC void qmv_impl( qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -569,8 +771,7 @@ METAL_FUNC void qmv_impl( x, x_thread, remaining); for (int row = 0; out_row + row < out_vec_size; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -591,7 +792,8 @@ METAL_FUNC void qmv_impl( // In this case the last tile is moved back to redo some output values else { - w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread; + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; x += tid.y * in_vec_size + simd_lid * values_per_thread; @@ -602,8 +804,7 @@ METAL_FUNC void qmv_impl( U sum = load_vector(x, x_thread); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -613,7 +814,7 @@ METAL_FUNC void qmv_impl( qdot(wl, x_thread, s, b, sum); } - w += block_size / pack_factor; + ws += block_size * bytes_per_pack / pack_factor; scales += block_size / group_size; biases += block_size / group_size; x += block_size; @@ -627,8 +828,7 @@ METAL_FUNC void qmv_impl( x, x_thread, 
remaining); for (int row = 0; row < results_per_simdgroup; row++) { - const device uint8_t* wl = - (const device uint8_t*)(w + row * in_vec_size_w); + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); const device T* sl = scales + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g; @@ -659,14 +859,22 @@ METAL_FUNC void qvm_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; constexpr int tn = 32 / pack_factor; - constexpr int blocksize = SIMD_SIZE; + constexpr int block_size = SIMD_SIZE; + + // When bits is a power of two, read 1 uint32_t at a time + // When bits is 3 or 6, read 3 uint8_ts at a time + using W_T = + typename ConditionalType::type; + const device W_T* ws = (const device W_T*)w; typedef float U; typedef struct { - uint32_t wi[tn]; + W_T wi[tn * bytes_per_pack]; } vec_w; thread vec_w w_local; @@ -676,11 +884,10 @@ METAL_FUNC void qvm_impl( thread U x_local = 0; // Adjust positions - const int out_vec_size_w = out_vec_size / pack_factor; + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; const int out_vec_size_g = out_vec_size / group_size; - int out_col = - tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn; - w += out_col / pack_factor + simd_lid * out_vec_size_w; + int out_col = pack_factor * tn * (tid.x * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; scales += out_col / group_size + simd_lid * out_vec_size_g; biases += out_col / group_size + simd_lid * out_vec_size_g; x += tid.y * in_vec_size + simd_lid; @@ -690,43 +897,42 @@ METAL_FUNC void qvm_impl( return; } - // Loop over in_vec in blocks of blocksize - int remaining = in_vec_size % blocksize; + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; if (remaining == 0) { - for (int i = 0; i < in_vec_size; i += blocksize) { + for (int i = 0; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); - + w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; } } else { - for (int i = blocksize; i < in_vec_size; i += blocksize) { + for (int i = block_size; i < in_vec_size; i += block_size) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); + w_local = *((device vec_w*)ws); qouter( (thread uint8_t*)&w_local, x_local, scale, bias, result); - x += blocksize; - scales += blocksize * out_vec_size_g; - biases += blocksize * out_vec_size_g; - w += blocksize * out_vec_size_w; + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; } if (static_cast(simd_lid) < remaining) { x_local = *x; scale = *scales; bias = *biases; - w_local = *((device vec_w*)w); + w_local = *((device vec_w*)ws); } else { x_local = 0; scale = 0; @@ -781,8 +987,9 @@ 
METAL_FUNC void qmm_t_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -800,13 +1007,15 @@ METAL_FUNC void qmm_t_impl( bits>; // Set the block - const int K_w = K / pack_factor; + const int K_w = K * bytes_per_pack / pack_factor; const int K_g = K / group_size; const int y_row = tid.y * BM; const int y_col = tid.x * BN; + auto wl = (const device uint8_t*)w; + x += y_row * K; - w += y_col * K_w; + wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; y += y_row * N + y_col; @@ -815,7 +1024,7 @@ METAL_FUNC void qmm_t_impl( const short num_els = min(BM, M - y_row); const short num_outs = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, K, Ws, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { @@ -857,6 +1066,7 @@ METAL_FUNC void qmm_t_impl( loader_x.load_unsafe(); loader_w.load_unsafe(); threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); loader_x.next(); loader_w.next(); @@ -902,9 +1112,11 @@ METAL_FUNC void qmm_n_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = 32 / bits; + constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -921,11 +1133,13 @@ METAL_FUNC void qmm_n_impl( group_size, bits>; + auto wl = (const device uint8_t*)w; + // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; x += y_row * K; - w += y_col / pack_factor; + wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; y += y_row * N + y_col; @@ -933,7 +1147,7 @@ METAL_FUNC void qmm_n_impl( // Make the x loader and mma operation const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(w, scales, biases, N, Ws, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); if (num_els < BM) { @@ -1805,13 +2019,14 @@ template uint2 grid_dim [[threads_per_grid]]) { constexpr T eps = T(1e-7); constexpr int simd_size = 32; - constexpr int uint8_bits = 8; constexpr T n_bins = (1 << bits) - 1; - constexpr int packs_per_int = uint8_bits / bits; + constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; constexpr int values_per_reduce = group_size / simd_size; constexpr int writes_per_reduce = packs_per_int / values_per_reduce; constexpr int writes_per_pack = writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; static_assert( group_size % simd_size == 0, @@ -1819,7 +2034,9 @@ template size_t offset = index.x + grid_dim.x * size_t(index.y); size_t in_index = offset * values_per_reduce; - size_t out_index = offset * writes_per_pack; + size_t out_index = power_of_2_bits + ? 
offset * writes_per_pack + : offset * bytes_per_pack / writes_per_reduce; T w_thread[values_per_reduce]; T w_min = Limits::max; @@ -1852,7 +2069,11 @@ template biases[gindex] = bias; } - uint8_t output = 0; + // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t + using OutT = + typename ConditionalType::type; + OutT output = 0; + #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); @@ -1868,47 +2089,23 @@ template output = 0; } else { #pragma clang loop unroll(full) - for (int j = 0; j < writes_per_reduce - 1; j++) { - uint8_t sval = simd_shuffle_down(val, j + 1); - output += sval << (bits * (values_per_reduce + j + i)); + for (int j = 1; j < writes_per_reduce; j++) { + uint8_t sval = simd_shuffle_down(val, j); + output += sval << (bits * (j * values_per_reduce + i)); } } } - if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { - out[out_index / writes_per_reduce] = output; - } -} - -template -[[kernel]] void affine_quantize_scales_biases( - const device T* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - device uint8_t* out [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr int uint8_bits = 8; - constexpr int packs_per_int = uint8_bits / bits; - constexpr T n_bins = (1 << bits) - 1; - - size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t in_index = offset * packs_per_int; - size_t gindex = in_index / group_size; - - T scale = scales[gindex]; - T bias = biases[gindex]; - - uint8_t output = 0; -#pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { - uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins); - if (bits == 8) { - output = val; - } else { - output += val << (bits * i); + if (bits == 3 || bits == 6) { + if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else { + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; } } - out[offset] = output; } template @@ -1919,26 +2116,48 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int uint8_bits = 8; - constexpr int packs_per_int = uint8_bits / bits; + constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; size_t offset = index.x + grid_dim.x * size_t(index.y); size_t oindex = offset * packs_per_int; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; - uint val = w[offset]; + out += oindex; + + if (bits == 3) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x7) * scale + bias; + out[1] = ((w[0] & 0x38) >> 3) * scale + bias; + out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + out[3] = ((w[1] & 0xe) >> 1) * scale + bias; + out[4] = ((w[1] & 0x70) >> 4) * scale + bias; + out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + + } else if (bits == 6) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x3f) * scale + bias; + out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + out[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } else { + uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { - uint8_t d; - if (bits == 2) { - d = (val >> (bits * i)) & 0x03; - } else if (bits == 4) { - d = (val >> (bits * i)) & 0x0f; - } else if (bits == 8) { - d = val; + for (int i = 0; i < packs_per_int; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = scale * d + bias; } - out[oindex + i] = scale * d + bias; } } diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 5751d953f..7af554437 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -72,7 +72,6 @@ #define instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \ - instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \ instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ instantiate_quantized(bs_qmv, type, group_size, bits) \ @@ -116,7 +115,9 @@ #define instantiate_quantized_all() \ instantiate_quantized_groups(2) \ + instantiate_quantized_groups(3) \ instantiate_quantized_groups(4) \ + instantiate_quantized_groups(6) \ instantiate_quantized_groups(8) instantiate_quantized_all() // clang-format on diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index b894426d4..c6add37f9 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -421,3 +421,14 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { return complex64_t( simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); } + +// std::conditional is not included with Metal +template +struct ConditionalType { + using type = U; +}; + +template +struct ConditionalType { + using type = T; +}; diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index d41502617..4454476c9 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -10,6 +10,7 @@ #include "mlx/backend/metal/utils.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" +#include "mlx/utils.h" namespace mlx::core { @@ -298,7 +299,7 @@ void qmm_op( bool quad = false; if (transpose) { - if (B < 6 && (D == 128 || D == 64)) { + if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) { name += 
"qmv_quad"; constexpr int quads_per_simd = 8; constexpr int results_per_quadgroup = 8; @@ -391,8 +392,6 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { void fast::AffineQuantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { - bool compute_scale_bias = inputs.size() == 1; - auto& w_pre = inputs[0]; auto& out = outputs[0]; out.set_data(allocator::malloc_or_wait(out.nbytes())); @@ -415,7 +414,7 @@ void fast::AffineQuantize::eval_gpu( auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_input_array(w, 0); - if (!compute_scale_bias) { + if (dequantize_) { auto& scales_pre = inputs[1]; auto& biases_pre = inputs[2]; auto scales = ensure_row_contiguous(scales_pre); @@ -436,12 +435,7 @@ void fast::AffineQuantize::eval_gpu( std::ostringstream kname; auto type_string = dequantize_ ? get_type_string(out.dtype()) : get_type_string(w_pre.dtype()); - auto kernel_func = "affine_quantize_scales_biases"; - if (dequantize_) { - kernel_func = "affine_dequantize"; - } else if (compute_scale_bias) { - kernel_func = "affine_quantize"; - } + auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize"; kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_" << bits_; auto template_def = get_template_definition( @@ -452,10 +446,10 @@ void fast::AffineQuantize::eval_gpu( // Treat uint32 as uint8 in kernel constexpr int uint8_per_uint32 = 4; constexpr int simd_size = 32; - int packs_per_int = 8 / bits_; - int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int; + int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_; + int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size; size_t nthreads = - dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread; + dequantize_ ? out.size() / packs_per_int : w.size() / per_thread; NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); if (thread_group_size > nthreads) { diff --git a/mlx/fast.cpp b/mlx/fast.cpp index d3eb77d06..ccce1d6b1 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -686,13 +686,11 @@ array pack_and_quantize( array& packed_w, const array& scales, const array& biases, - int group_size, int bits, const Stream& s) { int el_per_int = 32 / bits; array zero(0, packed_w.dtype()); array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1 - array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); packed_w = astype( clip( round(divide(subtract(packed_w, biases, s), scales, s), s), @@ -701,9 +699,30 @@ array pack_and_quantize( s), uint32, s); - packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); - packed_w = sum( - multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); + if (is_power_of_2(bits)) { + array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s); + packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); + packed_w = sum( + multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); + } else { + // This is slow but we have fast GPU/CPU versions of this function so we + // shouldn't be here often. 
+ packed_w = expand_dims(packed_w, /* axis= */ -1, s); + packed_w = bitwise_and( + right_shift(packed_w, arange(bits, uint32, s), s), + array({1}, uint32), + s); + auto new_shape = packed_w.shape(); + new_shape[new_shape.size() - 2] = -1; + new_shape.back() = 32; + packed_w = reshape(packed_w, new_shape, s); + array shifts = arange(32, uint32, s); + packed_w = + sum(left_shift(packed_w, shifts, s), + /* axis= */ -1, + /* keepdims= */ false, + s); + } return packed_w; } @@ -718,10 +737,10 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { throw std::invalid_argument(msg.str()); } - if (bits != 2 && bits != 4 && bits != 8) { + if (bits != 2 && bits != 3 && bits != 4 && bits != 6 && bits != 8) { std::ostringstream msg; msg << "[quantize] The requested number of bits " << bits - << " is not supported. The supported bits are 2, 4 and 8."; + << " is not supported. The supported bits are 2, 3, 4, 6 and 8."; throw std::invalid_argument(msg.str()); } @@ -740,9 +759,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { throw std::invalid_argument(msg.str()); } - int el_per_int = 32 / bits; - - auto fallback = [group_size, bits, el_per_int, s]( + auto fallback = [group_size, bits, s]( const std::vector& inputs) -> std::vector { auto& w = inputs[0]; auto wshape = w.shape(); @@ -765,7 +782,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); array biases = where(equal(q0, zero, s), zero, edge, s); - packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s); + packed_w = pack_and_quantize(packed_w, scales, biases, bits, s); return { reshape(packed_w, wshape, s), reshape(scales, wshape, s), @@ -774,7 +791,7 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { }; auto wq_shape = w.shape(); - wq_shape.back() = w.shape(-1) / el_per_int; + wq_shape.back() = w.shape(-1) * bits / 32; auto sshape = w.shape(); sshape.back() = w.shape(-1) / group_size; auto outputs = array::make_arrays( @@ -785,39 +802,6 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) { return {outputs[0], outputs[1], outputs[2]}; } -array affine_quantize( - const array& w, - const array& scales, - const array& biases, - int group_size, - int bits, - StreamOrDevice s_) { - auto s = to_stream(s_); - - int el_per_int = 32 / bits; - auto fallback = [group_size, bits, el_per_int, s]( - const std::vector& inputs) -> std::vector { - auto& w = inputs[0]; - auto scales = expand_dims(inputs[1], -1, s); - auto biases = expand_dims(inputs[2], -1, s); - - auto wshape = w.shape(); - wshape.back() = -1; - - array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); - packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s); - return {reshape(packed_w, wshape, s)}; - }; - - auto out_shape = w.shape(); - out_shape.back() = w.shape(-1) / el_per_int; - return array( - std::move(out_shape), - uint32, - std::make_shared(s, fallback, group_size, bits, false), - {w, scales, biases}); -} - array affine_dequantize( const array& w, const array& scales, @@ -860,9 +844,9 @@ array affine_dequantize( } // Packing into uint32 - int el_per_int = 32 / bits; + int out_size = w.shape(-1) * 32 / bits; - if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) { + if (out_size != scales.shape(-1) * group_size) { std::ostringstream msg; msg << "[dequantize] Shape of scales and biases does not match the 
matrix " << "given the quantization parameters. Provided matrix of shape " @@ -873,40 +857,52 @@ array affine_dequantize( auto s = to_stream(s_); - auto fallback = - [&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s]( - const std::vector& inputs) -> std::vector { - auto& w = inputs[0]; + auto fallback = [&wshape, &sshape, &scales, &biases, group_size, bits, s]( + const std::vector& inputs) -> std::vector { + auto w = inputs[0]; auto& scales = inputs[1]; auto& biases = inputs[2]; - std::vector parts; - for (int start = 0; start < 32; start += bits) { - int shift_left = 32 - (start + bits); - int shift_right = shift_left + start; + if (is_power_of_2(bits)) { + std::vector parts; + for (int start = 0; start < 32; start += bits) { + int shift_left = 32 - (start + bits); + int shift_right = shift_left + start; - parts.push_back(expand_dims( - right_shift( - left_shift(w, array(32 - (start + bits), uint32), s), - array(32 - bits, uint32), - s), - -1, - s)); + parts.push_back(expand_dims( + right_shift( + left_shift(w, array(32 - (start + bits), uint32), s), + array(32 - bits, uint32), + s), + -1, + s)); + } + w = concatenate(parts, -1, s); + } else { + w = expand_dims(w, /* axis= */ -1, s); + w = bitwise_and( + right_shift(w, arange(32, uint32, s), s), array({1}, uint32), s); + auto new_shape = w.shape(); + new_shape[new_shape.size() - 2] = -1; + new_shape.back() = bits; + w = reshape(w, new_shape, s); + array shifts = arange(bits, uint32, s); + w = sum( + left_shift(w, shifts, s), /* axis= */ -1, /* keepdims= */ false, s); } - array w_full = concatenate(parts, -1, s); // Dequantize wshape.push_back(group_size); - w_full = reshape(w_full, wshape, s); - w_full = multiply(w_full, expand_dims(scales, -1, s), s); - w_full = add(w_full, expand_dims(biases, -1, s), s); - w_full = reshape(w_full, sshape, s); + w = reshape(w, wshape, s); + w = multiply(w, expand_dims(scales, -1, s), s); + w = add(w, expand_dims(biases, -1, s), s); + w = reshape(w, sshape, s); - return {w_full}; + return {w}; }; if (s.device == Device::gpu) { auto out_shape = w.shape(); - out_shape.back() = w.shape(-1) * el_per_int; + out_shape.back() = out_size; return array( std::move(out_shape), scales.dtype(), diff --git a/mlx/fast.h b/mlx/fast.h index e1a876882..ddc3512b5 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -47,14 +47,6 @@ std::tuple affine_quantize( int bits = 4, StreamOrDevice s = {}); -array affine_quantize( - const array& w, - const array& scales, - const array& biases, - int group_size = 64, - int bits = 4, - StreamOrDevice s = {}); - array affine_dequantize( const array& w, const array& scales, diff --git a/mlx/ops.cpp b/mlx/ops.cpp index d02ff4b48..cbe8a6861 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3683,7 +3683,7 @@ std::tuple quantize( int group_size /* = 64 */, int bits /* = 4 */, StreamOrDevice s /* = {} */) { - return fast::affine_quantize(w, group_size, bits); + return fast::affine_quantize(w, group_size, bits, s); } array dequantize( diff --git a/python/src/fast.cpp b/python/src/fast.cpp index b71baa183..cbc8b934d 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -161,49 +161,6 @@ void init_fast(nb::module_& parent_module) { array: The output array. 
)pbdoc"); - m.def( - "affine_quantize", - nb::overload_cast< - const array&, - const array&, - const array&, - int, - int, - StreamOrDevice>(&fast::affine_quantize), - "w"_a, - "scales"_a, - "biases"_a, - "group_size"_a = 64, - "bits"_a = 4, - nb::kw_only(), - "stream"_a = nb::none(), - nb::sig( - "def affine_quantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), - R"pbdoc( - Quantize the matrix ``w`` using the provided ``scales`` and - ``biases`` and the ``group_size`` and ``bits`` configuration. - - Formally, given the notation in :func:`quantize`, we compute - :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and - :math:`\beta` as follows - - .. math:: - - w_i = s (\hat{w_i} + \beta) - - Args: - w (array): Matrix to be quantize - scales (array): The scales to use per ``group_size`` elements of ``w`` - biases (array): The biases to use per ``group_size`` elements of ``w`` - group_size (int, optional): The size of the group in ``w`` that shares a - scale and bias. (default: ``64``) - bits (int, optional): The number of bits occupied by each element in - ``w``. (default: ``4``) - - Returns: - array: The quantized version of ``w`` - )pbdoc"); - m.def( "metal_kernel", [](const std::string& name, diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index d27bdddfb..56006e59f 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -549,18 +549,6 @@ class TestFast(mlx_tests.MLXTestCase): )(x) self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) - def test_affine_quantize(self): - mx.random.seed(7) - x = mx.random.uniform(shape=(4, 1024)) - for bits in (2, 4, 8): - for group_size in (32, 64, 128): - with self.subTest(bits=bits, group_size=group_size): - w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size) - w_p = mx.fast.affine_quantize( - x, scales, biases, bits=bits, group_size=group_size - ) - self.assertTrue(mx.allclose(w, w_p)) - @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_basic(self): mx.random.seed(7) diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 607f7ef24..7d4ba9949 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase): def test_quantize_dequantize(self): w = mx.random.normal(shape=(128, 512)) for gs in [32, 64, 128]: - for b in [2, 4, 8]: + for b in [2, 3, 6, 4, 8]: with self.subTest(gs=gs, b=b): w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b) w_hat = mx.dequantize(w_q, scales, biases, gs, b) @@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase): # test quantize/dequantize 0s a = mx.zeros((256, 512)) for gs in [32, 64, 128]: - for b in [2, 4, 8]: + for b in [2, 3, 4, 6, 8]: w_q, scales, biases = mx.quantize(a, gs, b) a_hat = mx.dequantize(w_q, scales, biases, gs, b) self.assertTrue(mx.all(a_hat == 0)) @@ -116,7 +116,7 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 4, 8], # bits + [2, 3, 4, 6, 8], # bits [512, 1024, 67], # M [64, 128, 512, 1024], # N [0, 1, 3, 8], # B @@ -143,7 +143,7 @@ class TestQuantized(mlx_tests.MLXTestCase): k1, k2 = mx.random.split(key) tests = product( [128, 64, 32], # group_size - [2, 4, 8], # bits + [2, 3, 4, 6, 8], # bits [512, 1024], # M [512, 1024, 67], # N [0, 1, 3, 8], # B