mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Switch to nanobind (#839)
* mostly builds * most tests pass * fix circle build * add back buffer protocol * includes * fix for py38 * limit to cpu device * include * fix stubs * move signatures for docs * stubgen + docs fix * doc for compiled function, comments
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
@@ -7,19 +7,19 @@
|
||||
|
||||
#include "mlx/ops.h"
|
||||
|
||||
bool is_none_slice(const py::slice& in_slice) {
|
||||
bool is_none_slice(const nb::slice& in_slice) {
|
||||
return (
|
||||
py::getattr(in_slice, "start").is_none() &&
|
||||
py::getattr(in_slice, "stop").is_none() &&
|
||||
py::getattr(in_slice, "step").is_none());
|
||||
nb::getattr(in_slice, "start").is_none() &&
|
||||
nb::getattr(in_slice, "stop").is_none() &&
|
||||
nb::getattr(in_slice, "step").is_none());
|
||||
}
|
||||
|
||||
int get_slice_int(py::object obj, int default_val) {
|
||||
int get_slice_int(nb::object obj, int default_val) {
|
||||
if (!obj.is_none()) {
|
||||
if (!py::isinstance<py::int_>(obj)) {
|
||||
if (!nb::isinstance<nb::int_>(obj)) {
|
||||
throw std::invalid_argument("Slice indices must be integers or None.");
|
||||
}
|
||||
return py::cast<int>(py::cast<py::int_>(obj));
|
||||
return nb::cast<int>(nb::cast<nb::int_>(obj));
|
||||
}
|
||||
return default_val;
|
||||
}
|
||||
@@ -28,7 +28,7 @@ void get_slice_params(
|
||||
int& starts,
|
||||
int& ends,
|
||||
int& strides,
|
||||
const py::slice& in_slice,
|
||||
const nb::slice& in_slice,
|
||||
int axis_size) {
|
||||
// Following numpy's convention
|
||||
// Assume n is the number of elements in the dimension being sliced.
|
||||
@@ -36,26 +36,26 @@ void get_slice_params(
|
||||
// k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
|
||||
// k < 0 . If k is not given it defaults to 1
|
||||
|
||||
strides = get_slice_int(py::getattr(in_slice, "step"), 1);
|
||||
strides = get_slice_int(nb::getattr(in_slice, "step"), 1);
|
||||
starts = get_slice_int(
|
||||
py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||
nb::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
|
||||
ends = get_slice_int(
|
||||
py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
|
||||
}
|
||||
|
||||
array get_int_index(py::object idx, int axis_size) {
|
||||
int idx_ = py::cast<int>(idx);
|
||||
array get_int_index(nb::object idx, int axis_size) {
|
||||
int idx_ = nb::cast<int>(idx);
|
||||
idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
|
||||
|
||||
return array(idx_, uint32);
|
||||
}
|
||||
|
||||
bool is_valid_index_type(const py::object& obj) {
|
||||
return py::isinstance<py::slice>(obj) || py::isinstance<py::int_>(obj) ||
|
||||
py::isinstance<array>(obj) || obj.is_none() || py::ellipsis().is(obj);
|
||||
bool is_valid_index_type(const nb::object& obj) {
|
||||
return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
|
||||
nb::isinstance<array>(obj) || obj.is_none() || nb::ellipsis().is(obj);
|
||||
}
|
||||
|
||||
array mlx_get_item_slice(const array& src, const py::slice& in_slice) {
|
||||
array mlx_get_item_slice(const array& src, const nb::slice& in_slice) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@@ -92,7 +92,7 @@ array mlx_get_item_array(const array& src, const array& indices) {
|
||||
return take(src, indices, 0);
|
||||
}
|
||||
|
||||
array mlx_get_item_int(const array& src, const py::int_& idx) {
|
||||
array mlx_get_item_int(const array& src, const nb::int_& idx) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@@ -106,7 +106,7 @@ array mlx_get_item_int(const array& src, const py::int_& idx) {
|
||||
|
||||
array mlx_gather_nd(
|
||||
array src,
|
||||
const std::vector<py::object>& indices,
|
||||
const std::vector<nb::object>& indices,
|
||||
bool gather_first,
|
||||
int& max_dims) {
|
||||
max_dims = 0;
|
||||
@@ -117,9 +117,10 @@ array mlx_gather_nd(
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
|
||||
if (py::isinstance<py::slice>(idx)) {
|
||||
if (nb::isinstance<nb::slice>(idx)) {
|
||||
int start, end, stride;
|
||||
get_slice_params(start, end, stride, idx, src.shape(i));
|
||||
get_slice_params(
|
||||
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
|
||||
|
||||
// Handle negative indices
|
||||
start = (start < 0) ? start + src.shape(i) : start;
|
||||
@@ -128,10 +129,10 @@ array mlx_gather_nd(
|
||||
gather_indices.push_back(arange(start, end, stride, uint32));
|
||||
num_slices++;
|
||||
is_slice[i] = true;
|
||||
} else if (py::isinstance<py::int_>(idx)) {
|
||||
} else if (nb::isinstance<nb::int_>(idx)) {
|
||||
gather_indices.push_back(get_int_index(idx, src.shape(i)));
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
auto arr = py::cast<array>(idx);
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
auto arr = nb::cast<array>(idx);
|
||||
max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
|
||||
gather_indices.push_back(arr);
|
||||
}
|
||||
@@ -185,7 +186,7 @@ array mlx_gather_nd(
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
array mlx_get_item_nd(array src, const nb::tuple& entries) {
|
||||
// No indices make this a noop
|
||||
if (entries.size() == 0) {
|
||||
return src;
|
||||
@@ -197,11 +198,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
// 3. Calculate the remaining slices and reshapes
|
||||
|
||||
// Ellipsis handling
|
||||
std::vector<py::object> indices;
|
||||
std::vector<nb::object> indices;
|
||||
{
|
||||
int non_none_indices_before = 0;
|
||||
int non_none_indices_after = 0;
|
||||
std::vector<py::object> r_indices;
|
||||
std::vector<nb::object> r_indices;
|
||||
int i = 0;
|
||||
for (; i < entries.size(); i++) {
|
||||
auto idx = entries[i];
|
||||
@@ -209,7 +210,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (!py::ellipsis().is(idx)) {
|
||||
if (!nb::ellipsis().is(idx)) {
|
||||
indices.push_back(idx);
|
||||
non_none_indices_before += !idx.is_none();
|
||||
} else {
|
||||
@@ -222,7 +223,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
}
|
||||
if (py::ellipsis().is(idx)) {
|
||||
if (nb::ellipsis().is(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"An index can only have a single ellipsis (...)");
|
||||
}
|
||||
@@ -232,7 +233,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(py::slice(0, src.shape(axis), 1));
|
||||
indices.push_back(nb::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
|
||||
}
|
||||
@@ -256,7 +257,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
//
|
||||
// Check whether we have arrays or integer indices and delegate to gather_nd
|
||||
// after removing the slices at the end and all Nones.
|
||||
std::vector<py::object> remaining_indices;
|
||||
std::vector<nb::object> remaining_indices;
|
||||
bool have_array = false;
|
||||
{
|
||||
// First check whether the results of gather are going to be 1st or
|
||||
@@ -264,7 +265,7 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
bool have_non_array = false;
|
||||
bool gather_first = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
if (have_array && have_non_array) {
|
||||
gather_first = true;
|
||||
break;
|
||||
@@ -280,12 +281,12 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
// Then find the last array
|
||||
for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
|
||||
auto& idx = indices[last_array];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<py::object> gather_indices;
|
||||
std::vector<nb::object> gather_indices;
|
||||
for (int i = 0; i <= last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (!idx.is_none()) {
|
||||
@@ -299,15 +300,15 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
if (gather_first) {
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
for (int i = 0; i < last_array; i++) {
|
||||
auto& idx = indices[i];
|
||||
if (idx.is_none()) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
} else if (py::isinstance<py::slice>(idx)) {
|
||||
} else if (nb::isinstance<nb::slice>(idx)) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
@@ -316,18 +317,18 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
} else {
|
||||
for (int i = 0; i < indices.size(); i++) {
|
||||
auto& idx = indices[i];
|
||||
if (py::isinstance<array>(idx) || py::isinstance<py::int_>(idx)) {
|
||||
if (nb::isinstance<array>(idx) || nb::isinstance<nb::int_>(idx)) {
|
||||
break;
|
||||
} else if (idx.is_none()) {
|
||||
remaining_indices.push_back(idx);
|
||||
} else {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < max_dims; i++) {
|
||||
remaining_indices.push_back(
|
||||
py::slice(py::none(), py::none(), py::none()));
|
||||
nb::slice(nb::none(), nb::none(), nb::none()));
|
||||
}
|
||||
for (int i = last_array + 1; i < indices.size(); i++) {
|
||||
remaining_indices.push_back(indices[i]);
|
||||
@@ -351,7 +352,11 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (!idx.is_none()) {
|
||||
get_slice_params(
|
||||
starts[axis], ends[axis], strides[axis], idx, ends[axis]);
|
||||
starts[axis],
|
||||
ends[axis],
|
||||
strides[axis],
|
||||
nb::cast<nb::slice>(idx),
|
||||
ends[axis]);
|
||||
axis++;
|
||||
}
|
||||
}
|
||||
@@ -375,15 +380,17 @@ array mlx_get_item_nd(array src, const py::tuple& entries) {
|
||||
return src;
|
||||
}
|
||||
|
||||
array mlx_get_item(const array& src, const py::object& obj) {
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_get_item_slice(src, obj);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_get_item_array(src, py::cast<array>(obj));
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_get_item_int(src, obj);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_get_item_nd(src, obj);
|
||||
array mlx_get_item(const array& src, const nb::object& obj) {
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
return mlx_get_item_array(src, nb::cast<array>(obj));
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
|
||||
} else if (nb::isinstance<nb::ellipsis>(obj)) {
|
||||
return src;
|
||||
} else if (obj.is_none()) {
|
||||
std::vector<int> s(1, 1);
|
||||
s.insert(s.end(), src.shape().begin(), src.shape().end());
|
||||
@@ -394,7 +401,7 @@ array mlx_get_item(const array& src, const py::object& obj) {
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_int(
|
||||
const array& src,
|
||||
const py::int_& idx,
|
||||
const nb::int_& idx,
|
||||
const array& update) {
|
||||
if (src.ndim() == 0) {
|
||||
throw std::invalid_argument(
|
||||
@@ -446,7 +453,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_array(
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
const array& src,
|
||||
const py::slice& in_slice,
|
||||
const nb::slice& in_slice,
|
||||
const array& update) {
|
||||
// Check input and raise error if 0 dim for parity with np
|
||||
if (src.ndim() == 0) {
|
||||
@@ -478,9 +485,9 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_slice(
|
||||
|
||||
std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
const array& src,
|
||||
const py::tuple& entries,
|
||||
const nb::tuple& entries,
|
||||
const array& update) {
|
||||
std::vector<py::object> indices;
|
||||
std::vector<nb::object> indices;
|
||||
int non_none_indices = 0;
|
||||
|
||||
// Expand ellipses into a series of ':' slices
|
||||
@@ -494,7 +501,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
if (!is_valid_index_type(idx)) {
|
||||
throw std::invalid_argument(
|
||||
"Cannot index mlx array using the given type yet");
|
||||
} else if (!py::ellipsis().is(idx)) {
|
||||
} else if (!nb::ellipsis().is(idx)) {
|
||||
if (!has_ellipsis) {
|
||||
indices_before++;
|
||||
non_none_indices_before += !idx.is_none();
|
||||
@@ -514,7 +521,7 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
axis < src.ndim() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.insert(
|
||||
indices.begin() + indices_before, py::slice(0, src.shape(axis), 1));
|
||||
indices.begin() + indices_before, nb::slice(0, src.shape(axis), 1));
|
||||
}
|
||||
non_none_indices = src.ndim();
|
||||
} else {
|
||||
@@ -549,15 +556,15 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
bool have_array = false;
|
||||
bool have_non_array = false;
|
||||
for (auto& idx : indices) {
|
||||
if (py::isinstance<py::slice>(idx) || idx.is_none()) {
|
||||
if (nb::isinstance<nb::slice>(idx) || idx.is_none()) {
|
||||
have_non_array = have_array;
|
||||
num_slices++;
|
||||
} else if (py::isinstance<array>(idx)) {
|
||||
} else if (nb::isinstance<array>(idx)) {
|
||||
have_array = true;
|
||||
if (have_array && have_non_array) {
|
||||
arrays_first = true;
|
||||
}
|
||||
max_dim = std::max(py::cast<array>(idx).ndim(), max_dim);
|
||||
max_dim = std::max(nb::cast<array>(idx).ndim(), max_dim);
|
||||
num_arrays++;
|
||||
}
|
||||
}
|
||||
@@ -569,10 +576,11 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
int ax = 0;
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (py::isinstance<py::slice>(pyidx)) {
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
int start, end, stride;
|
||||
auto axis_size = src.shape(ax++);
|
||||
get_slice_params(start, end, stride, pyidx, axis_size);
|
||||
get_slice_params(
|
||||
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
|
||||
|
||||
// Handle negative indices
|
||||
start = (start < 0) ? start + axis_size : start;
|
||||
@@ -584,13 +592,13 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
slice_num++;
|
||||
idx_shape[loc] = idx.size();
|
||||
arr_indices.push_back(reshape(idx, idx_shape));
|
||||
} else if (py::isinstance<py::int_>(pyidx)) {
|
||||
} else if (nb::isinstance<nb::int_>(pyidx)) {
|
||||
arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
|
||||
} else if (pyidx.is_none()) {
|
||||
slice_num++;
|
||||
} else if (py::isinstance<array>(pyidx)) {
|
||||
} else if (nb::isinstance<array>(pyidx)) {
|
||||
ax++;
|
||||
auto idx = py::cast<array>(pyidx);
|
||||
auto idx = nb::cast<array>(pyidx);
|
||||
std::vector<int> idx_shape;
|
||||
if (!arrays_first) {
|
||||
idx_shape.insert(idx_shape.end(), slice_num, 1);
|
||||
@@ -629,24 +637,24 @@ std::tuple<std::vector<array>, array, std::vector<int>> mlx_scatter_args_nd(
|
||||
std::tuple<std::vector<array>, array, std::vector<int>>
|
||||
mlx_compute_scatter_args(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto vals = to_array(v, src.dtype());
|
||||
if (py::isinstance<py::slice>(obj)) {
|
||||
return mlx_scatter_args_slice(src, obj, vals);
|
||||
} else if (py::isinstance<array>(obj)) {
|
||||
return mlx_scatter_args_array(src, py::cast<array>(obj), vals);
|
||||
} else if (py::isinstance<py::int_>(obj)) {
|
||||
return mlx_scatter_args_int(src, obj, vals);
|
||||
} else if (py::isinstance<py::tuple>(obj)) {
|
||||
return mlx_scatter_args_nd(src, obj, vals);
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
|
||||
} else if (nb::isinstance<array>(obj)) {
|
||||
return mlx_scatter_args_array(src, nb::cast<array>(obj), vals);
|
||||
} else if (nb::isinstance<nb::int_>(obj)) {
|
||||
return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
|
||||
} else if (nb::isinstance<nb::tuple>(obj)) {
|
||||
return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
|
||||
} else if (obj.is_none()) {
|
||||
return {{}, broadcast_to(vals, src.shape()), {}};
|
||||
}
|
||||
throw std::invalid_argument("Cannot index mlx array using the given type.");
|
||||
}
|
||||
|
||||
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
||||
void mlx_set_item(array& src, const nb::object& obj, const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
auto out = scatter(src, indices, updates, axes);
|
||||
@@ -658,7 +666,7 @@ void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) {
|
||||
|
||||
array mlx_add_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@@ -670,7 +678,7 @@ array mlx_add_item(
|
||||
|
||||
array mlx_subtract_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@@ -682,7 +690,7 @@ array mlx_subtract_item(
|
||||
|
||||
array mlx_multiply_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@@ -694,7 +702,7 @@ array mlx_multiply_item(
|
||||
|
||||
array mlx_divide_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@@ -706,7 +714,7 @@ array mlx_divide_item(
|
||||
|
||||
array mlx_maximum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
@@ -718,7 +726,7 @@ array mlx_maximum_item(
|
||||
|
||||
array mlx_minimum_item(
|
||||
const array& src,
|
||||
const py::object& obj,
|
||||
const nb::object& obj,
|
||||
const ScalarOrArray& v) {
|
||||
auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
|
||||
if (indices.size() > 0) {
|
||||
|
||||
Reference in New Issue
Block a user