WIP (python)

This commit is contained in:
Ronan Collobert
2025-10-31 16:24:51 -07:00
parent 18aa921388
commit 9f649b5658
11 changed files with 54 additions and 52 deletions

View File

@@ -83,7 +83,7 @@ class ArrayPythonIterator {
throw nb::stop_iteration(); throw nb::stop_iteration();
} }
if (idx_ >= 0 && idx_ < splits_.size()) { if (idx_ >= 0 && idx_ < std::ssize(splits_)) {
return mx::squeeze(splits_[idx_++], 0); return mx::squeeze(splits_[idx_++], 0);
} }
@@ -390,7 +390,7 @@ void init_array(nb::module_& m) {
)pbdoc") )pbdoc")
.def( .def(
"__array_namespace__", "__array_namespace__",
[](const mx::array& a, [](const mx::array& /* a */,
const std::optional<std::string>& api_version) { const std::optional<std::string>& api_version) {
if (api_version) { if (api_version) {
throw std::invalid_argument( throw std::invalid_argument(
@@ -501,7 +501,7 @@ void init_array(nb::module_& m) {
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
.def( .def(
"__dlpack_device__", "__dlpack_device__",
[](const mx::array& a) { [](const mx::array& /* a */) {
// See // See
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74 // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
if (mx::metal::is_available()) { if (mx::metal::is_available()) {

View File

@@ -50,7 +50,7 @@ mx::array nd_array_to_mlx(
// Compute the shape and size // Compute the shape and size
mx::Shape shape; mx::Shape shape;
shape.reserve(nd_array.ndim()); shape.reserve(nd_array.ndim());
for (int i = 0; i < nd_array.ndim(); i++) { for (int i = 0; i < static_cast<int>(nd_array.ndim()); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i))); shape.push_back(check_shape_dim(nd_array.shape(i)));
} }
auto type = nd_array.dtype(); auto type = nd_array.dtype();
@@ -289,7 +289,7 @@ PyScalarT validate_shape(
throw std::invalid_argument("Initialization encountered extra dimension."); throw std::invalid_argument("Initialization encountered extra dimension.");
} }
auto s = shape[idx]; auto s = shape[idx];
if (nb::len(list) != s) { if (nb::len(list) != static_cast<size_t>(s)) {
throw std::invalid_argument( throw std::invalid_argument(
"Initialization encountered non-uniform length."); "Initialization encountered non-uniform length.");
} }

View File

@@ -201,7 +201,6 @@ void init_fast(nb::module_& parent_module) {
bool has_mask = !std::holds_alternative<std::monostate>(mask); bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask = bool has_str_mask =
has_mask && std::holds_alternative<std::string>(mask); has_mask && std::holds_alternative<std::string>(mask);
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
if (has_mask) { if (has_mask) {
if (has_str_mask) { if (has_str_mask) {

View File

@@ -115,7 +115,7 @@ mx::array mlx_gather_nd(
std::vector<bool> is_slice(indices.size(), false); std::vector<bool> is_slice(indices.size(), false);
int num_slices = 0; int num_slices = 0;
// gather all the arrays // gather all the arrays
for (int i = 0; i < indices.size(); i++) { for (int i = 0; i < std::ssize(indices); i++) {
auto& idx = indices[i]; auto& idx = indices[i];
if (nb::isinstance<nb::slice>(idx)) { if (nb::isinstance<nb::slice>(idx)) {
@@ -142,7 +142,7 @@ mx::array mlx_gather_nd(
// reshape them so that the int/array indices are first // reshape them so that the int/array indices are first
if (gather_first) { if (gather_first) {
int slice_index = 0; int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) { for (int i = 0; i < std::ssize(gather_indices); i++) {
if (is_slice[i]) { if (is_slice[i]) {
mx::Shape index_shape(max_dims + num_slices, 1); mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0); index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
@@ -156,7 +156,7 @@ mx::array mlx_gather_nd(
} }
} else { } else {
// reshape them so that the int/array indices are last // reshape them so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) { for (int i = 0; i < std::ssize(gather_indices); i++) {
if (i < num_slices) { if (i < num_slices) {
mx::Shape index_shape(max_dims + num_slices, 1); mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0); index_shape[i] = gather_indices[i].shape(0);
@@ -190,7 +190,7 @@ auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
bool has_ellipsis = false; bool has_ellipsis = false;
// Start from dimension 0 till we hit an ellipsis // Start from dimension 0 till we hit an ellipsis
for (; i < entries.size(); i++) { for (; i < std::ssize(entries); i++) {
auto idx = entries[i]; auto idx = entries[i];
if (!is_valid_index_type(idx)) { if (!is_valid_index_type(idx)) {
throw std::invalid_argument( throw std::invalid_argument(
@@ -301,7 +301,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
if (have_array) { if (have_array) {
int last_array; int last_array;
// Then find the last array // Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) { for (last_array = std::ssize(indices) - 1; last_array >= 0;
last_array--) {
auto& idx = indices[last_array]; auto& idx = indices[last_array];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) { if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break; break;
@@ -333,11 +334,11 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
nb::slice(nb::none(), nb::none(), nb::none())); nb::slice(nb::none(), nb::none(), nb::none()));
} }
} }
for (int i = last_array + 1; i < indices.size(); i++) { for (int i = last_array + 1; i < std::ssize(indices); i++) {
remaining_indices.push_back(indices[i]); remaining_indices.push_back(indices[i]);
} }
} else { } else {
for (int i = 0; i < indices.size(); i++) { for (int i = 0; i < std::ssize(indices); i++) {
auto& idx = indices[i]; auto& idx = indices[i];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) { if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break; break;
@@ -352,7 +353,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
remaining_indices.push_back( remaining_indices.push_back(
nb::slice(nb::none(), nb::none(), nb::none())); nb::slice(nb::none(), nb::none(), nb::none()));
} }
for (int i = last_array + 1; i < indices.size(); i++) { for (int i = last_array + 1; i < std::ssize(indices); i++) {
remaining_indices.push_back(indices[i]); remaining_indices.push_back(indices[i]);
} }
} }
@@ -406,7 +407,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
if (unsqueeze_needed || squeeze_needed) { if (unsqueeze_needed || squeeze_needed) {
std::vector<int> squeeze_axes; std::vector<int> squeeze_axes;
std::vector<int> unsqueeze_axes; std::vector<int> unsqueeze_axes;
for (int axis = 0; axis < remaining_indices.size(); ++axis) { for (int axis = 0; axis < std::ssize(remaining_indices); ++axis) {
auto& idx = remaining_indices[axis]; auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) { if (unsqueeze_needed && idx.is_none()) {
unsqueeze_axes.push_back(axis - squeeze_axes.size()); unsqueeze_axes.push_back(axis - squeeze_axes.size());
@@ -583,7 +584,7 @@ mlx_scatter_args_nd(
} }
// Analyse the types of the indices // Analyse the types of the indices
size_t max_dim = 0; int max_dim = 0;
bool arrays_first = false; bool arrays_first = false;
int num_none = 0; int num_none = 0;
int num_slices = 0; int num_slices = 0;
@@ -640,7 +641,7 @@ mlx_scatter_args_nd(
std::vector<int> update_shape(non_none_indices, 1); std::vector<int> update_shape(non_none_indices, 1);
std::vector<int> slice_shapes; std::vector<int> slice_shapes;
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < std::ssize(indices); ++i) {
auto& pyidx = indices[i]; auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) { if (nb::isinstance<nb::slice>(pyidx)) {
mx::ShapeElem start, end, stride; mx::ShapeElem start, end, stride;
@@ -848,7 +849,7 @@ auto mlx_slice_update(
int unspecified = src.ndim() - non_none_indices; int unspecified = src.ndim() - non_none_indices;
std::vector<int> squeeze_dims; std::vector<int> squeeze_dims;
std::vector<int> expand_dims; std::vector<int> expand_dims;
for (int i = indices.size() - 1, for (int i = std::ssize(indices) - 1,
ax = non_none_indices - 1, ax = non_none_indices - 1,
upd_ax = upd.ndim() - unspecified - 1; upd_ax = upd.ndim() - unspecified - 1;
i >= 0; i >= 0;

View File

@@ -436,7 +436,7 @@ void mlx_savez_helper(
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs); nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
auto arrays_list = nb::cast<std::vector<mx::array>>(args); auto arrays_list = nb::cast<std::vector<mx::array>>(args);
for (int i = 0; i < arrays_list.size(); i++) { for (int i = 0; i < std::ssize(arrays_list); i++) {
std::string arr_name = "arr_" + std::to_string(i); std::string arr_name = "arr_" + std::to_string(i);
if (arrays_dict.count(arr_name) > 0) { if (arrays_dict.count(arr_name) > 0) {

View File

@@ -22,7 +22,9 @@ bool DEPRECATE(const char* old_fn, const char* new_fn) {
return true; return true;
} }
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn) #define DEPRECATE(oldfn, newfn) \
static bool dep = DEPRECATE(oldfn, newfn); \
(void)dep;
void init_metal(nb::module_& m) { void init_metal(nb::module_& m) {
nb::module_ metal = m.def_submodule("metal", "mlx.metal"); nb::module_ metal = m.def_submodule("metal", "mlx.metal");

View File

@@ -107,7 +107,7 @@ nb::callable mlx_func(
return nb::steal<nb::callable>((PyObject*)r); return nb::steal<nb::callable>((PyObject*)r);
} }
void init_mlx_func(nb::module_& m) { void init_mlx_func(nb::module_& /* m */) {
gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec); gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec);
if (!gc_func_tp) { if (!gc_func_tp) {
nb::raise("Could not register MLX function type."); nb::raise("Could not register MLX function type.");

View File

@@ -100,9 +100,9 @@ void init_stream(nb::module_& m) {
.def( .def(
"__exit__", "__exit__",
[](PyStreamContext& scm, [](PyStreamContext& scm,
const std::optional<nb::type_object>& exc_type, const std::optional<nb::type_object>& /* exc_type */,
const std::optional<nb::object>& exc_value, const std::optional<nb::object>& /* exc_value */,
const std::optional<nb::object>& traceback) { scm.exit(); }, const std::optional<nb::object>& /* traceback */) { scm.exit(); },
"exc_type"_a = nb::none(), "exc_type"_a = nb::none(),
"exc_value"_a = nb::none(), "exc_value"_a = nb::none(),
"traceback"_a = nb::none()); "traceback"_a = nb::none());

View File

@@ -86,7 +86,7 @@ auto py_value_and_grad(
<< argnums[0]; << argnums[0];
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
for (int i = 1; i < argnums.size(); ++i) { for (int i = 1; i < std::ssize(argnums); ++i) {
if (argnums[i] == argnums[i - 1]) { if (argnums[i] == argnums[i - 1]) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " Duplicate argument index " << argnums[0] msg << error_msg_tag << " Duplicate argument index " << argnums[0]
@@ -99,7 +99,7 @@ auto py_value_and_grad(
return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
nb::args& args, nb::kwargs& kwargs) { nb::args& args, nb::kwargs& kwargs) {
// Sanitize the input // Sanitize the input
if (argnums.size() > 0 && argnums.back() >= args.size()) { if (argnums.size() > 0 && argnums.back() >= std::ssize(args)) {
std::ostringstream msg; std::ostringstream msg;
msg << error_msg_tag << " Can't compute the gradient of argument index " msg << error_msg_tag << " Can't compute the gradient of argument index "
<< argnums.back() << " because the function is called with only " << argnums.back() << " because the function is called with only "
@@ -126,8 +126,8 @@ auto py_value_and_grad(
std::vector<mx::array> arrays; std::vector<mx::array> arrays;
std::vector<int> counts(1, 0); std::vector<int> counts(1, 0);
std::vector<int> gradient_indices; std::vector<int> gradient_indices;
for (int i = 0, j = 0; i < args.size(); ++i) { for (int i = 0, j = 0; i < std::ssize(args); ++i) {
bool needs_grad = (j < argnums.size() && argnums[j] == i); bool needs_grad = (j < std::ssize(argnums) && argnums[j] == i);
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad); auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
if (needs_grad) { if (needs_grad) {
auto old_size = gradient_indices.size(); auto old_size = gradient_indices.size();
@@ -257,7 +257,7 @@ auto py_value_and_grad(
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]); positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
} else if (argnums.size() > 1) { } else if (argnums.size() > 1) {
nb::list grads_; nb::list grads_;
for (int i = 0; i < argnums.size(); i++) { for (int i = 0; i < std::ssize(argnums); i++) {
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i])); grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
} }
positional_grads = nb::tuple(grads_); positional_grads = nb::tuple(grads_);
@@ -366,14 +366,13 @@ auto py_vmap(
// able to reconstruct the python tree of extra return values // able to reconstruct the python tree of extra return values
nb::object py_outputs; nb::object py_outputs;
auto vmap_fn = auto vmap_fn = [&fun, &args, &py_outputs](const std::vector<mx::array>& a) {
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) { // Call the python function
// Call the python function py_outputs = fun(*tree_unflatten(args, a));
py_outputs = fun(*tree_unflatten(args, a));
// Flatten the outputs // Flatten the outputs
return tree_flatten(py_outputs, true); return tree_flatten(py_outputs, true);
}; };
auto [trace_inputs, trace_outputs] = auto [trace_inputs, trace_outputs] =
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes); mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
@@ -451,7 +450,7 @@ struct PyCompiledFun {
if (nb::isinstance<nb::list>(obj)) { if (nb::isinstance<nb::list>(obj)) {
auto l = nb::cast<nb::list>(obj); auto l = nb::cast<nb::list>(obj);
constants.push_back(list_identifier); constants.push_back(list_identifier);
for (int i = 0; i < l.size(); ++i) { for (int i = 0; i < std::ssize(l); ++i) {
recurse(l[i]); recurse(l[i]);
} }
} else if (nb::isinstance<nb::tuple>(obj)) { } else if (nb::isinstance<nb::tuple>(obj)) {

View File

@@ -6,7 +6,8 @@ template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<nb::object>& subtrees) { void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = nb::cast<T>(subtrees[0]).size(); int len = nb::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) { for (auto& subtree : subtrees) {
if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) || if ((nb::isinstance<T>(subtree) &&
std::ssize(nb::cast<T>(subtree)) != len) ||
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) { nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
throw std::invalid_argument( throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree."); "[tree_map] Additional input tree is not a valid prefix of the first tree.");
@@ -24,8 +25,8 @@ nb::object tree_map(
nb::list l; nb::list l;
std::vector<nb::object> items(subtrees.size()); std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees); validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) { for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::list>(subtrees[j])) { if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i]; items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else { } else {
@@ -42,7 +43,7 @@ nb::object tree_map(
nb::list l; nb::list l;
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees); validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) { if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i]; items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else { } else {
@@ -57,7 +58,7 @@ nb::object tree_map(
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees); validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
nb::dict d; nb::dict d;
for (auto item : nb::cast<nb::dict>(subtrees[0])) { for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::dict>(subtrees[j])) { if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]); auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) { if (!subdict.contains(item.first)) {
@@ -96,8 +97,8 @@ void tree_visit(
if (nb::isinstance<nb::list>(subtrees[0])) { if (nb::isinstance<nb::list>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size()); std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees); validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) { for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::list>(subtrees[j])) { if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i]; items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else { } else {
@@ -112,7 +113,7 @@ void tree_visit(
int len = nb::cast<nb::tuple>(subtrees[0]).size(); int len = nb::cast<nb::tuple>(subtrees[0]).size();
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees); validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) { if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i]; items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else { } else {
@@ -125,7 +126,7 @@ void tree_visit(
std::vector<nb::object> items(subtrees.size()); std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees); validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
for (auto item : nb::cast<nb::dict>(subtrees[0])) { for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) { for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::dict>(subtrees[j])) { if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]); auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) { if (!subdict.contains(item.first)) {
@@ -173,13 +174,13 @@ void tree_visit_update(
recurse = [&](nb::handle subtree) { recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree)) { if (nb::isinstance<nb::list>(subtree)) {
auto l = nb::cast<nb::list>(subtree); auto l = nb::cast<nb::list>(subtree);
for (int i = 0; i < l.size(); ++i) { for (int i = 0; i < std::ssize(l); ++i) {
l[i] = recurse(l[i]); l[i] = recurse(l[i]);
} }
return nb::cast<nb::object>(l); return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) { } else if (nb::isinstance<nb::tuple>(subtree)) {
nb::list l(subtree); nb::list l(subtree);
for (int i = 0; i < l.size(); ++i) { for (int i = 0; i < std::ssize(l); ++i) {
l[i] = recurse(l[i]); l[i] = recurse(l[i]);
} }
return nb::cast<nb::object>(nb::tuple(l)); return nb::cast<nb::object>(nb::tuple(l));
@@ -204,7 +205,7 @@ void tree_visit_update(
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) { void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
size_t index = 0; size_t index = 0;
tree_visit_update( tree_visit_update(
tree, [&](nb::handle node) { return nb::cast(values[index++]); }); tree, [&](nb::handle /* node */) { return nb::cast(values[index++]); });
} }
// Replace all the arrays from the src values with the dst values in the tree // Replace all the arrays from the src values with the dst values in the tree
@@ -213,7 +214,7 @@ void tree_replace(
const std::vector<mx::array>& src, const std::vector<mx::array>& src,
const std::vector<mx::array>& dst) { const std::vector<mx::array>& dst) {
std::unordered_map<uintptr_t, mx::array> src_to_dst; std::unordered_map<uintptr_t, mx::array> src_to_dst;
for (int i = 0; i < src.size(); ++i) { for (int i = 0; i < std::ssize(src); ++i) {
src_to_dst.insert({src[i].id(), dst[i]}); src_to_dst.insert({src[i].id(), dst[i]});
} }
tree_visit_update(tree, [&](nb::handle node) { tree_visit_update(tree, [&](nb::handle node) {

View File

@@ -57,8 +57,8 @@ std::pair<mx::array, mx::array> to_arrays(
// - If neither is an array convert to arrays but leave their types alone // - If neither is an array convert to arrays but leave their types alone
auto is_mlx_array = [](const ScalarOrArray& x) { auto is_mlx_array = [](const ScalarOrArray& x) {
return std::holds_alternative<mx::array>(x) || return std::holds_alternative<mx::array>(x) ||
std::holds_alternative<ArrayLike>(x) && (std::holds_alternative<ArrayLike>(x) &&
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__"); nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__"));
}; };
auto get_mlx_array = [](const ScalarOrArray& x) { auto get_mlx_array = [](const ScalarOrArray& x) {
if (auto px = std::get_if<mx::array>(&x); px) { if (auto px = std::get_if<mx::array>(&x); px) {