diff --git a/python/src/array.cpp b/python/src/array.cpp index b405faf84..0ef27880b 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -162,14 +162,6 @@ PyScalarT validate_shape( shape, idx + 1, all_python_primitive_elements); - } else if (nb::isinstance(l)) { - t = pybool; - } else if (nb::isinstance(l)) { - t = pyint; - } else if (nb::isinstance(l)) { - t = pyfloat; - } else if (PyComplex_Check(l.ptr())) { - t = pycomplex; } else if (nb::isinstance(l)) { all_python_primitive_elements = false; auto arr = nb::cast(l); @@ -184,10 +176,25 @@ PyScalarT validate_shape( "Initialization encountered non-uniform length."); } } else { - std::ostringstream msg; - msg << "Invalid type " << nb::type_name(l.type()).c_str() - << " received in array initialization."; - throw std::invalid_argument(msg.str()); + if (nb::isinstance(l)) { + t = pybool; + } else if (nb::isinstance(l)) { + t = pyint; + } else if (nb::isinstance(l)) { + t = pyfloat; + } else if (PyComplex_Check(l.ptr())) { + t = pycomplex; + } else { + std::ostringstream msg; + msg << "Invalid type " << nb::type_name(l.type()).c_str() + << " received in array initialization."; + throw std::invalid_argument(msg.str()); + } + + if (idx + 1 != shape.size()) { + throw std::invalid_argument( + "Initialization encountered non-uniform length."); + } } type = std::max(type, t); } @@ -1440,4 +1447,4 @@ void init_array(nb::module_& m) { R"pbdoc( Extract a diagonal or construct a diagonal matrix. )pbdoc"); -} \ No newline at end of file +} diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4e8c00134..587a98e2e 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -391,7 +391,9 @@ class TestArray(mlx_tests.MLXTestCase): # shape check from `stack()` with self.assertRaises(ValueError) as e: mx.array([x, 1.0]) - self.assertEqual(str(e.exception), "All arrays must have the same shape") + self.assertEqual( + str(e.exception), "Initialization encountered non-uniform length." + ) # shape check from `validate_shape` with self.assertRaises(ValueError) as e: