Batched Quantized Matmul + Fast Small QMV (#1503)

* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
Author: Alex Barron
Date: 2024-10-21 16:23:17 -07:00
Committed by: GitHub
Parent: 58a855682c
Commit: d15fa13daf
9 changed files with 866 additions and 761 deletions
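
For context: instead of rejecting any w with more than two dimensions, quantized_matmul now broadcasts the leading (batch) dimensions of x and w against each other before dispatch. Below is a minimal usage sketch, not part of the commit, assuming the mlx::core C++ API (random::normal, quantize, stack, quantized_matmul, eval) with its usual signatures; shapes are illustrative.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Two independent weight matrices of shape (N, K) = (128, 512),
  // quantized separately and stacked into a batched weight.
  auto w0 = random::normal({128, 512});
  auto w1 = random::normal({128, 512});
  auto [wq0, s0, b0] = quantize(w0, /* group_size */ 64, /* bits */ 4);
  auto [wq1, s1, b1] = quantize(w1, /* group_size */ 64, /* bits */ 4);
  auto wq = stack({wq0, wq1}, 0);    // (2, 128, 64)  packed 4-bit weights
  auto scales = stack({s0, s1}, 0);  // (2, 128, 8)
  auto biases = stack({b0, b1}, 0);  // (2, 128, 8)

  // Activations with a leading batch dim of 1; it broadcasts against the
  // weight batch dim of 2, so the result has shape (2, 8, 128).
  auto x = random::normal({1, 8, 512});
  auto out = quantized_matmul(x, wq, scales, biases, /* transpose */ true);
  eval(out);
  return 0;
}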

@@ -3592,10 +3592,10 @@ array conv_general(
 }
 array quantized_matmul(
-    const array& x,
-    const array& w,
-    const array& scales,
-    const array& biases,
+    array x,
+    array w,
+    array scales,
+    array biases,
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
@@ -3604,11 +3604,27 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
-  if (w.ndim() != 2) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
-        << "received w with shape " << w.shape();
-    throw std::invalid_argument(msg.str());
-  }
+  // QuantizedMatmul handles w.ndim == 2 case.
+  if (x.ndim() > 2 && w.ndim() > 2) {
+    std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2);
+    std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2);
+    auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
+    // Broadcast x
+    inner_shape.push_back(x.shape(-2));
+    inner_shape.push_back(x.shape(-1));
+    x = broadcast_to(x, inner_shape, s);
+    // Broadcast w
+    *(inner_shape.end() - 2) = w.shape(-2);
+    *(inner_shape.end() - 1) = w.shape(-1);
+    w = broadcast_to(w, inner_shape, s);
+    *(inner_shape.end() - 1) = scales.shape(-1);
+    scales = broadcast_to(scales, inner_shape, s);
+    *(inner_shape.end() - 1) = biases.shape(-1);
+    biases = broadcast_to(biases, inner_shape, s);
+  }
   auto dtype = result_type(x, scales, biases);
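
The new branch above only runs when both x and w carry extra leading dimensions; it broadcasts those batch dims to a common shape and then re-appends each operand's own trailing two axes, reusing inner_shape and overwriting its last one or two entries in place. Below is a standalone sketch of that shape arithmetic over plain std::vector<int>; broadcast_batch is a hypothetical stand-in for mlx's broadcast_shapes restricted to the batch dims, and the numbers are illustrative.

#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

// NumPy-style broadcasting of two batch shapes (hypothetical helper,
// standing in for mlx::core::broadcast_shapes on the batch dims only).
std::vector<int> broadcast_batch(std::vector<int> a, std::vector<int> b) {
  if (a.size() < b.size()) {
    std::swap(a, b);
  }
  for (size_t i = 0; i < b.size(); ++i) {
    int& da = a[a.size() - 1 - i];
    int db = b[b.size() - 1 - i];
    assert(da == db || da == 1 || db == 1);
    da = std::max(da, db);
  }
  return a;
}

void print(const char* name, const std::vector<int>& s) {
  std::cout << name << ": ";
  for (int d : s) std::cout << d << ' ';
  std::cout << '\n';
}

int main() {
  // x: (2, 1, M, K); w (quantized, transposed): (1, 3, N, K_packed).
  std::vector<int> x_shape{2, 1, 8, 512};
  std::vector<int> w_shape{1, 3, 128, 64};

  // Batch dims are everything except the trailing two axes.
  std::vector<int> bsx_x(x_shape.begin(), x_shape.end() - 2);
  std::vector<int> bsx_w(w_shape.begin(), w_shape.end() - 2);
  auto inner_shape = broadcast_batch(bsx_x, bsx_w);  // {2, 3}

  // x keeps its own (M, K) behind the broadcast batch dims ...
  inner_shape.push_back(x_shape[x_shape.size() - 2]);
  inner_shape.push_back(x_shape[x_shape.size() - 1]);
  print("broadcast x", inner_shape);  // 2 3 8 512

  // ... and w swaps in its own trailing two axes, exactly as the diff
  // overwrites *(inner_shape.end() - 2) and *(inner_shape.end() - 1).
  *(inner_shape.end() - 2) = w_shape[w_shape.size() - 2];
  *(inner_shape.end() - 1) = w_shape[w_shape.size() - 1];
  print("broadcast w", inner_shape);  // 2 3 128 64
  return 0;
}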