mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims
* fix test
* batched cpu
* add batched template param
* refactor metal quantized.cpp
This commit is contained in:
34
mlx/ops.cpp
34
mlx/ops.cpp
@@ -3592,10 +3592,10 @@ array conv_general(
|
||||
}
|
||||
|
||||
array quantized_matmul(
|
||||
const array& x,
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
array x,
|
||||
array w,
|
||||
array scales,
|
||||
array biases,
|
||||
bool transpose /* = true */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
@@ -3604,11 +3604,27 @@ array quantized_matmul(
|
||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
|
||||
|
||||
if (w.ndim() != 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
|
||||
<< "received w with shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
// QuantizedMatmul handles w.ndim == 2 case.
|
||||
if (x.ndim() > 2 && w.ndim() > 2) {
|
||||
std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2);
|
||||
std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2);
|
||||
auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
|
||||
|
||||
// Broadcast x
|
||||
inner_shape.push_back(x.shape(-2));
|
||||
inner_shape.push_back(x.shape(-1));
|
||||
x = broadcast_to(x, inner_shape, s);
|
||||
|
||||
// Broadcast w
|
||||
*(inner_shape.end() - 2) = w.shape(-2);
|
||||
*(inner_shape.end() - 1) = w.shape(-1);
|
||||
w = broadcast_to(w, inner_shape, s);
|
||||
|
||||
*(inner_shape.end() - 1) = scales.shape(-1);
|
||||
scales = broadcast_to(scales, inner_shape, s);
|
||||
|
||||
*(inner_shape.end() - 1) = biases.shape(-1);
|
||||
biases = broadcast_to(biases, inner_shape, s);
|
||||
}
|
||||
|
||||
auto dtype = result_type(x, scales, biases);
|
||||
|
||||
Reference in New Issue
Block a user