Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-29 13:55:29 +08:00)
Fix qvm splitk (#2415)
Commit: 5597fa089c
Parent: 9acec364c2
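Summary: qvm_split_k previously assumed its input had at least two dimensions. This change promotes a 1D input (and a 1D output shape) to 2D before the split-k bookkeeping, replaces a stride index computed from x.ndim() with one computed from the promoted shape, and adds a regression test for quantized_matmul with a 1D vector.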
@@ -265,9 +265,15 @@ void qvm_split_k(
   MTL::Size group_dims = MTL::Size(bk, 2, 1);
   MTL::Size grid_dims = MTL::Size(M, N / bn, B);
 
-  int x_batch_ndims = x.ndim() - 2;
   auto x_shape = x.shape();
   auto x_strides = x.strides();
+  if (x_shape.size() == 1) {
+    x_shape.insert(x_shape.begin(), 1);
+    x_strides.insert(x_strides.begin(), 0);
+  }
+
+  int x_ndim = x_shape.size();
+  int x_batch_ndims = x_ndim - 2;
   int w_batch_ndims = w.ndim() - 2;
   auto w_shape = w.shape();
   auto w_strides = w.strides();
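This hunk is the core of the fix: when x arrives as a bare 1D vector, qvm_split_k now promotes its shape to 2D by prepending a singleton axis (with stride 0), and derives x_ndim and x_batch_ndims from the promoted shape rather than from x.ndim(). A minimal Python sketch of the same promotion, using plain lists in place of the C++ shape/stride vectors (promote_to_2d is an illustrative name, not an MLX function):

def promote_to_2d(shape, strides):
    # Mirror of the new C++ guard: a 1D input gains a leading
    # singleton axis whose stride is 0.
    shape, strides = list(shape), list(strides)
    if len(shape) == 1:
        shape.insert(0, 1)
        strides.insert(0, 0)
    x_ndim = len(shape)         # ndim of the promoted view
    x_batch_ndims = x_ndim - 2  # dims beyond the trailing matrix dims
    return shape, strides, x_ndim, x_batch_ndims

# A contiguous vector of length K = 2048:
print(promote_to_2d([2048], [1]))  # ([1, 2048], [0, 1], 2, 0)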
@@ -278,7 +284,7 @@ void qvm_split_k(
   x_shape.insert(x_shape.end() - 2, split_k);
   x_shape.back() /= split_k;
   x_strides.insert(x_strides.end() - 2, split_D);
-  x_strides[x.ndim() - 1] = split_D;
+  x_strides[x_ndim - 1] = split_D;
   x_batch_ndims += 1;
 
   w_shape.insert(w_shape.end() - 2, split_k);
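The single changed line above is the indexing fix itself. For inputs that were already 2D or higher, x.ndim() equals the new x_ndim, so behavior is unchanged; for a promoted 1D input, x.ndim() still reports 1, so the old code wrote split_D into x_strides[0] (the slot the preceding insert had just filled) and left the intended slot at 0. A sketch with illustrative numbers (K = 2048, split_k = 4, so split_D = K / split_k = 512; list ops stand in for the C++ vector calls):

split_k, split_D = 4, 512
shape, strides = [1, 2048], [0, 1]  # 1D input promoted as in the first hunk

shape.insert(len(shape) - 2, split_k)      # [4, 1, 2048]
shape[-1] //= split_k                      # [4, 1, 512]
strides.insert(len(strides) - 2, split_D)  # [512, 0, 1]

x_ndim = 2  # size of the promoted shape, fixed before the inserts
strides[x_ndim - 1] = split_D              # fixed: [512, 512, 1]
# The old code indexed with the array's own ndim: strides[1 - 1] = split_D
# would rewrite strides[0] and leave strides[1] stuck at 0.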
@@ -291,6 +297,9 @@ void qvm_split_k(
   int final_block_size = K - (split_k - 1) * split_D;
 
   auto temp_shape = out.shape();
+  if (temp_shape.size() == 1) {
+    temp_shape.insert(temp_shape.begin(), 1);
+  }
   temp_shape.insert(temp_shape.end() - 2, split_k);
   array intermediate(temp_shape, x.dtype(), nullptr, {});
   intermediate.set_data(allocator::malloc(intermediate.nbytes()));
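The intermediate accumulator needs the same guard, since out.shape() is 1D exactly when x is. Without it, temp_shape.insert(temp_shape.end() - 2, split_k) on a one-element shape vector would form an iterator before begin(), which is undefined behavior in C++. A sketch of the guarded path (Python lists again standing in for the C++ shape vector):

split_k = 4
temp_shape = [2048]          # out.shape() for a 1D result

if len(temp_shape) == 1:
    temp_shape.insert(0, 1)  # [1, 2048]: makes the end() - 2 insert valid

temp_shape.insert(len(temp_shape) - 2, split_k)
print(temp_shape)            # [4, 1, 2048]: one partial result per split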
@@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 2e-3)
 
+        # Test with 1D vector
+        group_size = 32
+        bits = 8
+        N = 2048
+        x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
+        w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
+        w_q, scales, biases = mx.quantize(w, group_size, bits)
+        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 2e-3)
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))
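The new test covers exactly the path fixed above: a 1D activation against a 2048x2048 quantized weight, sized to hit the split-k kernel this commit fixes. For reference, a standalone sketch of the same check, runnable against an MLX build that includes this fix (shapes and tolerance mirror the test; the key values are arbitrary):

import mlx.core as mx

group_size, bits, N = 32, 8, 2048

x = 1e-1 * mx.random.normal(shape=(N,), key=mx.random.key(0))  # 1D vector
w = 1e-1 * mx.random.normal(shape=(N, N), key=mx.random.key(1))

w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)

# transpose=False, matching the test: the weight is used as (K, N).
y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
y_hat = x @ w_hat

assert y_q.shape == y_hat.shape
assert (y_q - y_hat).abs().max().item() < 2e-3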