More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun
2024-12-19 08:08:20 -08:00
committed by GitHub
parent f17536af9c
commit e03f0372b1
38 changed files with 260 additions and 258 deletions

View File

@@ -889,13 +889,13 @@ void init_array(nb::module_& m) {
.def(
"reshape",
[](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {
std::vector<int> shape;
mx::Shape shape;
if (!nb::isinstance<int>(shape_[0])) {
shape = nb::cast<std::vector<int>>(shape_[0]);
shape = nb::cast<mx::Shape>(shape_[0]);
} else {
shape = nb::cast<std::vector<int>>(shape_);
shape = nb::cast<mx::Shape>(shape_);
}
return mx::reshape(a, shape, s);
return mx::reshape(a, std::move(shape), s);
},
"shape"_a,
"stream"_a = nb::none(),
@@ -1182,14 +1182,14 @@ void init_array(nb::module_& m) {
.def(
"split",
[](const mx::array& a,
const std::variant<int, std::vector<int>>& indices_or_sections,
const std::variant<int, mx::Shape>& indices_or_sections,
int axis,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
return mx::split(a, *pv, axis, s);
} else {
return mx::split(
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
a, std::get<mx::Shape>(indices_or_sections), axis, s);
}
},
"indices_or_sections"_a,

View File

@@ -181,7 +181,7 @@ void init_fast(nb::module_& parent_module) {
return nb::cpp_function(
[kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,

View File

@@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"fft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"ifft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"fftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"ifftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"rfft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"irfft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"rfftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"irfftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {

View File

@@ -25,9 +25,9 @@ int get_slice_int(nb::object obj, int default_val) {
}
void get_slice_params(
int& starts,
int& ends,
int& strides,
mx::ShapeElem& starts,
mx::ShapeElem& ends,
mx::ShapeElem& strides,
const nb::slice& in_slice,
int axis_size) {
// Following numpy's convention
@@ -68,9 +68,9 @@ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
return src;
}
std::vector<int> starts(src.ndim(), 0);
std::vector<int> ends = src.shape();
std::vector<int> strides(src.ndim(), 1);
mx::Shape starts(src.ndim(), 0);
auto ends = src.shape();
mx::Shape strides(src.ndim(), 1);
// Check and update slice params
get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
@@ -119,7 +119,7 @@ mx::array mlx_gather_nd(
auto& idx = indices[i];
if (nb::isinstance<nb::slice>(idx)) {
int start, end, stride;
mx::ShapeElem start, end, stride;
get_slice_params(
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
@@ -168,7 +168,7 @@ mx::array mlx_gather_nd(
// Do the gather
std::vector<int> axes(indices.size());
std::iota(axes.begin(), axes.end(), 0);
std::vector<int> slice_sizes = src.shape();
auto slice_sizes = src.shape();
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes);
@@ -179,9 +179,7 @@ mx::array mlx_gather_nd(
return mx::squeeze(src, axes);
}
auto mlx_expand_ellipsis(
const std::vector<int>& shape,
const nb::tuple& entries) {
auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
std::vector<nb::object> indices;
// Go over all entries and note the position of ellipsis
@@ -230,7 +228,8 @@ auto mlx_expand_ellipsis(
for (int axis = non_none_indices_before;
axis < shape.size() - non_none_indices_after;
axis++) {
indices.push_back(nb::slice(0, shape[axis], 1));
indices.push_back(
nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
non_none_indices++;
}
}
@@ -371,9 +370,9 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Slice handling
{
std::vector<int> starts(src.ndim(), 0);
std::vector<int> ends = src.shape();
std::vector<int> strides(src.ndim(), 1);
mx::Shape starts(src.ndim(), 0);
auto ends = src.shape();
mx::Shape strides(src.ndim(), 1);
int axis = 0;
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
@@ -461,8 +460,7 @@ mlx_scatter_args_int(
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
auto shape = src.shape();
shape[0] = 1;
@@ -521,9 +519,9 @@ mlx_scatter_args_slice(
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
}
int start = 0;
int end = src.shape(0);
int stride = 1;
mx::ShapeElem start = 0;
auto end = src.shape(0);
mx::ShapeElem stride = 1;
// Check and update slice params
get_slice_params(start, end, stride, in_slice, end);
@@ -645,7 +643,7 @@ mlx_scatter_args_nd(
for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) {
int start, end, stride;
mx::ShapeElem start, end, stride;
auto axis_size = src.shape(ax++);
get_slice_params(
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
@@ -654,7 +652,7 @@ mlx_scatter_args_nd(
start = (start < 0) ? start + axis_size : start;
end = (end < 0) ? end + axis_size : end;
std::vector<int> idx_shape(idx_ndim, 1);
mx::Shape idx_shape(idx_ndim, 1);
// If it's a simple slice, we only need to add the start index
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {

View File

@@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"full",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
const ScalarOrArray& vals,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::full({*pv}, to_array(vals, dtype), s);
} else {
return mx::full(
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
}
},
"shape"_a,
@@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"zeros",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::zeros({*pv}, t, s);
} else {
return mx::zeros(std::get<std::vector<int>>(shape), t, s);
return mx::zeros(std::get<mx::Shape>(shape), t, s);
}
},
"shape"_a,
@@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"ones",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::ones({*pv}, t, s);
} else {
return mx::ones(std::get<std::vector<int>>(shape), t, s);
return mx::ones(std::get<mx::Shape>(shape), t, s);
}
},
"shape"_a,
@@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
m.def(
"split",
[](const mx::array& a,
const std::variant<int, std::vector<int>>& indices_or_sections,
const std::variant<int, mx::Shape>& indices_or_sections,
int axis,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
return mx::split(a, *pv, axis, s);
} else {
return mx::split(
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
a, std::get<mx::Shape>(indices_or_sections), axis, s);
}
},
nb::arg(),
@@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"broadcast_to",
[](const ScalarOrArray& a,
const std::vector<int>& shape,
mx::StreamOrDevice s) {
[](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
return mx::broadcast_to(to_array(a), shape, s);
},
nb::arg(),
@@ -4895,23 +4892,15 @@ void init_ops(nb::module_& m) {
m.def(
"roll",
[](const mx::array& a,
const IntOrVec& shift,
const std::variant<int, mx::Shape>& shift,
const IntOrVec& axis,
mx::StreamOrDevice s) {
return std::visit(
[&](auto sh, auto ax) -> mx::array {
using T = decltype(ax);
using V = decltype(sh);
if constexpr (std::is_same_v<V, std::monostate>) {
throw std::invalid_argument(
"[roll] Expected two arguments but only one was given.");
if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
return mx::roll(a, sh, s);
} else {
if constexpr (std::is_same_v<T, std::monostate>) {
return mx::roll(a, sh, s);
} else {
return mx::roll(a, sh, ax, s);
}
return mx::roll(a, sh, ax, s);
}
},
shift,

View File

@@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
"uniform",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"normal",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
float loc,
float scale,
@@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
@@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
"multivariate_normal",
[](const mx::array& mean,
const mx::array& cov,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
},
"mean"_a,
"cov"_a,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
"randint",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
},
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
m.def(
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<mx::Shape> shape,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
@@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
"truncated_normal",
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
const std::optional<mx::Shape> shape_,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"gumbel",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
"categorical",
[](const mx::array& logits,
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<mx::Shape> shape,
const std::optional<int> num_samples,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"laplace",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
float loc,
float scale,
@@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::laplace(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
@@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
}
},
"shape"_a = std::vector<int>{},
"x"_a,
"axis"_a = 0,
"key"_a = nb::none(),
"stream"_a = nb::none(),

View File

@@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>