mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Support Python mlx.array creation from a list of mlx.arrays (#325)
* Support Python mlx.array creation from a list of mlx.arrays.
* Include bfloat16 in the unit tests.
* Refactor so that a sub-array made up entirely of Python primitive types is initialized by fill_vector.
* Address PR comment: arr.shape().size() -> arr.ndim().
* Address PR comment: restore Dtype constness and let stack handle type promotion automatically.
This commit is contained in:
		| @@ -120,7 +120,11 @@ void fill_vector(T list, std::vector<U>& vals) { | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) { | ||||
| PyScalarT validate_shape( | ||||
|     T list, | ||||
|     const std::vector<int>& shape, | ||||
|     int idx, | ||||
|     bool& all_python_primitive_elements) { | ||||
|   if (idx >= shape.size()) { | ||||
|     throw std::invalid_argument("Initialization encountered extra dimension."); | ||||
|   } | ||||
| @@ -138,9 +142,17 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) { | ||||
|   for (auto l : list) { | ||||
|     PyScalarT t; | ||||
|     if (py::isinstance<py::list>(l)) { | ||||
|       t = validate_shape(l.template cast<py::list>(), shape, idx + 1); | ||||
|       t = validate_shape( | ||||
|           l.template cast<py::list>(), | ||||
|           shape, | ||||
|           idx + 1, | ||||
|           all_python_primitive_elements); | ||||
|     } else if (py::isinstance<py::tuple>(*list.begin())) { | ||||
|       t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1); | ||||
|       t = validate_shape( | ||||
|           l.template cast<py::tuple>(), | ||||
|           shape, | ||||
|           idx + 1, | ||||
|           all_python_primitive_elements); | ||||
|     } else if (py::isinstance<py::bool_>(l)) { | ||||
|       t = pybool; | ||||
|     } else if (py::isinstance<py::int_>(l)) { | ||||
| @@ -149,6 +161,19 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) { | ||||
|       t = pyfloat; | ||||
|     } else if (PyComplex_Check(l.ptr())) { | ||||
|       t = pycomplex; | ||||
|     } else if (py::isinstance<array>(l)) { | ||||
|       all_python_primitive_elements = false; | ||||
|       auto arr = py::cast<array>(l); | ||||
|       if (arr.ndim() + idx + 1 == shape.size() && | ||||
|           std::equal( | ||||
|               arr.shape().cbegin(), | ||||
|               arr.shape().cend(), | ||||
|               shape.cbegin() + idx + 1)) { | ||||
|         t = pybool; | ||||
|       } else { | ||||
|         throw std::invalid_argument( | ||||
|             "Initialization encountered non-uniform length."); | ||||
|       } | ||||
|     } else { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Invalid type in array initialization" << l.get_type() << "."; | ||||
| @@ -168,6 +193,64 @@ void get_shape(T list, std::vector<int>& shape) { | ||||
|       return get_shape(l.template cast<py::list>(), shape); | ||||
|     } else if (py::isinstance<py::tuple>(l)) { | ||||
|       return get_shape(l.template cast<py::tuple>(), shape); | ||||
|     } else if (py::isinstance<array>(l)) { | ||||
|       auto arr = py::cast<array>(l); | ||||
|       shape.insert(shape.end(), arr.shape().begin(), arr.shape().end()); | ||||
|       return; | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| using array_init_type = std::variant< | ||||
|     py::bool_, | ||||
|     py::int_, | ||||
|     py::float_, | ||||
|     std::complex<float>, | ||||
|     py::list, | ||||
|     py::tuple, | ||||
|     py::array, | ||||
|     py::buffer, | ||||
|     py::object>; | ||||
|  | ||||
| // Forward declaration | ||||
| array create_array(array_init_type v, std::optional<Dtype> t); | ||||
|  | ||||
| template <typename T> | ||||
| array array_from_list( | ||||
|     T pl, | ||||
|     const PyScalarT& inferred_type, | ||||
|     std::optional<Dtype> specified_type, | ||||
|     const std::vector<int>& shape) { | ||||
|   // Make the array | ||||
|   switch (inferred_type) { | ||||
|     case pybool: { | ||||
|       std::vector<bool> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array(vals.begin(), shape, specified_type.value_or(bool_)); | ||||
|     } | ||||
|     case pyint: { | ||||
|       std::vector<int> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array(vals.begin(), shape, specified_type.value_or(int32)); | ||||
|     } | ||||
|     case pyfloat: { | ||||
|       std::vector<float> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array(vals.begin(), shape, specified_type.value_or(float32)); | ||||
|     } | ||||
|     case pycomplex: { | ||||
|       std::vector<std::complex<float>> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array( | ||||
|           reinterpret_cast<complex64_t*>(vals.data()), | ||||
|           shape, | ||||
|           specified_type.value_or(complex64)); | ||||
|     } | ||||
|     default: { | ||||
|       std::ostringstream msg; | ||||
|       msg << "Should not happen, inferred: " << inferred_type | ||||
|           << " on subarray made of only python primitive types."; | ||||
|       throw std::runtime_error(msg.str()); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| @@ -179,39 +262,20 @@ array array_from_list(T pl, std::optional<Dtype> dtype) { | ||||
|   get_shape(pl, shape); | ||||
|  | ||||
|   // Validate the shape and type | ||||
|   auto type = validate_shape(pl, shape, 0); | ||||
|   bool all_python_primitive_elements = true; | ||||
|   auto type = validate_shape(pl, shape, 0, all_python_primitive_elements); | ||||
|  | ||||
|   size_t size = 1; | ||||
|   for (auto s : shape) { | ||||
|     size *= s; | ||||
|   if (all_python_primitive_elements) { | ||||
|     // `pl` does not contain mlx arrays | ||||
|     return array_from_list(pl, type, dtype, shape); | ||||
|   } | ||||
|  | ||||
|   // Make the array | ||||
|   switch (type) { | ||||
|     case pybool: { | ||||
|       std::vector<bool> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array(vals.begin(), shape, dtype.value_or(bool_)); | ||||
|     } | ||||
|     case pyint: { | ||||
|       std::vector<int> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array(vals.begin(), shape, dtype.value_or(int32)); | ||||
|     } | ||||
|     case pyfloat: { | ||||
|       std::vector<float> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array(vals.begin(), shape, dtype.value_or(float32)); | ||||
|     } | ||||
|     case pycomplex: { | ||||
|       std::vector<std::complex<float>> vals; | ||||
|       fill_vector(pl, vals); | ||||
|       return array( | ||||
|           reinterpret_cast<complex64_t*>(vals.data()), | ||||
|           shape, | ||||
|           dtype.value_or(complex64)); | ||||
|     } | ||||
|   // `pl` contains mlx arrays | ||||
|   std::vector<array> arrays; | ||||
|   for (auto l : pl) { | ||||
|     arrays.push_back(create_array(py::cast<array_init_type>(l), dtype)); | ||||
|   } | ||||
|   return stack(arrays); | ||||
| } | ||||
|  | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
| @@ -419,6 +483,29 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) { | ||||
| // Module | ||||
| /////////////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| array create_array(array_init_type v, std::optional<Dtype> t) { | ||||
|   if (auto pv = std::get_if<py::bool_>(&v); pv) { | ||||
|     return array(py::cast<bool>(*pv), t.value_or(bool_)); | ||||
|   } else if (auto pv = std::get_if<py::int_>(&v); pv) { | ||||
|     return array(py::cast<int>(*pv), t.value_or(int32)); | ||||
|   } else if (auto pv = std::get_if<py::float_>(&v); pv) { | ||||
|     return array(py::cast<float>(*pv), t.value_or(float32)); | ||||
|   } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { | ||||
|     return array(static_cast<complex64_t>(*pv), t.value_or(complex64)); | ||||
|   } else if (auto pv = std::get_if<py::list>(&v); pv) { | ||||
|     return array_from_list(*pv, t); | ||||
|   } else if (auto pv = std::get_if<py::tuple>(&v); pv) { | ||||
|     return array_from_list(*pv, t); | ||||
|   } else if (auto pv = std::get_if<py::array>(&v); pv) { | ||||
|     return np_array_to_mlx(*pv, t); | ||||
|   } else if (auto pv = std::get_if<py::buffer>(&v); pv) { | ||||
|     return np_array_to_mlx(*pv, t); | ||||
|   } else { | ||||
|     auto arr = to_array_with_accessor(std::get<py::object>(v)); | ||||
|     return astype(arr, t.value_or(arr.dtype())); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void init_array(py::module_& m) { | ||||
|   // Types | ||||
|   py::class_<Dtype>( | ||||
| @@ -466,37 +553,8 @@ void init_array(py::module_& m) { | ||||
|     options.disable_function_signatures(); | ||||
|  | ||||
|     array_class.def( | ||||
|         py::init([](std::variant< | ||||
|                         py::bool_, | ||||
|                         py::int_, | ||||
|                         py::float_, | ||||
|                         std::complex<float>, | ||||
|                         py::list, | ||||
|                         py::tuple, | ||||
|                         py::array, | ||||
|                         py::buffer, | ||||
|                         py::object> v, | ||||
|                     std::optional<Dtype> t) { | ||||
|           if (auto pv = std::get_if<py::bool_>(&v); pv) { | ||||
|             return array(py::cast<bool>(*pv), t.value_or(bool_)); | ||||
|           } else if (auto pv = std::get_if<py::int_>(&v); pv) { | ||||
|             return array(py::cast<int>(*pv), t.value_or(int32)); | ||||
|           } else if (auto pv = std::get_if<py::float_>(&v); pv) { | ||||
|             return array(py::cast<float>(*pv), t.value_or(float32)); | ||||
|           } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) { | ||||
|             return array(static_cast<complex64_t>(*pv), t.value_or(complex64)); | ||||
|           } else if (auto pv = std::get_if<py::list>(&v); pv) { | ||||
|             return array_from_list(*pv, t); | ||||
|           } else if (auto pv = std::get_if<py::tuple>(&v); pv) { | ||||
|             return array_from_list(*pv, t); | ||||
|           } else if (auto pv = std::get_if<py::array>(&v); pv) { | ||||
|             return np_array_to_mlx(*pv, t); | ||||
|           } else if (auto pv = std::get_if<py::buffer>(&v); pv) { | ||||
|             return np_array_to_mlx(*pv, t); | ||||
|           } else { | ||||
|             auto arr = to_array_with_accessor(std::get<py::object>(v)); | ||||
|             return astype(arr, t.value_or(arr.dtype())); | ||||
|           } | ||||
|         py::init([](array_init_type v, std::optional<Dtype> t) { | ||||
|           return create_array(v, t); | ||||
|         }), | ||||
|         "val"_a, | ||||
|         "dtype"_a = std::nullopt, | ||||
|   | ||||
| @@ -218,6 +218,64 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|         x = mx.array([1 + 0j, 2j, True, 0], mx.complex64) | ||||
|         self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j]) | ||||
|  | ||||
|     def test_construction_from_lists_of_mlx_arrays(self): | ||||
|         dtypes = [ | ||||
|             mx.bool_, | ||||
|             mx.uint8, | ||||
|             mx.uint16, | ||||
|             mx.uint32, | ||||
|             mx.uint64, | ||||
|             mx.int8, | ||||
|             mx.int16, | ||||
|             mx.int32, | ||||
|             mx.int64, | ||||
|             mx.float16, | ||||
|             mx.float32, | ||||
|             mx.bfloat16, | ||||
|             mx.complex64, | ||||
|         ] | ||||
|         for x_t, y_t in permutations(dtypes, 2): | ||||
|             # check type promotion and numeric correctness | ||||
|             x, y = mx.array([1.0], x_t), mx.array([2.0], y_t) | ||||
|             z = mx.array([x, y]) | ||||
|             expected = mx.stack([x, y], axis=0) | ||||
|             self.assertEqualArray(z, expected) | ||||
|  | ||||
|             # check heterogeneous construction with mlx arrays and python primitive types | ||||
|             x, y = mx.array([True], x_t), mx.array([False], y_t) | ||||
|             z = mx.array([[x, [2.0]], [[3.0], y]]) | ||||
|             expected = mx.array([[[x.item()], [2.0]], [[3.0], [y.item()]]], z.dtype) | ||||
|             self.assertEqualArray(z, expected) | ||||
|  | ||||
|         # check when create from an array which does not contain memory to the raw data | ||||
|         x = mx.array([1.0]).astype(mx.bfloat16)  # x does not hold raw data | ||||
|         for y_t in dtypes: | ||||
|             y = mx.array([2.0], y_t) | ||||
|             z = mx.array([x, y]) | ||||
|             expected = mx.stack([x, y], axis=0) | ||||
|             self.assertEqualArray(z, expected) | ||||
|  | ||||
|         # shape check from `stack()` | ||||
|         with self.assertRaises(ValueError) as e: | ||||
|             mx.array([x, 1.0]) | ||||
|         self.assertEqual(str(e.exception), "All arrays must have the same shape") | ||||
|  | ||||
|         # shape check from `validate_shape` | ||||
|         with self.assertRaises(ValueError) as e: | ||||
|             mx.array([1.0, x]) | ||||
|         self.assertEqual( | ||||
|             str(e.exception), "Initialization encountered non-uniform length." | ||||
|         ) | ||||
|  | ||||
|         # check that `[mx.array, ...]` retains the `mx.array` in the graph | ||||
|         def f(x): | ||||
|             y = mx.array([x, mx.array([2.0])]) | ||||
|             return (2 * y).sum() | ||||
|  | ||||
|         x = mx.array([1.0]) | ||||
|         dfdx = mx.grad(f) | ||||
|         self.assertEqual(dfdx(x).item(), 2.0) | ||||
|  | ||||
|     def test_init_from_array(self): | ||||
|         x = mx.array(3.0) | ||||
|         y = mx.array(x) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 mutexuan
					mutexuan