mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
WIP (python)
This commit is contained in:
@@ -83,7 +83,7 @@ class ArrayPythonIterator {
|
||||
throw nb::stop_iteration();
|
||||
}
|
||||
|
||||
if (idx_ >= 0 && idx_ < splits_.size()) {
|
||||
if (idx_ >= 0 && idx_ < std::ssize(splits_)) {
|
||||
return mx::squeeze(splits_[idx_++], 0);
|
||||
}
|
||||
|
||||
@@ -390,7 +390,7 @@ void init_array(nb::module_& m) {
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__array_namespace__",
|
||||
[](const mx::array& a,
|
||||
[](const mx::array& /* a */,
|
||||
const std::optional<std::string>& api_version) {
|
||||
if (api_version) {
|
||||
throw std::invalid_argument(
|
||||
@@ -501,7 +501,7 @@ void init_array(nb::module_& m) {
|
||||
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
|
||||
.def(
|
||||
"__dlpack_device__",
|
||||
[](const mx::array& a) {
|
||||
[](const mx::array& /* a */) {
|
||||
// See
|
||||
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
|
||||
if (mx::metal::is_available()) {
|
||||
|
||||
@@ -50,7 +50,7 @@ mx::array nd_array_to_mlx(
|
||||
// Compute the shape and size
|
||||
mx::Shape shape;
|
||||
shape.reserve(nd_array.ndim());
|
||||
for (int i = 0; i < nd_array.ndim(); i++) {
|
||||
for (int i = 0; i < static_cast<int>(nd_array.ndim()); i++) {
|
||||
shape.push_back(check_shape_dim(nd_array.shape(i)));
|
||||
}
|
||||
auto type = nd_array.dtype();
|
||||
@@ -289,7 +289,7 @@ PyScalarT validate_shape(
|
||||
throw std::invalid_argument("Initialization encountered extra dimension.");
|
||||
}
|
||||
auto s = shape[idx];
|
||||
if (nb::len(list) != s) {
|
||||
if (nb::len(list) != static_cast<size_t>(s)) {
|
||||
throw std::invalid_argument(
|
||||
"Initialization encountered non-uniform length.");
|
||||
}
|
||||
|
||||
@@ -201,7 +201,6 @@ void init_fast(nb::module_& parent_module) {
|
||||
bool has_mask = !std::holds_alternative<std::monostate>(mask);
|
||||
bool has_str_mask =
|
||||
has_mask && std::holds_alternative<std::string>(mask);
|
||||
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
|
||||
|
||||
if (has_mask) {
|
||||
if (has_str_mask) {
|
||||
|
||||
@@ -115,7 +115,7 @@ mx::array mlx_gather_nd(
|
||||
std::vector<bool> is_slice(indices.size(), false);
|
||||
int num_slices = 0;
|
||||
// gather all the arrays
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(indices); i++) {
|
||||
auto& idx = indices[i];
|
||||
|
||||
if (nb::isinstance<nb::slice>(idx)) {
|
||||
@@ -142,7 +142,7 @@ mx::array mlx_gather_nd(
|
||||
// reshape them so that the int/array indices are first
|
||||
if (gather_first) {
|
||||
int slice_index = 0;
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(gather_indices); i++) {
|
||||
if (is_slice[i]) {
|
||||
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
|
||||
@@ -156,7 +156,7 @@ mx::array mlx_gather_nd(
|
||||
}
|
||||
} else {
|
||||
// reshape them so that the int/array indices are last
|
||||
for (int i = 0; i < gather_indices.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(gather_indices); i++) {
|
||||
if (i < num_slices) {
|
||||
mx::Shape index_shape(max_dims + num_slices, 1);
|
||||
index_shape[i] = gather_indices[i].shape(0);
|
||||
@@ -190,7 +190,7 @@ auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
|
||||
bool has_ellipsis = false;
|
||||
|
||||
// Start from dimension 0 till we hit an ellipsis
|
||||
for (; i < entries.size(); i++) {
|
||||
for (; i < std::ssize(entries); i++) {
|
||||
auto idx = entries[i];
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
@@ -301,7 +301,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
||||
if (have_array) {
|
||||
int last_array;
|
||||
// Then find the last array
|
||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
||||
for (last_array = std::ssize(indices) - 1; last_array >= 0;
|
||||
last_array--) {
|
||||
auto& idx = indices[last_array];
|
||||
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
@@ -333,11 +334,11 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
for (int i = last_array + 1; i < std::ssize(indices); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(indices); i++) {
|
||||
auto& idx = indices[i];
|
||||
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
@@ -352,7 +353,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
||||
remaining_indices.push_back(
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
for (int i = last_array + 1; i < std::ssize(indices); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
}
|
||||
}
|
||||
@@ -406,7 +407,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
||||
if (unsqueeze_needed || squeeze_needed) {
|
||||
std::vector<int> squeeze_axes;
|
||||
std::vector<int> unsqueeze_axes;
|
||||
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
|
||||
for (int axis = 0; axis < std::ssize(remaining_indices); ++axis) {
|
||||
auto& idx = remaining_indices[axis];
|
||||
if (unsqueeze_needed && idx.is_none()) {
|
||||
unsqueeze_axes.push_back(axis - squeeze_axes.size());
|
||||
@@ -583,7 +584,7 @@ mlx_scatter_args_nd(
|
||||
}
|
||||
|
||||
// Analyse the types of the indices
|
||||
size_t max_dim = 0;
|
||||
int max_dim = 0;
|
||||
bool arrays_first = false;
|
||||
int num_none = 0;
|
||||
int num_slices = 0;
|
||||
@@ -640,7 +641,7 @@ mlx_scatter_args_nd(
|
||||
std::vector<int> update_shape(non_none_indices, 1);
|
||||
std::vector<int> slice_shapes;
|
||||
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(indices); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
mx::ShapeElem start, end, stride;
|
||||
@@ -848,7 +849,7 @@ auto mlx_slice_update(
|
||||
int unspecified = src.ndim() - non_none_indices;
|
||||
std::vector<int> squeeze_dims;
|
||||
std::vector<int> expand_dims;
|
||||
for (int i = indices.size() - 1,
|
||||
for (int i = std::ssize(indices) - 1,
|
||||
ax = non_none_indices - 1,
|
||||
upd_ax = upd.ndim() - unspecified - 1;
|
||||
i >= 0;
|
||||
|
||||
@@ -436,7 +436,7 @@ void mlx_savez_helper(
|
||||
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
|
||||
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
|
||||
|
||||
for (int i = 0; i < arrays_list.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(arrays_list); i++) {
|
||||
std::string arr_name = "arr_" + std::to_string(i);
|
||||
|
||||
if (arrays_dict.count(arr_name) > 0) {
|
||||
|
||||
@@ -22,7 +22,9 @@ bool DEPRECATE(const char* old_fn, const char* new_fn) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)
|
||||
#define DEPRECATE(oldfn, newfn) \
|
||||
static bool dep = DEPRECATE(oldfn, newfn); \
|
||||
(void)dep;
|
||||
|
||||
void init_metal(nb::module_& m) {
|
||||
nb::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||
|
||||
@@ -107,7 +107,7 @@ nb::callable mlx_func(
|
||||
return nb::steal<nb::callable>((PyObject*)r);
|
||||
}
|
||||
|
||||
void init_mlx_func(nb::module_& m) {
|
||||
void init_mlx_func(nb::module_& /* m */) {
|
||||
gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec);
|
||||
if (!gc_func_tp) {
|
||||
nb::raise("Could not register MLX function type.");
|
||||
|
||||
@@ -100,9 +100,9 @@ void init_stream(nb::module_& m) {
|
||||
.def(
|
||||
"__exit__",
|
||||
[](PyStreamContext& scm,
|
||||
const std::optional<nb::type_object>& exc_type,
|
||||
const std::optional<nb::object>& exc_value,
|
||||
const std::optional<nb::object>& traceback) { scm.exit(); },
|
||||
const std::optional<nb::type_object>& /* exc_type */,
|
||||
const std::optional<nb::object>& /* exc_value */,
|
||||
const std::optional<nb::object>& /* traceback */) { scm.exit(); },
|
||||
"exc_type"_a = nb::none(),
|
||||
"exc_value"_a = nb::none(),
|
||||
"traceback"_a = nb::none());
|
||||
|
||||
@@ -86,7 +86,7 @@ auto py_value_and_grad(
|
||||
<< argnums[0];
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
for (int i = 1; i < argnums.size(); ++i) {
|
||||
for (int i = 1; i < std::ssize(argnums); ++i) {
|
||||
if (argnums[i] == argnums[i - 1]) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " Duplicate argument index " << argnums[0]
|
||||
@@ -99,7 +99,7 @@ auto py_value_and_grad(
|
||||
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
|
||||
nb::args& args, nb::kwargs& kwargs) {
|
||||
// Sanitize the input
|
||||
if (argnums.size() > 0 && argnums.back() >= args.size()) {
|
||||
if (argnums.size() > 0 && argnums.back() >= std::ssize(args)) {
|
||||
std::ostringstream msg;
|
||||
msg << error_msg_tag << " Can't compute the gradient of argument index "
|
||||
<< argnums.back() << " because the function is called with only "
|
||||
@@ -126,8 +126,8 @@ auto py_value_and_grad(
|
||||
std::vector<mx::array> arrays;
|
||||
std::vector<int> counts(1, 0);
|
||||
std::vector<int> gradient_indices;
|
||||
for (int i = 0, j = 0; i < args.size(); ++i) {
|
||||
bool needs_grad = (j < argnums.size() && argnums[j] == i);
|
||||
for (int i = 0, j = 0; i < std::ssize(args); ++i) {
|
||||
bool needs_grad = (j < std::ssize(argnums) && argnums[j] == i);
|
||||
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
|
||||
if (needs_grad) {
|
||||
auto old_size = gradient_indices.size();
|
||||
@@ -257,7 +257,7 @@ auto py_value_and_grad(
|
||||
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
|
||||
} else if (argnums.size() > 1) {
|
||||
nb::list grads_;
|
||||
for (int i = 0; i < argnums.size(); i++) {
|
||||
for (int i = 0; i < std::ssize(argnums); i++) {
|
||||
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
|
||||
}
|
||||
positional_grads = nb::tuple(grads_);
|
||||
@@ -366,14 +366,13 @@ auto py_vmap(
|
||||
// able to reconstruct the python tree of extra return values
|
||||
nb::object py_outputs;
|
||||
|
||||
auto vmap_fn =
|
||||
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
|
||||
// Call the python function
|
||||
py_outputs = fun(*tree_unflatten(args, a));
|
||||
auto vmap_fn = [&fun, &args, &py_outputs](const std::vector<mx::array>& a) {
|
||||
// Call the python function
|
||||
py_outputs = fun(*tree_unflatten(args, a));
|
||||
|
||||
// Flatten the outputs
|
||||
return tree_flatten(py_outputs, true);
|
||||
};
|
||||
// Flatten the outputs
|
||||
return tree_flatten(py_outputs, true);
|
||||
};
|
||||
|
||||
auto [trace_inputs, trace_outputs] =
|
||||
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
|
||||
@@ -451,7 +450,7 @@ struct PyCompiledFun {
|
||||
if (nb::isinstance<nb::list>(obj)) {
|
||||
auto l = nb::cast<nb::list>(obj);
|
||||
constants.push_back(list_identifier);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(l); ++i) {
|
||||
recurse(l[i]);
|
||||
}
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
|
||||
@@ -6,7 +6,8 @@ template <typename T, typename U, typename V>
|
||||
void validate_subtrees(const std::vector<nb::object>& subtrees) {
|
||||
int len = nb::cast<T>(subtrees[0]).size();
|
||||
for (auto& subtree : subtrees) {
|
||||
if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) ||
|
||||
if ((nb::isinstance<T>(subtree) &&
|
||||
std::ssize(nb::cast<T>(subtree)) != len) ||
|
||||
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
|
||||
throw std::invalid_argument(
|
||||
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
|
||||
@@ -24,8 +25,8 @@ nb::object tree_map(
|
||||
nb::list l;
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
||||
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
|
||||
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||
if (nb::isinstance<nb::list>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
||||
} else {
|
||||
@@ -42,7 +43,7 @@ nb::object tree_map(
|
||||
nb::list l;
|
||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
||||
} else {
|
||||
@@ -57,7 +58,7 @@ nb::object tree_map(
|
||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||
nb::dict d;
|
||||
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
||||
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
||||
if (!subdict.contains(item.first)) {
|
||||
@@ -96,8 +97,8 @@ void tree_visit(
|
||||
if (nb::isinstance<nb::list>(subtrees[0])) {
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
|
||||
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
|
||||
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||
if (nb::isinstance<nb::list>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::list>(subtrees[j])[i];
|
||||
} else {
|
||||
@@ -112,7 +113,7 @@ void tree_visit(
|
||||
int len = nb::cast<nb::tuple>(subtrees[0]).size();
|
||||
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||
if (nb::isinstance<nb::tuple>(subtrees[j])) {
|
||||
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
|
||||
} else {
|
||||
@@ -125,7 +126,7 @@ void tree_visit(
|
||||
std::vector<nb::object> items(subtrees.size());
|
||||
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
|
||||
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
|
||||
for (int j = 0; j < subtrees.size(); ++j) {
|
||||
for (int j = 0; j < std::ssize(subtrees); ++j) {
|
||||
if (nb::isinstance<nb::dict>(subtrees[j])) {
|
||||
auto subdict = nb::cast<nb::dict>(subtrees[j]);
|
||||
if (!subdict.contains(item.first)) {
|
||||
@@ -173,13 +174,13 @@ void tree_visit_update(
|
||||
recurse = [&](nb::handle subtree) {
|
||||
if (nb::isinstance<nb::list>(subtree)) {
|
||||
auto l = nb::cast<nb::list>(subtree);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(l); ++i) {
|
||||
l[i] = recurse(l[i]);
|
||||
}
|
||||
return nb::cast<nb::object>(l);
|
||||
} else if (nb::isinstance<nb::tuple>(subtree)) {
|
||||
nb::list l(subtree);
|
||||
for (int i = 0; i < l.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(l); ++i) {
|
||||
l[i] = recurse(l[i]);
|
||||
}
|
||||
return nb::cast<nb::object>(nb::tuple(l));
|
||||
@@ -204,7 +205,7 @@ void tree_visit_update(
|
||||
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
|
||||
size_t index = 0;
|
||||
tree_visit_update(
|
||||
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
|
||||
tree, [&](nb::handle /* node */) { return nb::cast(values[index++]); });
|
||||
}
|
||||
|
||||
// Replace all the arrays from the src values with the dst values in the tree
|
||||
@@ -213,7 +214,7 @@ void tree_replace(
|
||||
const std::vector<mx::array>& src,
|
||||
const std::vector<mx::array>& dst) {
|
||||
std::unordered_map<uintptr_t, mx::array> src_to_dst;
|
||||
for (int i = 0; i < src.size(); ++i) {
|
||||
for (int i = 0; i < std::ssize(src); ++i) {
|
||||
src_to_dst.insert({src[i].id(), dst[i]});
|
||||
}
|
||||
tree_visit_update(tree, [&](nb::handle node) {
|
||||
|
||||
@@ -57,8 +57,8 @@ std::pair<mx::array, mx::array> to_arrays(
|
||||
// - If neither is an array convert to arrays but leave their types alone
|
||||
auto is_mlx_array = [](const ScalarOrArray& x) {
|
||||
return std::holds_alternative<mx::array>(x) ||
|
||||
std::holds_alternative<ArrayLike>(x) &&
|
||||
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
|
||||
(std::holds_alternative<ArrayLike>(x) &&
|
||||
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__"));
|
||||
};
|
||||
auto get_mlx_array = [](const ScalarOrArray& x) {
|
||||
if (auto px = std::get_if<mx::array>(&x); px) {
|
||||
|
||||
Reference in New Issue
Block a user