Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun
2024-03-18 20:12:25 -07:00
committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <numeric>
#include <sstream>
@@ -7,19 +7,19 @@
#include "mlx/ops.h"
// Reports whether a Python slice is the bare `:` slice, i.e. whether its
// `start`, `stop` and `step` attributes are all None.
// NOTE(review): this listing is a rendered diff without +/- markers; each
// `py::` line is the removed form of the `nb::` line that replaces it.
bool is_none_slice(const py::slice& in_slice) {
bool is_none_slice(const nb::slice& in_slice) {
return (
py::getattr(in_slice, "start").is_none() &&
py::getattr(in_slice, "stop").is_none() &&
py::getattr(in_slice, "step").is_none());
nb::getattr(in_slice, "start").is_none() &&
nb::getattr(in_slice, "stop").is_none() &&
nb::getattr(in_slice, "step").is_none());
}
// Converts one slice attribute (it is called with a slice's `start`, `stop`
// or `step`) to an int.
//   obj         - the attribute value taken from the slice
//   default_val - value returned when `obj` is None
// Throws std::invalid_argument when `obj` is neither None nor a Python int.
// NOTE(review): this listing is a rendered diff without +/- markers; each
// `py::` line is the removed form of the `nb::` line that replaces it.
int get_slice_int(py::object obj, int default_val) {
int get_slice_int(nb::object obj, int default_val) {
if (!obj.is_none()) {
// Reject anything that is not a Python int before casting.
if (!py::isinstance<py::int_>(obj)) {
if (!nb::isinstance<nb::int_>(obj)) {
throw std::invalid_argument("Slice indices must be integers or None.");
}
return py::cast<int>(py::cast<py::int_>(obj));
return nb::cast<int>(nb::cast<nb::int_>(obj));
}
// None: fall back to the caller-supplied default.
return default_val;
}
@@ -28,7 +28,7 @@ void get_slice_params(
int& starts,
int& ends,
int& strides,
const py::slice& in_slice,
const nb::slice& in_slice,
int axis_size) {
// Following numpy's convention
// Assume n is the number of elements in the dimension being sliced.
@@ -36,26 +36,26 @@ void get_slice_params(
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
// k < 0 . If k is not given it defaults to 1
strides = get_slice_int(py::getattr(in_slice, "step"), 1);
strides = get_slice_int(nb::getattr(in_slice, "step"), 1);
starts = get_slice_int(
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
nb::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
ends = get_slice_int(
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
}
// Builds a uint32 `array` holding a single index taken from a Python object.
// A negative value is shifted by `axis_size` so that it counts from the end
// of the axis; non-negative values are used unchanged.
// NOTE(review): no range check is visible here. A value that is still
// negative after the shift (idx < -axis_size) is handed to
// array(idx_, uint32) as is - confirm how that constructor treats it.
// NOTE(review): this listing is a rendered diff without +/- markers; each
// `py::` line is the removed form of the `nb::` line that replaces it.
array get_int_index(py::object idx, int axis_size) {
int idx_ = py::cast<int>(idx);
array get_int_index(nb::object idx, int axis_size) {
int idx_ = nb::cast<int>(idx);
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
return array(idx_, uint32);
}
// Returns true when `obj` is one of the index kinds accepted here: a Python
// slice, a Python int, an `array`, None, or the Ellipsis object.
// NOTE(review): this listing is a rendered diff without +/- markers; each
// `py::` line is the removed form of the `nb::` line that replaces it.
bool is_valid_index_type(const py::object& obj) {
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) ||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj);
bool is_valid_index_type(const nb::object& obj) {
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
}
array mlx_get_item_slice(const array& src, const py::slice& in_slice) {
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
@@ -92,7 +92,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
return take(src, indices, 0);
}
array mlx_get_item_int(const array& src, const py::int_& idx) {
array mlx_get_item_int(const array& src, const nb::int_& idx) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
throw std::invalid_argument(
@@ -106,7 +106,7 @@ array mlx_get_item_int(const array& src, const py::int_& idx) {
array mlx_gather_nd(
array src,
const std::vector<py::object>& indices,
const std::vector<nb::object>& indices,
bool gather_first,
int& max_dims) {
max_dims = 0;
@@ -117,9 +117,10 @@ array mlx_gather_nd(
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (py::isinstance<py::slice>(idx)) {
if (nb::isinstance<nb::slice>(idx)) {
int start, end, stride;
get_slice_params(start, end, stride, idx, src.shape(i));
get_slice_params(
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
// Handle negative indices
start = (start < 0) ? start + src.shape(i) : start;
@@ -128,10 +129,10 @@ array mlx_gather_nd(
gather_indices.push_back(arange(start, end, stride, uint32));
num_slices++;
is_slice[i] = true;
} else if (py::isinstance<py::int_>(idx)) {
} else if (nb::isinstance<nb::int_>(idx)) {
gather_indices.push_back(get_int_index(idx, src.shape(i)));
} else if (py::isinstance<array>(idx)) {
auto arr = py::cast<array>(idx);
} else if (nb::isinstance<array>(idx)) {
auto arr = nb::cast<array>(idx);
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
gather_indices.push_back(arr);
}
@@ -185,7 +186,7 @@ array mlx_gather_nd(
return src;
}
array mlx_get_item_nd(array src, const py::tuple& entries) {
array mlx_get_item_nd(array src, const nb::tuple& entries) {
// No indices make this a noop
if (entries.size() == 0) {
return src;
@@ -197,11 +198,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
// 3. Calculate the remaining slices and reshapes
// Ellipsis handling
std::vector<py::object> indices;
std::vector<nb::object> indices;
{
int non_none_indices_before = 0;
int non_none_indices_after = 0;
std::vector<py::object> r_indices;
std::vector<nb::object> r_indices;
int i = 0;
for (; i < entries.size(); i++) {
auto idx = entries[i];
@@ -209,7 +210,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
if (!py::ellipsis().is(idx)) {
if (!nb::ellipsis().is(idx)) {
indices.push_back(idx);
non_none_indices_before += !idx.is_none();
} else {
@@ -222,7 +223,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
}
if (py::ellipsis().is(idx)) {
if (nb::ellipsis().is(idx)) {
throw std::invalid_argument(
"An index can only have a single ellipsis (...)");
}
@@ -232,7 +233,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
for (int axis = non_none_indices_before;
axis < src.ndim() - non_none_indices_after;
axis++) {
indices.push_back(py::slice(0, src.shape(axis), 1));
indices.push_back(nb::slice(0, src.shape(axis), 1));
}
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
}
@@ -256,7 +257,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
//
// Check whether we have arrays or integer indices and delegate to gather_nd
// after removing the slices at the end and all Nones.
std::vector<py::object> remaining_indices;
std::vector<nb::object> remaining_indices;
bool have_array = false;
{
// First check whether the results of gather are going to be 1st or
@@ -264,7 +265,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
bool have_non_array = false;
bool gather_first = false;
for (auto& idx : indices) {
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
if (have_array && have_non_array) {
gather_first = true;
break;
@@ -280,12 +281,12 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
// Then find the last array
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
auto& idx = indices[last_array];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
}
}
std::vector<py::object> gather_indices;
std::vector<nb::object> gather_indices;
for (int i = 0; i <= last_array; i++) {
auto& idx = indices[i];
if (!idx.is_none()) {
@@ -299,15 +300,15 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
if (gather_first) {
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
for (int i = 0; i < last_array; i++) {
auto& idx = indices[i];
if (idx.is_none()) {
remaining_indices.push_back(indices[i]);
} else if (py::isinstance<py::slice>(idx)) {
} else if (nb::isinstance<nb::slice>(idx)) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
}
for (int i = last_array + 1; i < indices.size(); i++) {
@@ -316,18 +317,18 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
} else {
for (int i = 0; i < indices.size(); i++) {
auto& idx = indices[i];
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
break;
} else if (idx.is_none()) {
remaining_indices.push_back(idx);
} else {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
}
for (int i = 0; i < max_dims; i++) {
remaining_indices.push_back(
py::slice(py::none(), py::none(), py::none()));
nb::slice(nb::none(), nb::none(), nb::none()));
}
for (int i = last_array + 1; i < indices.size(); i++) {
remaining_indices.push_back(indices[i]);
@@ -351,7 +352,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
get_slice_params(
starts[axis], ends[axis], strides[axis], idx, ends[axis]);
starts[axis],
ends[axis],
strides[axis],
nb::cast<nb::slice>(idx),
ends[axis]);
axis++;
}
}
@@ -375,15 +380,17 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
return src;
}
array mlx_get_item(const array& src, const py::object& obj) {
if (py::isinstance<py::slice>(obj)) {
return mlx_get_item_slice(src, obj);
} else if (py::isinstance<array>(obj)) {
return mlx_get_item_array(src, py::cast<array>(obj));
} else if (py::isinstance<py::int_>(obj)) {
return mlx_get_item_int(src, obj);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_get_item_nd(src, obj);
array mlx_get_item(const array& src, const nb::object& obj) {
if (nb::isinstance<nb::slice>(obj)) {
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
} else if (nb::isinstance<array>(obj)) {
return mlx_get_item_array(src, nb::cast<array>(obj));
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
} else if (nb::isinstance<nb::ellipsis>(obj)) {
return src;
} else if (obj.is_none()) {
std::vector<int> s(1, 1);
s.insert(s.end(), src.shape().begin(), src.shape().end());
@@ -394,7 +401,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
const array& src,
const py::int_& idx,
const nb::int_& idx,
const array& update) {
if (src.ndim() == 0) {
throw std::invalid_argument(
@@ -446,7 +453,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
const array& src,
const py::slice& in_slice,
const nb::slice& in_slice,
const array& update) {
// Check input and raise error if 0 dim for parity with np
if (src.ndim() == 0) {
@@ -478,9 +485,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
const array& src,
const py::tuple& entries,
const nb::tuple& entries,
const array& update) {
std::vector<py::object> indices;
std::vector<nb::object> indices;
int non_none_indices = 0;
// Expand ellipses into a series of ':' slices
@@ -494,7 +501,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
if (!is_valid_index_type(idx)) {
throw std::invalid_argument(
"Cannot index mlx array using the given type yet");
} else if (!py::ellipsis().is(idx)) {
} else if (!nb::ellipsis().is(idx)) {
if (!has_ellipsis) {
indices_before++;
non_none_indices_before += !idx.is_none();
@@ -514,7 +521,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
axis < src.ndim() - non_none_indices_after;
axis++) {
indices.insert(
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1));
indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1));
}
non_none_indices = src.ndim();
} else {
@@ -549,15 +556,15 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
bool have_array = false;
bool have_non_array = false;
for (auto& idx : indices) {
if (py::isinstance<py::slice>(idx) || idx.is_none()) {
if (nb::isinstance<nb::slice>(idx) || idx.is_none()) {
have_non_array = have_array;
num_slices++;
} else if (py::isinstance<array>(idx)) {
} else if (nb::isinstance<array>(idx)) {
have_array = true;
if (have_array && have_non_array) {
arrays_first = true;
}
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim);
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
num_arrays++;
}
}
@@ -569,10 +576,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
int ax = 0;
for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i];
if (py::isinstance<py::slice>(pyidx)) {
if (nb::isinstance<nb::slice>(pyidx)) {
int start, end, stride;
auto axis_size = src.shape(ax++);
get_slice_params(start, end, stride, pyidx, axis_size);
get_slice_params(
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
// Handle negative indices
start = (start < 0) ? start + axis_size : start;
@@ -584,13 +592,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
slice_num++;
idx_shape[loc] = idx.size();
arr_indices.push_back(reshape(idx, idx_shape));
} else if (py::isinstance<py::int_>(pyidx)) {
} else if (nb::isinstance<nb::int_>(pyidx)) {
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
} else if (pyidx.is_none()) {
slice_num++;
} else if (py::isinstance<array>(pyidx)) {
} else if (nb::isinstance<array>(pyidx)) {
ax++;
auto idx = py::cast<array>(pyidx);
auto idx = nb::cast<array>(pyidx);
std::vector<int> idx_shape;
if (!arrays_first) {
idx_shape.insert(idx_shape.end(), slice_num, 1);
@@ -629,24 +637,24 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
// Dispatches on the Python type of the index `obj` to the matching
// mlx_scatter_args_* helper and returns that helper's result: a tuple of
// (vector<array>, array, vector<int>), which the *_item functions in this
// file unpack as (indices, updates, axes).
//   src - array being indexed; its dtype is used to convert `v`
//   obj - the index: slice, array, int, tuple or None
//   v   - the update value, converted with to_array(v, src.dtype())
// Throws std::invalid_argument for any other index type.
// NOTE(review): unlike mlx_get_item there is no Ellipsis branch here, so an
// Ellipsis index falls through to the throw - confirm whether that is
// intended.
// NOTE(review): this listing is a rendered diff without +/- markers; each
// `py::` line is the removed form of the `nb::` line that replaces it.
std::tuple<std::vector<array>, array, std::vector<int>>
mlx_compute_scatter_args(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto vals = to_array(v, src.dtype());
if (py::isinstance<py::slice>(obj)) {
return mlx_scatter_args_slice(src, obj, vals);
} else if (py::isinstance<array>(obj)) {
return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
} else if (py::isinstance<py::int_>(obj)) {
return mlx_scatter_args_int(src, obj, vals);
} else if (py::isinstance<py::tuple>(obj)) {
return mlx_scatter_args_nd(src, obj, vals);
if (nb::isinstance<nb::slice>(obj)) {
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
} else if (nb::isinstance<array>(obj)) {
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
} else if (nb::isinstance<nb::int_>(obj)) {
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
} else if (nb::isinstance<nb::tuple>(obj)) {
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
} else if (obj.is_none()) {
// None index: no index arrays and no axes; the update is `vals` broadcast
// to the full shape of `src`.
return {{}, broadcast_to(vals, src.shape()), {}};
}
throw std::invalid_argument("Cannot index mlx array using the given type.");
}
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
auto out = scatter(src, indices, updates, axes);
@@ -658,7 +666,7 @@ void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
array mlx_add_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@@ -670,7 +678,7 @@ array mlx_add_item(
array mlx_subtract_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@@ -682,7 +690,7 @@ array mlx_subtract_item(
array mlx_multiply_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@@ -694,7 +702,7 @@ array mlx_multiply_item(
array mlx_divide_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@@ -706,7 +714,7 @@ array mlx_divide_item(
array mlx_maximum_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {
@@ -718,7 +726,7 @@ array mlx_maximum_item(
array mlx_minimum_item(
const array& src,
const py::object& obj,
const nb::object& obj,
const ScalarOrArray& v) {
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
if (indices.size() > 0) {