From dfa0b9aab427a28139c54ebeb9564ff35c35d758 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 8 Nov 2024 20:10:39 -0800
Subject: [PATCH] Cpu fast quantize (#1578)

* cpu quantize

* fix
---
 mlx/backend/common/quantized.cpp  | 103 +++++++++++++++++++++++++++++-
 mlx/backend/no_cpu/primitives.cpp |   5 ++
 mlx/fast.cpp                      |  42 +++++-------
 mlx/fast_primitives.h             |   4 +-
 4 files changed, 125 insertions(+), 29 deletions(-)

diff --git a/mlx/backend/common/quantized.cpp b/mlx/backend/common/quantized.cpp
index daeb50c6a..d939334c9 100644
--- a/mlx/backend/common/quantized.cpp
+++ b/mlx/backend/common/quantized.cpp
@@ -2,7 +2,9 @@
 
 #include <cassert>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/ops.h"
+#include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -404,4 +406,103 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
       transpose_);
 }
 
+template <typename T>
+void quantize(
+    const array& w_,
+    array& out_,
+    array& scales_,
+    array& biases_,
+    int bits,
+    int group_size,
+    bool compute_scale_bias) {
+  const T* w = w_.data<T>();
+  T* scales = scales_.data<T>();
+  T* biases = biases_.data<T>();
+  auto out = out_.data<uint32_t>();
+
+  T n_bins = (1 << bits) - 1;
+  T eps = 1e-7;
+  int el_per_int = 32 / bits;
+  int int_per_group = group_size / el_per_int;
+  size_t n_groups = w_.size() / group_size;
+
+  for (size_t i = 0; i < n_groups; ++i) {
+    size_t w_idx = i * group_size;
+    if (compute_scale_bias) {
+      T w_min = std::numeric_limits<float>::infinity();
+      T w_max = -w_min;
+      for (int j = 0; j < group_size; ++j) {
+        w_max = std::max(w_max, w[w_idx + j]);
+        w_min = std::min(w_min, w[w_idx + j]);
+      }
+      bool mask = std::abs(w_min) > std::abs(w_max);
+      T scale = std::max(T((w_max - w_min) / n_bins), eps);
+      scale = mask ? scale : -scale;
+
+      auto edge = mask ? w_min : w_max;
+      auto q0 = std::rint(edge / scale);
+      if (q0 == 0) {
+        scales[i] = scale;
+        biases[i] = 0;
+      } else {
+        scales[i] = edge / q0;
+        biases[i] = edge;
+      }
+    }
+    size_t out_idx = i * int_per_group;
+    for (int j = 0; j < int_per_group; ++j) {
+      uint32_t out_el = 0;
+      for (int k = 0; k < el_per_int; ++k) {
+        T w_el = w[w_idx + j * el_per_int + k];
+        w_el = std::rint((w_el - biases[i]) / scales[i]);
+        w_el = std::min(std::max(w_el, T(0)), n_bins);
+        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
+      }
+      out[out_idx + j] = out_el;
+    }
+  }
+}
+
+void fast::AffineQuantize::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  bool compute_scale_bias = inputs.size() == 1;
+
+  auto ensure_row_contiguous = [](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };
+  auto w = ensure_row_contiguous(inputs[0]);
+
+  auto& out = outputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& scales =
+      compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
+  auto& biases =
+      compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
+  if (compute_scale_bias) {
+    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
+    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+  }
+  if (w.dtype() == float16) {
+    quantize<float16_t>(
+        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+  } else if (w.dtype() == bfloat16) {
+    quantize<bfloat16_t>(
+        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+  } else if (w.dtype() == float32) {
+    quantize<float>(
+        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+  } else {
+    throw std::runtime_error(
+        "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
+  }
+}
+
 } // namespace mlx::core
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index c87fcc8bb..9afeaec8b 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 #include "mlx/primitives.h"
+#include "mlx/fast_primitives.h"
 
 #define NO_CPU_MULTI(func)                                  \
   void func::eval_cpu(                                      \
@@ -112,4 +113,8 @@ NO_CPU(Transpose)
 NO_CPU(Inverse)
 NO_CPU(View)
 
+namespace fast {
+NO_CPU_MULTI(AffineQuantize)
+} // namespace fast
+
 } // namespace mlx::core
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index 895ae2d52..d3eb77d06 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -773,20 +773,15 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
     };
   };
 
-  std::vector<array> outputs;
-  if (s.device == Device::gpu) {
-    auto wq_shape = w.shape();
-    wq_shape.back() = w.shape(-1) / el_per_int;
-    auto sshape = w.shape();
-    sshape.back() = w.shape(-1) / group_size;
-    outputs = array::make_arrays(
-        {wq_shape, sshape, sshape},
-        {uint32, w.dtype(), w.dtype()},
-        std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
-        {w});
-  } else {
-    outputs = fallback({w});
-  }
+  auto wq_shape = w.shape();
+  wq_shape.back() = w.shape(-1) / el_per_int;
+  auto sshape = w.shape();
+  sshape.back() = w.shape(-1) / group_size;
+  auto outputs = array::make_arrays(
+      {std::move(wq_shape), sshape, sshape},
+      {uint32, w.dtype(), w.dtype()},
+      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
+      {w});
   return {outputs[0], outputs[1], outputs[2]};
 }
 
@@ -814,16 +809,13 @@ array affine_quantize(
     return {reshape(packed_w, wshape, s)};
   };
 
-  if (s.device == Device::gpu) {
-    auto out_shape = w.shape();
-    out_shape.back() = w.shape(-1) / el_per_int;
-    return array(
-        out_shape,
-        uint32,
-        std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
-        {w, scales, biases});
-  }
-  return fallback({w, scales, biases})[0];
+  auto out_shape = w.shape();
+  out_shape.back() = w.shape(-1) / el_per_int;
+  return array(
+      std::move(out_shape),
+      uint32,
+      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
+      {w, scales, biases});
 }
 
 array affine_dequantize(
@@ -916,7 +908,7 @@ array affine_dequantize(
   auto out_shape = w.shape();
   out_shape.back() = w.shape(-1) * el_per_int;
   return array(
-      out_shape,
+      std::move(out_shape),
       scales.dtype(),
       std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
       {w, scales, biases});
diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h
index 9233a1628..30db282ff 100644
--- a/mlx/fast_primitives.h
+++ b/mlx/fast_primitives.h
@@ -228,9 +228,7 @@ class AffineQuantize : public Custom {
         dequantize_(dequantize) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    throw std::runtime_error("NYI");
-  }
+      override;
 
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
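
For readers who want to see the packing arithmetic in isolation, below is a minimal standalone sketch (not part of the patch) of the scheme the new CPU `quantize` kernel implements: with `bits` bits per value, 32 / bits values are packed into each output uint32, and every group of `group_size` values shares one scale and bias. The scale/bias computation here is the plain min/max form; the kernel in the patch uses a slightly more involved variant that may negate the scale and snap it to whichever group extreme has the larger magnitude. The helper name `pack_group` is hypothetical and exists only for this illustration.

// Illustrative sketch only: affine-quantize and bit-pack one group of weights.
// Assumes bits = 4 and group_size = 8, so all eight values fit in one uint32.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<uint32_t> pack_group(
    const std::vector<float>& w, int bits, float& scale, float& bias) {
  const float n_bins = (1 << bits) - 1; // e.g. 15 bins for 4-bit
  const int el_per_int = 32 / bits;     // e.g. 8 values per uint32
  // Per-group scale/bias from the group's min and max (simplified form).
  float w_min = *std::min_element(w.begin(), w.end());
  float w_max = *std::max_element(w.begin(), w.end());
  scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  bias = w_min;
  // Round each value to its bin and pack el_per_int bins into each uint32.
  std::vector<uint32_t> packed(w.size() / el_per_int, 0);
  for (size_t i = 0; i < w.size(); ++i) {
    float q = std::rint((w[i] - bias) / scale);
    q = std::min(std::max(q, 0.0f), n_bins);
    packed[i / el_per_int] |=
        static_cast<uint32_t>(q) << ((i % el_per_int) * bits);
  }
  return packed;
}

int main() {
  std::vector<float> group = {0.1f, -0.3f, 0.7f, 0.2f, -0.1f, 0.5f, 0.0f, 0.4f};
  float scale, bias;
  auto packed = pack_group(group, /*bits=*/4, scale, bias);
  // Dequantize a packed value q as q * scale + bias.
  std::printf(
      "scale=%f bias=%f packed=0x%08x\n",
      scale,
      bias,
      static_cast<unsigned>(packed[0]));
}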