More helpful error message in vjp transform + concate bug (#543)

* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
This commit is contained in:
Awni Hannun
2024-01-24 09:58:33 -08:00
committed by GitHub
parent f30e63353a
commit f27ec5e097
3 changed files with 31 additions and 9 deletions

View File

@@ -1345,6 +1345,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6))
with self.assertRaises(ValueError):
a = mx.array([[1, 2], [1, 2], [1, 2]])
b = mx.array([1, 2])
mx.concatenate([a, b], axis=0)
def test_pad(self):
pad_width_and_values = [
([(1, 1), (1, 1), (1, 1)], 0),