support python mlx.array creation from list of mlx.array's (#325)

* support python mlx.array creation from list of mlx.array's

* include bfloat16 in UT

* refactor so that sub array made of all python primitive types gets initialized by fill_vector

* address PR comment: arr.shape().size() -> arr.ndim()

* address PR comment: restore Dtype constness and let stack handle the type promotions automatically
This commit is contained in:
mutexuan 2024-01-05 10:53:33 +08:00 committed by GitHub
parent b9e415d19c
commit d8f41a5c0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 180 additions and 63 deletions

View File

@ -714,6 +714,7 @@ array stack(
} }
return concatenate(new_arrays, axis, s); return concatenate(new_arrays, axis, s);
} }
array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) { array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
return stack(arrays, 0, s); return stack(arrays, 0, s);
} }

View File

@ -120,7 +120,11 @@ void fill_vector(T list, std::vector<U>& vals) {
} }
template <typename T> template <typename T>
PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) { PyScalarT validate_shape(
T list,
const std::vector<int>& shape,
int idx,
bool& all_python_primitive_elements) {
if (idx >= shape.size()) { if (idx >= shape.size()) {
throw std::invalid_argument("Initialization encountered extra dimension."); throw std::invalid_argument("Initialization encountered extra dimension.");
} }
@ -138,9 +142,17 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
for (auto l : list) { for (auto l : list) {
PyScalarT t; PyScalarT t;
if (py::isinstance<py::list>(l)) { if (py::isinstance<py::list>(l)) {
t = validate_shape(l.template cast<py::list>(), shape, idx + 1); t = validate_shape(
l.template cast<py::list>(),
shape,
idx + 1,
all_python_primitive_elements);
} else if (py::isinstance<py::tuple>(*list.begin())) { } else if (py::isinstance<py::tuple>(*list.begin())) {
t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1); t = validate_shape(
l.template cast<py::tuple>(),
shape,
idx + 1,
all_python_primitive_elements);
} else if (py::isinstance<py::bool_>(l)) { } else if (py::isinstance<py::bool_>(l)) {
t = pybool; t = pybool;
} else if (py::isinstance<py::int_>(l)) { } else if (py::isinstance<py::int_>(l)) {
@ -149,6 +161,19 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
t = pyfloat; t = pyfloat;
} else if (PyComplex_Check(l.ptr())) { } else if (PyComplex_Check(l.ptr())) {
t = pycomplex; t = pycomplex;
} else if (py::isinstance<array>(l)) {
all_python_primitive_elements = false;
auto arr = py::cast<array>(l);
if (arr.ndim() + idx + 1 == shape.size() &&
std::equal(
arr.shape().cbegin(),
arr.shape().cend(),
shape.cbegin() + idx + 1)) {
t = pybool;
} else {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "Invalid type in array initialization" << l.get_type() << "."; msg << "Invalid type in array initialization" << l.get_type() << ".";
@ -168,6 +193,64 @@ void get_shape(T list, std::vector<int>& shape) {
return get_shape(l.template cast<py::list>(), shape); return get_shape(l.template cast<py::list>(), shape);
} else if (py::isinstance<py::tuple>(l)) { } else if (py::isinstance<py::tuple>(l)) {
return get_shape(l.template cast<py::tuple>(), shape); return get_shape(l.template cast<py::tuple>(), shape);
} else if (py::isinstance<array>(l)) {
auto arr = py::cast<array>(l);
shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
return;
}
}
}
// Every Python value kind accepted by the mlx.array constructor. The
// std::variant lets pybind11 try each alternative in order; py::object is
// the catch-all for anything convertible through the array accessor.
using array_init_type = std::variant<
py::bool_,
py::int_,
py::float_,
std::complex<float>,
py::list,
py::tuple,
py::array,
py::buffer,
py::object>;
// Forward declaration: needed because array_from_list (below) recursively
// constructs sub-arrays via create_array for list elements.
array create_array(array_init_type v, std::optional<Dtype> t);
// Build an array from a (possibly nested) Python sequence whose leaves are
// all Python primitives. `inferred_type` selects the element buffer type;
// `specified_type` (when set) overrides the default mlx dtype; `shape` is the
// already-validated shape of the nested sequence.
template <typename T>
array array_from_list(
    T pl,
    const PyScalarT& inferred_type,
    std::optional<Dtype> specified_type,
    const std::vector<int>& shape) {
  if (inferred_type == pybool) {
    std::vector<bool> buffer;
    fill_vector(pl, buffer);
    return array(buffer.begin(), shape, specified_type.value_or(bool_));
  }
  if (inferred_type == pyint) {
    std::vector<int> buffer;
    fill_vector(pl, buffer);
    return array(buffer.begin(), shape, specified_type.value_or(int32));
  }
  if (inferred_type == pyfloat) {
    std::vector<float> buffer;
    fill_vector(pl, buffer);
    return array(buffer.begin(), shape, specified_type.value_or(float32));
  }
  if (inferred_type == pycomplex) {
    // std::complex<float> and complex64_t share layout, so the buffer can be
    // handed over via reinterpret_cast without a copy.
    std::vector<std::complex<float>> buffer;
    fill_vector(pl, buffer);
    return array(
        reinterpret_cast<complex64_t*>(buffer.data()),
        shape,
        specified_type.value_or(complex64));
  }
  // validate_shape only ever reports the four scalar kinds above for
  // primitive-only sub-arrays, so reaching here is an internal error.
  std::ostringstream msg;
  msg << "Should not happen, inferred: " << inferred_type
      << " on subarray made of only python primitive types.";
  throw std::runtime_error(msg.str());
}
} }
@ -179,39 +262,20 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
get_shape(pl, shape); get_shape(pl, shape);
// Validate the shape and type // Validate the shape and type
auto type = validate_shape(pl, shape, 0); bool all_python_primitive_elements = true;
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
size_t size = 1; if (all_python_primitive_elements) {
for (auto s : shape) { // `pl` does not contain mlx arrays
size *= s; return array_from_list(pl, type, dtype, shape);
} }
// Make the array // `pl` contains mlx arrays
switch (type) { std::vector<array> arrays;
case pybool: { for (auto l : pl) {
std::vector<bool> vals; arrays.push_back(create_array(py::cast<array_init_type>(l), dtype));
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype.value_or(bool_));
}
case pyint: {
std::vector<int> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype.value_or(int32));
}
case pyfloat: {
std::vector<float> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype.value_or(float32));
}
case pycomplex: {
std::vector<std::complex<float>> vals;
fill_vector(pl, vals);
return array(
reinterpret_cast<complex64_t*>(vals.data()),
shape,
dtype.value_or(complex64));
}
} }
return stack(arrays);
} }
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -419,6 +483,29 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
// Module // Module
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Construct an mlx array from any accepted Python initializer. Exactly one
// std::get_if below succeeds, since it matches only the alternative the
// variant currently holds; `t`, when provided, overrides the default dtype.
array create_array(array_init_type v, std::optional<Dtype> t) {
  if (auto b = std::get_if<py::bool_>(&v)) {
    return array(py::cast<bool>(*b), t.value_or(bool_));
  }
  if (auto i = std::get_if<py::int_>(&v)) {
    return array(py::cast<int>(*i), t.value_or(int32));
  }
  if (auto f = std::get_if<py::float_>(&v)) {
    return array(py::cast<float>(*f), t.value_or(float32));
  }
  if (auto c = std::get_if<std::complex<float>>(&v)) {
    return array(static_cast<complex64_t>(*c), t.value_or(complex64));
  }
  if (auto lst = std::get_if<py::list>(&v)) {
    return array_from_list(*lst, t);
  }
  if (auto tup = std::get_if<py::tuple>(&v)) {
    return array_from_list(*tup, t);
  }
  if (auto np = std::get_if<py::array>(&v)) {
    return np_array_to_mlx(*np, t);
  }
  if (auto buf = std::get_if<py::buffer>(&v)) {
    return np_array_to_mlx(*buf, t);
  }
  // Fallback: any object exposing the mlx array accessor protocol; cast to
  // the requested dtype (or keep the source dtype when none was given).
  auto arr = to_array_with_accessor(std::get<py::object>(v));
  return astype(arr, t.value_or(arr.dtype()));
}
void init_array(py::module_& m) { void init_array(py::module_& m) {
// Types // Types
py::class_<Dtype>( py::class_<Dtype>(
@ -466,37 +553,8 @@ void init_array(py::module_& m) {
options.disable_function_signatures(); options.disable_function_signatures();
array_class.def( array_class.def(
py::init([](std::variant< py::init([](array_init_type v, std::optional<Dtype> t) {
py::bool_, return create_array(v, t);
py::int_,
py::float_,
std::complex<float>,
py::list,
py::tuple,
py::array,
py::buffer,
py::object> v,
std::optional<Dtype> t) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), t.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
return array(py::cast<int>(*pv), t.value_or(int32));
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
return array(py::cast<float>(*pv), t.value_or(float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
} else if (auto pv = std::get_if<py::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::array>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else {
auto arr = to_array_with_accessor(std::get<py::object>(v));
return astype(arr, t.value_or(arr.dtype()));
}
}), }),
"val"_a, "val"_a,
"dtype"_a = std::nullopt, "dtype"_a = std::nullopt,

View File

@ -218,6 +218,64 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([1 + 0j, 2j, True, 0], mx.complex64) x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j]) self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])
def test_construction_from_lists_of_mlx_arrays(self):
    # mx.array(...) built from python lists whose elements are mlx arrays,
    # optionally mixed with python scalars / nested lists.
    dtypes = [
        mx.bool_,
        mx.uint8,
        mx.uint16,
        mx.uint32,
        mx.uint64,
        mx.int8,
        mx.int16,
        mx.int32,
        mx.int64,
        mx.float16,
        mx.float32,
        mx.bfloat16,
        mx.complex64,
    ]
    for first_t, second_t in permutations(dtypes, 2):
        # Type promotion and numeric correctness.
        first = mx.array([1.0], first_t)
        second = mx.array([2.0], second_t)
        stacked = mx.array([first, second])
        self.assertEqualArray(stacked, mx.stack([first, second], axis=0))

        # Heterogeneous construction: mlx arrays mixed with python lists.
        first = mx.array([True], first_t)
        second = mx.array([False], second_t)
        mixed = mx.array([[first, [2.0]], [[3.0], second]])
        reference = mx.array(
            [[[first.item()], [2.0]], [[3.0], [second.item()]]], mixed.dtype
        )
        self.assertEqualArray(mixed, reference)

    # Construction from an array that does not hold raw data itself.
    x = mx.array([1.0]).astype(mx.bfloat16)  # x does not hold raw data
    for other_t in dtypes:
        other = mx.array([2.0], other_t)
        combined = mx.array([x, other])
        self.assertEqualArray(combined, mx.stack([x, other], axis=0))

    # Shape mismatch reported by `stack()`.
    with self.assertRaises(ValueError) as e:
        mx.array([x, 1.0])
    self.assertEqual(str(e.exception), "All arrays must have the same shape")

    # Shape mismatch reported by `validate_shape`.
    with self.assertRaises(ValueError) as e:
        mx.array([1.0, x])
    self.assertEqual(
        str(e.exception), "Initialization encountered non-uniform length."
    )

    # `[mx.array, ...]` construction keeps the source array in the graph.
    def f(x):
        y = mx.array([x, mx.array([2.0])])
        return (2 * y).sum()

    x = mx.array([1.0])
    dfdx = mx.grad(f)
    self.assertEqual(dfdx(x).item(), 2.0)
def test_init_from_array(self): def test_init_from_array(self):
x = mx.array(3.0) x = mx.array(3.0)
y = mx.array(x) y = mx.array(x)