mirror of https://github.com/ml-explore/mlx.git
No quant reshape (#957)

* precise option on cpu
* remove print
* remove reshape in quant matmul
* no quant reshape
parent d88d2124b5
commit 039da779d1
mlx/ops.cpp (22 changed lines: 4 additions, 18 deletions)
@@ -3057,7 +3057,6 @@ array quantized_matmul(
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
   array x = in_x;
-
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[quantized_matmul] The weight matrix should be uint32 "
@@ -3074,12 +3073,6 @@ array quantized_matmul(
   // Keep x's batch dimensions to reshape it back after the matmul
   auto original_shape = x.shape();
   int x_inner_dims = original_shape.back();
-  original_shape.pop_back();
-
-  // Reshape x into a matrix if it isn't already one
-  if (x.ndim() != 2) {
-    x = reshape(x, {-1, x_inner_dims}, s);
-  }
 
   if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
     std::ostringstream msg;
@@ -3122,9 +3115,10 @@ array quantized_matmul(
         << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
-
-  auto out = array(
-      {x.shape(0), w_outer_dims},
+  std::vector<array> inputs;
+  original_shape.back() = w_outer_dims;
+  return array(
+      std::move(original_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
@@ -3132,14 +3126,6 @@ array quantized_matmul(
       w,
       astype(scales, dtype, s),
       astype(biases, dtype, s)});
-
-  // If needed reshape x to the original batch shape
-  if (original_shape.size() != 1) {
-    original_shape.push_back(w_outer_dims);
-    out = reshape(out, std::move(original_shape), s);
-  }
-
-  return out;
 }
 
 std::tuple<array, array, array> quantize(
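Taken together, the ops.cpp hunks drop the flatten-to-2-D round-trip: quantized_matmul still accepts inputs with leading batch dimensions, but the batched shape is now handed to the QuantizedMatmul primitive directly, and the output shape is original_shape with its last entry set to w_outer_dims, instead of reshaping x to a matrix before the op and reshaping the result back afterwards. A minimal sketch of the user-visible behavior from the Python API (the shapes and quantization parameters here are illustrative, not taken from the commit):

import mlx.core as mx

# Illustrative shapes: a batch of 8 sequences, 16 tokens each, hidden size 512.
x = mx.random.normal(shape=(8, 16, 512))

# Quantize an (N, K) weight; with transpose=True the op computes x @ w.T.
w = mx.random.normal(shape=(1024, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

# The 3-D x is passed straight through; after this commit the output shape
# x.shape[:-1] + (1024,) comes from the primitive itself, with no
# flatten-to-2-D / reshape-back pair wrapped around it in the graph.
y = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=64, bits=4
)
print(y.shape)  # (8, 16, 1024)

The result is the same as before; the change just avoids inserting two extra reshape ops into the graph for every batched quantized matmul.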
python/tests/test_quantized.py

@@ -47,6 +47,36 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
+    def test_qmm_vjp(self):
+        key = mx.random.key(0)
+        k1, k2 = mx.random.split(key)
+
+        bits = 8
+        group_size = 64
+        M = 64
+        N = 1024
+        K = 512
+
+        x = mx.random.normal(shape=(2, M, K), key=k1)
+        c = mx.ones(shape=(2, M, N))
+
+        transposes = [True, False]
+        for transposed in transposes:
+            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
+            w_q, scales, biases = mx.quantize(w, group_size, bits)
+
+            def fn(x):
+                return mx.quantized_matmul(
+                    x, w_q, scales, biases, transposed, group_size, bits
+                )
+
+            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))
+
+            expected_out = mx.quantized_matmul(
+                c, w_q, scales, biases, not transposed, group_size, bits
+            )
+            self.assertTrue(mx.allclose(vjp_out[0], expected_out))
+
     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
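The new test_qmm_vjp pins down the pullback of quantized_matmul: for y = x @ w.T (transpose=True) the vjp of a cotangent c is c @ w, which is the same quantized product with the transpose flag flipped, and symmetrically for transpose=False. The identity is easy to sanity-check with plain unquantized matrices; a small self-contained sketch (not part of the commit):

import mlx.core as mx

# Identity behind test_qmm_vjp, checked with plain matrices: for
# f(x) = x @ w.T, the pullback of a cotangent c is c @ w, i.e. the
# same product with the transpose flipped.
x = mx.random.normal(shape=(2, 4, 8))
w = mx.random.normal(shape=(16, 8))
c = mx.ones(shape=(2, 4, 16))

_, vjp_out = mx.vjp(lambda x_: x_ @ w.T, primals=(x,), cotangents=(c,))
print(mx.allclose(vjp_out[0], c @ w))  # array(True, dtype=bool)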