WIP (python)

This commit is contained in:
Ronan Collobert
2025-10-31 16:24:51 -07:00
parent 18aa921388
commit 9f649b5658
11 changed files with 54 additions and 52 deletions

View File

@@ -83,7 +83,7 @@ class ArrayPythonIterator {
throw nb::stop_iteration();
}
if (idx_ >= 0 && idx_ < splits_.size()) {
if (idx_ >= 0 && idx_ < std::ssize(splits_)) {
return mx::squeeze(splits_[idx_++], 0);
}
@@ -390,7 +390,7 @@ void init_array(nb::module_& m) {
)pbdoc")
.def(
"__array_namespace__",
[](const mx::array& a,
[](const mx::array& /* a */,
const std::optional<std::string>& api_version) {
if (api_version) {
throw std::invalid_argument(
@@ -501,7 +501,7 @@ void init_array(nb::module_& m) {
.def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); })
.def(
"__dlpack_device__",
[](const mx::array& a) {
[](const mx::array& /* a */) {
// See
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
if (mx::metal::is_available()) {

View File

@@ -50,7 +50,7 @@ mx::array nd_array_to_mlx(
// Compute the shape and size
mx::Shape shape;
shape.reserve(nd_array.ndim());
for (int i = 0; i < nd_array.ndim(); i++) {
for (int i = 0; i < static_cast<int>(nd_array.ndim()); i++) {
shape.push_back(check_shape_dim(nd_array.shape(i)));
}
auto type = nd_array.dtype();
@@ -289,7 +289,7 @@ PyScalarT validate_shape(
throw std::invalid_argument("Initialization encountered extra dimension.");
}
auto s = shape[idx];
if (nb::len(list) != s) {
if (nb::len(list) != static_cast<size_t>(s)) {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}

View File

@@ -201,7 +201,6 @@ void init_fast(nb::module_& parent_module) {
bool has_mask = !std::holds_alternative<std::monostate>(mask);
bool has_str_mask =
has_mask && std::holds_alternative<std::string>(mask);
bool has_arr_mask = has_mask && std::holds_alternative<mx::array>(mask);
if (has_mask) {
if (has_str_mask) {

View File

@@ -115,7 +115,7 @@ mx::array mlx_gather_nd(
std::vector<bool> is_slice(indices.size(), false);
int num_slices = 0;
// gather all the arrays
for (int i = 0; i < indices.size(); i++) {
for (int i = 0; i < std::ssize(indices); i++) {
auto& idx = indices[i];
if (nb::isinstance<nb::slice>(idx)) {
@@ -142,7 +142,7 @@ mx::array mlx_gather_nd(
// reshape them so that the int/array indices are first
if (gather_first) {
int slice_index = 0;
for (int i = 0; i < gather_indices.size(); i++) {
for (int i = 0; i < std::ssize(gather_indices); i++) {
if (is_slice[i]) {
mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
@@ -156,7 +156,7 @@ mx::array mlx_gather_nd(
}
} else {
// reshape them so that the int/array indices are last
for (int i = 0; i < gather_indices.size(); i++) {
for (int i = 0; i < std::ssize(gather_indices); i++) {
if (i < num_slices) {
mx::Shape index_shape(max_dims + num_slices, 1);
index_shape[i] = gather_indices[i].shape(0);
@@ -190,7 +190,7 @@ auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
bool has_ellipsis = false;
// Start from dimension 0 till we hit an ellipsis
for (; i < entries.size(); i++) {
for (; i < std::ssize(entries); i++) {
auto idx = entries[i];
if (!is_valid_index_type(idx)) {
throw std::invalid_argument(
@@ -301,7 +301,8 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
if (have_array) {
int last_array;
// Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
for (last_array = std::ssize(indices) - 1; last_array >= 0;
last_array--) {
auto& idx = indices[last_array];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
@@ -333,11 +334,11 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
nb::slice(nb::none(), nb::none(), nb::none()));
}
}
for (int i = last_array + 1; i < indices.size(); i++) {
for (int i = last_array + 1; i < std::ssize(indices); i++) {
remaining_indices.push_back(indices[i]);
}
} else {
for (int i = 0; i < indices.size(); i++) {
for (int i = 0; i < std::ssize(indices); i++) {
auto& idx = indices[i];
if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
@@ -352,7 +353,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
remaining_indices.push_back(
nb::slice(nb::none(), nb::none(), nb::none()));
}
for (int i = last_array + 1; i < indices.size(); i++) {
for (int i = last_array + 1; i < std::ssize(indices); i++) {
remaining_indices.push_back(indices[i]);
}
}
@@ -406,7 +407,7 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
if (unsqueeze_needed || squeeze_needed) {
std::vector<int> squeeze_axes;
std::vector<int> unsqueeze_axes;
for (int axis = 0; axis < remaining_indices.size(); ++axis) {
for (int axis = 0; axis < std::ssize(remaining_indices); ++axis) {
auto& idx = remaining_indices[axis];
if (unsqueeze_needed && idx.is_none()) {
unsqueeze_axes.push_back(axis - squeeze_axes.size());
@@ -583,7 +584,7 @@ mlx_scatter_args_nd(
}
// Analyse the types of the indices
size_t max_dim = 0;
int max_dim = 0;
bool arrays_first = false;
int num_none = 0;
int num_slices = 0;
@@ -640,7 +641,7 @@ mlx_scatter_args_nd(
std::vector<int> update_shape(non_none_indices, 1);
std::vector<int> slice_shapes;
for (int i = 0; i < indices.size(); ++i) {
for (int i = 0; i < std::ssize(indices); ++i) {
auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) {
mx::ShapeElem start, end, stride;
@@ -848,7 +849,7 @@ auto mlx_slice_update(
int unspecified = src.ndim() - non_none_indices;
std::vector<int> squeeze_dims;
std::vector<int> expand_dims;
for (int i = indices.size() - 1,
for (int i = std::ssize(indices) - 1,
ax = non_none_indices - 1,
upd_ax = upd.ndim() - unspecified - 1;
i >= 0;

View File

@@ -436,7 +436,7 @@ void mlx_savez_helper(
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs);
auto arrays_list = nb::cast<std::vector<mx::array>>(args);
for (int i = 0; i < arrays_list.size(); i++) {
for (int i = 0; i < std::ssize(arrays_list); i++) {
std::string arr_name = "arr_" + std::to_string(i);
if (arrays_dict.count(arr_name) > 0) {

View File

@@ -22,7 +22,9 @@ bool DEPRECATE(const char* old_fn, const char* new_fn) {
return true;
}
#define DEPRECATE(oldfn, newfn) static bool dep = DEPRECATE(oldfn, newfn)
#define DEPRECATE(oldfn, newfn) \
static bool dep = DEPRECATE(oldfn, newfn); \
(void)dep;
void init_metal(nb::module_& m) {
nb::module_ metal = m.def_submodule("metal", "mlx.metal");

View File

@@ -107,7 +107,7 @@ nb::callable mlx_func(
return nb::steal<nb::callable>((PyObject*)r);
}
void init_mlx_func(nb::module_& m) {
void init_mlx_func(nb::module_& /* m */) {
gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec);
if (!gc_func_tp) {
nb::raise("Could not register MLX function type.");

View File

@@ -100,9 +100,9 @@ void init_stream(nb::module_& m) {
.def(
"__exit__",
[](PyStreamContext& scm,
const std::optional<nb::type_object>& exc_type,
const std::optional<nb::object>& exc_value,
const std::optional<nb::object>& traceback) { scm.exit(); },
const std::optional<nb::type_object>& /* exc_type */,
const std::optional<nb::object>& /* exc_value */,
const std::optional<nb::object>& /* traceback */) { scm.exit(); },
"exc_type"_a = nb::none(),
"exc_value"_a = nb::none(),
"traceback"_a = nb::none());

View File

@@ -86,7 +86,7 @@ auto py_value_and_grad(
<< argnums[0];
throw std::invalid_argument(msg.str());
}
for (int i = 1; i < argnums.size(); ++i) {
for (int i = 1; i < std::ssize(argnums); ++i) {
if (argnums[i] == argnums[i - 1]) {
std::ostringstream msg;
msg << error_msg_tag << " Duplicate argument index " << argnums[0]
@@ -99,7 +99,7 @@ auto py_value_and_grad(
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
nb::args& args, nb::kwargs& kwargs) {
// Sanitize the input
if (argnums.size() > 0 && argnums.back() >= args.size()) {
if (argnums.size() > 0 && argnums.back() >= std::ssize(args)) {
std::ostringstream msg;
msg << error_msg_tag << " Can't compute the gradient of argument index "
<< argnums.back() << " because the function is called with only "
@@ -126,8 +126,8 @@ auto py_value_and_grad(
std::vector<mx::array> arrays;
std::vector<int> counts(1, 0);
std::vector<int> gradient_indices;
for (int i = 0, j = 0; i < args.size(); ++i) {
bool needs_grad = (j < argnums.size() && argnums[j] == i);
for (int i = 0, j = 0; i < std::ssize(args); ++i) {
bool needs_grad = (j < std::ssize(argnums) && argnums[j] == i);
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
@@ -257,7 +257,7 @@ auto py_value_and_grad(
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
} else if (argnums.size() > 1) {
nb::list grads_;
for (int i = 0; i < argnums.size(); i++) {
for (int i = 0; i < std::ssize(argnums); i++) {
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
}
positional_grads = nb::tuple(grads_);
@@ -366,8 +366,7 @@ auto py_vmap(
// able to reconstruct the python tree of extra return values
nb::object py_outputs;
auto vmap_fn =
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
auto vmap_fn = [&fun, &args, &py_outputs](const std::vector<mx::array>& a) {
// Call the python function
py_outputs = fun(*tree_unflatten(args, a));
@@ -451,7 +450,7 @@ struct PyCompiledFun {
if (nb::isinstance<nb::list>(obj)) {
auto l = nb::cast<nb::list>(obj);
constants.push_back(list_identifier);
for (int i = 0; i < l.size(); ++i) {
for (int i = 0; i < std::ssize(l); ++i) {
recurse(l[i]);
}
} else if (nb::isinstance<nb::tuple>(obj)) {

View File

@@ -6,7 +6,8 @@ template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = nb::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) ||
if ((nb::isinstance<T>(subtree) &&
std::ssize(nb::cast<T>(subtree)) != len) ||
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
@@ -24,8 +25,8 @@ nb::object tree_map(
nb::list l;
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else {
@@ -42,7 +43,7 @@ nb::object tree_map(
nb::list l;
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else {
@@ -57,7 +58,7 @@ nb::object tree_map(
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
nb::dict d;
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
@@ -96,8 +97,8 @@ void tree_visit(
if (nb::isinstance<nb::list>(subtrees[0])) {
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
for (int i = 0; i < std::ssize(nb::cast<nb::list>(subtrees[0])); ++i) {
for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::list>(subtrees[j])) {
items[j] = nb::cast<nb::list>(subtrees[j])[i];
} else {
@@ -112,7 +113,7 @@ void tree_visit(
int len = nb::cast<nb::tuple>(subtrees[0]).size();
validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::tuple>(subtrees[j])) {
items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
} else {
@@ -125,7 +126,7 @@ void tree_visit(
std::vector<nb::object> items(subtrees.size());
validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
for (auto item : nb::cast<nb::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
for (int j = 0; j < std::ssize(subtrees); ++j) {
if (nb::isinstance<nb::dict>(subtrees[j])) {
auto subdict = nb::cast<nb::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
@@ -173,13 +174,13 @@ void tree_visit_update(
recurse = [&](nb::handle subtree) {
if (nb::isinstance<nb::list>(subtree)) {
auto l = nb::cast<nb::list>(subtree);
for (int i = 0; i < l.size(); ++i) {
for (int i = 0; i < std::ssize(l); ++i) {
l[i] = recurse(l[i]);
}
return nb::cast<nb::object>(l);
} else if (nb::isinstance<nb::tuple>(subtree)) {
nb::list l(subtree);
for (int i = 0; i < l.size(); ++i) {
for (int i = 0; i < std::ssize(l); ++i) {
l[i] = recurse(l[i]);
}
return nb::cast<nb::object>(nb::tuple(l));
@@ -204,7 +205,7 @@ void tree_visit_update(
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
size_t index = 0;
tree_visit_update(
tree, [&](nb::handle node) { return nb::cast(values[index++]); });
tree, [&](nb::handle /* node */) { return nb::cast(values[index++]); });
}
// Replace all the arrays from the src values with the dst values in the tree
@@ -213,7 +214,7 @@ void tree_replace(
const std::vector<mx::array>& src,
const std::vector<mx::array>& dst) {
std::unordered_map<uintptr_t, mx::array> src_to_dst;
for (int i = 0; i < src.size(); ++i) {
for (int i = 0; i < std::ssize(src); ++i) {
src_to_dst.insert({src[i].id(), dst[i]});
}
tree_visit_update(tree, [&](nb::handle node) {

View File

@@ -57,8 +57,8 @@ std::pair<mx::array, mx::array> to_arrays(
// - If neither is an array convert to arrays but leave their types alone
auto is_mlx_array = [](const ScalarOrArray& x) {
return std::holds_alternative<mx::array>(x) ||
std::holds_alternative<ArrayLike>(x) &&
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__");
(std::holds_alternative<ArrayLike>(x) &&
nb::hasattr(std::get<ArrayLike>(x).obj, "__mlx_array__"));
};
auto get_mlx_array = [](const ScalarOrArray& x) {
if (auto px = std::get_if<mx::array>(&x); px) {