Add a quantization type in the ops

Angelos Katharopoulos 2024-12-12 01:30:38 -08:00
parent 50f3535693
commit bdd68bd893
6 changed files with 152 additions and 44 deletions


@@ -75,10 +75,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     bool transpose,
     int group_size,
-    int bits) {
+    int bits,
+    QuantizationType type) {
+  // Check if we have biases as expected
+  switch (type) {
+    case QuantizationType::Affine:
+      if (!biases.has_value()) {
+        std::ostringstream msg;
+        msg << "[" << tag
+            << "] The biases argument is required for quantization "
+            << "type '" << type << "'";
+        throw std::invalid_argument(msg.str());
+      }
+      break;
+    case QuantizationType::AffinePacked:
+      if (biases.has_value()) {
+        std::ostringstream msg;
+        msg << "[" << tag << "] Quantization type '" << type
+            << "' does not use "
+            << "biases but biases were provided";
+        throw std::invalid_argument(msg.str());
+      }
+      break;
+  }
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -86,11 +109,11 @@ std::pair<int, int> extract_quantized_matmul_dims(
     throw std::invalid_argument(msg.str());
   }
-  if (scales.shape() != biases.shape()) {
+  if (biases.has_value() && scales.shape() != biases.value().shape()) {
     std::ostringstream msg;
     msg << "[" << tag << "] Scales and biases should have the same shape. "
         << "Received scales with shape " << scales.shape()
-        << " and biases with " << biases.shape();
+        << " and biases with " << biases.value().shape();
     throw std::invalid_argument(msg.str());
   }
@@ -99,25 +122,33 @@ std::pair<int, int> extract_quantized_matmul_dims(
     std::ostringstream msg;
     msg << "[" << tag
         << "] Weight, scales and biases should have the same batch shape. "
-        << "Received weight with shape " << w.shape() << ", scales with "
-        << scales.shape() << " and biases with " << biases.shape();
+        << "Received weight with shape " << w.shape()
+        << " and scales/biases with " << scales.shape();
     throw std::invalid_argument(msg.str());
   }

-  if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
+  int weight_dims = w.shape(-1) * 32 / bits;
+  int scales_dims = scales.shape(-1) * group_size;
+  if (type == QuantizationType::AffinePacked) {
+    scales_dims /= 8;
+  }
+  if (weight_dims != scales_dims) {
     std::ostringstream msg;
     msg << "[" << tag << "] The shapes of the weight and scales are "
-        << "incompatible based on bits and group_size. w.shape() == "
-        << w.shape() << " and scales.shape() == " << scales.shape()
-        << " with group_size=" << group_size << " and bits=" << bits;
+        << "incompatible based on bits, group_size and quantization type. "
+        << "w.shape() == " << w.shape()
+        << " and scales.shape() == " << scales.shape()
+        << " with group_size=" << group_size << ", bits=" << bits
+        << " and type='" << type << "'";
     throw std::invalid_argument(msg.str());
   }

   int x_inner_dims = x.shape(-1);

   // Calculate the expanded w's dims
-  int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
-  int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
+  int w_inner_dims = (transpose) ? weight_dims : w.shape(-2);
+  int w_outer_dims = (transpose) ? w.shape(-2) : weight_dims;

   if (w_inner_dims != x_inner_dims) {
     std::ostringstream msg;
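For orientation, a worked instance of the new weight/scales compatibility check (numbers are illustrative; the factor of 8 for the packed type presumably reflects 4 scales plus 4 biases stored per group in the single packed array):

    // K = 1024 dequantized columns, bits = 4, group_size = 64.
    // w packs 32 / bits = 8 values per uint32, so w.shape(-1) == 1024 / 8 == 128
    // and weight_dims == 128 * 32 / 4 == 1024.
    //
    // Affine:       scales.shape(-1) == 1024 / 64 == 16  -> scales_dims == 16 * 64       == 1024
    // AffinePacked: scales.shape(-1) == 8 * 16   == 128  -> scales_dims == 128 * 64 / 8  == 1024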
@@ -3662,14 +3693,23 @@ array quantized_matmul(
     array x,
     array w,
     array scales,
-    array biases,
+    std::optional<array> biases,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
+      "quantized_matmul",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      type);

   // QuantizedMatmul handles w.ndim == 2 case.
   if (x.ndim() > 2 && w.ndim() > 2) {
@@ -3690,37 +3730,53 @@ array quantized_matmul(
     *(inner_shape.end() - 1) = scales.shape(-1);
     scales = broadcast_to(scales, inner_shape, s);
-    *(inner_shape.end() - 1) = biases.shape(-1);
-    biases = broadcast_to(biases, inner_shape, s);
+    if (biases.has_value()) {
+      *(inner_shape.end() - 1) = biases.value().shape(-1);
+      biases = broadcast_to(biases.value(), inner_shape, s);
+    }
   }

-  auto dtype = result_type(x, scales, biases);
+  auto dtype = result_type(x, scales);
+  if (biases.has_value()) {
+    dtype = promote_types(dtype, biases.value().dtype());
+  }
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
     msg << "[quantized_matmul] Only real floating types are supported but "
         << "the passed types where x.dtype() == " << x.dtype()
-        << ", scales.dtype() == " << scales.dtype()
-        << " and biases.dtype() == " << biases.dtype();
+        << ", scales.dtype() == " << scales.dtype();
+    if (biases.has_value()) {
+      msg << " and biases.dtype() == " << biases.value().dtype();
+    }
     throw std::invalid_argument(msg.str());
   }

+  // Prepare the inputs vector
+  std::vector<array> inputs;
+  inputs.reserve(4);
+  inputs.push_back(astype(x, dtype, s));
+  inputs.push_back(w);
+  inputs.push_back(astype(scales, dtype, s));
+  if (biases.has_value()) {
+    inputs.push_back(astype(biases.value(), dtype, s));
+  }
+
   auto out_shape = x.shape();
   out_shape.back() = w_outer_dims;
   return array(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), group_size, bits, transpose),
-      {astype(x, dtype, s),
-       w,
-       astype(scales, dtype, s),
-       astype(biases, dtype, s)});
+          to_stream(s), type, group_size, bits, transpose),
+      std::move(inputs));
 }

-std::tuple<array, array, array> quantize(
+std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   return fast::affine_quantize(w, group_size, bits, s);
 }
@@ -3728,31 +3784,40 @@ std::tuple<array, array, array> quantize(
 array dequantize(
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  return fast::affine_dequantize(w, scales, biases, group_size, bits, s);
+  return fast::affine_dequantize(
+      w, scales, biases.value(), group_size, bits, s);
 }

 array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
-    std::optional<array> lhs_indices_ /* = std::nullopt */,
-    std::optional<array> rhs_indices_ /* = std::nullopt */,
+    const std::optional<array>& biases,
+    const std::optional<array>& lhs_indices_ /* = std::nullopt */,
+    const std::optional<array>& rhs_indices_ /* = std::nullopt */,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, s);
+        x, w, scales, biases, transpose, group_size, bits, type, s);
+  }
+
+  if (type != QuantizationType::Affine) {
+    std::ostringstream msg;
+    msg << "[gather_qmm] Only quantization type '" << type << "' supported";
+    throw std::invalid_argument(msg.str());
   }

   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
+      "gather_qmm", x, w, scales, biases, transpose, group_size, bits, type);

   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3768,16 +3833,17 @@ array gather_qmm(
   out_shape.push_back(w_outer_dims);

   // and output type
-  auto out_type = result_type(x, scales, biases);
+  auto out_type = result_type(x, scales, biases.value());
   return array(
       std::move(out_shape),
       out_type,
-      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
+      std::make_shared<GatherQMM>(
+          to_stream(s), type, group_size, bits, transpose),
       {astype(x, out_type, s),
        w,
        astype(scales, out_type, s),
-       astype(biases, out_type, s),
+       astype(biases.value(), out_type, s),
        lhs_indices,
        rhs_indices});
 }
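Not part of the diff: a minimal round-trip sketch of the updated C++ API in this file, assuming the usual mlx/mlx.h umbrella header, the mlx::core namespace, and random::normal / eval as used elsewhere in the library. Shapes are illustrative only.

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      // Quantize a weight matrix; affine quantization returns scales and (now optional) biases.
      array w = random::normal({512, 1024});
      auto [w_q, scales, biases] =
          quantize(w, /*group_size=*/64, /*bits=*/4, QuantizationType::Affine);

      // Matmul against the quantized weights; biases is passed as std::optional<array>.
      array x = random::normal({8, 1024});
      array y = quantized_matmul(
          x, w_q, scales, biases,
          /*transpose=*/true, /*group_size=*/64, /*bits=*/4, QuantizationType::Affine);

      // Reconstruct an approximation of w for comparison.
      array w_hat = dequantize(
          w_q, scales, biases, /*group_size=*/64, /*bits=*/4, QuantizationType::Affine);

      eval({y, w_hat});
      return 0;
    }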


@@ -1277,31 +1277,34 @@ array conv_transpose3d(
     int groups = 1,
     StreamOrDevice s = {});

-/** Quantized matmul multiplies x with a quantized matrix w*/
+/** Quantized matmul multiplies x with a quantized matrix w */
 array quantized_matmul(
     array x,
     array w,
     array scales,
-    array biases,
+    std::optional<array> biases,
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Quantize a matrix along its last axis */
-std::tuple<array, array, array> quantize(
+std::tuple<array, array, std::optional<array>> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Dequantize a matrix produced by quantize() */
 array dequantize(
     const array& w,
     const array& scales,
-    const array& biases,
+    const std::optional<array>& biases,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Compute matrix products with matrix-level gather. */
@@ -1309,12 +1312,13 @@ array gather_qmm(
     const array& x,
     const array& w,
     const array& scales,
-    const array& biases,
-    std::optional<array> lhs_indices = std::nullopt,
-    std::optional<array> rhs_indices = std::nullopt,
+    const std::optional<array>& biases,
+    const std::optional<array>& lhs_indices = std::nullopt,
+    const std::optional<array>& rhs_indices = std::nullopt,
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationType type = QuantizationType::Affine,
     StreamOrDevice s = {});

 /** Returns a contraction of a and b over multiple dimensions. */
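Also not from the diff: since the dims check earlier rejects biases for the packed type, a caller using AffinePacked would omit biases entirely. The names w_packed and scales_packed below are hypothetical placeholders for arrays produced by a packed quantizer that this diff does not show.

    array y = quantized_matmul(
        x, w_packed, scales_packed, /*biases=*/std::nullopt,
        /*transpose=*/true, /*group_size=*/64, /*bits=*/4,
        QuantizationType::AffinePacked);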


@@ -2777,10 +2777,11 @@ std::vector<array> QuantizedMatmul::vjp(
         cotangents[0],
         primals[1],
         primals[2],
-        primals[3],
+        (primals.size() > 3) ? std::optional(primals[3]) : std::nullopt,
         !transpose_,
         group_size_,
         bits_,
+        type_,
         stream()));
   }
@@ -2855,6 +2856,7 @@ std::vector<array> GatherQMM::vjp(
                 !transpose_,
                 group_size_,
                 bits_,
+                type_,
                 stream()),
             -3,
             stream()),


@@ -8,6 +8,7 @@
 #include "mlx/device.h"
 #include "mlx/io/load.h"
 #include "mlx/stream.h"
+#include "mlx/utils.h"

 #define DEFINE_VMAP() \
   virtual std::pair<std::vector<array>, std::vector<int>> vmap( \
@@ -1568,10 +1569,12 @@ class QuantizedMatmul : public UnaryPrimitive {
  public:
   explicit QuantizedMatmul(
       Stream stream,
+      QuantizationType type,
       int group_size,
       int bits,
       bool transpose)
       : UnaryPrimitive(stream),
+        type_(type),
         group_size_(group_size),
         bits_(bits),
         transpose_(transpose) {}
@@ -1586,6 +1589,7 @@ class QuantizedMatmul : public UnaryPrimitive {
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;

  private:
+  QuantizationType type_;
   int group_size_;
   int bits_;
   bool transpose_;
@@ -1595,8 +1599,14 @@ class QuantizedMatmul : public UnaryPrimitive {
 class GatherQMM : public UnaryPrimitive {
  public:
-  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
+  explicit GatherQMM(
+      Stream stream,
+      QuantizationType type,
+      int group_size,
+      int bits,
+      bool transpose)
       : UnaryPrimitive(stream),
+        type_(type),
         group_size_(group_size),
         bits_(bits),
         transpose_(transpose) {}
@@ -1610,6 +1620,7 @@ class GatherQMM : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;

  private:
+  QuantizationType type_;
   int group_size_;
   int bits_;
   bool transpose_;


@@ -145,6 +145,19 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
   return os;
 }

+std::ostream& operator<<(std::ostream& os, QuantizationType type) {
+  std::string_view quantization_type;
+  switch (type) {
+    case QuantizationType::Affine:
+      quantization_type = "affine";
+      break;
+    case QuantizationType::AffinePacked:
+      quantization_type = "affine-packed";
+      break;
+  }
+  return os << quantization_type;
+}
+
 namespace {

 inline size_t


@@ -100,6 +100,18 @@ inline int next_power_of_2(int n) {
   return pow(2, std::ceil(std::log2(n)));
 }

+/** Enumerate the different quantization types */
+enum class QuantizationType {
+  // Traditional affine quantization with separate scales and biases
+  Affine,
+
+  // The same quantization as affine but with the scales and biases packed in a
+  // single array and contiguously every 4 rows
+  AffinePacked,
+};
+
+std::ostream& operator<<(std::ostream& os, QuantizationType type);
+
 namespace env {

 int get_var(const char* name, int default_value);
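One hedged reading of the packed layout described in the enum comment (the exact element order is fixed by the kernels, which are outside this diff): for group_size 64, the affine pair

    scales: [n_rows, K / 64]    biases: [n_rows, K / 64]

would become a single array of roughly

    packed: [n_rows / 4, 8 * (K / 64)]

so each row of packed holds the scales and biases of 4 consecutive weight rows back to back, which is consistent with the scales_dims /= 8 adjustment in extract_quantized_matmul_dims.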