mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-24 09:08:09 +08:00
list based indexing (#1150)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "python/src/convert.h"
|
||||
#include "python/src/indexing.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
@@ -51,7 +52,8 @@ array get_int_index(nb::object idx, int axis_size) {
|
||||
|
||||
bool is_valid_index_type(const nb::object& obj) {
|
||||
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
||||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
|
||||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) ||
|
||||
nb::isinstance<nb::list>(obj);
|
||||
}
|
||||
|
||||
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
||||
@@ -255,11 +257,18 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
|
||||
// The plan is as follows:
|
||||
// 1. Replace the ellipsis with a series of slice(None)
|
||||
// 2. Loop over the indices and calculate the gather indices
|
||||
// 3. Calculate the remaining slices and reshapes
|
||||
// 2. Convert list to array
|
||||
// 3. Loop over the indices and calculate the gather indices
|
||||
// 4. Calculate the remaining slices and reshapes
|
||||
|
||||
// Ellipsis handling
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
// List handling
|
||||
for (auto& idx : indices) {
|
||||
if (nb::isinstance<nb::list>(idx)) {
|
||||
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for the number of indices passed
|
||||
if (non_none_indices > src.ndim()) {
|
||||
@@ -440,6 +449,9 @@ array mlx_get_item(const array& src, const nb::object& obj) {
|
||||
std::vector<int> s(1, 1);
|
||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||
return reshape(src, s);
|
||||
} else if (nb::isinstance<nb::list>(obj)) {
|
||||
return mlx_get_item_array(
|
||||
src, array_from_list(nb::cast<nb::list>(obj), {}));
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
@@ -564,6 +576,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
// Expand ellipses into a series of ':' slices
|
||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||
|
||||
// Convert List to array
|
||||
for (auto& idx : indices) {
|
||||
if (nb::isinstance<nb::list>(idx)) {
|
||||
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
|
||||
}
|
||||
}
|
||||
|
||||
if (non_none_indices > src.ndim()) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||
@@ -753,7 +772,11 @@ mlx_compute_scatter_args(
|
||||
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
|
||||
} else if (obj.is_none()) {
|
||||
return {{}, broadcast_to(vals, src.shape()), {}};
|
||||
} else if (nb::isinstance<nb::list>(obj)) {
|
||||
return mlx_scatter_args_array(
|
||||
src, array_from_list(nb::cast<nb::list>(obj), {}), vals);
|
||||
}
|
||||
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
@@ -769,7 +792,7 @@ auto mlx_slice_update(
|
||||
if (nb::isinstance<nb::tuple>(obj)) {
|
||||
// Can't route to slice update if any arrays are present
|
||||
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
||||
if (nb::isinstance<array>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user