More helpful error message in vjp transform + concatenate bug (#543)

* more helpful message in vjp transform

* fix concatenate on mismatch dims

* typo

* typo
This commit is contained in:
Awni Hannun 2024-01-24 09:58:33 -08:00 committed by GitHub
parent f30e63353a
commit f27ec5e097
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 9 deletions

View File

@ -669,26 +669,27 @@ array concatenate(
int axis, int axis,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (arrays.size() == 0) { if (arrays.size() == 0) {
throw std::invalid_argument("No arrays provided for concatenation"); throw std::invalid_argument(
"[concatenate] No arrays provided for concatenation");
} }
// Normalize the given axis // Normalize the given axis
auto ax = axis < 0 ? axis + arrays[0].ndim() : axis; auto ax = axis < 0 ? axis + arrays[0].ndim() : axis;
if (ax < 0 || ax >= arrays[0].ndim()) { if (ax < 0 || ax >= arrays[0].ndim()) {
std::ostringstream msg; std::ostringstream msg;
msg << "Invalid axis (" << axis << ") passed to concatenate" msg << "[concatenate] Invalid axis (" << axis << ") passed to concatenate"
<< " for array with shape " << arrays[0].shape() << "."; << " for array with shape " << arrays[0].shape() << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
auto throw_invalid_shapes = [&]() { auto throw_invalid_shapes = [&]() {
std::ostringstream msg; std::ostringstream msg;
msg << "All the input array dimensions must match exactly except" msg << "[concatenate] All the input array dimensions must match exactly "
<< " for the concatenation axis. However, the provided shapes are "; << "except for the concatenation axis. However, the provided shapes are ";
for (auto& a : arrays) { for (auto& a : arrays) {
msg << a.shape() << ", "; msg << a.shape() << ", ";
} }
msg << "and the concatenation axis is " << axis; msg << "and the concatenation axis is " << axis << ".";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
}; };
@ -697,6 +698,13 @@ array concatenate(
// Make the output shape and validate that all arrays have the same shape // Make the output shape and validate that all arrays have the same shape
// except for the concatenation axis. // except for the concatenation axis.
for (auto& a : arrays) { for (auto& a : arrays) {
if (a.ndim() != shape.size()) {
std::ostringstream msg;
msg << "[concatenate] All the input arrays must have the same number of "
<< "dimensions. However, got arrays with dimensions " << shape.size()
<< " and " << a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < a.ndim(); i++) { for (int i = 0; i < a.ndim(); i++) {
if (i == ax) { if (i == ax) {
continue; continue;

View File

@ -337,12 +337,21 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
} }
} }
if (cotan_index >= cotans.size()) { if (cotan_index >= cotans.size()) {
throw std::invalid_argument( std::ostringstream msg;
"[vjp] Number of outputs with gradient does not match number of cotangents."); msg << "[vjp] Number of outputs to compute gradients for ("
<< outputs.size() << ") does not match number of cotangents ("
<< cotans.size() << ").";
throw std::invalid_argument(msg.str());
} }
if (out.shape() != cotans[cotan_index].shape()) { if (out.shape() != cotans[cotan_index].shape()) {
throw std::invalid_argument( std::ostringstream msg;
"[vjp] Output shape does not match shape of cotangent."); msg << "[vjp] Output shape " << out.shape()
<< " does not match cotangent shape " << cotans[cotan_index].shape()
<< ".";
if (outputs.size() == 1 && out.size() == 1) {
msg << " If you are using grad your function must return a scalar.";
}
throw std::invalid_argument(msg.str());
} }
output_cotan_pairs.emplace_back(i, cotan_index++); output_cotan_pairs.emplace_back(i, cotan_index++);
} }

View File

@ -1345,6 +1345,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(list(c_npy.shape), list(c_mlx.shape)) self.assertEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6)) self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6))
with self.assertRaises(ValueError):
a = mx.array([[1, 2], [1, 2], [1, 2]])
b = mx.array([1, 2])
mx.concatenate([a, b], axis=0)
def test_pad(self): def test_pad(self):
pad_width_and_values = [ pad_width_and_values = [
([(1, 1), (1, 1), (1, 1)], 0), ([(1, 1), (1, 1), (1, 1)], 0),