mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (python)
This commit is contained in:
@@ -83,7 +83,7 @@ class ArrayPythonIterator {
|
|||||||
throw nb::stop_iteration();
|
throw nb::stop_iteration();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (idx_ >= 0 && idx_ < splits_.size()) {
|
if (idx_ >= 0 && idx_ < std::ssize(splits_)) {
|
||||||
return mx::squeeze(splits_[idx_++], 0);
|
return mx::squeeze(splits_[idx_++], 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -390,7 +390,7 @@ void init_array(nb::module_& m) {
|
|||||||
)pbdoc")
|
)pbdoc")
|
||||||
.def(
|
.def(
|
||||||
"__array_namespace__",
|
"__array_namespace__",
|
||||||
[](const mx::array& a,
|
[](const mx::array& /* a */,
|
||||||
const std::optional<std::string>& api_version) {
|
const std::optional<std::string>& api_version) {
|
||||||
if (api_version) {
|
if (api_version) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@@ -501,7 +501,7 @@ void init_array(nb::module_& m) {
|
|||||||
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
|
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
|
||||||
.def(
|
.def(
|
||||||
"__dlpack_device__",
|
"__dlpack_device__",
|
||||||
[](const mx::array& a) {
|
[](const mx::array& /* a */) {
|
||||||
// See
|
// See
|
||||||
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
|
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
|
||||||
if (mx::metal::is_available()) {
|
if (mx::metal::is_available()) {
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ mx::array nd_array_to_mlx(
|
|||||||
// Compute the shape and size
|
// Compute the shape and size
|
||||||
mx::Shape shape;
|
mx::Shape shape;
|
||||||
shape.reserve(nd_array.ndim());
|
shape.reserve(nd_array.ndim());
|
||||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
for (int i = 0; i < static_cast<int>(nd_array.ndim()); i++) {
|
||||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||||
}
|
}
|
||||||
auto type = nd_array.dtype();
|
auto type = nd_array.dtype();
|
||||||
@@ -289,7 +289,7 @@ PyScalarT validate_shape(
|
|||||||
throw std::invalid_argument("Initialization encountered extra dimension.");
|
throw std::invalid_argument("Initialization encountered extra dimension.");
|
||||||
}
|
}
|
||||||
auto s = shape[idx];
|
auto s = shape[idx];
|
||||||
if (nb::len(list) != s) {
|
if (nb::len(list) != static_cast<size_t>(s)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"Initialization encountered non-uniform length.");
|
"Initialization encountered non-uniform length.");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -201,7 +201,6 @@ void init_fast(nb::module_& parent_module) {
|
|||||||
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||||
bool has_str_mask =
|
bool has_str_mask =
|
||||||
has_mask && std::holds_alternative<std::string>(mask);
|
has_mask && std::holds_alternative<std::string>(mask);
|
||||||
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
|
|
||||||
|
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
if (has_str_mask) {
|
if (has_str_mask) {
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ mx::array mlx_gather_nd(
|
|||||||
std::vector<bool> is_slice(indices.size(), false);
|
std::vector<bool> is_slice(indices.size(), false);
|
||||||
int num_slices = 0;
|
int num_slices = 0;
|
||||||
// gather all the arrays
|
// gather all the arrays
|
||||||
for (int i = 0; i < indices.size(); i++) {
|
for (int i = 0; i < std::ssize(indices); i++) {
|
||||||
auto& idx = indices[i];
|
auto& idx = indices[i];
|
||||||
|
|
||||||
if (nb::isinstance<nb::slice>(idx)) {
|
if (nb::isinstance<nb::slice>(idx)) {
|
||||||
@@ -142,7 +142,7 @@ mx::array mlx_gather_nd(
|
|||||||
// reshape them so that the int/array indices are first
|
// reshape them so that the int/array indices are first
|
||||||
if (gather_first) {
|
if (gather_first) {
|
||||||
int slice_index = 0;
|
int slice_index = 0;
|
||||||
for (int i = 0; i < gather_indices.size(); i++) {
|
for (int i = 0; i < std::ssize(gather_indices); i++) {
|
||||||
if (is_slice[i]) {
|
if (is_slice[i]) {
|
||||||
mx::Shape index_shape(max_dims + num_slices, 1);
|
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||||
@@ -156,7 +156,7 @@ mx::array mlx_gather_nd(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// reshape them so that the int/array indices are last
|
// reshape them so that the int/array indices are last
|
||||||
for (int i = 0; i < gather_indices.size(); i++) {
|
for (int i = 0; i < std::ssize(gather_indices); i++) {
|
||||||
if (i < num_slices) {
|
if (i < num_slices) {
|
||||||
mx::Shape index_shape(max_dims + num_slices, 1);
|
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||||
index_shape[i] = gather_indices[i].shape(0);
|
index_shape[i] = gather_indices[i].shape(0);
|
||||||
@@ -190,7 +190,7 @@ auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
|
|||||||
bool has_ellipsis = false;
|
bool has_ellipsis = false;
|
||||||
|
|
||||||
// Start from dimension 0 till we hit an ellipsis
|
// Start from dimension 0 till we hit an ellipsis
|
||||||
for (; i < entries.size(); i++) {
|
for (; i < std::ssize(entries); i++) {
|
||||||
auto idx = entries[i];
|
auto idx = entries[i];
|
||||||
if (!is_valid_index_type(idx)) {
|
if (!is_valid_index_type(idx)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
@@ -301,7 +301,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
|||||||
if (have_array) {
|
if (have_array) {
|
||||||
int last_array;
|
int last_array;
|
||||||
// Then find the last array
|
// Then find the last array
|
||||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
for (last_array = std::ssize(indices) - 1; last_array >= 0;
|
||||||
|
last_array--) {
|
||||||
auto& idx = indices[last_array];
|
auto& idx = indices[last_array];
|
||||||
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||||
break;
|
break;
|
||||||
@@ -333,11 +334,11 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
|||||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
for (int i = last_array + 1; i < std::ssize(indices); i++) {
|
||||||
remaining_indices.push_back(indices[i]);
|
remaining_indices.push_back(indices[i]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < indices.size(); i++) {
|
for (int i = 0; i < std::ssize(indices); i++) {
|
||||||
auto& idx = indices[i];
|
auto& idx = indices[i];
|
||||||
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||||
break;
|
break;
|
||||||
@@ -352,7 +353,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
|||||||
remaining_indices.push_back(
|
remaining_indices.push_back(
|
||||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||||
}
|
}
|
||||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
for (int i = last_array + 1; i < std::ssize(indices); i++) {
|
||||||
remaining_indices.push_back(indices[i]);
|
remaining_indices.push_back(indices[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -406,7 +407,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
|||||||
if (unsqueeze_needed || squeeze_needed) {
|
if (unsqueeze_needed || squeeze_needed) {
|
||||||
std::vector<int> squeeze_axes;
|
std::vector<int> squeeze_axes;
|
||||||
std::vector<int> unsqueeze_axes;
|
std::vector<int> unsqueeze_axes;
|
||||||
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
|
for (int axis = 0; axis < std::ssize(remaining_indices); ++axis) {
|
||||||
auto& idx = remaining_indices[axis];
|
auto& idx = remaining_indices[axis];
|
||||||
if (unsqueeze_needed && idx.is_none()) {
|
if (unsqueeze_needed && idx.is_none()) {
|
||||||
unsqueeze_axes.push_back(axis - squeeze_axes.size());
|
unsqueeze_axes.push_back(axis - squeeze_axes.size());
|
||||||
@@ -583,7 +584,7 @@ mlx_scatter_args_nd(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Analyse the types of the indices
|
// Analyse the types of the indices
|
||||||
size_t max_dim = 0;
|
int max_dim = 0;
|
||||||
bool arrays_first = false;
|
bool arrays_first = false;
|
||||||
int num_none = 0;
|
int num_none = 0;
|
||||||
int num_slices = 0;
|
int num_slices = 0;
|
||||||
@@ -640,7 +641,7 @@ mlx_scatter_args_nd(
|
|||||||
std::vector<int> update_shape(non_none_indices, 1);
|
std::vector<int> update_shape(non_none_indices, 1);
|
||||||
std::vector<int> slice_shapes;
|
std::vector<int> slice_shapes;
|
||||||
|
|
||||||
for (int i = 0; i < indices.size(); ++i) {
|
for (int i = 0; i < std::ssize(indices); ++i) {
|
||||||
auto& pyidx = indices[i];
|
auto& pyidx = indices[i];
|
||||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||||
mx::ShapeElem start, end, stride;
|
mx::ShapeElem start, end, stride;
|
||||||
@@ -848,7 +849,7 @@ auto mlx_slice_update(
|
|||||||
int unspecified = src.ndim() - non_none_indices;
|
int unspecified = src.ndim() - non_none_indices;
|
||||||
std::vector<int> squeeze_dims;
|
std::vector<int> squeeze_dims;
|
||||||
std::vector<int> expand_dims;
|
std::vector<int> expand_dims;
|
||||||
for (int i = indices.size() - 1,
|
for (int i = std::ssize(indices) - 1,
|
||||||
ax = non_none_indices - 1,
|
ax = non_none_indices - 1,
|
||||||
upd_ax = upd.ndim() - unspecified - 1;
|
upd_ax = upd.ndim() - unspecified - 1;
|
||||||
i >= 0;
|
i >= 0;
|
||||||
|
|||||||
@@ -436,7 +436,7 @@ void mlx_savez_helper(
|
|||||||
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
|
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
|
||||||
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
|
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
|
||||||
|
|
||||||
for (int i = 0; i < arrays_list.size(); i++) {
|
for (int i = 0; i < std::ssize(arrays_list); i++) {
|
||||||
std::string arr_name = "arr_" + std::to_string(i);
|
std::string arr_name = "arr_" + std::to_string(i);
|
||||||
|
|
||||||
if (arrays_dict.count(arr_name) > 0) {
|
if (arrays_dict.count(arr_name) > 0) {
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ bool DEPRECATE(const char* old_fn, const char* new_fn) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)
|
#define DEPRECATE(oldfn, newfn) \
|
||||||
|
static bool dep = DEPRECATE(oldfn, newfn); \
|
||||||
|
(void)dep;
|
||||||
|
|
||||||
void init_metal(nb::module_& m) {
|
void init_metal(nb::module_& m) {
|
||||||
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ nb::callable mlx_func(
|
|||||||
return nb::steal<nb::callable>((PyObject*)r);
|
return nb::steal<nb::callable>((PyObject*)r);
|
||||||
}
|
}
|
||||||
|
|
||||||
void init_mlx_func(nb::module_& m) {
|
void init_mlx_func(nb::module_& /* m */) {
|
||||||
gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec);
|
gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec);
|
||||||
if (!gc_func_tp) {
|
if (!gc_func_tp) {
|
||||||
nb::raise("Could not register MLX function type.");
|
nb::raise("Could not register MLX function type.");
|
||||||
|
|||||||
@@ -100,9 +100,9 @@ void init_stream(nb::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"__exit__",
|
"__exit__",
|
||||||
[](PyStreamContext& scm,
|
[](PyStreamContext& scm,
|
||||||
const std::optional<nb::type_object>& exc_type,
|
const std::optional<nb::type_object>& /* exc_type */,
|
||||||
const std::optional<nb::object>& exc_value,
|
const std::optional<nb::object>& /* exc_value */,
|
||||||
const std::optional<nb::object>& traceback) { scm.exit(); },
|
const std::optional<nb::object>& /* traceback */) { scm.exit(); },
|
||||||
"exc_type"_a = nb::none(),
|
"exc_type"_a = nb::none(),
|
||||||
"exc_value"_a = nb::none(),
|
"exc_value"_a = nb::none(),
|
||||||
"traceback"_a = nb::none());
|
"traceback"_a = nb::none());
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ auto py_value_and_grad(
|
|||||||
<< argnums[0];
|
<< argnums[0];
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
for (int i = 1; i < argnums.size(); ++i) {
|
for (int i = 1; i < std::ssize(argnums); ++i) {
|
||||||
if (argnums[i] == argnums[i - 1]) {
|
if (argnums[i] == argnums[i - 1]) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << error_msg_tag << " Duplicate argument index " << argnums[0]
|
msg << error_msg_tag << " Duplicate argument index " << argnums[0]
|
||||||
@@ -99,7 +99,7 @@ auto py_value_and_grad(
|
|||||||
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
||||||
nb::args& args, nb::kwargs& kwargs) {
|
nb::args& args, nb::kwargs& kwargs) {
|
||||||
// Sanitize the input
|
// Sanitize the input
|
||||||
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
if (argnums.size() > 0 && argnums.back() >= std::ssize(args)) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << error_msg_tag << " Can't compute the gradient of argument index "
|
msg << error_msg_tag << " Can't compute the gradient of argument index "
|
||||||
<< argnums.back() << " because the function is called with only "
|
<< argnums.back() << " because the function is called with only "
|
||||||
@@ -126,8 +126,8 @@ auto py_value_and_grad(
|
|||||||
std::vector<mx::array> arrays;
|
std::vector<mx::array> arrays;
|
||||||
std::vector<int> counts(1, 0);
|
std::vector<int> counts(1, 0);
|
||||||
std::vector<int> gradient_indices;
|
std::vector<int> gradient_indices;
|
||||||
for (int i = 0, j = 0; i < args.size(); ++i) {
|
for (int i = 0, j = 0; i < std::ssize(args); ++i) {
|
||||||
bool needs_grad = (j < argnums.size() && argnums[j] == i);
|
bool needs_grad = (j < std::ssize(argnums) && argnums[j] == i);
|
||||||
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
|
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
|
||||||
if (needs_grad) {
|
if (needs_grad) {
|
||||||
auto old_size = gradient_indices.size();
|
auto old_size = gradient_indices.size();
|
||||||
@@ -257,7 +257,7 @@ auto py_value_and_grad(
|
|||||||
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
|
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
|
||||||
} else if (argnums.size() > 1) {
|
} else if (argnums.size() > 1) {
|
||||||
nb::list grads_;
|
nb::list grads_;
|
||||||
for (int i = 0; i < argnums.size(); i++) {
|
for (int i = 0; i < std::ssize(argnums); i++) {
|
||||||
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
|
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
|
||||||
}
|
}
|
||||||
positional_grads = nb::tuple(grads_);
|
positional_grads = nb::tuple(grads_);
|
||||||
@@ -366,14 +366,13 @@ auto py_vmap(
|
|||||||
// able to reconstruct the python tree of extra return values
|
// able to reconstruct the python tree of extra return values
|
||||||
nb::object py_outputs;
|
nb::object py_outputs;
|
||||||
|
|
||||||
auto vmap_fn =
|
auto vmap_fn = [&fun, &args, &py_outputs](const std::vector<mx::array>& a) {
|
||||||
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
|
// Call the python function
|
||||||
// Call the python function
|
py_outputs = fun(*tree_unflatten(args, a));
|
||||||
py_outputs = fun(*tree_unflatten(args, a));
|
|
||||||
|
|
||||||
// Flatten the outputs
|
// Flatten the outputs
|
||||||
return tree_flatten(py_outputs, true);
|
return tree_flatten(py_outputs, true);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto [trace_inputs, trace_outputs] =
|
auto [trace_inputs, trace_outputs] =
|
||||||
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||||
@@ -451,7 +450,7 @@ struct PyCompiledFun {
|
|||||||
if (nb::isinstance<nb::list>(obj)) {
|
if (nb::isinstance<nb::list>(obj)) {
|
||||||
auto l = nb::cast<nb::list>(obj);
|
auto l = nb::cast<nb::list>(obj);
|
||||||
constants.push_back(list_identifier);
|
constants.push_back(list_identifier);
|
||||||
for (int i = 0; i < l.size(); ++i) {
|
for (int i = 0; i < std::ssize(l); ++i) {
|
||||||
recurse(l[i]);
|
recurse(l[i]);
|
||||||
}
|
}
|
||||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ template <typename T, typename U, typename V>
|
|||||||
void validate_subtrees(const std::vector<nb::object>& subtrees) {
|
void validate_subtrees(const std::vector<nb::object>& subtrees) {
|
||||||
int len = nb::cast<T>(subtrees[0]).size();
|
int len = nb::cast<T>(subtrees[0]).size();
|
||||||
for (auto& subtree : subtrees) {
|
for (auto& subtree : subtrees) {
|
||||||
if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) ||
|
if ((nb::isinstance<T>(subtree) &&
|
||||||
|
std::ssize(nb::cast<T>(subtree)) != len) ||
|
||||||
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
|
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
||||||
@@ -24,8 +25,8 @@ nb::object tree_map(
|
|||||||
nb::list l;
|
nb::list l;
|
||||||
std::vector<nb::object> items(subtrees.size());
|
std::vector<nb::object> items(subtrees.size());
|
||||||
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
||||||
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
|
for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||||
if (nb::isinstance<nb::list>(subtrees[j])) {
|
if (nb::isinstance<nb::list>(subtrees[j])) {
|
||||||
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
||||||
} else {
|
} else {
|
||||||
@@ -42,7 +43,7 @@ nb::object tree_map(
|
|||||||
nb::list l;
|
nb::list l;
|
||||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||||
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
||||||
} else {
|
} else {
|
||||||
@@ -57,7 +58,7 @@ nb::object tree_map(
|
|||||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||||
nb::dict d;
|
nb::dict d;
|
||||||
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||||
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
||||||
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
||||||
if (!subdict.contains(item.first)) {
|
if (!subdict.contains(item.first)) {
|
||||||
@@ -96,8 +97,8 @@ void tree_visit(
|
|||||||
if (nb::isinstance<nb::list>(subtrees[0])) {
|
if (nb::isinstance<nb::list>(subtrees[0])) {
|
||||||
std::vector<nb::object> items(subtrees.size());
|
std::vector<nb::object> items(subtrees.size());
|
||||||
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
||||||
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
|
for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||||
if (nb::isinstance<nb::list>(subtrees[j])) {
|
if (nb::isinstance<nb::list>(subtrees[j])) {
|
||||||
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
||||||
} else {
|
} else {
|
||||||
@@ -112,7 +113,7 @@ void tree_visit(
|
|||||||
int len = nb::cast<nb::tuple>(subtrees[0]).size();
|
int len = nb::cast<nb::tuple>(subtrees[0]).size();
|
||||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||||
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
||||||
} else {
|
} else {
|
||||||
@@ -125,7 +126,7 @@ void tree_visit(
|
|||||||
std::vector<nb::object> items(subtrees.size());
|
std::vector<nb::object> items(subtrees.size());
|
||||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||||
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
||||||
for (int j = 0; j < subtrees.size(); ++j) {
|
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||||
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
||||||
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
||||||
if (!subdict.contains(item.first)) {
|
if (!subdict.contains(item.first)) {
|
||||||
@@ -173,13 +174,13 @@ void tree_visit_update(
|
|||||||
recurse = [&](nb::handle subtree) {
|
recurse = [&](nb::handle subtree) {
|
||||||
if (nb::isinstance<nb::list>(subtree)) {
|
if (nb::isinstance<nb::list>(subtree)) {
|
||||||
auto l = nb::cast<nb::list>(subtree);
|
auto l = nb::cast<nb::list>(subtree);
|
||||||
for (int i = 0; i < l.size(); ++i) {
|
for (int i = 0; i < std::ssize(l); ++i) {
|
||||||
l[i] = recurse(l[i]);
|
l[i] = recurse(l[i]);
|
||||||
}
|
}
|
||||||
return nb::cast<nb::object>(l);
|
return nb::cast<nb::object>(l);
|
||||||
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
||||||
nb::list l(subtree);
|
nb::list l(subtree);
|
||||||
for (int i = 0; i < l.size(); ++i) {
|
for (int i = 0; i < std::ssize(l); ++i) {
|
||||||
l[i] = recurse(l[i]);
|
l[i] = recurse(l[i]);
|
||||||
}
|
}
|
||||||
return nb::cast<nb::object>(nb::tuple(l));
|
return nb::cast<nb::object>(nb::tuple(l));
|
||||||
@@ -204,7 +205,7 @@ void tree_visit_update(
|
|||||||
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
|
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
tree_visit_update(
|
tree_visit_update(
|
||||||
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
|
tree, [&](nb::handle /* node */) { return nb::cast(values[index++]); });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace all the arrays from the src values with the dst values in the tree
|
// Replace all the arrays from the src values with the dst values in the tree
|
||||||
@@ -213,7 +214,7 @@ void tree_replace(
|
|||||||
const std::vector<mx::array>& src,
|
const std::vector<mx::array>& src,
|
||||||
const std::vector<mx::array>& dst) {
|
const std::vector<mx::array>& dst) {
|
||||||
std::unordered_map<uintptr_t, mx::array> src_to_dst;
|
std::unordered_map<uintptr_t, mx::array> src_to_dst;
|
||||||
for (int i = 0; i < src.size(); ++i) {
|
for (int i = 0; i < std::ssize(src); ++i) {
|
||||||
src_to_dst.insert({src[i].id(), dst[i]});
|
src_to_dst.insert({src[i].id(), dst[i]});
|
||||||
}
|
}
|
||||||
tree_visit_update(tree, [&](nb::handle node) {
|
tree_visit_update(tree, [&](nb::handle node) {
|
||||||
|
|||||||
@@ -57,8 +57,8 @@ std::pair<mx::array, mx::array> to_arrays(
|
|||||||
// - If neither is an array convert to arrays but leave their types alone
|
// - If neither is an array convert to arrays but leave their types alone
|
||||||
auto is_mlx_array = [](const ScalarOrArray& x) {
|
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||||
return std::holds_alternative<mx::array>(x) ||
|
return std::holds_alternative<mx::array>(x) ||
|
||||||
std::holds_alternative<ArrayLike>(x) &&
|
(std::holds_alternative<ArrayLike>(x) &&
|
||||||
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
|
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__"));
|
||||||
};
|
};
|
||||||
auto get_mlx_array = [](const ScalarOrArray& x) {
|
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||||
if (auto px = std::get_if<mx::array>(&x); px) {
|
if (auto px = std::get_if<mx::array>(&x); px) {
|
||||||
|
|||||||
Reference in New Issue
Block a user