Cpu fast quantize (#1578)

* cpu quantize

* fix
Awni Hannun 2024-11-08 20:10:39 -08:00 committed by GitHub
parent a4c47b0276
commit dfa0b9aab4
4 changed files with 125 additions and 29 deletions


@@ -2,7 +2,9 @@
#include <cassert>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -404,4 +406,103 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
      transpose_);
}

template <typename T>
void quantize(
    const array& w_,
    array& out_,
    array& scales_,
    array& biases_,
    int bits,
    int group_size,
    bool compute_scale_bias) {
  const T* w = w_.data<T>();
  T* scales = scales_.data<T>();
  T* biases = biases_.data<T>();
  auto out = out_.data<uint32_t>();

  T n_bins = (1 << bits) - 1;
  T eps = 1e-7;
  int el_per_int = 32 / bits;
  int int_per_group = group_size / el_per_int;
  size_t n_groups = w_.size() / group_size;

  for (size_t i = 0; i < n_groups; ++i) {
    size_t w_idx = i * group_size;
    if (compute_scale_bias) {
      T w_min = std::numeric_limits<float>::infinity();
      T w_max = -w_min;
      for (int j = 0; j < group_size; ++j) {
        w_max = std::max(w_max, w[w_idx + j]);
        w_min = std::min(w_min, w[w_idx + j]);
      }
      bool mask = std::abs(w_min) > std::abs(w_max);
      T scale = std::max(T((w_max - w_min) / n_bins), eps);
      scale = mask ? scale : -scale;
      auto edge = mask ? w_min : w_max;
      auto q0 = std::rint(edge / scale);
      if (q0 == 0) {
        scales[i] = scale;
        biases[i] = 0;
      } else {
        scales[i] = edge / q0;
        biases[i] = edge;
      }
    }
    size_t out_idx = i * int_per_group;
    for (int j = 0; j < int_per_group; ++j) {
      uint32_t out_el = 0;
      for (int k = 0; k < el_per_int; ++k) {
        T w_el = w[w_idx + j * el_per_int + k];
        w_el = std::rint((w_el - biases[i]) / scales[i]);
        w_el = std::min(std::max(w_el, T(0)), n_bins);
        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
      }
      out[out_idx + j] = out_el;
    }
  }
}

void fast::AffineQuantize::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  bool compute_scale_bias = inputs.size() == 1;

  auto ensure_row_contiguous = [](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      return arr_copy;
    }
  };

  auto w = ensure_row_contiguous(inputs[0]);
  auto& out = outputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& scales =
      compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
  auto& biases =
      compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
  if (compute_scale_bias) {
    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
  }

  if (w.dtype() == float16) {
    quantize<float16_t>(
        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
  } else if (w.dtype() == bfloat16) {
    quantize<bfloat16_t>(
        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
  } else if (w.dtype() == float32) {
    quantize<float>(
        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
  } else {
    throw std::runtime_error(
        "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
  }
}
} // namespace mlx::core
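The hunk above is almost entirely new code: a templated quantize routine that computes a per-group scale and bias (when they are not supplied) and packs the rounded values into uint32 words, plus the fast::AffineQuantize::eval_cpu entry point that ensures row-contiguous input, allocates the outputs, and dispatches on the input dtype. The arithmetic is easier to follow outside the array machinery, so below is a minimal standalone sketch of the same per-group logic on plain floats. The sample values, the 8-element group, and the printouts are illustrative assumptions, not part of the commit; with bits = 4, el_per_int = 32 / 4 = 8, so the whole group here fits in a single packed uint32.

// Standalone sketch of the per-group affine quantization above, using plain
// std::vector<float> instead of MLX arrays. Illustrative only; not part of
// the commit.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  const int bits = 4;                    // 4-bit codes
  const int group_size = 8;              // one group per packed uint32 here
  const int el_per_int = 32 / bits;      // 8 elements fit in one uint32
  const float n_bins = (1 << bits) - 1;  // 15

  std::vector<float> w = {0.1f, -0.3f, 0.7f, 0.2f, -0.9f, 0.4f, 0.0f, 0.6f};

  // Per-group scale/bias, mirroring the compute_scale_bias branch above.
  float w_min = std::numeric_limits<float>::infinity();
  float w_max = -w_min;
  for (float x : w) {
    w_min = std::min(w_min, x);
    w_max = std::max(w_max, x);
  }
  bool mask = std::abs(w_min) > std::abs(w_max);
  float scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  scale = mask ? scale : -scale;
  float edge = mask ? w_min : w_max;
  float q0 = std::rint(edge / scale);
  float bias = 0.0f;
  if (q0 != 0) {
    scale = edge / q0;
    bias = edge;
  }

  // Pack the eight quantized 4-bit codes into a single uint32, LSB first.
  uint32_t packed = 0;
  for (int k = 0; k < el_per_int; ++k) {
    float q = std::rint((w[k] - bias) / scale);
    q = std::min(std::max(q, 0.0f), n_bins);
    packed |= static_cast<uint32_t>(q) << (k * bits);
  }

  // Unpack and dequantize to check the round trip.
  for (int k = 0; k < el_per_int; ++k) {
    uint32_t q = (packed >> (k * bits)) & ((1u << bits) - 1);
    float w_hat = q * scale + bias;
    std::printf("w=%+.3f  q=%2u  w_hat=%+.3f\n", w[k], static_cast<unsigned>(q), w_hat);
  }
  return 0;
}

In the sketch, the group edge (w_min here, since it has the larger magnitude) maps to code 0, and thanks to the edge / q0 adjustment an exact 0.0 also lands on an integer code, so both values round-trip exactly. The next hunks register the new primitive in builds that have no CPU backend.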


@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include "mlx/primitives.h"
#include "mlx/fast_primitives.h"
#define NO_CPU_MULTI(func) \
void func::eval_cpu( \
@@ -112,4 +113,8 @@ NO_CPU(Transpose)
NO_CPU(Inverse)
NO_CPU(View)
namespace fast {
NO_CPU_MULTI(AffineQuantize)
} // namespace fast
} // namespace mlx::core
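Only the first two lines of the NO_CPU_MULTI macro fall inside the hunk above; its role is to give fast::AffineQuantize an eval_cpu symbol in builds compiled without a CPU backend, mirroring the single-output NO_CPU stubs around it. A plausible expansion is sketched below purely for orientation; the macro body and error message are not shown in the diff, so treat the details as assumptions.

// Hypothetical expansion of NO_CPU_MULTI(AffineQuantize) inside namespace
// fast; the real macro body is not visible in this hunk and the message
// text is an assumption.
void AffineQuantize::eval_cpu(
    const std::vector<array>& inputs, std::vector<array>& outputs) {
  throw std::runtime_error("AffineQuantize has no CPU implementation.");
}

The next hunks change affine_quantize so that it always builds the AffineQuantize primitive instead of eagerly calling the fallback lambda on non-GPU streams, which is what makes the new CPU implementation reachable.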


@@ -773,20 +773,15 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
      };
  };
-  std::vector<array> outputs;
-  if (s.device == Device::gpu) {
  auto wq_shape = w.shape();
  wq_shape.back() = w.shape(-1) / el_per_int;
  auto sshape = w.shape();
  sshape.back() = w.shape(-1) / group_size;
-  outputs = array::make_arrays(
-      {wq_shape, sshape, sshape},
+  auto outputs = array::make_arrays(
+      {std::move(wq_shape), sshape, sshape},
      {uint32, w.dtype(), w.dtype()},
      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
      {w});
-  } else {
-    outputs = fallback({w});
-  }
  return {outputs[0], outputs[1], outputs[2]};
}
@@ -814,17 +809,14 @@ array affine_quantize(
    return {reshape(packed_w, wshape, s)};
  };
-  if (s.device == Device::gpu) {
  auto out_shape = w.shape();
  out_shape.back() = w.shape(-1) / el_per_int;
  return array(
-      out_shape,
+      std::move(out_shape),
      uint32,
      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
      {w, scales, biases});
-  }
-  return fallback({w, scales, biases})[0];
}

array affine_dequantize(
    const array& w,
@@ -916,7 +908,7 @@ array affine_dequantize(
  auto out_shape = w.shape();
  out_shape.back() = w.shape(-1) * el_per_int;
  return array(
-      out_shape,
+      std::move(out_shape),
      scales.dtype(),
      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
      {w, scales, biases});
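With the device check gone, both affine_quantize overloads above build the AffineQuantize primitive unconditionally; what remains is mostly shape bookkeeping. Packing shrinks the last axis by el_per_int (and by group_size for the per-group scales and biases), while affine_dequantize grows it back. The sketch below just evaluates that arithmetic for a hypothetical 4096 x 4096 weight with bits = 4 and group_size = 64; the sizes are illustrative, not taken from the commit.

// Shape arithmetic for a hypothetical (4096, 4096) weight quantized with
// bits = 4 and group_size = 64. Values are illustrative only.
#include <cstdio>

int main() {
  const int bits = 4;
  const int group_size = 64;
  const int el_per_int = 32 / bits;  // 8 weights packed per uint32

  const int rows = 4096;
  const int cols = 4096;

  // affine_quantize: the last axis shrinks by el_per_int for the packed
  // weights and by group_size for the per-group scales and biases.
  std::printf("wq:     (%d, %d) uint32\n", rows, cols / el_per_int);  // (4096, 512)
  std::printf("scales: (%d, %d)\n", rows, cols / group_size);         // (4096, 64)
  std::printf("biases: (%d, %d)\n", rows, cols / group_size);         // (4096, 64)

  // affine_dequantize reverses the packing: (4096, 512) -> (4096, 4096).
  std::printf("w_hat:  (%d, %d)\n", rows, (cols / el_per_int) * el_per_int);
  return 0;
}

The final hunk below drops the placeholder eval_cpu body (the "NYI" throw) from the AffineQuantize declaration, since the implementation added in the first file now provides it.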


@@ -228,9 +228,7 @@ class AffineQuantize : public Custom {
        dequantize_(dequantize) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    throw std::runtime_error("NYI");
-  }
+      override;

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;