mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-12 07:18:52 +08:00
@@ -889,13 +889,13 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"reshape",
|
||||
[](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {
|
||||
std::vector<int> shape;
|
||||
mx::Shape shape;
|
||||
if (!nb::isinstance<int>(shape_[0])) {
|
||||
shape = nb::cast<std::vector<int>>(shape_[0]);
|
||||
shape = nb::cast<mx::Shape>(shape_[0]);
|
||||
} else {
|
||||
shape = nb::cast<std::vector<int>>(shape_);
|
||||
shape = nb::cast<mx::Shape>(shape_);
|
||||
}
|
||||
return mx::reshape(a, shape, s);
|
||||
return mx::reshape(a, std::move(shape), s);
|
||||
},
|
||||
"shape"_a,
|
||||
"stream"_a = nb::none(),
|
||||
@@ -1182,14 +1182,14 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"split",
|
||||
[](const mx::array& a,
|
||||
const std::variant<int, std::vector<int>>& indices_or_sections,
|
||||
const std::variant<int, mx::Shape>& indices_or_sections,
|
||||
int axis,
|
||||
mx::StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
|
||||
return mx::split(a, *pv, axis, s);
|
||||
} else {
|
||||
return mx::split(
|
||||
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
|
||||
a, std::get<mx::Shape>(indices_or_sections), axis, s);
|
||||
}
|
||||
},
|
||||
"indices_or_sections"_a,
|
||||
|
||||
@@ -181,7 +181,7 @@ void init_fast(nb::module_& parent_module) {
|
||||
return nb::cpp_function(
|
||||
[kernel = std::move(kernel)](
|
||||
const std::vector<ScalarOrArray>& inputs_,
|
||||
const std::vector<std::vector<int>>& output_shapes,
|
||||
const std::vector<mx::Shape>& output_shapes,
|
||||
const std::vector<mx::Dtype>& output_dtypes,
|
||||
std::tuple<int, int, int> grid,
|
||||
std::tuple<int, int, int> threadgroup,
|
||||
|
||||
@@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"fft2",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
@@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"ifft2",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
@@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"fftn",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
@@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"ifftn",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
@@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"rfft2",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
@@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"irfft2",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
@@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"rfftn",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
@@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"irfftn",
|
||||
[](const mx::array& a,
|
||||
const std::optional<std::vector<int>>& n,
|
||||
const std::optional<mx::Shape>& n,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
mx::StreamOrDevice s) {
|
||||
if (axes.has_value() && n.has_value()) {
|
||||
|
||||
@@ -25,9 +25,9 @@ int get_slice_int(nb::object obj, int default_val) {
|
||||
}
|
||||
|
||||
void get_slice_params(
|
||||
int& starts,
|
||||
int& ends,
|
||||
int& strides,
|
||||
mx::ShapeElem& starts,
|
||||
mx::ShapeElem& ends,
|
||||
mx::ShapeElem& strides,
|
||||
const nb::slice& in_slice,
|
||||
int axis_size) {
|
||||
// Following numpy's convention
|
||||
@@ -68,9 +68,9 @@ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
|
||||
return src;
|
||||
}
|
||||
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> ends = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
mx::Shape starts(src.ndim(), 0);
|
||||
auto ends = src.shape();
|
||||
mx::Shape strides(src.ndim(), 1);
|
||||
|
||||
// Check and update slice params
|
||||
get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
|
||||
@@ -119,7 +119,7 @@ mx::array mlx_gather_nd(
|
||||
auto& idx = indices[i];
|
||||
|
||||
if (nb::isinstance<nb::slice>(idx)) {
|
||||
int start, end, stride;
|
||||
mx::ShapeElem start, end, stride;
|
||||
get_slice_params(
|
||||
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
|
||||
|
||||
@@ -168,7 +168,7 @@ mx::array mlx_gather_nd(
|
||||
// Do the gather
|
||||
std::vector<int> axes(indices.size());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::vector<int> slice_sizes = src.shape();
|
||||
auto slice_sizes = src.shape();
|
||||
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
|
||||
src = gather(src, gather_indices, axes, slice_sizes);
|
||||
|
||||
@@ -179,9 +179,7 @@ mx::array mlx_gather_nd(
|
||||
return mx::squeeze(src, axes);
|
||||
}
|
||||
|
||||
auto mlx_expand_ellipsis(
|
||||
const std::vector<int>& shape,
|
||||
const nb::tuple& entries) {
|
||||
auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
|
||||
std::vector<nb::object> indices;
|
||||
|
||||
// Go over all entries and note the position of ellipsis
|
||||
@@ -230,7 +228,8 @@ auto mlx_expand_ellipsis(
|
||||
for (int axis = non_none_indices_before;
|
||||
axis < shape.size() - non_none_indices_after;
|
||||
axis++) {
|
||||
indices.push_back(nb::slice(0, shape[axis], 1));
|
||||
indices.push_back(
|
||||
nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
|
||||
non_none_indices++;
|
||||
}
|
||||
}
|
||||
@@ -371,9 +370,9 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
|
||||
|
||||
// Slice handling
|
||||
{
|
||||
std::vector<int> starts(src.ndim(), 0);
|
||||
std::vector<int> ends = src.shape();
|
||||
std::vector<int> strides(src.ndim(), 1);
|
||||
mx::Shape starts(src.ndim(), 0);
|
||||
auto ends = src.shape();
|
||||
mx::Shape strides(src.ndim(), 1);
|
||||
int axis = 0;
|
||||
for (auto& idx : remaining_indices) {
|
||||
if (!idx.is_none()) {
|
||||
@@ -461,8 +460,7 @@ mlx_scatter_args_int(
|
||||
int s = 0;
|
||||
for (; s < update.ndim() && update.shape(s) == 1; s++)
|
||||
;
|
||||
auto up_shape =
|
||||
std::vector<int>(update.shape().begin() + s, update.shape().end());
|
||||
auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
|
||||
auto shape = src.shape();
|
||||
shape[0] = 1;
|
||||
|
||||
@@ -521,9 +519,9 @@ mlx_scatter_args_slice(
|
||||
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
|
||||
}
|
||||
|
||||
int start = 0;
|
||||
int end = src.shape(0);
|
||||
int stride = 1;
|
||||
mx::ShapeElem start = 0;
|
||||
auto end = src.shape(0);
|
||||
mx::ShapeElem stride = 1;
|
||||
|
||||
// Check and update slice params
|
||||
get_slice_params(start, end, stride, in_slice, end);
|
||||
@@ -645,7 +643,7 @@ mlx_scatter_args_nd(
|
||||
for (int i = 0; i < indices.size(); ++i) {
|
||||
auto& pyidx = indices[i];
|
||||
if (nb::isinstance<nb::slice>(pyidx)) {
|
||||
int start, end, stride;
|
||||
mx::ShapeElem start, end, stride;
|
||||
auto axis_size = src.shape(ax++);
|
||||
get_slice_params(
|
||||
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
|
||||
@@ -654,7 +652,7 @@ mlx_scatter_args_nd(
|
||||
start = (start < 0) ? start + axis_size : start;
|
||||
end = (end < 0) ? end + axis_size : end;
|
||||
|
||||
std::vector<int> idx_shape(idx_ndim, 1);
|
||||
mx::Shape idx_shape(idx_ndim, 1);
|
||||
|
||||
// If it's a simple slice, we only need to add the start index
|
||||
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
|
||||
|
||||
@@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"full",
|
||||
[](const std::variant<int, std::vector<int>>& shape,
|
||||
[](const std::variant<int, mx::Shape>& shape,
|
||||
const ScalarOrArray& vals,
|
||||
std::optional<mx::Dtype> dtype,
|
||||
mx::StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&shape); pv) {
|
||||
return mx::full({*pv}, to_array(vals, dtype), s);
|
||||
} else {
|
||||
return mx::full(
|
||||
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
|
||||
return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
|
||||
}
|
||||
},
|
||||
"shape"_a,
|
||||
@@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"zeros",
|
||||
[](const std::variant<int, std::vector<int>>& shape,
|
||||
[](const std::variant<int, mx::Shape>& shape,
|
||||
std::optional<mx::Dtype> dtype,
|
||||
mx::StreamOrDevice s) {
|
||||
auto t = dtype.value_or(mx::float32);
|
||||
if (auto pv = std::get_if<int>(&shape); pv) {
|
||||
return mx::zeros({*pv}, t, s);
|
||||
} else {
|
||||
return mx::zeros(std::get<std::vector<int>>(shape), t, s);
|
||||
return mx::zeros(std::get<mx::Shape>(shape), t, s);
|
||||
}
|
||||
},
|
||||
"shape"_a,
|
||||
@@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"ones",
|
||||
[](const std::variant<int, std::vector<int>>& shape,
|
||||
[](const std::variant<int, mx::Shape>& shape,
|
||||
std::optional<mx::Dtype> dtype,
|
||||
mx::StreamOrDevice s) {
|
||||
auto t = dtype.value_or(mx::float32);
|
||||
if (auto pv = std::get_if<int>(&shape); pv) {
|
||||
return mx::ones({*pv}, t, s);
|
||||
} else {
|
||||
return mx::ones(std::get<std::vector<int>>(shape), t, s);
|
||||
return mx::ones(std::get<mx::Shape>(shape), t, s);
|
||||
}
|
||||
},
|
||||
"shape"_a,
|
||||
@@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"split",
|
||||
[](const mx::array& a,
|
||||
const std::variant<int, std::vector<int>>& indices_or_sections,
|
||||
const std::variant<int, mx::Shape>& indices_or_sections,
|
||||
int axis,
|
||||
mx::StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
|
||||
return mx::split(a, *pv, axis, s);
|
||||
} else {
|
||||
return mx::split(
|
||||
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
|
||||
a, std::get<mx::Shape>(indices_or_sections), axis, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
@@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"broadcast_to",
|
||||
[](const ScalarOrArray& a,
|
||||
const std::vector<int>& shape,
|
||||
mx::StreamOrDevice s) {
|
||||
[](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
|
||||
return mx::broadcast_to(to_array(a), shape, s);
|
||||
},
|
||||
nb::arg(),
|
||||
@@ -4895,23 +4892,15 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"roll",
|
||||
[](const mx::array& a,
|
||||
const IntOrVec& shift,
|
||||
const std::variant<int, mx::Shape>& shift,
|
||||
const IntOrVec& axis,
|
||||
mx::StreamOrDevice s) {
|
||||
return std::visit(
|
||||
[&](auto sh, auto ax) -> mx::array {
|
||||
using T = decltype(ax);
|
||||
using V = decltype(sh);
|
||||
|
||||
if constexpr (std::is_same_v<V, std::monostate>) {
|
||||
throw std::invalid_argument(
|
||||
"[roll] Expected two arguments but only one was given.");
|
||||
if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
|
||||
return mx::roll(a, sh, s);
|
||||
} else {
|
||||
if constexpr (std::is_same_v<T, std::monostate>) {
|
||||
return mx::roll(a, sh, s);
|
||||
} else {
|
||||
return mx::roll(a, sh, ax, s);
|
||||
}
|
||||
return mx::roll(a, sh, ax, s);
|
||||
}
|
||||
},
|
||||
shift,
|
||||
|
||||
@@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"uniform",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"low"_a = 0,
|
||||
"high"_a = 1,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"normal",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
@@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::normal(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
@@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"multivariate_normal",
|
||||
[](const mx::array& mean,
|
||||
const mx::array& cov,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"mean"_a,
|
||||
"cov"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"randint",
|
||||
[](const ScalarOrArray& low,
|
||||
const ScalarOrArray& high,
|
||||
const std::vector<int>& shape,
|
||||
const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
|
||||
},
|
||||
"low"_a,
|
||||
"high"_a,
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::int32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"bernoulli",
|
||||
[](const ScalarOrArray& p_,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<mx::Shape> shape,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
@@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"truncated_normal",
|
||||
[](const ScalarOrArray& lower_,
|
||||
const ScalarOrArray& upper_,
|
||||
const std::optional<std::vector<int>> shape_,
|
||||
const std::optional<mx::Shape> shape_,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"gumbel",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"categorical",
|
||||
[](const mx::array& logits,
|
||||
int axis,
|
||||
const std::optional<std::vector<int>> shape,
|
||||
const std::optional<mx::Shape> shape,
|
||||
const std::optional<int> num_samples,
|
||||
const std::optional<mx::array>& key_,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"laplace",
|
||||
[](const std::vector<int>& shape,
|
||||
[](const mx::Shape& shape,
|
||||
std::optional<mx::Dtype> type,
|
||||
float loc,
|
||||
float scale,
|
||||
@@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::laplace(
|
||||
shape, type.value_or(mx::float32), loc, scale, key, s);
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"shape"_a = mx::Shape{},
|
||||
"dtype"_a.none() = mx::float32,
|
||||
"loc"_a = 0.0,
|
||||
"scale"_a = 1.0,
|
||||
@@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
|
||||
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
|
||||
}
|
||||
},
|
||||
"shape"_a = std::vector<int>{},
|
||||
"x"_a,
|
||||
"axis"_a = 0,
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
|
||||
Reference in New Issue
Block a user