No quant reshape (#957)

* precise option on cpu

* remove print

* remove reshape in quant matmul

* no quant reshape
Awni Hannun, 2024-04-04 11:52:12 -07:00 (committed by GitHub)
parent d88d2124b5
commit 039da779d1
2 changed files with 34 additions and 18 deletions
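
The user-visible effect of the "no quant reshape" change: mx.quantized_matmul no longer flattens a batched input to 2D and reshapes the result back; the output shape is simply the input shape with the trailing (contracted) dimension replaced by the weight's outer dimension. A minimal illustrative sketch, not part of the commit (shapes and values here are arbitrary):

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 8, 512))    # batched activations (..., K)
    w = mx.random.normal(shape=(1024, 512))    # weights in the transposed (N, K) layout
    w_q, scales, biases = mx.quantize(w, 64, 4)

    # With the internal reshape gone, the batched x feeds the primitive directly
    # and the output keeps the leading batch dimensions: x.shape[:-1] + (N,).
    y = mx.quantized_matmul(x, w_q, scales, biases, True, 64, 4)
    print(y.shape)  # (2, 8, 1024)

The call produced the same result before the change; the difference is that the graph no longer contains reshape nodes around the QuantizedMatmul primitive.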

Changed file 1 of 2:

@@ -3057,7 +3057,6 @@ array quantized_matmul(
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
   array x = in_x;
-
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[quantized_matmul] The weight matrix should be uint32 "
@@ -3074,12 +3073,6 @@ array quantized_matmul(
   // Keep x's batch dimensions to reshape it back after the matmul
   auto original_shape = x.shape();
   int x_inner_dims = original_shape.back();
-  original_shape.pop_back();
-
-  // Reshape x into a matrix if it isn't already one
-  if (x.ndim() != 2) {
-    x = reshape(x, {-1, x_inner_dims}, s);
-  }
 
   if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
     std::ostringstream msg;
@@ -3122,9 +3115,10 @@ array quantized_matmul(
         << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
-  std::vector<array> inputs;
-  auto out = array(
-      {x.shape(0), w_outer_dims},
+
+  original_shape.back() = w_outer_dims;
+  return array(
+      std::move(original_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
@@ -3132,14 +3126,6 @@ array quantized_matmul(
        w,
        astype(scales, dtype, s),
        astype(biases, dtype, s)});
-
-  // If needed reshape x to the original batch shape
-  if (original_shape.size() != 1) {
-    original_shape.push_back(w_outer_dims);
-    out = reshape(out, std::move(original_shape), s);
-  }
-
-  return out;
 }
 
 std::tuple<array, array, array> quantize(
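
The core of the ops.cpp change is the output-shape bookkeeping: instead of popping the trailing dimension, flattening x to 2D, and reshaping the matmul result back afterwards, the new code overwrites the trailing entry of the copied input shape with w_outer_dims and builds the output array with that shape directly. A rough Python rendering of that bookkeeping (qmm_out_shape is a hypothetical name used only for illustration):

    def qmm_out_shape(x_shape, w_outer_dims):
        # Mirrors `original_shape.back() = w_outer_dims;` in the new C++ path:
        # keep every batch dimension and swap only the contracted trailing dim.
        out_shape = list(x_shape)
        out_shape[-1] = w_outer_dims
        return tuple(out_shape)

    # Old path: reshape x to (-1, K), matmul, reshape back to x.shape[:-1] + (N,).
    # New path: no reshapes; the primitive is constructed with this shape up front.
    assert qmm_out_shape((2, 64, 512), 1024) == (2, 64, 1024)

Since x is no longer reshaped, the QuantizedMatmul primitive now receives the batched input as-is rather than a flattened 2D matrix.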

Changed file 2 of 2:

@@ -47,6 +47,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_qmm_vjp(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+
+        bits = 8
+        group_size = 64
+        M = 64
+        N = 1024
+        K = 512
+
+        x = mx.random.normal(shape=(2, M, K), key=k1)
+        c = mx.ones(shape=(2, M, N))
+
+        transposes = [True, False]
+        for transposed in transposes:
+            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+            w_q, scales, biases = mx.quantize(w, group_size, bits)
+
+            def fn(x):
+                return mx.quantized_matmul(
+                    x, w_q, scales, biases, transposed, group_size, bits
+                )
+
+            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))
+
+            expected_out = mx.quantized_matmul(
+                c, w_q, scales, biases, not transposed, group_size, bits
+            )
+            self.assertTrue(mx.allclose(vjp_out[0], expected_out))
+
     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
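
The new test_qmm_vjp leans on the identity that for y = x @ w.T (the transpose=True layout) the VJP with respect to x, applied to a cotangent c, is c @ w, i.e. the same matmul with the transpose flag flipped. A small sketch of that identity on the unquantized op (variable names and shapes here are ours, not from the test):

    import mlx.core as mx

    x = mx.random.normal(shape=(2, 64, 512))
    w = mx.random.normal(shape=(1024, 512))    # (N, K): the transpose=True layout
    c = mx.ones(shape=(2, 64, 1024))           # cotangent for the matmul output

    # VJP of y = x @ w.T w.r.t. x is c @ w; the test computes the quantized
    # analogue as quantized_matmul(c, ..., not transposed, ...).
    _, (dx,) = mx.vjp(lambda x_: x_ @ w.T, primals=(x,), cotangents=(c,))
    print(mx.allclose(dx, c @ w))  # True

test_qmm_vjp asserts exactly this equivalence for the quantized kernels, for both settings of the transpose flag.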