mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-20 08:01:12 +08:00
parent a4c47b0276
commit dfa0b9aab4

@@ -2,7 +2,9 @@

#include <cassert>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"

namespace mlx::core {
@@ -404,4 +406,103 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
      transpose_);
}

template <typename T>
void quantize(
    const array& w_,
    array& out_,
    array& scales_,
    array& biases_,
    int bits,
    int group_size,
    bool compute_scale_bias) {
  const T* w = w_.data<T>();
  T* scales = scales_.data<T>();
  T* biases = biases_.data<T>();
  auto out = out_.data<uint32_t>();

  T n_bins = (1 << bits) - 1;
  T eps = 1e-7;
  int el_per_int = 32 / bits;
  int int_per_group = group_size / el_per_int;
  size_t n_groups = w_.size() / group_size;

  for (size_t i = 0; i < n_groups; ++i) {
    size_t w_idx = i * group_size;
    if (compute_scale_bias) {
      T w_min = std::numeric_limits<float>::infinity();
      T w_max = -w_min;
      for (int j = 0; j < group_size; ++j) {
        w_max = std::max(w_max, w[w_idx + j]);
        w_min = std::min(w_min, w[w_idx + j]);
      }
      bool mask = std::abs(w_min) > std::abs(w_max);
      T scale = std::max(T((w_max - w_min) / n_bins), eps);
      scale = mask ? scale : -scale;

      auto edge = mask ? w_min : w_max;
      auto q0 = std::rint(edge / scale);
      if (q0 == 0) {
        scales[i] = scale;
        biases[i] = 0;
      } else {
        scales[i] = edge / q0;
        biases[i] = edge;
      }
    }
    size_t out_idx = i * int_per_group;
    for (int j = 0; j < int_per_group; ++j) {
      uint32_t out_el = 0;
      for (int k = 0; k < el_per_int; ++k) {
        T w_el = w[w_idx + j * el_per_int + k];
        w_el = std::rint((w_el - biases[i]) / scales[i]);
        w_el = std::min(std::max(w_el, T(0)), n_bins);
        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
      }
      out[out_idx + j] = out_el;
    }
  }
}
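
For orientation, here is a minimal standalone sketch of the same affine scheme applied to a single group, in plain C++ with illustrative constants (4 bits, 8 values per group); it mirrors the loop above but is not part of the diff:

// Sketch only: quantize one group of 8 floats to 4 bits and pack them into a
// single uint32_t, following the scale/bias logic of quantize() above.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int bits = 4;
  const float n_bins = (1 << bits) - 1;  // 15 levels for 4-bit
  const int el_per_int = 32 / bits;      // 8 values per packed uint32_t

  std::vector<float> w = {0.1f, -0.3f, 0.7f, 0.2f, -0.5f, 0.0f, 0.4f, -0.1f};

  // Scale/bias anchored to the larger-magnitude edge so that edge (and zero)
  // stay exactly representable, as in quantize() above.
  float w_min = *std::min_element(w.begin(), w.end());
  float w_max = *std::max_element(w.begin(), w.end());
  bool mask = std::fabs(w_min) > std::fabs(w_max);
  float scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  scale = mask ? scale : -scale;
  float edge = mask ? w_min : w_max;
  float q0 = std::rint(edge / scale);
  float bias = (q0 == 0.0f) ? 0.0f : edge;
  if (q0 != 0.0f) {
    scale = edge / q0;
  }

  // Pack: each value becomes a 4-bit code shifted into its slot.
  uint32_t packed = 0;
  for (int k = 0; k < el_per_int; ++k) {
    float q = std::rint((w[k] - bias) / scale);
    q = std::min(std::max(q, 0.0f), n_bins);
    packed |= static_cast<uint32_t>(q) << (k * bits);
  }

  // Unpack and dequantize (w_hat = q * scale + bias) to see the round trip.
  for (int k = 0; k < el_per_int; ++k) {
    unsigned q = (packed >> (k * bits)) & ((1u << bits) - 1);
    std::printf("w=% .3f  q=%2u  w_hat=% .3f\n", w[k], q, q * scale + bias);
  }
  return 0;
}

With these inputs the value 0.0 round-trips exactly, which is the point of anchoring the scale to the larger-magnitude edge.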

void fast::AffineQuantize::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  bool compute_scale_bias = inputs.size() == 1;

  auto ensure_row_contiguous = [](const array& arr) {
    if (arr.flags().row_contiguous) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      return arr_copy;
    }
  };
  auto w = ensure_row_contiguous(inputs[0]);

  auto& out = outputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& scales =
      compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
  auto& biases =
      compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
  if (compute_scale_bias) {
    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
  }
  if (w.dtype() == float16) {
    quantize<float16_t>(
        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
  } else if (w.dtype() == bfloat16) {
    quantize<bfloat16_t>(
        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
  } else if (w.dtype() == float32) {
    quantize<float>(
        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
  } else {
    throw std::runtime_error(
        "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
  }
}

} // namespace mlx::core
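
Note on eval_cpu above: with a single input the primitive also materializes scales and biases (three outputs); with three inputs it packs against the provided scales and biases (one output). A small sketch of the group bookkeeping both paths assume (hypothetical helper, not MLX code):

// Row-contiguous weights are split into consecutive groups along the flat
// index, and scales/biases hold exactly one entry per group in that order.
#include <cstddef>
#include <cstdio>

std::size_t group_of(std::size_t flat_index, int group_size) {
  return flat_index / group_size;  // matches w_idx = i * group_size above
}

int main() {
  const int group_size = 64;
  // A {4, 256} weight has 1024 elements -> 16 groups, i.e. 4 groups per row.
  std::printf("element 300 -> group %zu\n", group_of(300, group_size));  // group 4
  return 0;
}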
@ -1,6 +1,7 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
|
||||
#define NO_CPU_MULTI(func) \
|
||||
void func::eval_cpu( \
|
||||
@ -112,4 +113,8 @@ NO_CPU(Transpose)
|
||||
NO_CPU(Inverse)
|
||||
NO_CPU(View)
|
||||
|
||||
namespace fast {
|
||||
NO_CPU_MULTI(AffineQuantize)
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
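
The hunk above shows only the first two lines of NO_CPU_MULTI; a plausible expansion, by analogy with the existing NO_CPU macro, looks like the following (the body and the error message are assumptions, not taken from this diff):

#define NO_CPU_MULTI(func)                                               \
  void func::eval_cpu(                                                   \
      const std::vector<array>& inputs, std::vector<array>& outputs) {  \
    throw std::runtime_error(#func " has no CPU implementation.");      \
  }

// NO_CPU_MULTI(AffineQuantize) then stamps out an eval_cpu override that
// simply throws when MLX is built without a CPU implementation for this
// primitive.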

mlx/fast.cpp

@@ -773,20 +773,15 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
    };
  };

  std::vector<array> outputs;
  if (s.device == Device::gpu) {
  auto wq_shape = w.shape();
  wq_shape.back() = w.shape(-1) / el_per_int;
  auto sshape = w.shape();
  sshape.back() = w.shape(-1) / group_size;
  outputs = array::make_arrays(
      {wq_shape, sshape, sshape},
  auto outputs = array::make_arrays(
      {std::move(wq_shape), sshape, sshape},
      {uint32, w.dtype(), w.dtype()},
      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
      {w});
  } else {
    outputs = fallback({w});
  }
  return {outputs[0], outputs[1], outputs[2]};
}
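
The output shapes above are plain integer division; a worked example with illustrative values (4-bit, group size 64, last dimension 4096), mirroring the formulas in the hunk:

#include <cstdio>

int main() {
  int bits = 4, group_size = 64, last_dim = 4096;
  int el_per_int = 32 / bits;  // 8 elements packed per uint32
  std::printf("packed last dim = %d\n", last_dim / el_per_int);   // 512
  std::printf("scales last dim = %d\n", last_dim / group_size);   // 64
  return 0;
}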

@@ -814,17 +809,14 @@ array affine_quantize(
    return {reshape(packed_w, wshape, s)};
  };

  if (s.device == Device::gpu) {
  auto out_shape = w.shape();
  out_shape.back() = w.shape(-1) / el_per_int;
  return array(
      out_shape,
      std::move(out_shape),
      uint32,
      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
      {w, scales, biases});
  }
  return fallback({w, scales, biases})[0];
}

array affine_dequantize(
    const array& w,

@@ -916,7 +908,7 @@ array affine_dequantize(
  auto out_shape = w.shape();
  out_shape.back() = w.shape(-1) * el_per_int;
  return array(
      out_shape,
      std::move(out_shape),
      scales.dtype(),
      std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
      {w, scales, biases});
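
A hedged end-to-end usage sketch of the two entry points touched in this file; the signatures (a 3-output affine_quantize, an array-returning affine_dequantize) are inferred from the hunk headers and are not confirmed by the diff itself:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array w = random::normal({128, 256});

  // Let the primitive compute scales and biases (single-input path);
  // assumed to return a 3-tuple (wq, scales, biases).
  auto [wq, scales, biases] =
      fast::affine_quantize(w, /*group_size=*/64, /*bits=*/4);

  // Reconstruct floating point values from the packed weights.
  array w_hat = fast::affine_dequantize(wq, scales, biases, 64, 4);
  eval(w_hat);  // force evaluation
  return 0;
}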

@@ -228,9 +228,7 @@ class AffineQuantize : public Custom {
        dequantize_(dequantize) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override {
    throw std::runtime_error("NYI");
  }
      override;

  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;