Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
No quant reshape (#957)
* precise option on cpu
* remove print
* remove reshape in quant matmul
* no quant reshape
This commit is contained in:
parent d88d2124b5
commit 039da779d1
mlx/ops.cpp (22 changed lines)
@@ -3057,7 +3057,6 @@ array quantized_matmul(
    int bits /* = 4 */,
    StreamOrDevice s /* = {} */) {
  array x = in_x;

  if (w.dtype() != uint32) {
    std::ostringstream msg;
    msg << "[quantized_matmul] The weight matrix should be uint32 "
@@ -3074,12 +3073,6 @@ array quantized_matmul(
  // Keep x's batch dimensions to reshape it back after the matmul
  auto original_shape = x.shape();
  int x_inner_dims = original_shape.back();
  original_shape.pop_back();

  // Reshape x into a matrix if it isn't already one
  if (x.ndim() != 2) {
    x = reshape(x, {-1, x_inner_dims}, s);
  }

  if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
    std::ostringstream msg;
@@ -3122,9 +3115,10 @@ array quantized_matmul(
        << " and biases.dtype() == " << biases.dtype();
    throw std::invalid_argument(msg.str());
  }

  auto out = array(
      {x.shape(0), w_outer_dims},
  std::vector<array> inputs;
  original_shape.back() = w_outer_dims;
  return array(
      std::move(original_shape),
      dtype,
      std::make_shared<QuantizedMatmul>(
          to_stream(s), group_size, bits, transpose),
@@ -3132,14 +3126,6 @@ array quantized_matmul(
       w,
       astype(scales, dtype, s),
       astype(biases, dtype, s)});

  // If needed reshape x to the original batch shape
  if (original_shape.size() != 1) {
    original_shape.push_back(w_outer_dims);
    out = reshape(out, std::move(original_shape), s);
  }

  return out;
}

std::tuple<array, array, array> quantize(
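For context, a minimal usage sketch of the Python binding this change affects (shapes and values are hypothetical; it assumes the mx.quantize / mx.quantized_matmul API exercised by the test below). With the reshape removed, the output shape is computed directly from x's shape, so a batched x keeps its leading batch dimensions:

import mlx.core as mx

# Hypothetical sizes for illustration only.
batch, M, K, N = 2, 8, 512, 1024

x = mx.random.normal(shape=(batch, M, K))
w = mx.random.normal(shape=(N, K))  # "transposed" weight layout, as in the test below

# Quantize the weights; bits=4 is the default visible in the ops.cpp signature above,
# and group_size=64 matches the value used by the tests.
w_q, scales, biases = mx.quantize(w, 64, 4)

# Batched input goes straight in; the result keeps x's batch dimensions.
y = mx.quantized_matmul(x, w_q, scales, biases, True, 64, 4)
print(y.shape)  # expected: (2, 8, 1024)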
@@ -47,6 +47,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_qmm_vjp(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

        bits = 8
        group_size = 64
        M = 64
        N = 1024
        K = 512

        x = mx.random.normal(shape=(2, M, K), key=k1)
        c = mx.ones(shape=(2, M, N))

        transposes = [True, False]
        for transposed in transposes:
            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
            w_q, scales, biases = mx.quantize(w, group_size, bits)

            def fn(x):
                return mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )

            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))

            expected_out = mx.quantized_matmul(
                c, w_q, scales, biases, not transposed, group_size, bits
            )
            self.assertTrue(mx.allclose(vjp_out[0], expected_out))

    def test_qmm_shapes(self):
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
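A note on what test_qmm_vjp verifies: for y = quantized_matmul(x, w, ..., transpose), the cotangent pulled back onto x is the same quantized matmul applied to the cotangent with the transpose flag flipped, which is why expected_out uses `not transposed`. A dense-matmul sketch of that identity, using only standard mlx ops and illustrative sizes (not part of the commit):

import mlx.core as mx

# If y = x @ w.T (the transpose=True case), then for a cotangent c the
# vjp with respect to x is c @ w, i.e. the same product with the
# transposition flipped.
x = mx.random.normal(shape=(4, 16))
w = mx.random.normal(shape=(32, 16))
c = mx.ones(shape=(4, 32))

_, vjp_out = mx.vjp(
    lambda x: mx.matmul(x, mx.transpose(w)), primals=(x,), cotangents=(c,)
)
print(mx.allclose(vjp_out[0], mx.matmul(c, w)))  # expected: True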