list based indexing (#1150)

Awni Hannun 2024-05-22 15:52:05 -07:00 committed by GitHub
parent 79ef49b2c2
commit eb8321d863
5 changed files with 425 additions and 328 deletions
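
For orientation, here is a minimal usage sketch of what this commit enables, based on the new tests added below: plain Python lists become valid indices alongside ints, slices, None, ellipsis, and mx.array.

import mlx.core as mx

a = mx.arange(16).reshape(4, 4)
idx = [0, 2]          # a plain Python list of row indices

a[idx]                # gathers rows 0 and 2, like a[mx.array([0, 2])]
a[idx, 0]             # lists mix with other index types
a[idx, :] = 7         # assignment with a list index scatters into rows 0 and 2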


@@ -23,330 +23,6 @@ namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
enum PyScalarT {
pybool = 0,
pyint = 1,
pyfloat = 2,
pycomplex = 3,
};
template <typename T, typename U = T>
nb::list to_list(array& a, size_t index, int dim) {
nb::list pl;
auto stride = a.strides()[dim];
for (int i = 0; i < a.shape(dim); ++i) {
if (dim == a.ndim() - 1) {
pl.append(static_cast<U>(a.data<T>()[index]));
} else {
pl.append(to_list<T, U>(a, index, dim + 1));
}
index += stride;
}
return pl;
}
auto to_scalar(array& a) {
{
nb::gil_scoped_release nogil;
a.eval();
}
switch (a.dtype()) {
case bool_:
return nb::cast(a.item<bool>());
case uint8:
return nb::cast(a.item<uint8_t>());
case uint16:
return nb::cast(a.item<uint16_t>());
case uint32:
return nb::cast(a.item<uint32_t>());
case uint64:
return nb::cast(a.item<uint64_t>());
case int8:
return nb::cast(a.item<int8_t>());
case int16:
return nb::cast(a.item<int16_t>());
case int32:
return nb::cast(a.item<int32_t>());
case int64:
return nb::cast(a.item<int64_t>());
case float16:
return nb::cast(static_cast<float>(a.item<float16_t>()));
case float32:
return nb::cast(a.item<float>());
case bfloat16:
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
case complex64:
return nb::cast(a.item<std::complex<float>>());
}
}
nb::object tolist(array& a) {
if (a.ndim() == 0) {
return to_scalar(a);
}
{
nb::gil_scoped_release nogil;
a.eval();
}
switch (a.dtype()) {
case bool_:
return to_list<bool>(a, 0, 0);
case uint8:
return to_list<uint8_t>(a, 0, 0);
case uint16:
return to_list<uint16_t>(a, 0, 0);
case uint32:
return to_list<uint32_t>(a, 0, 0);
case uint64:
return to_list<uint64_t>(a, 0, 0);
case int8:
return to_list<int8_t>(a, 0, 0);
case int16:
return to_list<int16_t>(a, 0, 0);
case int32:
return to_list<int32_t>(a, 0, 0);
case int64:
return to_list<int64_t>(a, 0, 0);
case float16:
return to_list<float16_t, float>(a, 0, 0);
case float32:
return to_list<float>(a, 0, 0);
case bfloat16:
return to_list<bfloat16_t, float>(a, 0, 0);
case complex64:
return to_list<std::complex<float>>(a, 0, 0);
}
}
template <typename T, typename U>
void fill_vector(T list, std::vector<U>& vals) {
for (auto l : list) {
if (nb::isinstance<nb::list>(l)) {
fill_vector(nb::cast<nb::list>(l), vals);
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
fill_vector(nb::cast<nb::tuple>(l), vals);
} else {
vals.push_back(nb::cast<U>(l));
}
}
}
template <typename T>
PyScalarT validate_shape(
T list,
const std::vector<int>& shape,
int idx,
bool& all_python_primitive_elements) {
if (idx >= shape.size()) {
throw std::invalid_argument("Initialization encountered extra dimension.");
}
auto s = shape[idx];
if (nb::len(list) != s) {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}
if (s == 0) {
return pyfloat;
}
PyScalarT type = pybool;
for (auto l : list) {
PyScalarT t;
if (nb::isinstance<nb::list>(l)) {
t = validate_shape(
nb::cast<nb::list>(l), shape, idx + 1, all_python_primitive_elements);
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
t = validate_shape(
nb::cast<nb::tuple>(l),
shape,
idx + 1,
all_python_primitive_elements);
} else if (nb::isinstance<array>(l)) {
all_python_primitive_elements = false;
auto arr = nb::cast<array>(l);
if (arr.ndim() + idx + 1 == shape.size() &&
std::equal(
arr.shape().cbegin(),
arr.shape().cend(),
shape.cbegin() + idx + 1)) {
t = pybool;
} else {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}
} else {
if (nb::isinstance<nb::bool_>(l)) {
t = pybool;
} else if (nb::isinstance<nb::int_>(l)) {
t = pyint;
} else if (nb::isinstance<nb::float_>(l)) {
t = pyfloat;
} else if (PyComplex_Check(l.ptr())) {
t = pycomplex;
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(l.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}
if (idx + 1 != shape.size()) {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}
}
type = std::max(type, t);
}
return type;
}
template <typename T>
void get_shape(T list, std::vector<int>& shape) {
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
if (nb::isinstance<nb::list>(*l)) {
return get_shape(nb::cast<nb::list>(*l), shape);
} else if (nb::isinstance<nb::tuple>(*l)) {
return get_shape(nb::cast<nb::tuple>(*l), shape);
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
}
return;
}
}
}
using ArrayInitType = std::variant<
nb::bool_,
nb::int_,
nb::float_,
// Must be above ndarray
array,
// Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>,
nb::list,
nb::tuple,
nb::object>;
// Forward declaration
array create_array(ArrayInitType v, std::optional<Dtype> t);
template <typename T>
array array_from_list(
T pl,
const PyScalarT& inferred_type,
std::optional<Dtype> specified_type,
const std::vector<int>& shape) {
// Make the array
switch (inferred_type) {
case pybool: {
std::vector<bool> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(bool_));
}
case pyint: {
auto dtype = specified_type.value_or(int32);
if (dtype == int64) {
std::vector<int64_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (dtype == uint64) {
std::vector<uint64_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (dtype == uint32) {
std::vector<uint32_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (issubdtype(dtype, inexact)) {
std::vector<float> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else {
std::vector<int> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
}
}
case pyfloat: {
std::vector<float> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(float32));
}
case pycomplex: {
std::vector<std::complex<float>> vals;
fill_vector(pl, vals);
return array(
reinterpret_cast<complex64_t*>(vals.data()),
shape,
specified_type.value_or(complex64));
}
default: {
std::ostringstream msg;
msg << "Should not happen, inferred: " << inferred_type
<< " on subarray made of only python primitive types.";
throw std::runtime_error(msg.str());
}
}
}
template <typename T>
array array_from_list(T pl, std::optional<Dtype> dtype) {
// Compute the shape
std::vector<int> shape;
get_shape(pl, shape);
// Validate the shape and type
bool all_python_primitive_elements = true;
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
if (all_python_primitive_elements) {
// `pl` does not contain mlx arrays
return array_from_list(pl, type, dtype, shape);
}
// `pl` contains mlx arrays
std::vector<array> arrays;
for (auto l : pl) {
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
}
return stack(arrays);
}
///////////////////////////////////////////////////////////////////////////////
// Module
///////////////////////////////////////////////////////////////////////////////
array create_array(ArrayInitType v, std::optional<Dtype> t) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), t.value_or(bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
return array(nb::cast<int>(*pv), t.value_or(int32));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
return array(nb::cast<float>(*pv), t.value_or(float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
pv) {
return nd_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<array>(&v); pv) {
return astype(*pv, t.value_or((*pv).dtype()));
} else {
auto arr = to_array_with_accessor(std::get<nb::object>(v));
return astype(arr, t.value_or(arr.dtype()));
}
}
class ArrayAt {
public:
ArrayAt(array x) : x_(std::move(x)) {}


@@ -3,9 +3,17 @@
#include <nanobind/stl/complex.h>
#include "python/src/convert.h"
#include "python/src/utils.h"
#include "mlx/utils.h"
enum PyScalarT {
pybool = 0,
pyint = 1,
pyfloat = 2,
pycomplex = 3,
};
namespace nanobind {
template <>
struct ndarray_traits<float16_t> {
@@ -158,3 +166,308 @@ nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
nb::ndarray<> mlx_to_dlpack(const array& a) {
return mlx_to_nd_array<>(a);
}
nb::object to_scalar(array& a) {
{
nb::gil_scoped_release nogil;
a.eval();
}
switch (a.dtype()) {
case bool_:
return nb::cast(a.item<bool>());
case uint8:
return nb::cast(a.item<uint8_t>());
case uint16:
return nb::cast(a.item<uint16_t>());
case uint32:
return nb::cast(a.item<uint32_t>());
case uint64:
return nb::cast(a.item<uint64_t>());
case int8:
return nb::cast(a.item<int8_t>());
case int16:
return nb::cast(a.item<int16_t>());
case int32:
return nb::cast(a.item<int32_t>());
case int64:
return nb::cast(a.item<int64_t>());
case float16:
return nb::cast(static_cast<float>(a.item<float16_t>()));
case float32:
return nb::cast(a.item<float>());
case bfloat16:
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
case complex64:
return nb::cast(a.item<std::complex<float>>());
}
}
template <typename T, typename U = T>
nb::list to_list(array& a, size_t index, int dim) {
nb::list pl;
auto stride = a.strides()[dim];
for (int i = 0; i < a.shape(dim); ++i) {
if (dim == a.ndim() - 1) {
pl.append(static_cast<U>(a.data<T>()[index]));
} else {
pl.append(to_list<T, U>(a, index, dim + 1));
}
index += stride;
}
return pl;
}
nb::object tolist(array& a) {
if (a.ndim() == 0) {
return to_scalar(a);
}
{
nb::gil_scoped_release nogil;
a.eval();
}
switch (a.dtype()) {
case bool_:
return to_list<bool>(a, 0, 0);
case uint8:
return to_list<uint8_t>(a, 0, 0);
case uint16:
return to_list<uint16_t>(a, 0, 0);
case uint32:
return to_list<uint32_t>(a, 0, 0);
case uint64:
return to_list<uint64_t>(a, 0, 0);
case int8:
return to_list<int8_t>(a, 0, 0);
case int16:
return to_list<int16_t>(a, 0, 0);
case int32:
return to_list<int32_t>(a, 0, 0);
case int64:
return to_list<int64_t>(a, 0, 0);
case float16:
return to_list<float16_t, float>(a, 0, 0);
case float32:
return to_list<float>(a, 0, 0);
case bfloat16:
return to_list<bfloat16_t, float>(a, 0, 0);
case complex64:
return to_list<std::complex<float>>(a, 0, 0);
}
}
template <typename T, typename U>
void fill_vector(T list, std::vector<U>& vals) {
for (auto l : list) {
if (nb::isinstance<nb::list>(l)) {
fill_vector(nb::cast<nb::list>(l), vals);
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
fill_vector(nb::cast<nb::tuple>(l), vals);
} else {
vals.push_back(nb::cast<U>(l));
}
}
}
template <typename T>
PyScalarT validate_shape(
T list,
const std::vector<int>& shape,
int idx,
bool& all_python_primitive_elements) {
if (idx >= shape.size()) {
throw std::invalid_argument("Initialization encountered extra dimension.");
}
auto s = shape[idx];
if (nb::len(list) != s) {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}
if (s == 0) {
return pyfloat;
}
PyScalarT type = pybool;
for (auto l : list) {
PyScalarT t;
if (nb::isinstance<nb::list>(l)) {
t = validate_shape(
nb::cast<nb::list>(l), shape, idx + 1, all_python_primitive_elements);
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
t = validate_shape(
nb::cast<nb::tuple>(l),
shape,
idx + 1,
all_python_primitive_elements);
} else if (nb::isinstance<array>(l)) {
all_python_primitive_elements = false;
auto arr = nb::cast<array>(l);
if (arr.ndim() + idx + 1 == shape.size() &&
std::equal(
arr.shape().cbegin(),
arr.shape().cend(),
shape.cbegin() + idx + 1)) {
t = pybool;
} else {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}
} else {
if (nb::isinstance<nb::bool_>(l)) {
t = pybool;
} else if (nb::isinstance<nb::int_>(l)) {
t = pyint;
} else if (nb::isinstance<nb::float_>(l)) {
t = pyfloat;
} else if (PyComplex_Check(l.ptr())) {
t = pycomplex;
} else {
std::ostringstream msg;
msg << "Invalid type " << nb::type_name(l.type()).c_str()
<< " received in array initialization.";
throw std::invalid_argument(msg.str());
}
if (idx + 1 != shape.size()) {
throw std::invalid_argument(
"Initialization encountered non-uniform length.");
}
}
type = std::max(type, t);
}
return type;
}
template <typename T>
void get_shape(T list, std::vector<int>& shape) {
shape.push_back(check_shape_dim(nb::len(list)));
if (shape.back() > 0) {
auto l = list.begin();
if (nb::isinstance<nb::list>(*l)) {
return get_shape(nb::cast<nb::list>(*l), shape);
} else if (nb::isinstance<nb::tuple>(*l)) {
return get_shape(nb::cast<nb::tuple>(*l), shape);
} else if (nb::isinstance<array>(*l)) {
auto arr = nb::cast<array>(*l);
for (int i = 0; i < arr.ndim(); i++) {
shape.push_back(check_shape_dim(arr.shape(i)));
}
return;
}
}
}
template <typename T>
array array_from_list_impl(
T pl,
const PyScalarT& inferred_type,
std::optional<Dtype> specified_type,
const std::vector<int>& shape) {
// Make the array
switch (inferred_type) {
case pybool: {
std::vector<bool> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(bool_));
}
case pyint: {
auto dtype = specified_type.value_or(int32);
if (dtype == int64) {
std::vector<int64_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (dtype == uint64) {
std::vector<uint64_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (dtype == uint32) {
std::vector<uint32_t> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else if (issubdtype(dtype, inexact)) {
std::vector<float> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
} else {
std::vector<int> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, dtype);
}
}
case pyfloat: {
std::vector<float> vals;
fill_vector(pl, vals);
return array(vals.begin(), shape, specified_type.value_or(float32));
}
case pycomplex: {
std::vector<std::complex<float>> vals;
fill_vector(pl, vals);
return array(
reinterpret_cast<complex64_t*>(vals.data()),
shape,
specified_type.value_or(complex64));
}
default: {
std::ostringstream msg;
msg << "Should not happen, inferred: " << inferred_type
<< " on subarray made of only python primitive types.";
throw std::runtime_error(msg.str());
}
}
}
template <typename T>
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
// Compute the shape
std::vector<int> shape;
get_shape(pl, shape);
// Validate the shape and type
bool all_python_primitive_elements = true;
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
if (all_python_primitive_elements) {
// `pl` does not contain mlx arrays
return array_from_list_impl(pl, type, dtype, shape);
}
// `pl` contains mlx arrays
std::vector<array> arrays;
for (auto l : pl) {
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
}
return stack(arrays);
}
array array_from_list(nb::list pl, std::optional<Dtype> dtype) {
return array_from_list_impl(pl, dtype);
}
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype) {
return array_from_list_impl(pl, dtype);
}
array create_array(ArrayInitType v, std::optional<Dtype> t) {
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
return array(nb::cast<bool>(*pv), t.value_or(bool_));
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
return array(nb::cast<int>(*pv), t.value_or(int32));
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
return array(nb::cast<float>(*pv), t.value_or(float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
pv) {
return nd_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<array>(&v); pv) {
return astype(*pv, t.value_or((*pv).dtype()));
} else {
auto arr = to_array_with_accessor(std::get<nb::object>(v));
return astype(arr, t.value_or(arr.dtype()));
}
}
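
The array_from_list machinery moved into this file is what backs constructing an mlx array from nested Python sequences, including sequences whose elements are themselves mlx arrays (those get stacked). A short sketch of the behavior it implements, exercised through the Python constructor; promotion and dtype rules follow the code above:

import mlx.core as mx

mx.array([[1, 2], [3, 4]])              # nested lists -> shape (2, 2), default int32
mx.array([1, 2, 3], dtype=mx.float16)   # a specified dtype overrides the inferred one
mx.array([mx.array([1, 2]), [3, 4]])    # elements containing arrays are stacked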


@@ -1,4 +1,5 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <optional>
@@ -6,13 +7,35 @@
#include <nanobind/ndarray.h>
#include "mlx/array.h"
#include "mlx/ops.h"
namespace nb = nanobind;
using namespace mlx::core;
using ArrayInitType = std::variant<
nb::bool_,
nb::int_,
nb::float_,
// Must be above ndarray
array,
// Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>,
nb::list,
nb::tuple,
nb::object>;
array nd_array_to_mlx(
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
std::optional<Dtype> dtype);
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
nb::ndarray<> mlx_to_dlpack(const array& a);
nb::object to_scalar(array& a);
nb::object tolist(array& a);
array create_array(ArrayInitType v, std::optional<Dtype> t);
array array_from_list(nb::list pl, std::optional<Dtype> dtype);
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype);


@@ -2,6 +2,7 @@
#include <numeric>
#include <sstream>
#include "python/src/convert.h"
#include "python/src/indexing.h"
#include "mlx/ops.h"
@@ -51,7 +52,8 @@ array get_int_index(nb::object idx, int axis_size) {
bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) ||
nb::isinstance<nb::list>(obj);
}
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
@@ -255,11 +257,18 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
// The plan is as follows:
// 1. Replace the ellipsis with a series of slice(None)
// 2. Convert list to array
// 3. Loop over the indices and calculate the gather indices
// 4. Calculate the remaining slices and reshapes
// Ellipsis handling
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
// List handling
for (auto& idx : indices) {
if (nb::isinstance<nb::list>(idx)) {
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
}
}
// Check for the number of indices passed
if (non_none_indices > src.ndim()) {
@@ -440,6 +449,9 @@ array mlx_get_item(const array& src, const nb::object& obj) {
std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end());
return reshape(src, s);
} else if (nb::isinstance<nb::list>(obj)) {
return mlx_get_item_array(
src, array_from_list(nb::cast<nb::list>(obj), {}));
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
@@ -564,6 +576,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
// Expand ellipses into a series of ':' slices
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
// Convert List to array
for (auto& idx : indices) {
if (nb::isinstance<nb::list>(idx)) {
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
}
}
if (non_none_indices > src.ndim()) {
std::ostringstream msg;
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
@@ -753,7 +772,11 @@ mlx_compute_scatter_args(
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
} else if (obj.is_none()) {
return {{}, broadcast_to(vals, src.shape()), {}};
} else if (nb::isinstance<nb::list>(obj)) {
return mlx_scatter_args_array(
src, array_from_list(nb::cast<nb::list>(obj), {}), vals);
} }
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
@@ -769,7 +792,7 @@ auto mlx_slice_update(
if (nb::isinstance<nb::tuple>(obj)) {
// Can't route to slice update if any arrays are present
for (auto idx : nb::cast<nb::tuple>(obj)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) {
return std::make_pair(false, src);
}
}
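
The change to mlx_slice_update above means a list inside a tuple index now disqualifies the in-place slice-update path, just as an array index does, so the assignment is routed through the scatter logic instead. A hedged sketch of the two situations from the Python side:

import mlx.core as mx

a = mx.arange(10).reshape(5, 2)

a[1:3, :] = 0        # pure slices: can take the slice-update path
a[[0, 2, 4], :] = 1  # a list index is converted to an array and handled via scatter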


@@ -1740,6 +1740,68 @@ class TestArray(mlx_tests.MLXTestCase):
        y = np.from_dlpack(x)
        self.assertTrue(mx.array_equal(y, x))
    def test_getitem_with_list(self):
        a = mx.array([1, 2, 3, 4, 5])
        idx = [0, 2, 4]
        self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))

        a = mx.array([[1, 2], [3, 4], [5, 6]])
        idx = [0, 2]
        self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))

        a = mx.arange(10).reshape(5, 2)
        idx = [0, 2, 4]
        self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))

        idx = [0, 2]
        a = mx.arange(16).reshape(4, 4)
        anp = np.array(a)
        self.assertTrue(np.array_equal(a[idx, 0], anp[idx, 0]))
        self.assertTrue(np.array_equal(a[idx, :], anp[idx, :]))
        self.assertTrue(np.array_equal(a[0, idx], anp[0, idx]))
        self.assertTrue(np.array_equal(a[:, idx], anp[:, idx]))

    def test_setitem_with_list(self):
        a = mx.array([1, 2, 3, 4, 5])
        anp = np.array(a)
        idx = [0, 2, 4]
        a[idx] = 3
        anp[idx] = 3
        self.assertTrue(np.array_equal(a, anp))

        a = mx.array([[1, 2], [3, 4], [5, 6]])
        idx = [0, 2]
        anp = np.array(a)
        a[idx] = 3
        anp[idx] = 3
        self.assertTrue(np.array_equal(a, anp))

        a = mx.arange(10).reshape(5, 2)
        idx = [0, 2, 4]
        anp = np.array(a)
        a[idx] = 3
        anp[idx] = 3
        self.assertTrue(np.array_equal(a, anp))

        idx = [0, 2]
        a = mx.arange(16).reshape(4, 4)
        anp = np.array(a)
        a[idx, 0] = 1
        anp[idx, 0] = 1
        self.assertTrue(np.array_equal(a, anp))

        a[idx, :] = 2
        anp[idx, :] = 2
        self.assertTrue(np.array_equal(a, anp))

        a[0, idx] = 3
        anp[0, idx] = 3
        self.assertTrue(np.array_equal(a, anp))

        a[:, idx] = 4
        anp[:, idx] = 4
        self.assertTrue(np.array_equal(a, anp))
if __name__ == "__main__":
    unittest.main()