mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix matvec vector stride bug (#1168)
This commit is contained in:
		| @@ -517,6 +517,70 @@ class TestBlas(mlx_tests.MLXTestCase): | ||||
|                                     ) | ||||
|                                     self.assertTrue(np.array_equal(c_mlx, c_npy)) | ||||
|  | ||||
|     def test_mismatch_stride_mm(self): | ||||
|         np.random.seed(0) | ||||
|         a_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32) | ||||
|         b_npy = np.random.normal(0.0, 1.0 / 128, (4, 16, 16)).astype(np.float32) | ||||
|  | ||||
|         a_mlx = mx.array(a_npy) | ||||
|         b_mlx = mx.array(b_npy) | ||||
|  | ||||
|         # Matmul with batches | ||||
|         c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, :] | ||||
|         c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, :] | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|         # Matvec with batches | ||||
|         c_npy = a_npy[::2, :, :] @ b_npy[1::2, :, 2:3] | ||||
|         c_mlx = a_mlx[::2, :, :] @ b_mlx[1::2, :, 2:3] | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|         # Matmul with slice | ||||
|         c_npy = a_npy[:, :8, :] @ b_npy[:, :, :8] | ||||
|         c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, :8] | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|         # Matmul with slice | ||||
|         c_npy = a_npy[:, :, :8] @ b_npy[:, :8, :] | ||||
|         c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :8, :] | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|         # Matmul transpose with slice | ||||
|         c_npy = a_npy[:, :8, :] @ b_npy[:, :8, :].swapaxes(-1, -2) | ||||
|         c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :8, :].swapaxes(-1, -2) | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|         # Matmul transpose with slice | ||||
|         c_npy = a_npy[:, :, :8] @ b_npy[:, :, :8].swapaxes(-1, -2) | ||||
|         c_mlx = a_mlx[:, :, :8] @ b_mlx[:, :, :8].swapaxes(-1, -2) | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|         # Matvec with slice | ||||
|         c_npy = a_npy[:, :8, :] @ b_npy[:, :, 6:7] | ||||
|         c_mlx = a_mlx[:, :8, :] @ b_mlx[:, :, 6:7] | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|         # Matvec with slice | ||||
|         c_npy = a_npy[:, :, :8] @ b_npy[:, 3:11, 2:3] | ||||
|         c_mlx = a_mlx[:, :, :8] @ b_mlx[:, 3:11, 2:3] | ||||
|  | ||||
|         self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) | ||||
|         self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-5)) | ||||
|  | ||||
|     def test_addmm(self): | ||||
|         np.random.seed(0) | ||||
|         # Batched matmul | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jagrit Digani
					Jagrit Digani