diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index b7596d415..1858715c2 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -710,15 +710,19 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
 
+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
   // Keep a vector with copies to be cleared in the completed buffer to release
   // the arrays
   std::vector<array> copies;
-  auto check_transpose = [&copies, &s](const array& arr) {
+  auto check_transpose = [&copies, &s](const array& arr, bool is_vector) {
     auto stx = arr.strides()[arr.ndim() - 2];
     auto sty = arr.strides()[arr.ndim() - 1];
-    if (sty == 1) {
+    if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
       return std::make_tuple(false, stx, arr);
-    } else if (stx == 1) {
+    } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
       return std::make_tuple(true, sty, arr);
     } else {
       array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@@ -729,12 +733,8 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     }
   };
 
-  auto [a_transposed, a_cols, a] = check_transpose(a_pre);
-  auto [b_transposed, b_cols, b] = check_transpose(b_pre);
-
-  int M = a.shape(-2);
-  int N = b.shape(-1);
-  int K = a.shape(-1);
+  auto [a_transposed, a_cols, a] = check_transpose(a_pre, M == 1);
+  auto [b_transposed, b_cols, b] = check_transpose(b_pre, N == 1);
 
   /////////////////////////////////////////////////////////////////////////////
   // Check and collapse batch dimensions
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index de69d00dd..20df4341e 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -517,6 +517,70 @@ class TestBlas(mlx_tests.MLXTestCase):
         )
         self.assertTrue(np.array_equal(c_mlx, c_npy))
 
+    def test_mismatch_stride_mm(self):
+        np.random.seed(0)
+        a_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)
+        b_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)
+
+        a_mlx = mx.array(a_npy)
+        b_mlx = mx.array(b_npy)
+
+        # Matmul with batches
+        c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, :]
+        c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, :]
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
+        # Matvec with batches
+        c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, 2:3]
+        c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, 2:3]
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
+        # Matmul with slice
+        c_npy = a_npy[:, :8, :] @ b_npy[:, :, :8]
+        c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, :8]
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
+        # Matmul with slice
+        c_npy = a_npy[:, :, :8] @ b_npy[:, :8, :]
+        c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :8, :]
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
+        # Matmul transpose with slice
+        c_npy = a_npy[:, :8, :] @ b_npy[:, :8, :].swapaxes(-1, -2)
+        c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :8, :].swapaxes(-1, -2)
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
+        # Matmul transpose with slice
+        c_npy = a_npy[:, :, :8] @ b_npy[:, :, :8].swapaxes(-1, -2)
+        c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :, :8].swapaxes(-1, -2)
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
+        # Matvec with slice
+        c_npy = a_npy[:, :8, :] @ b_npy[:, :, 6:7]
+        c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, 6:7]
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
+        # Matvec with slice
+        c_npy = a_npy[:, :, :8] @ b_npy[:, 3:11, 2:3]
+        c_mlx = a_mlx[:, :, :8] @ b_mlx[:, 3:11, 2:3]
+
+        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
+        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
+
     def test_addmm(self):
         np.random.seed(0)
         # Batched matmul