mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-22 01:21:14 +08:00
support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's * include bfloat16 in UT * refactor so that sub array made of all python primitive types gets initialized by fill_vector * address PR comment: arr.shape().size() -> arr.ndim() * address PR comment: get back Dtype constness and let stack to handle type promotions automatically
This commit is contained in:
parent
b9e415d19c
commit
d8f41a5c0f
@ -714,6 +714,7 @@ array stack(
|
|||||||
}
|
}
|
||||||
return concatenate(new_arrays, axis, s);
|
return concatenate(new_arrays, axis, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
|
array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
|
||||||
return stack(arrays, 0, s);
|
return stack(arrays, 0, s);
|
||||||
}
|
}
|
||||||
|
@ -120,7 +120,11 @@ void fill_vector(T list, std::vector<U>& vals) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
|
PyScalarT validate_shape(
|
||||||
|
T list,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
int idx,
|
||||||
|
bool& all_python_primitive_elements) {
|
||||||
if (idx >= shape.size()) {
|
if (idx >= shape.size()) {
|
||||||
throw std::invalid_argument("Initialization encountered extra dimension.");
|
throw std::invalid_argument("Initialization encountered extra dimension.");
|
||||||
}
|
}
|
||||||
@ -138,9 +142,17 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
|
|||||||
for (auto l : list) {
|
for (auto l : list) {
|
||||||
PyScalarT t;
|
PyScalarT t;
|
||||||
if (py::isinstance<py::list>(l)) {
|
if (py::isinstance<py::list>(l)) {
|
||||||
t = validate_shape(l.template cast<py::list>(), shape, idx + 1);
|
t = validate_shape(
|
||||||
|
l.template cast<py::list>(),
|
||||||
|
shape,
|
||||||
|
idx + 1,
|
||||||
|
all_python_primitive_elements);
|
||||||
} else if (py::isinstance<py::tuple>(*list.begin())) {
|
} else if (py::isinstance<py::tuple>(*list.begin())) {
|
||||||
t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1);
|
t = validate_shape(
|
||||||
|
l.template cast<py::tuple>(),
|
||||||
|
shape,
|
||||||
|
idx + 1,
|
||||||
|
all_python_primitive_elements);
|
||||||
} else if (py::isinstance<py::bool_>(l)) {
|
} else if (py::isinstance<py::bool_>(l)) {
|
||||||
t = pybool;
|
t = pybool;
|
||||||
} else if (py::isinstance<py::int_>(l)) {
|
} else if (py::isinstance<py::int_>(l)) {
|
||||||
@ -149,6 +161,19 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
|
|||||||
t = pyfloat;
|
t = pyfloat;
|
||||||
} else if (PyComplex_Check(l.ptr())) {
|
} else if (PyComplex_Check(l.ptr())) {
|
||||||
t = pycomplex;
|
t = pycomplex;
|
||||||
|
} else if (py::isinstance<array>(l)) {
|
||||||
|
all_python_primitive_elements = false;
|
||||||
|
auto arr = py::cast<array>(l);
|
||||||
|
if (arr.ndim() + idx + 1 == shape.size() &&
|
||||||
|
std::equal(
|
||||||
|
arr.shape().cbegin(),
|
||||||
|
arr.shape().cend(),
|
||||||
|
shape.cbegin() + idx + 1)) {
|
||||||
|
t = pybool;
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Initialization encountered non-uniform length.");
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Invalid type in array initialization" << l.get_type() << ".";
|
msg << "Invalid type in array initialization" << l.get_type() << ".";
|
||||||
@ -168,6 +193,64 @@ void get_shape(T list, std::vector<int>& shape) {
|
|||||||
return get_shape(l.template cast<py::list>(), shape);
|
return get_shape(l.template cast<py::list>(), shape);
|
||||||
} else if (py::isinstance<py::tuple>(l)) {
|
} else if (py::isinstance<py::tuple>(l)) {
|
||||||
return get_shape(l.template cast<py::tuple>(), shape);
|
return get_shape(l.template cast<py::tuple>(), shape);
|
||||||
|
} else if (py::isinstance<array>(l)) {
|
||||||
|
auto arr = py::cast<array>(l);
|
||||||
|
shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
using array_init_type = std::variant<
|
||||||
|
py::bool_,
|
||||||
|
py::int_,
|
||||||
|
py::float_,
|
||||||
|
std::complex<float>,
|
||||||
|
py::list,
|
||||||
|
py::tuple,
|
||||||
|
py::array,
|
||||||
|
py::buffer,
|
||||||
|
py::object>;
|
||||||
|
|
||||||
|
// Forward declaration
|
||||||
|
array create_array(array_init_type v, std::optional<Dtype> t);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array array_from_list(
|
||||||
|
T pl,
|
||||||
|
const PyScalarT& inferred_type,
|
||||||
|
std::optional<Dtype> specified_type,
|
||||||
|
const std::vector<int>& shape) {
|
||||||
|
// Make the array
|
||||||
|
switch (inferred_type) {
|
||||||
|
case pybool: {
|
||||||
|
std::vector<bool> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
||||||
|
}
|
||||||
|
case pyint: {
|
||||||
|
std::vector<int> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, specified_type.value_or(int32));
|
||||||
|
}
|
||||||
|
case pyfloat: {
|
||||||
|
std::vector<float> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, specified_type.value_or(float32));
|
||||||
|
}
|
||||||
|
case pycomplex: {
|
||||||
|
std::vector<std::complex<float>> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(
|
||||||
|
reinterpret_cast<complex64_t*>(vals.data()),
|
||||||
|
shape,
|
||||||
|
specified_type.value_or(complex64));
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Should not happen, inferred: " << inferred_type
|
||||||
|
<< " on subarray made of only python primitive types.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -179,39 +262,20 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
|
|||||||
get_shape(pl, shape);
|
get_shape(pl, shape);
|
||||||
|
|
||||||
// Validate the shape and type
|
// Validate the shape and type
|
||||||
auto type = validate_shape(pl, shape, 0);
|
bool all_python_primitive_elements = true;
|
||||||
|
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
|
||||||
|
|
||||||
size_t size = 1;
|
if (all_python_primitive_elements) {
|
||||||
for (auto s : shape) {
|
// `pl` does not contain mlx arrays
|
||||||
size *= s;
|
return array_from_list(pl, type, dtype, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make the array
|
// `pl` contains mlx arrays
|
||||||
switch (type) {
|
std::vector<array> arrays;
|
||||||
case pybool: {
|
for (auto l : pl) {
|
||||||
std::vector<bool> vals;
|
arrays.push_back(create_array(py::cast<array_init_type>(l), dtype));
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype.value_or(bool_));
|
|
||||||
}
|
|
||||||
case pyint: {
|
|
||||||
std::vector<int> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype.value_or(int32));
|
|
||||||
}
|
|
||||||
case pyfloat: {
|
|
||||||
std::vector<float> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype.value_or(float32));
|
|
||||||
}
|
|
||||||
case pycomplex: {
|
|
||||||
std::vector<std::complex<float>> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(
|
|
||||||
reinterpret_cast<complex64_t*>(vals.data()),
|
|
||||||
shape,
|
|
||||||
dtype.value_or(complex64));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return stack(arrays);
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -419,6 +483,29 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
|
|||||||
// Module
|
// Module
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
array create_array(array_init_type v, std::optional<Dtype> t) {
|
||||||
|
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||||
|
return array(py::cast<bool>(*pv), t.value_or(bool_));
|
||||||
|
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||||
|
return array(py::cast<int>(*pv), t.value_or(int32));
|
||||||
|
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||||
|
return array(py::cast<float>(*pv), t.value_or(float32));
|
||||||
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
|
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||||
|
} else if (auto pv = std::get_if<py::list>(&v); pv) {
|
||||||
|
return array_from_list(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||||
|
return array_from_list(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||||
|
return np_array_to_mlx(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||||
|
return np_array_to_mlx(*pv, t);
|
||||||
|
} else {
|
||||||
|
auto arr = to_array_with_accessor(std::get<py::object>(v));
|
||||||
|
return astype(arr, t.value_or(arr.dtype()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void init_array(py::module_& m) {
|
void init_array(py::module_& m) {
|
||||||
// Types
|
// Types
|
||||||
py::class_<Dtype>(
|
py::class_<Dtype>(
|
||||||
@ -466,37 +553,8 @@ void init_array(py::module_& m) {
|
|||||||
options.disable_function_signatures();
|
options.disable_function_signatures();
|
||||||
|
|
||||||
array_class.def(
|
array_class.def(
|
||||||
py::init([](std::variant<
|
py::init([](array_init_type v, std::optional<Dtype> t) {
|
||||||
py::bool_,
|
return create_array(v, t);
|
||||||
py::int_,
|
|
||||||
py::float_,
|
|
||||||
std::complex<float>,
|
|
||||||
py::list,
|
|
||||||
py::tuple,
|
|
||||||
py::array,
|
|
||||||
py::buffer,
|
|
||||||
py::object> v,
|
|
||||||
std::optional<Dtype> t) {
|
|
||||||
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
|
||||||
return array(py::cast<bool>(*pv), t.value_or(bool_));
|
|
||||||
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
|
||||||
return array(py::cast<int>(*pv), t.value_or(int32));
|
|
||||||
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
|
||||||
return array(py::cast<float>(*pv), t.value_or(float32));
|
|
||||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
|
||||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
|
||||||
} else if (auto pv = std::get_if<py::list>(&v); pv) {
|
|
||||||
return array_from_list(*pv, t);
|
|
||||||
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
|
||||||
return array_from_list(*pv, t);
|
|
||||||
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
|
||||||
return np_array_to_mlx(*pv, t);
|
|
||||||
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
|
||||||
return np_array_to_mlx(*pv, t);
|
|
||||||
} else {
|
|
||||||
auto arr = to_array_with_accessor(std::get<py::object>(v));
|
|
||||||
return astype(arr, t.value_or(arr.dtype()));
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
"val"_a,
|
"val"_a,
|
||||||
"dtype"_a = std::nullopt,
|
"dtype"_a = std::nullopt,
|
||||||
|
@ -218,6 +218,64 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
|
x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
|
||||||
self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])
|
self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])
|
||||||
|
|
||||||
|
def test_construction_from_lists_of_mlx_arrays(self):
|
||||||
|
dtypes = [
|
||||||
|
mx.bool_,
|
||||||
|
mx.uint8,
|
||||||
|
mx.uint16,
|
||||||
|
mx.uint32,
|
||||||
|
mx.uint64,
|
||||||
|
mx.int8,
|
||||||
|
mx.int16,
|
||||||
|
mx.int32,
|
||||||
|
mx.int64,
|
||||||
|
mx.float16,
|
||||||
|
mx.float32,
|
||||||
|
mx.bfloat16,
|
||||||
|
mx.complex64,
|
||||||
|
]
|
||||||
|
for x_t, y_t in permutations(dtypes, 2):
|
||||||
|
# check type promotion and numeric correctness
|
||||||
|
x, y = mx.array([1.0], x_t), mx.array([2.0], y_t)
|
||||||
|
z = mx.array([x, y])
|
||||||
|
expected = mx.stack([x, y], axis=0)
|
||||||
|
self.assertEqualArray(z, expected)
|
||||||
|
|
||||||
|
# check heterogeneous construction with mlx arrays and python primitive types
|
||||||
|
x, y = mx.array([True], x_t), mx.array([False], y_t)
|
||||||
|
z = mx.array([[x, [2.0]], [[3.0], y]])
|
||||||
|
expected = mx.array([[[x.item()], [2.0]], [[3.0], [y.item()]]], z.dtype)
|
||||||
|
self.assertEqualArray(z, expected)
|
||||||
|
|
||||||
|
# check when create from an array which does not contain memory to the raw data
|
||||||
|
x = mx.array([1.0]).astype(mx.bfloat16) # x does not hold raw data
|
||||||
|
for y_t in dtypes:
|
||||||
|
y = mx.array([2.0], y_t)
|
||||||
|
z = mx.array([x, y])
|
||||||
|
expected = mx.stack([x, y], axis=0)
|
||||||
|
self.assertEqualArray(z, expected)
|
||||||
|
|
||||||
|
# shape check from `stack()`
|
||||||
|
with self.assertRaises(ValueError) as e:
|
||||||
|
mx.array([x, 1.0])
|
||||||
|
self.assertEqual(str(e.exception), "All arrays must have the same shape")
|
||||||
|
|
||||||
|
# shape check from `validate_shape`
|
||||||
|
with self.assertRaises(ValueError) as e:
|
||||||
|
mx.array([1.0, x])
|
||||||
|
self.assertEqual(
|
||||||
|
str(e.exception), "Initialization encountered non-uniform length."
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that `[mx.array, ...]` retains the `mx.array` in the graph
|
||||||
|
def f(x):
|
||||||
|
y = mx.array([x, mx.array([2.0])])
|
||||||
|
return (2 * y).sum()
|
||||||
|
|
||||||
|
x = mx.array([1.0])
|
||||||
|
dfdx = mx.grad(f)
|
||||||
|
self.assertEqual(dfdx(x).item(), 2.0)
|
||||||
|
|
||||||
def test_init_from_array(self):
|
def test_init_from_array(self):
|
||||||
x = mx.array(3.0)
|
x = mx.array(3.0)
|
||||||
y = mx.array(x)
|
y = mx.array(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user