mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
parent a4c47b0276
commit dfa0b9aab4
@@ -2,7 +2,9 @@
 
 #include <cassert>
 
-#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/ops.h"
+#include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -404,4 +406,103 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
       transpose_);
 }
 
+template <typename T>
+void quantize(
+    const array& w_,
+    array& out_,
+    array& scales_,
+    array& biases_,
+    int bits,
+    int group_size,
+    bool compute_scale_bias) {
+  const T* w = w_.data<T>();
+  T* scales = scales_.data<T>();
+  T* biases = biases_.data<T>();
+  auto out = out_.data<uint32_t>();
+
+  T n_bins = (1 << bits) - 1;
+  T eps = 1e-7;
+  int el_per_int = 32 / bits;
+  int int_per_group = group_size / el_per_int;
+  size_t n_groups = w_.size() / group_size;
+
+  for (size_t i = 0; i < n_groups; ++i) {
+    size_t w_idx = i * group_size;
+    if (compute_scale_bias) {
+      T w_min = std::numeric_limits<float>::infinity();
+      T w_max = -w_min;
+      for (int j = 0; j < group_size; ++j) {
+        w_max = std::max(w_max, w[w_idx + j]);
+        w_min = std::min(w_min, w[w_idx + j]);
+      }
+      bool mask = std::abs(w_min) > std::abs(w_max);
+      T scale = std::max(T((w_max - w_min) / n_bins), eps);
+      scale = mask ? scale : -scale;
+
+      auto edge = mask ? w_min : w_max;
+      auto q0 = std::rint(edge / scale);
+      if (q0 == 0) {
+        scales[i] = scale;
+        biases[i] = 0;
+      } else {
+        scales[i] = edge / q0;
+        biases[i] = edge;
+      }
+    }
+    size_t out_idx = i * int_per_group;
+    for (int j = 0; j < int_per_group; ++j) {
+      uint32_t out_el = 0;
+      for (int k = 0; k < el_per_int; ++k) {
+        T w_el = w[w_idx + j * el_per_int + k];
+        w_el = std::rint((w_el - biases[i]) / scales[i]);
+        w_el = std::min(std::max(w_el, T(0)), n_bins);
+        out_el |= static_cast<uint32_t>(w_el) << (k * bits);
+      }
+      out[out_idx + j] = out_el;
+    }
+  }
+}
+
+void fast::AffineQuantize::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  bool compute_scale_bias = inputs.size() == 1;
+
+  auto ensure_row_contiguous = [](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };
+  auto w = ensure_row_contiguous(inputs[0]);
+
+  auto& out = outputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& scales =
+      compute_scale_bias ? outputs[1] : const_cast<array&>(inputs[1]);
+  auto& biases =
+      compute_scale_bias ? outputs[2] : const_cast<array&>(inputs[2]);
+  if (compute_scale_bias) {
+    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
+    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+  }
+  if (w.dtype() == float16) {
+    quantize<float16_t>(
+        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+  } else if (w.dtype() == bfloat16) {
+    quantize<bfloat16_t>(
+        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+  } else if (w.dtype() == float32) {
+    quantize<float>(
+        w, out, scales, biases, bits_, group_size_, compute_scale_bias);
+  } else {
+    throw std::runtime_error(
+        "[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
+  }
+}
+
 } // namespace mlx::core
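The added kernel computes, for each group of group_size weights, a scale and bias mapping the group onto the integer range [0, 2^bits - 1], then packs 32 / bits quantized codes into each output uint32_t. Below is a minimal standalone sketch of that pack/unpack arithmetic, not MLX code: the 4-bit width, group of 8, and all names are invented for illustration, and it uses a plain min-edge bias rather than the sign-flipped scale and grid-snapped bias (scales[i] = edge / q0) the kernel above computes.

// Standalone sketch of affine quantization packing, 4 bits, one group of 8.
// All sizes and values here are invented for the example.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int bits = 4;
  const int group_size = 8;
  const float n_bins = (1 << bits) - 1; // 15: largest 4-bit code
  float w[group_size] = {-0.5f, -0.25f, 0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.6f};

  // One scale/bias per group, mapping [w_min, w_max] onto [0, n_bins].
  float w_min = *std::min_element(w, w + group_size);
  float w_max = *std::max_element(w, w + group_size);
  float scale = std::max((w_max - w_min) / n_bins, 1e-7f);
  float bias = w_min;

  // Pack the group's 4-bit codes into one uint32_t, low bits first.
  const int el_per_int = 32 / bits; // 8 codes per word
  uint32_t packed = 0;
  for (int k = 0; k < el_per_int; ++k) {
    float q = std::rint((w[k] - bias) / scale);
    q = std::min(std::max(q, 0.0f), n_bins);
    packed |= static_cast<uint32_t>(q) << (k * bits);
  }

  // Unpack and dequantize: w_hat = q * scale + bias.
  for (int k = 0; k < el_per_int; ++k) {
    uint32_t q = (packed >> (k * bits)) & ((1u << bits) - 1);
    float w_hat = q * scale + bias;
    std::printf("w = % .3f  q = %2u  w_hat = % .3f\n",
                w[k], static_cast<unsigned>(q), w_hat);
  }
  return 0;
}

The q0 adjustment in the real kernel serves a purpose this sketch skips: with biases[i] = edge and scales[i] = edge / q0, the group's dominant edge value rounds to code 0 and dequantizes back to itself exactly.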
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
 #include "mlx/primitives.h"
+#include "mlx/fast_primitives.h"
 
 #define NO_CPU_MULTI(func) \
   void func::eval_cpu(     \
@@ -112,4 +113,8 @@ NO_CPU(Transpose)
 NO_CPU(Inverse)
 NO_CPU(View)
 
+namespace fast {
+NO_CPU_MULTI(AffineQuantize)
+} // namespace fast
+
 } // namespace mlx::core
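Only the first two lines of NO_CPU_MULTI appear as context in the hunk above; its body is elided. A plausible expansion, assuming it mirrors the single-output NO_CPU stub and simply throws:

// Sketch only: assumed expansion of NO_CPU_MULTI (the real body is not
// shown in the hunk above).
#define NO_CPU_MULTI(func)                                             \
  void func::eval_cpu(                                                 \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#func " has no CPU implementation.");     \
  }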
mlx/fast.cpp: 16 lines changed

@@ -773,20 +773,15 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
    };
  };
 
-  std::vector<array> outputs;
-  if (s.device == Device::gpu) {
   auto wq_shape = w.shape();
   wq_shape.back() = w.shape(-1) / el_per_int;
   auto sshape = w.shape();
   sshape.back() = w.shape(-1) / group_size;
-  outputs = array::make_arrays(
-      {wq_shape, sshape, sshape},
+  auto outputs = array::make_arrays(
+      {std::move(wq_shape), sshape, sshape},
       {uint32, w.dtype(), w.dtype()},
       std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
       {w});
-  } else {
-    outputs = fallback({w});
-  }
   return {outputs[0], outputs[1], outputs[2]};
 }
 
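The shape arithmetic in this hunk: packing b-bit codes into 32-bit words gives el_per_int = 32 / b elements per word, so the last axis shrinks by a factor of el_per_int for the packed weights and by group_size for the scales and biases. A toy check with assumed sizes (4 bits, groups of 64, rows of 4096):

// Toy check of the packed-output shapes; the sizes are assumptions.
#include <cstdio>

int main() {
  int bits = 4, group_size = 64, n = 4096; // hypothetical row length
  int el_per_int = 32 / bits;              // 8 elements per uint32
  std::printf("packed row length: %d\n", n / el_per_int); // 512
  std::printf("scales per row:    %d\n", n / group_size); // 64
  return 0;
}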
@@ -814,16 +809,13 @@ array affine_quantize(
    return {reshape(packed_w, wshape, s)};
  };
 
-  if (s.device == Device::gpu) {
   auto out_shape = w.shape();
   out_shape.back() = w.shape(-1) / el_per_int;
   return array(
-      out_shape,
+      std::move(out_shape),
       uint32,
       std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
       {w, scales, biases});
-  }
-  return fallback({w, scales, biases})[0];
 }
 
 array affine_dequantize(
@@ -916,7 +908,7 @@ array affine_dequantize(
   auto out_shape = w.shape();
   out_shape.back() = w.shape(-1) * el_per_int;
   return array(
-      out_shape,
+      std::move(out_shape),
       scales.dtype(),
       std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
       {w, scales, biases});
@@ -228,9 +228,7 @@ class AffineQuantize : public Custom {
       dequantize_(dequantize) {}
 
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
-      override {
-    throw std::runtime_error("NYI");
-  }
+      override;
 
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;