mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-28 21:21:21 +08:00
Fix qvm splitk (#2415)
This commit is contained in:
parent
9acec364c2
commit
5597fa089c
@ -265,9 +265,15 @@ void qvm_split_k(
|
||||
MTL::Size group_dims = MTL::Size(bk, 2, 1);
|
||||
MTL::Size grid_dims = MTL::Size(M, N / bn, B);
|
||||
|
||||
int x_batch_ndims = x.ndim() - 2;
|
||||
auto x_shape = x.shape();
|
||||
auto x_strides = x.strides();
|
||||
if (x_shape.size() == 1) {
|
||||
x_shape.insert(x_shape.begin(), 1);
|
||||
x_strides.insert(x_strides.begin(), 0);
|
||||
}
|
||||
|
||||
int x_ndim = x_shape.size();
|
||||
int x_batch_ndims = x_ndim - 2;
|
||||
int w_batch_ndims = w.ndim() - 2;
|
||||
auto w_shape = w.shape();
|
||||
auto w_strides = w.strides();
|
||||
@ -278,7 +284,7 @@ void qvm_split_k(
|
||||
x_shape.insert(x_shape.end() - 2, split_k);
|
||||
x_shape.back() /= split_k;
|
||||
x_strides.insert(x_strides.end() - 2, split_D);
|
||||
x_strides[x.ndim() - 1] = split_D;
|
||||
x_strides[x_ndim - 1] = split_D;
|
||||
x_batch_ndims += 1;
|
||||
|
||||
w_shape.insert(w_shape.end() - 2, split_k);
|
||||
@ -291,6 +297,9 @@ void qvm_split_k(
|
||||
int final_block_size = K - (split_k - 1) * split_D;
|
||||
|
||||
auto temp_shape = out.shape();
|
||||
if (temp_shape.size() == 1) {
|
||||
temp_shape.insert(temp_shape.begin(), 1);
|
||||
}
|
||||
temp_shape.insert(temp_shape.end() - 2, split_k);
|
||||
array intermediate(temp_shape, x.dtype(), nullptr, {});
|
||||
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
|
||||
|
@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
|
||||
|
||||
# Test with 1D vector
|
||||
group_size = 32
|
||||
bits = 8
|
||||
N = 2048
|
||||
x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
|
||||
w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
|
||||
|
||||
def test_throw(self):
|
||||
x = mx.random.normal(shape=(10, 512))
|
||||
w = mx.random.normal(shape=(32, 512))
|
||||
|
Loading…
Reference in New Issue
Block a user