diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index df4b0495e..13e74d292 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -669,26 +669,27 @@ array concatenate(
     int axis,
     StreamOrDevice s /* = {} */) {
   if (arrays.size() == 0) {
-    throw std::invalid_argument("No arrays provided for concatenation");
+    throw std::invalid_argument(
+        "[concatenate] No arrays provided for concatenation");
   }
 
   // Normalize the given axis
   auto ax = axis < 0 ? axis + arrays[0].ndim() : axis;
   if (ax < 0 || ax >= arrays[0].ndim()) {
     std::ostringstream msg;
-    msg << "Invalid axis (" << axis << ") passed to concatenate"
+    msg << "[concatenate] Invalid axis (" << axis << ") passed to concatenate"
         << " for array with shape " << arrays[0].shape() << ".";
     throw std::invalid_argument(msg.str());
   }
 
   auto throw_invalid_shapes = [&]() {
     std::ostringstream msg;
-    msg << "All the input array dimensions must match exactly except"
-        << " for the concatenation axis. However, the provided shapes are ";
+    msg << "[concatenate] All the input array dimensions must match exactly "
+        << "except for the concatenation axis. However, the provided shapes are ";
     for (auto& a : arrays) {
       msg << a.shape() << ", ";
     }
-    msg << "and the concatenation axis is " << axis;
+    msg << "and the concatenation axis is " << axis << ".";
     throw std::invalid_argument(msg.str());
   };
 
@@ -697,6 +698,13 @@ array concatenate(
   // Make the output shape and validate that all arrays have the same shape
   // except for the concatenation axis.
   for (auto& a : arrays) {
+    if (a.ndim() != shape.size()) {
+      std::ostringstream msg;
+      msg << "[concatenate] All the input arrays must have the same number of "
+          << "dimensions. However, got arrays with dimensions " << shape.size()
+          << " and " << a.ndim() << ".";
+      throw std::invalid_argument(msg.str());
+    }
     for (int i = 0; i < a.ndim(); i++) {
       if (i == ax) {
         continue;
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 08dd92e0a..03d8d993b 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -337,12 +337,21 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
       }
     }
     if (cotan_index >= cotans.size()) {
-      throw std::invalid_argument(
-          "[vjp] Number of outputs with gradient does not match number of cotangents.");
+      std::ostringstream msg;
+      msg << "[vjp] Number of outputs to compute gradients for ("
+          << outputs.size() << ") does not match number of cotangents ("
+          << cotans.size() << ").";
+      throw std::invalid_argument(msg.str());
     }
     if (out.shape() != cotans[cotan_index].shape()) {
-      throw std::invalid_argument(
-          "[vjp] Output shape does not match shape of cotangent.");
+      std::ostringstream msg;
+      msg << "[vjp] Output shape " << out.shape()
+          << " does not match cotangent shape " << cotans[cotan_index].shape()
+          << ".";
+      if (outputs.size() == 1 && out.size() == 1) {
+        msg << " If you are using grad your function must return a scalar.";
+      }
+      throw std::invalid_argument(msg.str());
     }
     output_cotan_pairs.emplace_back(i, cotan_index++);
   }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 7284085d0..9f3226bdd 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1345,6 +1345,11 @@ class TestOps(mlx_tests.MLXTestCase):
             self.assertEqual(list(c_npy.shape), list(c_mlx.shape))
             self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6))
 
+        with self.assertRaises(ValueError):
+            a = mx.array([[1, 2], [1, 2], [1, 2]])
+            b = mx.array([1, 2])
+            mx.concatenate([a, b], axis=0)
+
     def test_pad(self):
         pad_width_and_values = [
             ([(1, 1), (1, 1), (1, 1)], 0),
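
Not part of the patch: a minimal sketch of how the new concatenate check surfaces in Python, mirroring the test case added above; it assumes an environment with mlx installed.

import mlx.core as mx

# Arrays with different numbers of dimensions: (3, 2) vs. (2,).
a = mx.array([[1, 2], [1, 2], [1, 2]])
b = mx.array([1, 2])

try:
    mx.concatenate([a, b], axis=0)
except ValueError as e:
    # With this patch the message is prefixed and specific:
    # [concatenate] All the input arrays must have the same number of
    # dimensions. However, got arrays with dimensions 2 and 1.
    print(e)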