mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
list based indexing (#1150)
This commit is contained in:
parent
79ef49b2c2
commit
eb8321d863
@ -23,330 +23,6 @@ namespace nb = nanobind;
|
|||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
enum PyScalarT {
|
|
||||||
pybool = 0,
|
|
||||||
pyint = 1,
|
|
||||||
pyfloat = 2,
|
|
||||||
pycomplex = 3,
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, typename U = T>
|
|
||||||
nb::list to_list(array& a, size_t index, int dim) {
|
|
||||||
nb::list pl;
|
|
||||||
auto stride = a.strides()[dim];
|
|
||||||
for (int i = 0; i < a.shape(dim); ++i) {
|
|
||||||
if (dim == a.ndim() - 1) {
|
|
||||||
pl.append(static_cast<U>(a.data<T>()[index]));
|
|
||||||
} else {
|
|
||||||
pl.append(to_list<T, U>(a, index, dim + 1));
|
|
||||||
}
|
|
||||||
index += stride;
|
|
||||||
}
|
|
||||||
return pl;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto to_scalar(array& a) {
|
|
||||||
{
|
|
||||||
nb::gil_scoped_release nogil;
|
|
||||||
a.eval();
|
|
||||||
}
|
|
||||||
switch (a.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
return nb::cast(a.item<bool>());
|
|
||||||
case uint8:
|
|
||||||
return nb::cast(a.item<uint8_t>());
|
|
||||||
case uint16:
|
|
||||||
return nb::cast(a.item<uint16_t>());
|
|
||||||
case uint32:
|
|
||||||
return nb::cast(a.item<uint32_t>());
|
|
||||||
case uint64:
|
|
||||||
return nb::cast(a.item<uint64_t>());
|
|
||||||
case int8:
|
|
||||||
return nb::cast(a.item<int8_t>());
|
|
||||||
case int16:
|
|
||||||
return nb::cast(a.item<int16_t>());
|
|
||||||
case int32:
|
|
||||||
return nb::cast(a.item<int32_t>());
|
|
||||||
case int64:
|
|
||||||
return nb::cast(a.item<int64_t>());
|
|
||||||
case float16:
|
|
||||||
return nb::cast(static_cast<float>(a.item<float16_t>()));
|
|
||||||
case float32:
|
|
||||||
return nb::cast(a.item<float>());
|
|
||||||
case bfloat16:
|
|
||||||
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
|
||||||
case complex64:
|
|
||||||
return nb::cast(a.item<std::complex<float>>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
nb::object tolist(array& a) {
|
|
||||||
if (a.ndim() == 0) {
|
|
||||||
return to_scalar(a);
|
|
||||||
}
|
|
||||||
{
|
|
||||||
nb::gil_scoped_release nogil;
|
|
||||||
a.eval();
|
|
||||||
}
|
|
||||||
switch (a.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
return to_list<bool>(a, 0, 0);
|
|
||||||
case uint8:
|
|
||||||
return to_list<uint8_t>(a, 0, 0);
|
|
||||||
case uint16:
|
|
||||||
return to_list<uint16_t>(a, 0, 0);
|
|
||||||
case uint32:
|
|
||||||
return to_list<uint32_t>(a, 0, 0);
|
|
||||||
case uint64:
|
|
||||||
return to_list<uint64_t>(a, 0, 0);
|
|
||||||
case int8:
|
|
||||||
return to_list<int8_t>(a, 0, 0);
|
|
||||||
case int16:
|
|
||||||
return to_list<int16_t>(a, 0, 0);
|
|
||||||
case int32:
|
|
||||||
return to_list<int32_t>(a, 0, 0);
|
|
||||||
case int64:
|
|
||||||
return to_list<int64_t>(a, 0, 0);
|
|
||||||
case float16:
|
|
||||||
return to_list<float16_t, float>(a, 0, 0);
|
|
||||||
case float32:
|
|
||||||
return to_list<float>(a, 0, 0);
|
|
||||||
case bfloat16:
|
|
||||||
return to_list<bfloat16_t, float>(a, 0, 0);
|
|
||||||
case complex64:
|
|
||||||
return to_list<std::complex<float>>(a, 0, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename U>
|
|
||||||
void fill_vector(T list, std::vector<U>& vals) {
|
|
||||||
for (auto l : list) {
|
|
||||||
if (nb::isinstance<nb::list>(l)) {
|
|
||||||
fill_vector(nb::cast<nb::list>(l), vals);
|
|
||||||
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
|
|
||||||
fill_vector(nb::cast<nb::tuple>(l), vals);
|
|
||||||
} else {
|
|
||||||
vals.push_back(nb::cast<U>(l));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
PyScalarT validate_shape(
|
|
||||||
T list,
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
int idx,
|
|
||||||
bool& all_python_primitive_elements) {
|
|
||||||
if (idx >= shape.size()) {
|
|
||||||
throw std::invalid_argument("Initialization encountered extra dimension.");
|
|
||||||
}
|
|
||||||
auto s = shape[idx];
|
|
||||||
if (nb::len(list) != s) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"Initialization encountered non-uniform length.");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (s == 0) {
|
|
||||||
return pyfloat;
|
|
||||||
}
|
|
||||||
|
|
||||||
PyScalarT type = pybool;
|
|
||||||
for (auto l : list) {
|
|
||||||
PyScalarT t;
|
|
||||||
if (nb::isinstance<nb::list>(l)) {
|
|
||||||
t = validate_shape(
|
|
||||||
nb::cast<nb::list>(l), shape, idx + 1, all_python_primitive_elements);
|
|
||||||
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
|
|
||||||
t = validate_shape(
|
|
||||||
nb::cast<nb::tuple>(l),
|
|
||||||
shape,
|
|
||||||
idx + 1,
|
|
||||||
all_python_primitive_elements);
|
|
||||||
} else if (nb::isinstance<array>(l)) {
|
|
||||||
all_python_primitive_elements = false;
|
|
||||||
auto arr = nb::cast<array>(l);
|
|
||||||
if (arr.ndim() + idx + 1 == shape.size() &&
|
|
||||||
std::equal(
|
|
||||||
arr.shape().cbegin(),
|
|
||||||
arr.shape().cend(),
|
|
||||||
shape.cbegin() + idx + 1)) {
|
|
||||||
t = pybool;
|
|
||||||
} else {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"Initialization encountered non-uniform length.");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (nb::isinstance<nb::bool_>(l)) {
|
|
||||||
t = pybool;
|
|
||||||
} else if (nb::isinstance<nb::int_>(l)) {
|
|
||||||
t = pyint;
|
|
||||||
} else if (nb::isinstance<nb::float_>(l)) {
|
|
||||||
t = pyfloat;
|
|
||||||
} else if (PyComplex_Check(l.ptr())) {
|
|
||||||
t = pycomplex;
|
|
||||||
} else {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Invalid type " << nb::type_name(l.type()).c_str()
|
|
||||||
<< " received in array initialization.";
|
|
||||||
throw std::invalid_argument(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (idx + 1 != shape.size()) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"Initialization encountered non-uniform length.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
type = std::max(type, t);
|
|
||||||
}
|
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void get_shape(T list, std::vector<int>& shape) {
|
|
||||||
shape.push_back(check_shape_dim(nb::len(list)));
|
|
||||||
if (shape.back() > 0) {
|
|
||||||
auto l = list.begin();
|
|
||||||
if (nb::isinstance<nb::list>(*l)) {
|
|
||||||
return get_shape(nb::cast<nb::list>(*l), shape);
|
|
||||||
} else if (nb::isinstance<nb::tuple>(*l)) {
|
|
||||||
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
|
||||||
} else if (nb::isinstance<array>(*l)) {
|
|
||||||
auto arr = nb::cast<array>(*l);
|
|
||||||
for (int i = 0; i < arr.ndim(); i++) {
|
|
||||||
shape.push_back(check_shape_dim(arr.shape(i)));
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
using ArrayInitType = std::variant<
|
|
||||||
nb::bool_,
|
|
||||||
nb::int_,
|
|
||||||
nb::float_,
|
|
||||||
// Must be above ndarray
|
|
||||||
array,
|
|
||||||
// Must be above complex
|
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
|
||||||
std::complex<float>,
|
|
||||||
nb::list,
|
|
||||||
nb::tuple,
|
|
||||||
nb::object>;
|
|
||||||
|
|
||||||
// Forward declaration
|
|
||||||
array create_array(ArrayInitType v, std::optional<Dtype> t);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
array array_from_list(
|
|
||||||
T pl,
|
|
||||||
const PyScalarT& inferred_type,
|
|
||||||
std::optional<Dtype> specified_type,
|
|
||||||
const std::vector<int>& shape) {
|
|
||||||
// Make the array
|
|
||||||
switch (inferred_type) {
|
|
||||||
case pybool: {
|
|
||||||
std::vector<bool> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
|
||||||
}
|
|
||||||
case pyint: {
|
|
||||||
auto dtype = specified_type.value_or(int32);
|
|
||||||
if (dtype == int64) {
|
|
||||||
std::vector<int64_t> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype);
|
|
||||||
} else if (dtype == uint64) {
|
|
||||||
std::vector<uint64_t> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype);
|
|
||||||
} else if (dtype == uint32) {
|
|
||||||
std::vector<uint32_t> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype);
|
|
||||||
} else if (issubdtype(dtype, inexact)) {
|
|
||||||
std::vector<float> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype);
|
|
||||||
} else {
|
|
||||||
std::vector<int> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, dtype);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case pyfloat: {
|
|
||||||
std::vector<float> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(vals.begin(), shape, specified_type.value_or(float32));
|
|
||||||
}
|
|
||||||
case pycomplex: {
|
|
||||||
std::vector<std::complex<float>> vals;
|
|
||||||
fill_vector(pl, vals);
|
|
||||||
return array(
|
|
||||||
reinterpret_cast<complex64_t*>(vals.data()),
|
|
||||||
shape,
|
|
||||||
specified_type.value_or(complex64));
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "Should not happen, inferred: " << inferred_type
|
|
||||||
<< " on subarray made of only python primitive types.";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
array array_from_list(T pl, std::optional<Dtype> dtype) {
|
|
||||||
// Compute the shape
|
|
||||||
std::vector<int> shape;
|
|
||||||
get_shape(pl, shape);
|
|
||||||
|
|
||||||
// Validate the shape and type
|
|
||||||
bool all_python_primitive_elements = true;
|
|
||||||
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
|
|
||||||
|
|
||||||
if (all_python_primitive_elements) {
|
|
||||||
// `pl` does not contain mlx arrays
|
|
||||||
return array_from_list(pl, type, dtype, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
// `pl` contains mlx arrays
|
|
||||||
std::vector<array> arrays;
|
|
||||||
for (auto l : pl) {
|
|
||||||
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
|
|
||||||
}
|
|
||||||
return stack(arrays);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Module
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
array create_array(ArrayInitType v, std::optional<Dtype> t) {
|
|
||||||
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
|
||||||
return array(nb::cast<bool>(*pv), t.value_or(bool_));
|
|
||||||
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
|
||||||
return array(nb::cast<int>(*pv), t.value_or(int32));
|
|
||||||
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
|
||||||
return array(nb::cast<float>(*pv), t.value_or(float32));
|
|
||||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
|
||||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
|
||||||
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
|
|
||||||
return array_from_list(*pv, t);
|
|
||||||
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
|
|
||||||
return array_from_list(*pv, t);
|
|
||||||
} else if (auto pv = std::get_if<
|
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
|
||||||
pv) {
|
|
||||||
return nd_array_to_mlx(*pv, t);
|
|
||||||
} else if (auto pv = std::get_if<array>(&v); pv) {
|
|
||||||
return astype(*pv, t.value_or((*pv).dtype()));
|
|
||||||
} else {
|
|
||||||
auto arr = to_array_with_accessor(std::get<nb::object>(v));
|
|
||||||
return astype(arr, t.value_or(arr.dtype()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class ArrayAt {
|
class ArrayAt {
|
||||||
public:
|
public:
|
||||||
ArrayAt(array x) : x_(std::move(x)) {}
|
ArrayAt(array x) : x_(std::move(x)) {}
|
||||||
|
@ -3,9 +3,17 @@
|
|||||||
#include <nanobind/stl/complex.h>
|
#include <nanobind/stl/complex.h>
|
||||||
|
|
||||||
#include "python/src/convert.h"
|
#include "python/src/convert.h"
|
||||||
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
enum PyScalarT {
|
||||||
|
pybool = 0,
|
||||||
|
pyint = 1,
|
||||||
|
pyfloat = 2,
|
||||||
|
pycomplex = 3,
|
||||||
|
};
|
||||||
|
|
||||||
namespace nanobind {
|
namespace nanobind {
|
||||||
template <>
|
template <>
|
||||||
struct ndarray_traits<float16_t> {
|
struct ndarray_traits<float16_t> {
|
||||||
@ -158,3 +166,308 @@ nb::ndarray<nb::numpy> mlx_to_np_array(const array& a) {
|
|||||||
nb::ndarray<> mlx_to_dlpack(const array& a) {
|
nb::ndarray<> mlx_to_dlpack(const array& a) {
|
||||||
return mlx_to_nd_array<>(a);
|
return mlx_to_nd_array<>(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nb::object to_scalar(array& a) {
|
||||||
|
{
|
||||||
|
nb::gil_scoped_release nogil;
|
||||||
|
a.eval();
|
||||||
|
}
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
return nb::cast(a.item<bool>());
|
||||||
|
case uint8:
|
||||||
|
return nb::cast(a.item<uint8_t>());
|
||||||
|
case uint16:
|
||||||
|
return nb::cast(a.item<uint16_t>());
|
||||||
|
case uint32:
|
||||||
|
return nb::cast(a.item<uint32_t>());
|
||||||
|
case uint64:
|
||||||
|
return nb::cast(a.item<uint64_t>());
|
||||||
|
case int8:
|
||||||
|
return nb::cast(a.item<int8_t>());
|
||||||
|
case int16:
|
||||||
|
return nb::cast(a.item<int16_t>());
|
||||||
|
case int32:
|
||||||
|
return nb::cast(a.item<int32_t>());
|
||||||
|
case int64:
|
||||||
|
return nb::cast(a.item<int64_t>());
|
||||||
|
case float16:
|
||||||
|
return nb::cast(static_cast<float>(a.item<float16_t>()));
|
||||||
|
case float32:
|
||||||
|
return nb::cast(a.item<float>());
|
||||||
|
case bfloat16:
|
||||||
|
return nb::cast(static_cast<float>(a.item<bfloat16_t>()));
|
||||||
|
case complex64:
|
||||||
|
return nb::cast(a.item<std::complex<float>>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U = T>
|
||||||
|
nb::list to_list(array& a, size_t index, int dim) {
|
||||||
|
nb::list pl;
|
||||||
|
auto stride = a.strides()[dim];
|
||||||
|
for (int i = 0; i < a.shape(dim); ++i) {
|
||||||
|
if (dim == a.ndim() - 1) {
|
||||||
|
pl.append(static_cast<U>(a.data<T>()[index]));
|
||||||
|
} else {
|
||||||
|
pl.append(to_list<T, U>(a, index, dim + 1));
|
||||||
|
}
|
||||||
|
index += stride;
|
||||||
|
}
|
||||||
|
return pl;
|
||||||
|
}
|
||||||
|
|
||||||
|
nb::object tolist(array& a) {
|
||||||
|
if (a.ndim() == 0) {
|
||||||
|
return to_scalar(a);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
nb::gil_scoped_release nogil;
|
||||||
|
a.eval();
|
||||||
|
}
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
return to_list<bool>(a, 0, 0);
|
||||||
|
case uint8:
|
||||||
|
return to_list<uint8_t>(a, 0, 0);
|
||||||
|
case uint16:
|
||||||
|
return to_list<uint16_t>(a, 0, 0);
|
||||||
|
case uint32:
|
||||||
|
return to_list<uint32_t>(a, 0, 0);
|
||||||
|
case uint64:
|
||||||
|
return to_list<uint64_t>(a, 0, 0);
|
||||||
|
case int8:
|
||||||
|
return to_list<int8_t>(a, 0, 0);
|
||||||
|
case int16:
|
||||||
|
return to_list<int16_t>(a, 0, 0);
|
||||||
|
case int32:
|
||||||
|
return to_list<int32_t>(a, 0, 0);
|
||||||
|
case int64:
|
||||||
|
return to_list<int64_t>(a, 0, 0);
|
||||||
|
case float16:
|
||||||
|
return to_list<float16_t, float>(a, 0, 0);
|
||||||
|
case float32:
|
||||||
|
return to_list<float>(a, 0, 0);
|
||||||
|
case bfloat16:
|
||||||
|
return to_list<bfloat16_t, float>(a, 0, 0);
|
||||||
|
case complex64:
|
||||||
|
return to_list<std::complex<float>>(a, 0, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
void fill_vector(T list, std::vector<U>& vals) {
|
||||||
|
for (auto l : list) {
|
||||||
|
if (nb::isinstance<nb::list>(l)) {
|
||||||
|
fill_vector(nb::cast<nb::list>(l), vals);
|
||||||
|
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
|
||||||
|
fill_vector(nb::cast<nb::tuple>(l), vals);
|
||||||
|
} else {
|
||||||
|
vals.push_back(nb::cast<U>(l));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
PyScalarT validate_shape(
|
||||||
|
T list,
|
||||||
|
const std::vector<int>& shape,
|
||||||
|
int idx,
|
||||||
|
bool& all_python_primitive_elements) {
|
||||||
|
if (idx >= shape.size()) {
|
||||||
|
throw std::invalid_argument("Initialization encountered extra dimension.");
|
||||||
|
}
|
||||||
|
auto s = shape[idx];
|
||||||
|
if (nb::len(list) != s) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Initialization encountered non-uniform length.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (s == 0) {
|
||||||
|
return pyfloat;
|
||||||
|
}
|
||||||
|
|
||||||
|
PyScalarT type = pybool;
|
||||||
|
for (auto l : list) {
|
||||||
|
PyScalarT t;
|
||||||
|
if (nb::isinstance<nb::list>(l)) {
|
||||||
|
t = validate_shape(
|
||||||
|
nb::cast<nb::list>(l), shape, idx + 1, all_python_primitive_elements);
|
||||||
|
} else if (nb::isinstance<nb::tuple>(*list.begin())) {
|
||||||
|
t = validate_shape(
|
||||||
|
nb::cast<nb::tuple>(l),
|
||||||
|
shape,
|
||||||
|
idx + 1,
|
||||||
|
all_python_primitive_elements);
|
||||||
|
} else if (nb::isinstance<array>(l)) {
|
||||||
|
all_python_primitive_elements = false;
|
||||||
|
auto arr = nb::cast<array>(l);
|
||||||
|
if (arr.ndim() + idx + 1 == shape.size() &&
|
||||||
|
std::equal(
|
||||||
|
arr.shape().cbegin(),
|
||||||
|
arr.shape().cend(),
|
||||||
|
shape.cbegin() + idx + 1)) {
|
||||||
|
t = pybool;
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Initialization encountered non-uniform length.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (nb::isinstance<nb::bool_>(l)) {
|
||||||
|
t = pybool;
|
||||||
|
} else if (nb::isinstance<nb::int_>(l)) {
|
||||||
|
t = pyint;
|
||||||
|
} else if (nb::isinstance<nb::float_>(l)) {
|
||||||
|
t = pyfloat;
|
||||||
|
} else if (PyComplex_Check(l.ptr())) {
|
||||||
|
t = pycomplex;
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Invalid type " << nb::type_name(l.type()).c_str()
|
||||||
|
<< " received in array initialization.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (idx + 1 != shape.size()) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Initialization encountered non-uniform length.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
type = std::max(type, t);
|
||||||
|
}
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void get_shape(T list, std::vector<int>& shape) {
|
||||||
|
shape.push_back(check_shape_dim(nb::len(list)));
|
||||||
|
if (shape.back() > 0) {
|
||||||
|
auto l = list.begin();
|
||||||
|
if (nb::isinstance<nb::list>(*l)) {
|
||||||
|
return get_shape(nb::cast<nb::list>(*l), shape);
|
||||||
|
} else if (nb::isinstance<nb::tuple>(*l)) {
|
||||||
|
return get_shape(nb::cast<nb::tuple>(*l), shape);
|
||||||
|
} else if (nb::isinstance<array>(*l)) {
|
||||||
|
auto arr = nb::cast<array>(*l);
|
||||||
|
for (int i = 0; i < arr.ndim(); i++) {
|
||||||
|
shape.push_back(check_shape_dim(arr.shape(i)));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array array_from_list_impl(
|
||||||
|
T pl,
|
||||||
|
const PyScalarT& inferred_type,
|
||||||
|
std::optional<Dtype> specified_type,
|
||||||
|
const std::vector<int>& shape) {
|
||||||
|
// Make the array
|
||||||
|
switch (inferred_type) {
|
||||||
|
case pybool: {
|
||||||
|
std::vector<bool> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, specified_type.value_or(bool_));
|
||||||
|
}
|
||||||
|
case pyint: {
|
||||||
|
auto dtype = specified_type.value_or(int32);
|
||||||
|
if (dtype == int64) {
|
||||||
|
std::vector<int64_t> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, dtype);
|
||||||
|
} else if (dtype == uint64) {
|
||||||
|
std::vector<uint64_t> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, dtype);
|
||||||
|
} else if (dtype == uint32) {
|
||||||
|
std::vector<uint32_t> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, dtype);
|
||||||
|
} else if (issubdtype(dtype, inexact)) {
|
||||||
|
std::vector<float> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, dtype);
|
||||||
|
} else {
|
||||||
|
std::vector<int> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, dtype);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case pyfloat: {
|
||||||
|
std::vector<float> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(vals.begin(), shape, specified_type.value_or(float32));
|
||||||
|
}
|
||||||
|
case pycomplex: {
|
||||||
|
std::vector<std::complex<float>> vals;
|
||||||
|
fill_vector(pl, vals);
|
||||||
|
return array(
|
||||||
|
reinterpret_cast<complex64_t*>(vals.data()),
|
||||||
|
shape,
|
||||||
|
specified_type.value_or(complex64));
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Should not happen, inferred: " << inferred_type
|
||||||
|
<< " on subarray made of only python primitive types.";
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
array array_from_list_impl(T pl, std::optional<Dtype> dtype) {
|
||||||
|
// Compute the shape
|
||||||
|
std::vector<int> shape;
|
||||||
|
get_shape(pl, shape);
|
||||||
|
|
||||||
|
// Validate the shape and type
|
||||||
|
bool all_python_primitive_elements = true;
|
||||||
|
auto type = validate_shape(pl, shape, 0, all_python_primitive_elements);
|
||||||
|
|
||||||
|
if (all_python_primitive_elements) {
|
||||||
|
// `pl` does not contain mlx arrays
|
||||||
|
return array_from_list_impl(pl, type, dtype, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
// `pl` contains mlx arrays
|
||||||
|
std::vector<array> arrays;
|
||||||
|
for (auto l : pl) {
|
||||||
|
arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
|
||||||
|
}
|
||||||
|
return stack(arrays);
|
||||||
|
}
|
||||||
|
|
||||||
|
array array_from_list(nb::list pl, std::optional<Dtype> dtype) {
|
||||||
|
return array_from_list_impl(pl, dtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype) {
|
||||||
|
return array_from_list_impl(pl, dtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
array create_array(ArrayInitType v, std::optional<Dtype> t) {
|
||||||
|
if (auto pv = std::get_if<nb::bool_>(&v); pv) {
|
||||||
|
return array(nb::cast<bool>(*pv), t.value_or(bool_));
|
||||||
|
} else if (auto pv = std::get_if<nb::int_>(&v); pv) {
|
||||||
|
return array(nb::cast<int>(*pv), t.value_or(int32));
|
||||||
|
} else if (auto pv = std::get_if<nb::float_>(&v); pv) {
|
||||||
|
return array(nb::cast<float>(*pv), t.value_or(float32));
|
||||||
|
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||||
|
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||||
|
} else if (auto pv = std::get_if<nb::list>(&v); pv) {
|
||||||
|
return array_from_list(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
|
||||||
|
return array_from_list(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<
|
||||||
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
|
||||||
|
pv) {
|
||||||
|
return nd_array_to_mlx(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<array>(&v); pv) {
|
||||||
|
return astype(*pv, t.value_or((*pv).dtype()));
|
||||||
|
} else {
|
||||||
|
auto arr = to_array_with_accessor(std::get<nb::object>(v));
|
||||||
|
return astype(arr, t.value_or(arr.dtype()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
@ -6,13 +7,35 @@
|
|||||||
#include <nanobind/ndarray.h>
|
#include <nanobind/ndarray.h>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
|
|
||||||
|
using ArrayInitType = std::variant<
|
||||||
|
nb::bool_,
|
||||||
|
nb::int_,
|
||||||
|
nb::float_,
|
||||||
|
// Must be above ndarray
|
||||||
|
array,
|
||||||
|
// Must be above complex
|
||||||
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
||||||
|
std::complex<float>,
|
||||||
|
nb::list,
|
||||||
|
nb::tuple,
|
||||||
|
nb::object>;
|
||||||
|
|
||||||
array nd_array_to_mlx(
|
array nd_array_to_mlx(
|
||||||
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
|
||||||
std::optional<Dtype> dtype);
|
std::optional<Dtype> dtype);
|
||||||
|
|
||||||
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
|
nb::ndarray<nb::numpy> mlx_to_np_array(const array& a);
|
||||||
nb::ndarray<> mlx_to_dlpack(const array& a);
|
nb::ndarray<> mlx_to_dlpack(const array& a);
|
||||||
|
|
||||||
|
nb::object to_scalar(array& a);
|
||||||
|
|
||||||
|
nb::object tolist(array& a);
|
||||||
|
|
||||||
|
array create_array(ArrayInitType v, std::optional<Dtype> t);
|
||||||
|
array array_from_list(nb::list pl, std::optional<Dtype> dtype);
|
||||||
|
array array_from_list(nb::tuple pl, std::optional<Dtype> dtype);
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "python/src/convert.h"
|
||||||
#include "python/src/indexing.h"
|
#include "python/src/indexing.h"
|
||||||
|
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
@ -51,7 +52,8 @@ array get_int_index(nb::object idx, int axis_size) {
|
|||||||
|
|
||||||
bool is_valid_index_type(const nb::object& obj) {
|
bool is_valid_index_type(const nb::object& obj) {
|
||||||
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
||||||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
|
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj) ||
|
||||||
|
nb::isinstance<nb::list>(obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
||||||
@ -255,11 +257,18 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
|||||||
|
|
||||||
// The plan is as follows:
|
// The plan is as follows:
|
||||||
// 1. Replace the ellipsis with a series of slice(None)
|
// 1. Replace the ellipsis with a series of slice(None)
|
||||||
// 2. Loop over the indices and calculate the gather indices
|
// 2. Convert list to array
|
||||||
// 3. Calculate the remaining slices and reshapes
|
// 3. Loop over the indices and calculate the gather indices
|
||||||
|
// 4. Calculate the remaining slices and reshapes
|
||||||
|
|
||||||
// Ellipsis handling
|
// Ellipsis handling
|
||||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||||
|
// List handling
|
||||||
|
for (auto& idx : indices) {
|
||||||
|
if (nb::isinstance<nb::list>(idx)) {
|
||||||
|
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check for the number of indices passed
|
// Check for the number of indices passed
|
||||||
if (non_none_indices > src.ndim()) {
|
if (non_none_indices > src.ndim()) {
|
||||||
@ -440,6 +449,9 @@ array mlx_get_item(const array& src, const nb::object& obj) {
|
|||||||
std::vector<int> s(1, 1);
|
std::vector<int> s(1, 1);
|
||||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||||
return reshape(src, s);
|
return reshape(src, s);
|
||||||
|
} else if (nb::isinstance<nb::list>(obj)) {
|
||||||
|
return mlx_get_item_array(
|
||||||
|
src, array_from_list(nb::cast<nb::list>(obj), {}));
|
||||||
}
|
}
|
||||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||||
}
|
}
|
||||||
@ -564,6 +576,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
|||||||
// Expand ellipses into a series of ':' slices
|
// Expand ellipses into a series of ':' slices
|
||||||
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
|
||||||
|
|
||||||
|
// Convert List to array
|
||||||
|
for (auto& idx : indices) {
|
||||||
|
if (nb::isinstance<nb::list>(idx)) {
|
||||||
|
idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (non_none_indices > src.ndim()) {
|
if (non_none_indices > src.ndim()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
msg << "Too many indices for array with " << src.ndim() << "dimensions.";
|
||||||
@ -753,7 +772,11 @@ mlx_compute_scatter_args(
|
|||||||
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
|
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
|
||||||
} else if (obj.is_none()) {
|
} else if (obj.is_none()) {
|
||||||
return {{}, broadcast_to(vals, src.shape()), {}};
|
return {{}, broadcast_to(vals, src.shape()), {}};
|
||||||
|
} else if (nb::isinstance<nb::list>(obj)) {
|
||||||
|
return mlx_scatter_args_array(
|
||||||
|
src, array_from_list(nb::cast<nb::list>(obj), {}), vals);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -769,7 +792,7 @@ auto mlx_slice_update(
|
|||||||
if (nb::isinstance<nb::tuple>(obj)) {
|
if (nb::isinstance<nb::tuple>(obj)) {
|
||||||
// Can't route to slice update if any arrays are present
|
// Can't route to slice update if any arrays are present
|
||||||
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
for (auto idx : nb::cast<nb::tuple>(obj)) {
|
||||||
if (nb::isinstance<array>(idx)) {
|
if (nb::isinstance<array>(idx) || nb::isinstance<nb::list>(idx)) {
|
||||||
return std::make_pair(false, src);
|
return std::make_pair(false, src);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1740,6 +1740,68 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
y = np.from_dlpack(x)
|
y = np.from_dlpack(x)
|
||||||
self.assertTrue(mx.array_equal(y, x))
|
self.assertTrue(mx.array_equal(y, x))
|
||||||
|
|
||||||
|
def test_getitem_with_list(self):
|
||||||
|
a = mx.array([1, 2, 3, 4, 5])
|
||||||
|
idx = [0, 2, 4]
|
||||||
|
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
|
||||||
|
|
||||||
|
a = mx.array([[1, 2], [3, 4], [5, 6]])
|
||||||
|
idx = [0, 2]
|
||||||
|
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
|
||||||
|
|
||||||
|
a = mx.arange(10).reshape(5, 2)
|
||||||
|
idx = [0, 2, 4]
|
||||||
|
self.assertTrue(np.array_equal(a[idx], np.array(a)[idx]))
|
||||||
|
|
||||||
|
idx = [0, 2]
|
||||||
|
a = mx.arange(16).reshape(4, 4)
|
||||||
|
anp = np.array(a)
|
||||||
|
self.assertTrue(np.array_equal(a[idx, 0], anp[idx, 0]))
|
||||||
|
self.assertTrue(np.array_equal(a[idx, :], anp[idx, :]))
|
||||||
|
self.assertTrue(np.array_equal(a[0, idx], anp[0, idx]))
|
||||||
|
self.assertTrue(np.array_equal(a[:, idx], anp[:, idx]))
|
||||||
|
|
||||||
|
def test_setitem_with_list(self):
|
||||||
|
a = mx.array([1, 2, 3, 4, 5])
|
||||||
|
anp = np.array(a)
|
||||||
|
idx = [0, 2, 4]
|
||||||
|
a[idx] = 3
|
||||||
|
anp[idx] = 3
|
||||||
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
|
a = mx.array([[1, 2], [3, 4], [5, 6]])
|
||||||
|
idx = [0, 2]
|
||||||
|
anp = np.array(a)
|
||||||
|
a[idx] = 3
|
||||||
|
anp[idx] = 3
|
||||||
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
|
a = mx.arange(10).reshape(5, 2)
|
||||||
|
idx = [0, 2, 4]
|
||||||
|
anp = np.array(a)
|
||||||
|
a[idx] = 3
|
||||||
|
anp[idx] = 3
|
||||||
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
|
idx = [0, 2]
|
||||||
|
a = mx.arange(16).reshape(4, 4)
|
||||||
|
anp = np.array(a)
|
||||||
|
a[idx, 0] = 1
|
||||||
|
anp[idx, 0] = 1
|
||||||
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
|
a[idx, :] = 2
|
||||||
|
anp[idx, :] = 2
|
||||||
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
|
a[0, idx] = 3
|
||||||
|
anp[0, idx] = 3
|
||||||
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
|
a[:, idx] = 4
|
||||||
|
anp[:, idx] = 4
|
||||||
|
self.assertTrue(np.array_equal(a, anp))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user