list based indexing (#1150)

This commit is contained in:
Awni Hannun
2024-05-22 15:52:05 -07:00
committed by GitHub
parent 79ef49b2c2
commit eb8321d863
5 changed files with 425 additions and 328 deletions

View File

@@ -2,6 +2,7 @@
#include <numeric>
#include <sstream>
#include "python/src/convert.h"
#include "python/src/indexing.h"
#include "mlx/ops.h"
@@ -51,7 +52,8 @@ array get_int_index(nb::object idx, int axis_size) {
bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) ||
nb::isinstance<nb::list>(obj);
}
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
@@ -255,11 +257,18 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
// The plan is as follows:
// 1. Replace the ellipsis with a series of slice(None)
// 2. Loop over the indices and calculate the gather indices
// 3. Calculate the remaining slices and reshapes
// 2. Convert list to array
// 3. Loop over the indices and calculate the gather indices
// 4. Calculate the remaining slices and reshapes
// Ellipsis handling
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
// List handling
for (auto& idx : indices) {
if (nb::isinstance<nb::list>(idx)) {
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
}
}
// Check for the number of indices passed
if (non_none_indices > src.ndim()) {
@@ -440,6 +449,9 @@ array mlx_get_item(const array& src, const nb::object& obj) {
std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end());
return reshape(src, s);
} else if (nb::isinstance<nb::list>(obj)) {
return mlx_get_item_array(
src, array_from_list(nb::cast<nb::list>(obj), {}));
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
@@ -564,6 +576,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
// Expand ellipses into a series of ':' slices
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
// Convert List to array
for (auto& idx : indices) {
if (nb::isinstance<nb::list>(idx)) {
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
}
}
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
@@ -753,7 +772,11 @@ mlx_compute_scatter_args(
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
} else if (obj.is_none()) {
return {{}, broadcast_to(vals, src.shape()), {}};
} else if (nb::isinstance<nb::list>(obj)) {
return mlx_scatter_args_array(
src, array_from_list(nb::cast<nb::list>(obj), {}), vals);
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
@@ -769,7 +792,7 @@ auto mlx_slice_update(
if (nb::isinstance<nb::tuple>(obj)) {
// Can't route to slice update if any arrays are present
for (auto idx : nb::cast<nb::tuple>(obj)) {
if (nb::isinstance<array>(idx)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) {
return std::make_pair(false, src);
}
}