mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-27 19:31:16 +08:00

Add a quantization type in the ops

parent 50f3535693
commit bdd68bd893

mlx/ops.cpp | 136
@@ -75,10 +75,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     bool transpose,
     int group_size,
-    int bits) {
+    int bits,
+    QuantizationType type) {
+  // Check if we have biases as expected
+  switch (type) {
+    case QuantizationType::Affine:
+      if (!biases.has_value()) {
+        std::ostringstream msg;
+        msg << "[" << tag
+            << "] The biases argument is required for quantization "
+            << "type '" << type << "'";
+        throw std::invalid_argument(msg.str());
+      }
+      break;
+    case QuantizationType::AffinePacked:
+      if (biases.has_value()) {
+        std::ostringstream msg;
+        msg << "[" << tag << "] Quantization type '" << type
+            << "' does not use "
+            << "biases but biases were provided";
+        throw std::invalid_argument(msg.str());
+      }
+      break;
+  }
+
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -86,11 +109,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
     throw std::invalid_argument(msg.str());
   }
 
-  if (scales.shape() != biases.shape()) {
+  if (biases.has_value() && scales.shape() != biases.value().shape()) {
     std::ostringstream msg;
     msg << "[" << tag << "] Scales and biases should have the same shape. "
         << "Received scales with shape " << scales.shape()
-        << " and biases with " << biases.shape();
+        << " and biases with " << biases.value().shape();
     throw std::invalid_argument(msg.str());
   }
 
@@ -99,25 +122,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
     std::ostringstream msg;
     msg << "[" << tag
         << "] Weight, scales and biases should have the same batch shape. "
-        << "Received weight with shape " << w.shape() << ", scales with "
-        << scales.shape() << " and biases with " << biases.shape();
+        << "Received weight with shape " << w.shape()
+        << " and scales/biases with " << scales.shape();
     throw std::invalid_argument(msg.str());
   }
 
-  if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
+  int weight_dims = w.shape(-1) * 32 / bits;
+  int scales_dims = scales.shape(-1) * group_size;
+  if (type == QuantizationType::AffinePacked) {
+    scales_dims /= 8;
+  }
+
+  if (weight_dims != scales_dims) {
     std::ostringstream msg;
     msg << "[" << tag << "] The shapes of the weight and scales are "
-        << "incompatible based on bits and group_size. w.shape() == "
-        << w.shape() << " and scales.shape() == " << scales.shape()
-        << " with group_size=" << group_size << " and bits=" << bits;
+        << "incompatible based on bits, group_size and quantization type. "
+        << "w.shape() == " << w.shape()
+        << " and scales.shape() == " << scales.shape()
+        << " with group_size=" << group_size << ", bits=" << bits
+        << " and type='" << type << "'";
     throw std::invalid_argument(msg.str());
   }
 
   int x_inner_dims = x.shape(-1);
 
   // Calculate the expanded w's dims
-  int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
-  int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
+  int w_inner_dims = (transpose) ? weight_dims : w.shape(-2);
+  int w_outer_dims = (transpose) ? w.shape(-2) : weight_dims;
 
   if (w_inner_dims != x_inner_dims) {
     std::ostringstream msg;
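Note: a worked example of the new shape check, with hypothetical shapes and the default group_size and bits. The extra factor of 8 for the packed type follows from the scales and biases sharing one array, packed contiguously every 4 rows (2 arrays x 4 rows):

    // Hypothetical shapes: bits = 4, group_size = 64, w.shape(-1) == 512.
    // weight_dims = 512 * 32 / 4 = 4096 dequantized columns
    // affine:        scales.shape(-1) == 64  -> 64 * 64      == 4096  (ok)
    // affine-packed: scales.shape(-1) == 512 -> 512 * 64 / 8 == 4096  (ok)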
@@ -3662,14 +3693,23 @@ array quantized_matmul(
     array x,
     array w,
     array scales,
-    array biases,
+    std::optional<array> biases,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
+      "quantized_matmul",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      type);
 
   // QuantizedMatmul handles w.ndim == 2 case.
   if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3690,37 +3730,53 @@ array quantized_matmul(
     *(inner_shape.end() - 1) = scales.shape(-1);
     scales = broadcast_to(scales, inner_shape, s);
 
-    *(inner_shape.end() - 1) = biases.shape(-1);
-    biases = broadcast_to(biases, inner_shape, s);
+    if (biases.has_value()) {
+      *(inner_shape.end() - 1) = biases.value().shape(-1);
+      biases = broadcast_to(biases.value(), inner_shape, s);
+    }
   }
 
-  auto dtype = result_type(x, scales, biases);
+  auto dtype = result_type(x, scales);
+  if (biases.has_value()) {
+    dtype = promote_types(dtype, biases.value().dtype());
+  }
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
     msg << "[quantized_matmul] Only real floating types are supported but "
         << "the passed types where x.dtype() == " << x.dtype()
-        << ", scales.dtype() == " << scales.dtype()
-        << " and biases.dtype() == " << biases.dtype();
+        << ", scales.dtype() == " << scales.dtype();
+    if (biases.has_value()) {
+      msg << " and biases.dtype() == " << biases.value().dtype();
+    }
     throw std::invalid_argument(msg.str());
   }
 
+  // Prepare the inputs vector
+  std::vector<array> inputs;
+  inputs.reserve(4);
+  inputs.push_back(astype(x, dtype, s));
+  inputs.push_back(w);
+  inputs.push_back(astype(scales, dtype, s));
+  if (biases.has_value()) {
+    inputs.push_back(astype(biases.value(), dtype, s));
+  }
+
   auto out_shape = x.shape();
   out_shape.back() = w_outer_dims;
 
   return array(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), group_size, bits, transpose),
-      {astype(x, dtype, s),
-       w,
-       astype(scales, dtype, s),
-       astype(biases, dtype, s)});
+          to_stream(s), type, group_size, bits, transpose),
+      std::move(inputs));
 }
 
-std::tuple<array, array, array> quantize(
+std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_quantize(w, group_size, bits, s);
 }
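Note: a caller-side sketch of the extended signature (the array names here are hypothetical). With the packed type the biases argument is simply omitted, since scales and biases live in a single array:

    // Affine quantization: separate scales and biases arrays.
    auto out = quantized_matmul(
        x, w, scales, biases, true, 64, 4, QuantizationType::Affine);
    // Packed quantization: pass std::nullopt for the biases.
    auto out_packed = quantized_matmul(
        x, w_packed, scales_packed, std::nullopt, true, 64, 4,
        QuantizationType::AffinePacked);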
@@ -3728,31 +3784,40 @@ std::tuple<array, array, array> quantize(
 array dequantize(
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
+  return fast::affine_dequantize(
+      w, scales, biases.value(), group_size, bits, s);
 }
 
 array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
-    std::optional<array> lhs_indices_ /* = std::nullopt */,
-    std::optional<array> rhs_indices_ /* = std::nullopt */,
+    const std::optional<array>& biases,
+    const std::optional<array>& lhs_indices_ /* = std::nullopt */,
+    const std::optional<array>& rhs_indices_ /* = std::nullopt */,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, s);
+        x, w, scales, biases, transpose, group_size, bits, type, s);
   }
 
+  if (type != QuantizationType::Affine) {
+    std::ostringstream msg;
+    msg << "[gather_qmm] Only quantization type '" << type << "' supported";
+    throw std::invalid_argument(msg.str());
+  }
+
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
+      "gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);
 
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3768,16 +3833,17 @@ array gather_qmm(
   out_shape.push_back(w_outer_dims);
 
   // and output type
-  auto out_type = result_type(x, scales, biases);
+  auto out_type = result_type(x, scales, biases.value());
 
   return array(
       std::move(out_shape),
       out_type,
-      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
+      std::make_shared<GatherQMM>(
+          to_stream(s), type, group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
        astype(scales, out_type, s),
-       astype(biases, out_type, s),
+       astype(biases.value(), out_type, s),
        lhs_indices,
        rhs_indices});
 }
mlx/ops.h | 18
@@ -1277,31 +1277,34 @@ array conv_transpose3d(
     int groups = 1,
     StreamOrDevice s = {});
 
-/** Quantized matmul multiplies x with a quantized matrix w*/
+/** Quantized matmul multiplies x with a quantized matrix w */
 array quantized_matmul(
     array x,
     array w,
     array scales,
-    array biases,
+    std::optional<array> biases,
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Quantize a matrix along its last axis */
-std::tuple<array, array, array> quantize(
+std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
    StreamOrDevice s = {});
 
 /** Dequantize a matrix produced by quantize() */
 array dequantize(
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Compute matrix products with matrix-level gather. */
@@ -1309,12 +1312,13 @@ array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
-    std::optional<array> lhs_indices = std::nullopt,
-    std::optional<array> rhs_indices = std::nullopt,
+    const std::optional<array>& biases,
+    const std::optional<array>& lhs_indices = std::nullopt,
+    const std::optional<array>& rhs_indices = std::nullopt,
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});
 
 /** Returns a contraction of a and b over multiple dimensions. */
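Note: a hedged round-trip sketch against these declarations (the names w, wq, and w_hat are hypothetical). quantize() now returns the biases as an optional third element, which can be forwarded to dequantize() unchanged:

    // Quantize and reconstruct a weight matrix with the default affine type.
    auto [wq, scales, biases] = quantize(w, 64, 4, QuantizationType::Affine);
    array w_hat =
        dequantize(wq, scales, biases, 64, 4, QuantizationType::Affine);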
mlx/primitives.cpp

@@ -2777,10 +2777,11 @@ std::vector<array> QuantizedMatmul::vjp(
           cotangents[0],
           primals[1],
           primals[2],
-          primals[3],
+          (primals.size() > 3) ? std::optional(primals[3]) : std::nullopt,
           !transpose_,
           group_size_,
           bits_,
+          type_,
           stream()));
 }
 
@@ -2855,6 +2856,7 @@ std::vector<array> GatherQMM::vjp(
               !transpose_,
               group_size_,
               bits_,
+              type_,
               stream()),
           -3,
           stream()),
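Note: because biases are now optional, a QuantizedMatmul primitive may carry only three inputs; the ternary above rebuilds the optional biases argument from primals.size(). Illustratively:

    // Inputs layout after this change (illustrative):
    //   affine:        {x, w, scales, biases} -> primals.size() == 4
    //   affine-packed: {x, w, scales}         -> primals.size() == 3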
mlx/primitives.h

@@ -8,6 +8,7 @@
 #include "mlx/device.h"
 #include "mlx/io/load.h"
 #include "mlx/stream.h"
+#include "mlx/utils.h"
 
 #define DEFINE_VMAP()                                                 \
   virtual std::pair<std::vector<array>, std::vector<int>> vmap(       \
@@ -1568,10 +1569,12 @@ class QuantizedMatmul : public UnaryPrimitive {
  public:
   explicit QuantizedMatmul(
       Stream stream,
+      QuantizationType type,
       int group_size,
       int bits,
       bool transpose)
       : UnaryPrimitive(stream),
+        type_(type),
         group_size_(group_size),
         bits_(bits),
         transpose_(transpose) {}
@@ -1586,6 +1589,7 @@ class QuantizedMatmul : public UnaryPrimitive {
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 
  private:
+  QuantizationType type_;
   int group_size_;
   int bits_;
   bool transpose_;
@@ -1595,8 +1599,14 @@ class QuantizedMatmul : public UnaryPrimitive {
 
 class GatherQMM : public UnaryPrimitive {
  public:
-  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
+  explicit GatherQMM(
+      Stream stream,
+      QuantizationType type,
+      int group_size,
+      int bits,
+      bool transpose)
       : UnaryPrimitive(stream),
+        type_(type),
         group_size_(group_size),
         bits_(bits),
         transpose_(transpose) {}
@@ -1610,6 +1620,7 @@ class GatherQMM : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
 
  private:
+  QuantizationType type_;
   int group_size_;
   int bits_;
   bool transpose_;
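Note: the matching is_equivalent bodies are outside this excerpt; a minimal sketch of comparing the new field, assuming the usual member-wise pattern (hypothetical, not the commit's actual code):

    // Hypothetical sketch: equivalence now also compares type_.
    bool GatherQMM::is_equivalent(const Primitive& other) const {
      const GatherQMM& g_other = static_cast<const GatherQMM&>(other);
      return type_ == g_other.type_ && group_size_ == g_other.group_size_ &&
          bits_ == g_other.bits_ && transpose_ == g_other.transpose_;
    }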
mlx/utils.cpp

@@ -145,6 +145,19 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
   return os;
 }
 
+std::ostream& operator<<(std::ostream& os, QuantizationType type) {
+  std::string_view quantization_type;
+  switch (type) {
+    case QuantizationType::Affine:
+      quantization_type = "affine";
+      break;
+    case QuantizationType::AffinePacked:
+      quantization_type = "affine-packed";
+      break;
+  }
+  return os << quantization_type;
+}
+
 namespace {
 
 inline size_t
mlx/utils.h | 12
@@ -100,6 +100,18 @@ inline int next_power_of_2(int n) {
   return pow(2, std::ceil(std::log2(n)));
 }
 
+/** Enumerate the different quantization types */
+enum class QuantizationType {
+  // Traditional affine quantization with separate scales and biases
+  Affine,
+
+  // The same quantization as affine but with the scales and biases packed in a
+  // single array and contiguously every 4 rows
+  AffinePacked,
+};
+
+std::ostream& operator<<(std::ostream& os, QuantizationType type);
+
 namespace env {
 
 int get_var(const char* name, int default_value);
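Note: the streaming operator declared here is what lets the error messages in ops.cpp print a readable type name. A small self-contained sketch:

    #include <iostream>
    #include <sstream>
    #include "mlx/utils.h"

    int main() {
      std::ostringstream msg;
      msg << "type '" << mlx::core::QuantizationType::AffinePacked << "'";
      std::cout << msg.str() << std::endl;  // prints: type 'affine-packed'
    }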