mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
More helpful error message in vjp transform + concate bug (#543)
* more helpful message in vjp transform * fix concatenate on mismatch dims * typo * typo
This commit is contained in:
@@ -1345,6 +1345,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||
self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
a = mx.array([[1, 2], [1, 2], [1, 2]])
|
||||
b = mx.array([1, 2])
|
||||
mx.concatenate([a, b], axis=0)
|
||||
|
||||
def test_pad(self):
|
||||
pad_width_and_values = [
|
||||
([(1, 1), (1, 1), (1, 1)], 0),
|
||||
|
Reference in New Issue
Block a user