From d8f41a5c0f858ce0dd0690f9a2e08a3afdabe388 Mon Sep 17 00:00:00 2001
From: mutexuan
Date: Fri, 5 Jan 2024 10:53:33 +0800
Subject: [PATCH] support python mlx.array creation from list of mlx.array's (#325)

* support python mlx.array creation from list of mlx.array's
* include bfloat16 in UT
* refactor so that a sub-array made of only python primitive types gets initialized by fill_vector
* address PR comment: arr.shape().size() -> arr.ndim()
* address PR comment: restore Dtype constness and let stack handle type promotions automatically
---
 mlx/ops.cpp                |   1 +
 python/src/array.cpp       | 184 ++++++++++++++++++++++++-------------
 python/tests/test_array.py |  58 ++++++++++++
 3 files changed, 180 insertions(+), 63 deletions(-)

diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 744aff68a..8ec7787f9 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -714,6 +714,7 @@ array stack(
   }
   return concatenate(new_arrays, axis, s);
 }
+
 array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
   return stack(arrays, 0, s);
 }
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 1c6f724f4..5e57eb4c2 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -120,7 +120,11 @@ void fill_vector(T list, std::vector<U>& vals) {
 }
 
 template <typename T>
-PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
+PyScalarT validate_shape(
+    T list,
+    const std::vector<int>& shape,
+    int idx,
+    bool& all_python_primitive_elements) {
   if (idx >= shape.size()) {
     throw std::invalid_argument("Initialization encountered extra dimension.");
   }
@@ -138,9 +142,17 @@
   for (auto l : list) {
     PyScalarT t;
     if (py::isinstance<py::list>(l)) {
-      t = validate_shape(l.template cast<py::list>(), shape, idx + 1);
+      t = validate_shape(
+          l.template cast<py::list>(),
+          shape,
+          idx + 1,
+          all_python_primitive_elements);
     } else if (py::isinstance<py::tuple>(*list.begin())) {
-      t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1);
+      t = validate_shape(
+          l.template cast<py::tuple>(),
+          shape,
+          idx + 1,
+          all_python_primitive_elements);
     } else if (py::isinstance<py::bool_>(l)) {
       t = pybool;
     } else if (py::isinstance<py::int_>(l)) {
@@ -149,6 +161,19 @@
       t = pyfloat;
     } else if (PyComplex_Check(l.ptr())) {
       t = pycomplex;
+    } else if (py::isinstance<array>(l)) {
+      all_python_primitive_elements = false;
+      auto arr = py::cast<array>(l);
+      if (arr.ndim() + idx + 1 == shape.size() &&
+          std::equal(
+              arr.shape().cbegin(),
+              arr.shape().cend(),
+              shape.cbegin() + idx + 1)) {
+        t = pybool;
+      } else {
+        throw std::invalid_argument(
+            "Initialization encountered non-uniform length.");
+      }
     } else {
       std::ostringstream msg;
       msg << "Invalid type in array initialization" << l.get_type() << ".";
@@ -168,6 +193,64 @@ void get_shape(T list, std::vector<int>& shape) {
       return get_shape(l.template cast<py::list>(), shape);
     } else if (py::isinstance<py::tuple>(l)) {
       return get_shape(l.template cast<py::tuple>(), shape);
+    } else if (py::isinstance<array>(l)) {
+      auto arr = py::cast<array>(l);
+      shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
+      return;
+    }
   }
 }
 
+using array_init_type = std::variant<
+    py::bool_,
+    py::int_,
+    py::float_,
+    std::complex<float>,
+    py::list,
+    py::tuple,
+    py::array,
+    py::buffer,
+    py::object>;
+
+// Forward declaration
+array create_array(array_init_type v, std::optional<Dtype> t);
+
+template <typename T>
+array array_from_list(
+    T pl,
+    const PyScalarT& inferred_type,
+    std::optional<Dtype> specified_type,
+    const std::vector<int>& shape) {
+  // Make the array
+  switch (inferred_type) {
+    case pybool: {
+      std::vector<bool> vals;
+      fill_vector(pl, vals);
+      return array(vals.begin(), shape, specified_type.value_or(bool_));
+    }
+    case pyint: {
+      std::vector<int> vals;
+      fill_vector(pl, vals);
+      return array(vals.begin(), shape, specified_type.value_or(int32));
+    }
+    case pyfloat: {
+      std::vector<float> vals;
+      fill_vector(pl, vals);
+      return array(vals.begin(), shape, specified_type.value_or(float32));
+    }
+    case pycomplex: {
+      std::vector<std::complex<float>> vals;
+      fill_vector(pl, vals);
+      return array(
+          reinterpret_cast<complex64_t*>(vals.data()),
+          shape,
+          specified_type.value_or(complex64));
+    }
+    default: {
+      std::ostringstream msg;
+      msg << "Should not happen, inferred: " << inferred_type
+          << " on subarray made of only python primitive types.";
+      throw std::runtime_error(msg.str());
+    }
+  }
+}
@@ -179,39 +262,20 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
   get_shape(pl, shape);
 
   // Validate the shape and type
-  auto type = validate_shape(pl, shape, 0);
+  bool all_python_primitive_elements = true;
+  auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
 
-  size_t size = 1;
-  for (auto s : shape) {
-    size *= s;
+  if (all_python_primitive_elements) {
+    // `pl` does not contain mlx arrays
+    return array_from_list(pl, type, dtype, shape);
   }
 
-  // Make the array
-  switch (type) {
-    case pybool: {
-      std::vector<bool> vals;
-      fill_vector(pl, vals);
-      return array(vals.begin(), shape, dtype.value_or(bool_));
-    }
-    case pyint: {
-      std::vector<int> vals;
-      fill_vector(pl, vals);
-      return array(vals.begin(), shape, dtype.value_or(int32));
-    }
-    case pyfloat: {
-      std::vector<float> vals;
-      fill_vector(pl, vals);
-      return array(vals.begin(), shape, dtype.value_or(float32));
-    }
-    case pycomplex: {
-      std::vector<std::complex<float>> vals;
-      fill_vector(pl, vals);
-      return array(
-          reinterpret_cast<complex64_t*>(vals.data()),
-          shape,
-          dtype.value_or(complex64));
-    }
+  // `pl` contains mlx arrays
+  std::vector<array> arrays;
+  for (auto l : pl) {
+    arrays.push_back(create_array(py::cast<array_init_type>(l), dtype));
   }
+  return stack(arrays);
 }
@@ -419,6 +483,29 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
 // Module
 ///////////////////////////////////////////////////////////////////////////////
 
+array create_array(array_init_type v, std::optional<Dtype> t) {
+  if (auto pv = std::get_if<py::bool_>(&v); pv) {
+    return array(py::cast<bool>(*pv), t.value_or(bool_));
+  } else if (auto pv = std::get_if<py::int_>(&v); pv) {
+    return array(py::cast<int>(*pv), t.value_or(int32));
+  } else if (auto pv = std::get_if<py::float_>(&v); pv) {
+    return array(py::cast<float>(*pv), t.value_or(float32));
+  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
+    return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
+  } else if (auto pv = std::get_if<py::list>(&v); pv) {
+    return array_from_list(*pv, t);
+  } else if (auto pv = std::get_if<py::tuple>(&v); pv) {
+    return array_from_list(*pv, t);
+  } else if (auto pv = std::get_if<py::array>(&v); pv) {
+    return np_array_to_mlx(*pv, t);
+  } else if (auto pv = std::get_if<py::buffer>(&v); pv) {
+    return np_array_to_mlx(*pv, t);
+  } else {
+    auto arr = to_array_with_accessor(std::get<py::object>(v));
+    return astype(arr, t.value_or(arr.dtype()));
+  }
+}
+
 void init_array(py::module_& m) {
   // Types
   py::class_<Dtype>(
@@ -466,37 +553,8 @@
   options.disable_function_signatures();
 
   array_class.def(
-      py::init([](std::variant<
-                      py::bool_,
-                      py::int_,
-                      py::float_,
-                      std::complex<float>,
-                      py::list,
-                      py::tuple,
-                      py::array,
-                      py::buffer,
-                      py::object> v,
-                  std::optional<Dtype> t) {
-        if (auto pv = std::get_if<py::bool_>(&v); pv) {
-          return array(py::cast<bool>(*pv), t.value_or(bool_));
-        } else if (auto pv = std::get_if<py::int_>(&v); pv) {
-          return array(py::cast<int>(*pv), t.value_or(int32));
-        } else if (auto pv = std::get_if<py::float_>(&v); pv) {
-          return array(py::cast<float>(*pv), t.value_or(float32));
-        } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
-          return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
-        } else if (auto pv = std::get_if<py::list>(&v); pv) {
-          return array_from_list(*pv, t);
-        } else if (auto pv = std::get_if<py::tuple>(&v); pv) {
-          return array_from_list(*pv, t);
-        } else if (auto pv = std::get_if<py::array>(&v); pv) {
-          return np_array_to_mlx(*pv, t);
-        } else if (auto pv = std::get_if<py::buffer>(&v); pv) {
-          return np_array_to_mlx(*pv, t);
-        } else {
-          auto arr = to_array_with_accessor(std::get<py::object>(v));
-          return astype(arr, t.value_or(arr.dtype()));
-        }
+      py::init([](array_init_type v, std::optional<Dtype> t) {
+        return create_array(v, t);
       }),
       "val"_a,
      "dtype"_a = std::nullopt,
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 9016f3ae4..42d41a550 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -218,6 +218,64 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
         self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])
 
+    def test_construction_from_lists_of_mlx_arrays(self):
+        dtypes = [
+            mx.bool_,
+            mx.uint8,
+            mx.uint16,
+            mx.uint32,
+            mx.uint64,
+            mx.int8,
+            mx.int16,
+            mx.int32,
+            mx.int64,
+            mx.float16,
+            mx.float32,
+            mx.bfloat16,
+            mx.complex64,
+        ]
+        for x_t, y_t in permutations(dtypes, 2):
+            # check type promotion and numeric correctness
+            x, y = mx.array([1.0], x_t), mx.array([2.0], y_t)
+            z = mx.array([x, y])
+            expected = mx.stack([x, y], axis=0)
+            self.assertEqualArray(z, expected)
+
+            # check heterogeneous construction with mlx arrays and python primitive types
+            x, y = mx.array([True], x_t), mx.array([False], y_t)
+            z = mx.array([[x, [2.0]], [[3.0], y]])
+            expected = mx.array([[[x.item()], [2.0]], [[3.0], [y.item()]]], z.dtype)
+            self.assertEqualArray(z, expected)
+
+        # check creating from an array that does not hold its raw data in memory
+        x = mx.array([1.0]).astype(mx.bfloat16)  # x does not hold raw data
+        for y_t in dtypes:
+            y = mx.array([2.0], y_t)
+            z = mx.array([x, y])
+            expected = mx.stack([x, y], axis=0)
+            self.assertEqualArray(z, expected)
+
+        # shape check from `stack()`
+        with self.assertRaises(ValueError) as e:
+            mx.array([x, 1.0])
+        self.assertEqual(str(e.exception), "All arrays must have the same shape")
+
+        # shape check from `validate_shape`
+        with self.assertRaises(ValueError) as e:
+            mx.array([1.0, x])
+        self.assertEqual(
+            str(e.exception), "Initialization encountered non-uniform length."
+        )
+
+        # check that `[mx.array, ...]` retains the `mx.array` in the graph
+        def f(x):
+            y = mx.array([x, mx.array([2.0])])
+            return (2 * y).sum()
+
+        x = mx.array([1.0])
+        dfdx = mx.grad(f)
+        self.assertEqual(dfdx(x).item(), 2.0)
+
     def test_init_from_array(self):
         x = mx.array(3.0)
         y = mx.array(x)
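
Usage sketch (illustration only, not part of the patch; it mirrors the new test and assumes the built package is importable as mlx.core):

    import mlx.core as mx

    # A list of mlx arrays is stacked along a new leading axis; the result
    # dtype follows stack's automatic type promotion.
    x = mx.array([1.0], mx.float16)
    y = mx.array([2.0], mx.int32)
    z = mx.array([x, y])  # equivalent to mx.stack([x, y], axis=0)
    print(z.shape)        # z has shape (2, 1)

    # mlx arrays can also be mixed with python scalars/lists of matching shape
    w = mx.array([[x, [3.0]], [[4.0], y]])
    print(w.shape)        # w has shape (2, 2, 1)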