mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-22 01:15:19 +08:00
support python mlx.array creation from list of mlx.array's (#325)
* support python mlx.array creation from list of mlx.array's * include bfloat16 in UT * refactor so that sub array made of all python primitive types gets initialized by fill_vector * address PR comment: arr.shape().size() -> arr.ndim() * address PR comment: get back Dtype constness and let stack to handle type promotions automatically
This commit is contained in:
parent
b9e415d19c
commit
d8f41a5c0f
@ -714,6 +714,7 @@ array stack(
|
||||
}
|
||||
return concatenate(new_arrays, axis, s);
|
||||
}
|
||||
|
||||
array stack(const std::vector<array>& arrays, StreamOrDevice s /* = {} */) {
|
||||
return stack(arrays, 0, s);
|
||||
}
|
||||
|
@ -120,7 +120,11 @@ void fill_vector(T list, std::vector<U>& vals) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
|
||||
PyScalarT validate_shape(
|
||||
T list,
|
||||
const std::vector<int>& shape,
|
||||
int idx,
|
||||
bool& all_python_primitive_elements) {
|
||||
if (idx >= shape.size()) {
|
||||
throw std::invalid_argument("Initialization encountered extra dimension.");
|
||||
}
|
||||
@ -138,9 +142,17 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
|
||||
for (auto l : list) {
|
||||
PyScalarT t;
|
||||
if (py::isinstance<py::list>(l)) {
|
||||
t = validate_shape(l.template cast<py::list>(), shape, idx + 1);
|
||||
t = validate_shape(
|
||||
l.template cast<py::list>(),
|
||||
shape,
|
||||
idx + 1,
|
||||
all_python_primitive_elements);
|
||||
} else if (py::isinstance<py::tuple>(*list.begin())) {
|
||||
t = validate_shape(l.template cast<py::tuple>(), shape, idx + 1);
|
||||
t = validate_shape(
|
||||
l.template cast<py::tuple>(),
|
||||
shape,
|
||||
idx + 1,
|
||||
all_python_primitive_elements);
|
||||
} else if (py::isinstance<py::bool_>(l)) {
|
||||
t = pybool;
|
||||
} else if (py::isinstance<py::int_>(l)) {
|
||||
@ -149,6 +161,19 @@ PyScalarT validate_shape(T list, const std::vector<int>& shape, int idx) {
|
||||
t = pyfloat;
|
||||
} else if (PyComplex_Check(l.ptr())) {
|
||||
t = pycomplex;
|
||||
} else if (py::isinstance<array>(l)) {
|
||||
all_python_primitive_elements = false;
|
||||
auto arr = py::cast<array>(l);
|
||||
if (arr.ndim() + idx + 1 == shape.size() &&
|
||||
std::equal(
|
||||
arr.shape().cbegin(),
|
||||
arr.shape().cend(),
|
||||
shape.cbegin() + idx + 1)) {
|
||||
t = pybool;
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"Initialization encountered non-uniform length.");
|
||||
}
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid type in array initialization" << l.get_type() << ".";
|
||||
@ -168,6 +193,64 @@ void get_shape(T list, std::vector<int>& shape) {
|
||||
return get_shape(l.template cast<py::list>(), shape);
|
||||
} else if (py::isinstance<py::tuple>(l)) {
|
||||
return get_shape(l.template cast<py::tuple>(), shape);
|
||||
} else if (py::isinstance<array>(l)) {
|
||||
auto arr = py::cast<array>(l);
|
||||
shape.insert(shape.end(), arr.shape().begin(), arr.shape().end());
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using array_init_type = std::variant<
|
||||
py::bool_,
|
||||
py::int_,
|
||||
py::float_,
|
||||
std::complex<float>,
|
||||
py::list,
|
||||
py::tuple,
|
||||
py::array,
|
||||
py::buffer,
|
||||
py::object>;
|
||||
|
||||
// Forward declaration
|
||||
array create_array(array_init_type v, std::optional<Dtype> t);
|
||||
|
||||
template <typename T>
|
||||
array array_from_list(
|
||||
T pl,
|
||||
const PyScalarT& inferred_type,
|
||||
std::optional<Dtype> specified_type,
|
||||
const std::vector<int>& shape) {
|
||||
// Make the array
|
||||
switch (inferred_type) {
|
||||
case pybool: {
|
||||
std::vector<bool> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
||||
}
|
||||
case pyint: {
|
||||
std::vector<int> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, specified_type.value_or(int32));
|
||||
}
|
||||
case pyfloat: {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, specified_type.value_or(float32));
|
||||
}
|
||||
case pycomplex: {
|
||||
std::vector<std::complex<float>> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(
|
||||
reinterpret_cast<complex64_t*>(vals.data()),
|
||||
shape,
|
||||
specified_type.value_or(complex64));
|
||||
}
|
||||
default: {
|
||||
std::ostringstream msg;
|
||||
msg << "Should not happen, inferred: " << inferred_type
|
||||
<< " on subarray made of only python primitive types.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -179,39 +262,20 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
|
||||
get_shape(pl, shape);
|
||||
|
||||
// Validate the shape and type
|
||||
auto type = validate_shape(pl, shape, 0);
|
||||
bool all_python_primitive_elements = true;
|
||||
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
|
||||
|
||||
size_t size = 1;
|
||||
for (auto s : shape) {
|
||||
size *= s;
|
||||
if (all_python_primitive_elements) {
|
||||
// `pl` does not contain mlx arrays
|
||||
return array_from_list(pl, type, dtype, shape);
|
||||
}
|
||||
|
||||
// Make the array
|
||||
switch (type) {
|
||||
case pybool: {
|
||||
std::vector<bool> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype.value_or(bool_));
|
||||
}
|
||||
case pyint: {
|
||||
std::vector<int> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype.value_or(int32));
|
||||
}
|
||||
case pyfloat: {
|
||||
std::vector<float> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(vals.begin(), shape, dtype.value_or(float32));
|
||||
}
|
||||
case pycomplex: {
|
||||
std::vector<std::complex<float>> vals;
|
||||
fill_vector(pl, vals);
|
||||
return array(
|
||||
reinterpret_cast<complex64_t*>(vals.data()),
|
||||
shape,
|
||||
dtype.value_or(complex64));
|
||||
}
|
||||
// `pl` contains mlx arrays
|
||||
std::vector<array> arrays;
|
||||
for (auto l : pl) {
|
||||
arrays.push_back(create_array(py::cast<array_init_type>(l), dtype));
|
||||
}
|
||||
return stack(arrays);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -419,6 +483,29 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
|
||||
// Module
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
array create_array(array_init_type v, std::optional<Dtype> t) {
|
||||
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||
return array(py::cast<bool>(*pv), t.value_or(bool_));
|
||||
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||
return array(py::cast<int>(*pv), t.value_or(int32));
|
||||
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||
return array(py::cast<float>(*pv), t.value_or(float32));
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||
} else if (auto pv = std::get_if<py::list>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else {
|
||||
auto arr = to_array_with_accessor(std::get<py::object>(v));
|
||||
return astype(arr, t.value_or(arr.dtype()));
|
||||
}
|
||||
}
|
||||
|
||||
void init_array(py::module_& m) {
|
||||
// Types
|
||||
py::class_<Dtype>(
|
||||
@ -466,37 +553,8 @@ void init_array(py::module_& m) {
|
||||
options.disable_function_signatures();
|
||||
|
||||
array_class.def(
|
||||
py::init([](std::variant<
|
||||
py::bool_,
|
||||
py::int_,
|
||||
py::float_,
|
||||
std::complex<float>,
|
||||
py::list,
|
||||
py::tuple,
|
||||
py::array,
|
||||
py::buffer,
|
||||
py::object> v,
|
||||
std::optional<Dtype> t) {
|
||||
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||
return array(py::cast<bool>(*pv), t.value_or(bool_));
|
||||
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||
return array(py::cast<int>(*pv), t.value_or(int32));
|
||||
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||
return array(py::cast<float>(*pv), t.value_or(float32));
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||
} else if (auto pv = std::get_if<py::list>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else {
|
||||
auto arr = to_array_with_accessor(std::get<py::object>(v));
|
||||
return astype(arr, t.value_or(arr.dtype()));
|
||||
}
|
||||
py::init([](array_init_type v, std::optional<Dtype> t) {
|
||||
return create_array(v, t);
|
||||
}),
|
||||
"val"_a,
|
||||
"dtype"_a = std::nullopt,
|
||||
|
@ -218,6 +218,64 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([1 + 0j, 2j, True, 0], mx.complex64)
|
||||
self.assertEqual(x.tolist(), [1 + 0j, 2j, 1 + 0j, 0j])
|
||||
|
||||
def test_construction_from_lists_of_mlx_arrays(self):
|
||||
dtypes = [
|
||||
mx.bool_,
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
for x_t, y_t in permutations(dtypes, 2):
|
||||
# check type promotion and numeric correctness
|
||||
x, y = mx.array([1.0], x_t), mx.array([2.0], y_t)
|
||||
z = mx.array([x, y])
|
||||
expected = mx.stack([x, y], axis=0)
|
||||
self.assertEqualArray(z, expected)
|
||||
|
||||
# check heterogeneous construction with mlx arrays and python primitive types
|
||||
x, y = mx.array([True], x_t), mx.array([False], y_t)
|
||||
z = mx.array([[x, [2.0]], [[3.0], y]])
|
||||
expected = mx.array([[[x.item()], [2.0]], [[3.0], [y.item()]]], z.dtype)
|
||||
self.assertEqualArray(z, expected)
|
||||
|
||||
# check when create from an array which does not contain memory to the raw data
|
||||
x = mx.array([1.0]).astype(mx.bfloat16) # x does not hold raw data
|
||||
for y_t in dtypes:
|
||||
y = mx.array([2.0], y_t)
|
||||
z = mx.array([x, y])
|
||||
expected = mx.stack([x, y], axis=0)
|
||||
self.assertEqualArray(z, expected)
|
||||
|
||||
# shape check from `stack()`
|
||||
with self.assertRaises(ValueError) as e:
|
||||
mx.array([x, 1.0])
|
||||
self.assertEqual(str(e.exception), "All arrays must have the same shape")
|
||||
|
||||
# shape check from `validate_shape`
|
||||
with self.assertRaises(ValueError) as e:
|
||||
mx.array([1.0, x])
|
||||
self.assertEqual(
|
||||
str(e.exception), "Initialization encountered non-uniform length."
|
||||
)
|
||||
|
||||
# check that `[mx.array, ...]` retains the `mx.array` in the graph
|
||||
def f(x):
|
||||
y = mx.array([x, mx.array([2.0])])
|
||||
return (2 * y).sum()
|
||||
|
||||
x = mx.array([1.0])
|
||||
dfdx = mx.grad(f)
|
||||
self.assertEqual(dfdx(x).item(), 2.0)
|
||||
|
||||
def test_init_from_array(self):
|
||||
x = mx.array(3.0)
|
||||
y = mx.array(x)
|
||||
|
Loading…
Reference in New Issue
Block a user