Fix matvec vector stride bug (#1168)

This commit is contained in:
Jagrit Digani
2024-05-29 12:18:28 -07:00
committed by GitHub
parent e7a2a3dcd1
commit 9f0df51f8d
2 changed files with 73 additions and 9 deletions

View File

@@ -517,6 +517,70 @@ class TestBlas(mlx_tests.MLXTestCase):
)
self.assertTrue(np.array_equal(c_mlx, c_npy))
def test_mismatch_stride_mm(self):
np.random.seed(0)
a_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
# Matmul with batches
c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, :]
c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, :]
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
# Matvec with batches
c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, 2:3]
c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, 2:3]
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
# Matmul with slice
c_npy = a_npy[:, :8, :] @ b_npy[:, :, :8]
c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, :8]
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
# Matmul with slice
c_npy = a_npy[:, :, :8] @ b_npy[:, :8, :]
c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :8, :]
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
# Matmul transpose with slice
c_npy = a_npy[:, :8, :] @ b_npy[:, :8, :].swapaxes(-1, -2)
c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :8, :].swapaxes(-1, -2)
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
# Matmul transpose with slice
c_npy = a_npy[:, :, :8] @ b_npy[:, :, :8].swapaxes(-1, -2)
c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :, :8].swapaxes(-1, -2)
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
# Matvec with slice
c_npy = a_npy[:, :8, :] @ b_npy[:, :, 6:7]
c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, 6:7]
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
# Matvec with slice
c_npy = a_npy[:, :, :8] @ b_npy[:, 3:11, 2:3]
c_mlx = a_mlx[:, :, :8] @ b_mlx[:, 3:11, 2:3]
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5))
def test_addmm(self):
np.random.seed(0)
# Batched matmul