diff --git a/python/src/array.cpp b/python/src/array.cpp index 5ce09dd90..6ca0ab0ca 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -208,6 +208,7 @@ using array_init_type = std::variant< std::complex, py::list, py::tuple, + array, py::array, py::buffer, py::object>; @@ -410,8 +411,9 @@ std::optional buffer_format(const array& a) { } case bfloat16: // not supported by python buffer protocol or numpy. - // musst be null according to + // must be null according to // https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT + // which implies 'B'. return {}; case complex64: return pybind11::format_descriptor>::format(); @@ -449,6 +451,8 @@ array create_array(array_init_type v, std::optional t) { return array_from_list(*pv, t); } else if (auto pv = std::get_if(&v); pv) { return array_from_list(*pv, t); + } else if (auto pv = std::get_if(&v); pv) { + return astype(*pv, t.value_or((*pv).dtype())); } else if (auto pv = std::get_if(&v); pv) { return np_array_to_mlx(*pv, t); } else if (auto pv = std::get_if(&v); pv) { @@ -528,7 +532,8 @@ void init_array(py::module_& m) { return pybind11::buffer_info( a.data(), a.itemsize(), - buffer_format(a).value_or(nullptr), + buffer_format(a).value_or("B"), // we use "B" because pybind uses a + // std::string which can't be null a.ndim(), a.shape(), buffer_strides(a)); diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index 9e8ed772e..8cccc61ec 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -52,10 +52,21 @@ class MLXTestCase(unittest.TestCase): atol=1e-2, rtol=1e-2, ): - assert tuple(mx_res.shape) == tuple( - expected.shape - ), f"shape mismatch expected={expected.shape} got={mx_res.shape}" - assert ( - mx_res.dtype == expected.dtype - ), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}" - np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) + self.assertEqual( + tuple(mx_res.shape), + tuple(expected.shape), + msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}", + ) + self.assertEqual( + mx_res.dtype, + expected.dtype, + msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}", + ) + if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array): + np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) + elif not isinstance(mx_res, mx.array): + mx_res = mx.array(mx_res) + self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) + elif not isinstance(expected, mx.array): + expected = mx.array(expected) + self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) diff --git a/python/tests/test_array.py b/python/tests/test_array.py index eee570920..44775a11a 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1170,7 +1170,6 @@ class TestArray(mlx_tests.MLXTestCase): f(a_np), atol=0, rtol=0, - msg=f"{mlx_dtype}{np_dtype}", ) # extra test for bfloat16, which is not numpy convertible @@ -1178,7 +1177,7 @@ class TestArray(mlx_tests.MLXTestCase): mv_mx = memoryview(a_mx) self.assertEqual(mv_mx.strides, (8, 2)) self.assertEqual(mv_mx.shape, (3, 4)) - self.assertEqual(mv_mx.format, "") + self.assertEqual(mv_mx.format, "B") with self.assertRaises(RuntimeError) as cm: np.array(a_mx) e = cm.exception @@ -1265,7 +1264,6 @@ class TestArray(mlx_tests.MLXTestCase): f(a_tf), atol=0, rtol=0, - msg=f"{mlx_dtype}{tf_dtype}", )