Block sparse qmm (#1124)
Commit e78a6518fa (parent 1873ffda01), committed via GitHub.

Changed: mlx/ops.cpp (263 lines)
@@ -50,6 +50,83 @@ Dtype at_least_float(const Dtype& d) {
   return issubdtype(d, inexact) ? d : promote_types(d, float32);
 }
 
+array indices_or_default(
+    std::optional<array> indices,
+    const array& x,
+    StreamOrDevice s) {
+  if (indices.has_value()) {
+    return indices.value();
+  }
+
+  std::vector<int> shape(x.shape().begin(), x.shape().end() - 2);
+  int total =
+      std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
+  return reshape(arange(total, uint32, s), shape, s);
+}
+
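A note on the helper above (illustrative values, not part of the commit): when no indices are supplied, indices_or_default enumerates every batch entry of x in row-major order. For a hypothetical weight stack w with w.shape() == (2, 3, K, N):

    // Default indices cover all 2 * 3 = 6 batch entries:
    //   indices_or_default(std::nullopt, w, s)
    //     == reshape(arange(6, uint32, s), {2, 3}, s)
    //     == {{0, 1, 2}, {3, 4, 5}}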
+std::pair<int, int> extract_quantized_matmul_dims(
+    std::string_view tag,
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    bool transpose,
+    int group_size,
+    int bits) {
+  if (w.dtype() != uint32) {
+    std::ostringstream msg;
+    msg << "[" << tag << "] The weight matrix should be uint32 "
+        << "but received " << w.dtype();
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (scales.shape() != biases.shape()) {
+    std::ostringstream msg;
+    msg << "[" << tag << "] Scales and biases should have the same shape. "
+        << "Received scales with shape " << scales.shape()
+        << " and biases with " << biases.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (!std::equal(
+          w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
+    std::ostringstream msg;
+    msg << "[" << tag
+        << "] Weight, scales and biases should have the same batch shape. "
+        << "Received weight with shape " << w.shape() << ", scales with "
+        << scales.shape() << " and biases with " << biases.shape();
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
+    std::ostringstream msg;
+    msg << "[" << tag << "] The shapes of the weight and scales are "
+        << "incompatible based on bits and group_size. w.shape() == "
+        << w.shape() << " and scales.shape() == " << scales.shape()
+        << " with group_size=" << group_size << " and bits=" << bits;
+    throw std::invalid_argument(msg.str());
+  }
+
+  int x_inner_dims = x.shape(-1);
+
+  // Calculate the expanded w's dims
+  int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
+  int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;
+
+  if (w_inner_dims != x_inner_dims) {
+    std::ostringstream msg;
+    msg << "[" << tag << "] Last dimension of first input with "
+        << "shape (..., " << x_inner_dims << ") does not match "
+        << "the expanded quantized matrix (" << w_inner_dims << ", "
+        << w_outer_dims << ") computed from shape " << w.shape()
+        << " with group_size=" << group_size << ", bits=" << bits
+        << " and transpose=" << std::boolalpha << transpose;
+    throw std::invalid_argument(msg.str());
+  }
+
+  return {w_inner_dims, w_outer_dims};
+}
+
 } // namespace
 
 array arange(
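To make the dimension arithmetic in extract_quantized_matmul_dims concrete, a worked example with illustrative values:

    // bits = 4, group_size = 64, transpose = true.
    // Each uint32 packs 32 / 4 = 8 elements, so a logical (N, K) = (1024, 512)
    // weight is stored packed as w.shape() == (1024, 64), and:
    //   w_inner_dims = w.shape(-1) * 32 / bits = 64 * 8 = 512  (must equal x.shape(-1))
    //   w_outer_dims = w.shape(-2)             = 1024          (output features)
    // scales and biases then have shape (1024, 512 / 64) == (1024, 8), which
    // satisfies w.shape(-1) * 32 / bits == scales.shape(-1) * group_size.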
@@ -3203,7 +3280,7 @@ array conv_general(
 }
 
 array quantized_matmul(
-    const array& in_x,
+    const array& x,
     const array& w,
     const array& scales,
     const array& biases,
@@ -3211,13 +3288,10 @@ array quantized_matmul(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
-  array x = in_x;
-  if (w.dtype() != uint32) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] The weight matrix should be uint32 "
-        << "but received" << w.dtype();
-    throw std::invalid_argument(msg.str());
-  }
+  // Check and extract the quantized matrix shape against x
+  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
+      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
+
   if (w.ndim() != 2) {
     std::ostringstream msg;
     msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
@@ -3225,42 +3299,6 @@ array quantized_matmul(
     throw std::invalid_argument(msg.str());
   }
 
-  // Keep x's batch dimensions to reshape it back after the matmul
-  auto original_shape = x.shape();
-  int x_inner_dims = original_shape.back();
-
-  if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] Scales and biases should have the same 2D shape. "
-        << "Received scales with shape " << scales.shape()
-        << " and biases with " << biases.shape();
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (w.shape(1) * 32 / bits != scales.shape(1) * group_size) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] The shapes of the weight and scales are "
-        << "incompatible based on bits and group_size. w.shape() == "
-        << w.shape() << " and scales.shape() == " << scales.shape()
-        << " with group_size=" << group_size << " and bits=" << bits;
-    throw std::invalid_argument(msg.str());
-  }
-
-  // Calculate the expanded w's dims
-  int w_inner_dims = (transpose) ? w.shape(1) * 32 / bits : w.shape(0);
-  int w_outer_dims = (transpose) ? w.shape(0) : w.shape(1) * 32 / bits;
-
-  if (w_inner_dims != x_inner_dims) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] Last dimension of first input with "
-        << "shape (..., " << x_inner_dims << ") does not match "
-        << "the expanded quantized matrix (" << w_inner_dims << ", "
-        << w_outer_dims << ") computed from shape " << w.shape()
-        << " with group_size=" << group_size << ", bits=" << bits
-        << " and transpose=" << std::boolalpha << transpose;
-    throw std::invalid_argument(msg.str());
-  }
-
   auto dtype = result_type(x, scales, biases);
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
@@ -3270,10 +3308,11 @@ array quantized_matmul(
         << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
-  std::vector<array> inputs;
-  original_shape.back() = w_outer_dims;
+
+  auto out_shape = x.shape();
+  out_shape.back() = w_outer_dims;
   return array(
-      std::move(original_shape),
+      std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
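A minimal call sketch for the consolidated quantized_matmul (shapes are illustrative; only a 2-D weight is supported at this point in the diff, and quantize returns the packed weight with its scales and biases as shown further down):

    // x: (B, M, K) activations, w: (N, K) weights quantized to 4 bits
    // auto [w_q, scales, biases] = quantize(w, 64, 4);
    // array y = quantized_matmul(x, w_q, scales, biases,
    //                            /* transpose= */ true,
    //                            /* group_size= */ 64, /* bits= */ 4);
    // y.shape() == (B, M, N)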
@@ -3302,11 +3341,14 @@ std::tuple<array, array, array> quantize(
     throw std::invalid_argument(msg.str());
   }
 
-  if (w.ndim() != 2) {
-    throw std::invalid_argument("[quantize] Only matrices supported for now");
+  if (w.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[quantize] The matrix to be quantized must have at least 2 dimensions "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
   }
 
-  if ((w.shape(1) % group_size) != 0) {
+  if ((w.shape(-1) % group_size) != 0) {
     std::ostringstream msg;
     msg << "[quantize] The last dimension of the matrix needs to be divisible by "
         << "the quantization group size " << group_size
@@ -3327,7 +3369,7 @@ std::tuple<array, array, array> quantize(
   // at least we bail out early which will result in a nice readable error.
   //
   // Hopefully nobody is quantizing matrices that small anyway.
-  if (w.shape(1) < 32 * el_per_int) {
+  if (w.shape(-1) < 32 * el_per_int) {
     std::ostringstream msg;
     msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
         << "too small for quantization. We support >=512 for 2 bits, "
@@ -3336,9 +3378,12 @@ std::tuple<array, array, array> quantize(
     throw std::invalid_argument(msg.str());
   }
 
+  // Prepare the shape for the outputs.
+  auto wshape = w.shape();
+  wshape.back() = -1;
+
   // Compute scales and biases
-  array packed_w =
-      reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
+  array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
   array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
   array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
 
@@ -3357,12 +3402,14 @@ std::tuple<array, array, array> quantize(
           zero,
           n_bins),
       uint32);
-  packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
+  packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
   packed_w = sum(
       multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
 
   return std::make_tuple(
-      packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s));
+      reshape(packed_w, wshape, s),
+      reshape(scales, wshape, s),
+      reshape(biases, wshape, s));
 }
 
 array dequantize(
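A shape walk-through of the packing above, with illustrative values:

    // w.shape() == (B, M, K), bits = 4, group_size = 64,
    // so el_per_int = 32 / 4 = 8 and wshape == (B, M, -1):
    //   packed_w: (B*M, K/64, 64)  -> quantize each group of 64 values
    //   packed_w: (B*M, K/8, 8)    -> gather 8 4-bit values per uint32
    //   sum(...): (B*M, K/8)       -> packed uint32 words
    // The final reshapes with wshape yield w_q of shape (B, M, K/8)
    // and scales/biases of shape (B, M, K/64).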
@@ -3382,11 +3429,21 @@ array dequantize(
     msg << "[dequantize] Invalid value for group_size: " << group_size;
     throw std::invalid_argument(msg.str());
   }
-  if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
-    throw std::invalid_argument("[dequantize] Only matrices supported for now");
+  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[dequantize] The matrix to be dequantized must have at least 2 dimensions "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
   }
 
-  if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
+  auto wshape = w.shape();
+  auto sshape = scales.shape();
+  auto bshape = biases.shape();
+  wshape.back() = -1;
+  sshape.back() = -1;
+  bshape.back() = -1;
+
+  if (wshape != sshape || wshape != bshape) {
     throw std::invalid_argument(
         "[dequantize] Shape of scales and biases does not match the matrix");
   }
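The wildcarding above compares the batch-and-row shapes while ignoring the last axis, which legitimately differs between the packed weight and its scales (illustrative):

    // w: (B, N, K/8), scales: (B, N, K/64)
    // wshape == sshape == (B, N, -1), so the check passes; the last-axis
    // relationship is validated separately against group_size below.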
@@ -3399,7 +3456,7 @@ array dequantize(
   // Compute some constants for the dequantization
   int el_per_int = 32 / bits;
 
-  if (w.shape(1) * el_per_int != scales.shape(1) * group_size) {
+  if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
     std::ostringstream msg;
     msg << "[dequantize] Shape of scales and biases does not match the matrix "
         << "given the quantization parameters. Provided matrix of shape "
@@ -3411,25 +3468,79 @@ array dequantize(
   // Extract the pieces from the passed quantized matrix
   std::vector<array> parts;
   for (int start = 0; start < 32; start += bits) {
-    // TODO: Implement bitwise operators for integral types
-    int shift_left = 32 - (start + bits);
-    int shift_right = shift_left + start;
-    array p = multiply(w, array(1 << shift_left, uint32), s);
-    p = floor_divide(p, array(1 << shift_right, uint32), s);
-    p = expand_dims(p, -1, s);
-    parts.push_back(p);
+    parts.push_back(expand_dims(
+        right_shift(
+            left_shift(w, array(32 - (start + bits), uint32), s),
+            array(32 - bits, uint32),
+            s),
+        -1,
+        s));
   }
   array w_full = concatenate(parts, -1, s);
 
   // Dequantize
-  w_full = reshape(w_full, {w.shape(0), -1, group_size}, s);
+  wshape.push_back(group_size);
+  w_full = reshape(w_full, wshape, s);
   w_full = multiply(w_full, expand_dims(scales, -1, s), s);
   w_full = add(w_full, expand_dims(biases, -1, s), s);
-  w_full = reshape(w_full, {w.shape(0), -1}, s);
+  w_full = reshape(w_full, sshape, s);
 
   return w_full;
 }
+
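Why the shift pair above extracts a bit field (a worked example, not part of the diff): left_shift by 32 - (start + bits) moves the target field to the top of the 32-bit word and discards everything above it; right_shift by 32 - bits then brings the field down to the low bits.

    // Plain C++ equivalent for one uint32 word:
    //   uint32_t field = (word << (32 - (start + bits))) >> (32 - bits);
    // e.g. word = 0xABCD1234, start = 8, bits = 4:
    //   (word << 20) == 0x23400000, then >> 28 == 0x2
    // i.e. bits 8..11 of word, a value in [0, 15].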
+array block_sparse_qmm(
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    std::optional<array> lhs_indices_ /* = std::nullopt */,
+    std::optional<array> rhs_indices_ /* = std::nullopt */,
+    bool transpose /* = true */,
+    int group_size /* = 64 */,
+    int bits /* = 4 */,
+    StreamOrDevice s /* = {} */) {
+  if (!lhs_indices_ && !rhs_indices_) {
+    return quantized_matmul(
+        x, w, scales, biases, transpose, group_size, bits, s);
+  }
+
+  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
+      "block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits);
+
+  // Extract indices and broadcast them
+  array lhs_indices = indices_or_default(lhs_indices_, x, s);
+  array rhs_indices = indices_or_default(rhs_indices_, w, s);
+  auto out_bsx_shape =
+      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
+  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
+  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
+
+  // Compute the full output shape
+  auto out_shape = out_bsx_shape;
+  out_shape.push_back(x.shape(-2));
+  out_shape.push_back(w_outer_dims);
+
+  // and output type
+  auto out_type = result_type(x, scales, biases);
+
+  auto out = array(
+      std::move(out_shape),
+      out_type,
+      std::make_shared<BlockSparseQMM>(
+          to_stream(s), group_size, bits, transpose),
+      {astype(x, out_type, s),
+       w,
+       astype(scales, out_type, s),
+       astype(biases, out_type, s),
+       lhs_indices,
+       rhs_indices});
+
+  return out;
+}
 
 array tensordot(
     const array& a,
     const array& b,
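A hedged usage sketch of the new entry point (all shapes and names below are illustrative, not from the diff): block_sparse_qmm gathers a possibly different quantized matrix per batch entry via rhs_indices, so a mixture-of-experts style dispatch looks roughly like:

    // w_q:    (8, N, K/8)      8 expert weights, 4-bit packed
    // scales: (8, N, K/64)     biases alike
    // x:      (T, M, K)        T token groups
    // rhs_indices: shape (T,)  expert chosen per group, values in [0, 8)
    // array y = block_sparse_qmm(
    //     x, w_q, scales, biases,
    //     /* lhs_indices= */ std::nullopt,  // defaults to arange(T)
    //     /* rhs_indices= */ rhs_indices,
    //     /* transpose= */ true, /* group_size= */ 64, /* bits= */ 4);
    // y.shape() == (T, M, N), per the broadcast of the two index shapes.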
@@ -3879,24 +3990,8 @@ array block_sparse_mm(
   b = astype(b, out_type, s);
 
-  // Handle broadcasting
-  std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
-  std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
-
-  auto indices_or_default = [&](const std::optional<array>& indices,
-                                const std::vector<int>& bsx_shape) {
-    if (indices.has_value()) {
-      return indices.value();
-    } else {
-      int n_batch = 1;
-      for (auto& i : bsx_shape)
-        n_batch *= i;
-      return reshape(arange(n_batch, uint32, s), bsx_shape, s);
-    }
-  };
-
   // Pull and broadcast indices
-  array lhs_indices = indices_or_default(lhs_indices_, bsx_a);
-  array rhs_indices = indices_or_default(rhs_indices_, bsx_b);
+  array lhs_indices = indices_or_default(lhs_indices_, a, s);
+  array rhs_indices = indices_or_default(rhs_indices_, b, s);
 
   if (!issubdtype(lhs_indices.dtype(), integer)) {
     throw std::invalid_argument(