More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun 2024-12-19 08:08:20 -08:00 committed by GitHub
parent f17536af9c
commit e03f0372b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 260 additions and 258 deletions

View File

@ -25,7 +25,7 @@ bool retain_graph() {
} // namespace } // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */) array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) { : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
auto cval = static_cast<complex64_t>(val); auto cval = static_cast<complex64_t>(val);
init(&cval); init(&cval);
} }
@ -61,14 +61,14 @@ std::vector<array> array::make_arrays(
array::array(std::initializer_list<float> data) array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>( : array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())}, Shape{static_cast<ShapeElem>(data.size())},
float32)) { float32)) {
init(data.begin()); init(data.begin());
} }
array::array(std::initializer_list<int> data, Dtype dtype) array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>( : array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())}, Shape{static_cast<ShapeElem>(data.size())},
dtype)) { dtype)) {
init(data.begin()); init(data.begin());
} }
@ -322,7 +322,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
} }
array::ArrayIterator::reference array::ArrayIterator::operator*() const { array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = std::vector<int>(arr.ndim(), 0); auto start = Shape(arr.ndim(), 0);
auto end = arr.shape(); auto end = arr.shape();
auto shape = arr.shape(); auto shape = arr.shape();
shape.erase(shape.begin()); shape.erase(shape.begin());

View File

@ -17,7 +17,8 @@ namespace mlx::core {
class Primitive; class Primitive;
using Deleter = std::function<void(allocator::Buffer)>; using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>; using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>; using Strides = std::vector<int64_t>;
class array { class array {
@ -498,7 +499,7 @@ class array {
template <typename T> template <typename T>
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */) array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) { : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
init(&val); init(&val);
} }
@ -516,7 +517,7 @@ array::array(
std::initializer_list<T> data, std::initializer_list<T> data,
Dtype dtype /* = TypeToDtype<T>() */) Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>( : array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())}, Shape{static_cast<ShapeElem>(data.size())},
dtype)) { dtype)) {
init(data.begin()); init(data.begin());
} }

View File

@ -130,7 +130,7 @@ std::string build_lib_name(
bool compiled_check_contiguity( bool compiled_check_contiguity(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& shape) { const Shape& shape) {
bool contiguous = true; bool contiguous = true;
bool all_contig = true; bool all_contig = true;
bool all_row_contig = true; bool all_row_contig = true;

View File

@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
// Check if we can use a contiguous operation given inputs and the output shape // Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity( bool compiled_check_contiguity(
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<int>& shape); const Shape& shape);
// Allocate space for the outputs possibly with input donation // Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs( void compiled_allocate_outputs(

View File

@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
auto conv_dtype = float32; auto conv_dtype = float32;
// Pad input // Pad input
std::vector<int> padded_shape = {N, iH + 2 * padding[0], C}; Shape padded_shape = {N, iH + 2 * padding[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros // Fill with zeros
@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0); in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view // Materialize strided view
std::vector<int> strided_reshape = {N * oH, wH * C}; Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General); copy(in_strided_view, in_strided, CopyType::General);
@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
auto conv_dtype = out.dtype(); auto conv_dtype = out.dtype();
// Pad input // Pad input
std::vector<int> padded_shape = { Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
array in_padded(padded_shape, conv_dtype, nullptr, {}); array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros // Fill with zeros
@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0); in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view // Materialize strided view
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C}; Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General); copy(in_strided_view, in_strided, CopyType::General);
@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
const std::vector<int>& wt_dilation, const std::vector<int>& wt_dilation,
const bool flip) { const bool flip) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>( const auto iDim =
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>( const auto oDim = Shape(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels const int C = wt.shape(-1); // In channels
const auto wDim = std::vector<int>( const auto wDim =
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32; auto conv_dtype = float32;
// Pad input // Pad input
std::vector<int> padded_shape(in.shape().size()); Shape padded_shape(in.shape().size());
padded_shape.front() = N; padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) { for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i]; padded_shape[i + 1] = iDim[i] + 2 * padding[i];

View File

@ -14,10 +14,10 @@ namespace mlx::core {
namespace { namespace {
template <typename T, typename IdxT = int32_t> template <typename T>
struct StridedIterator { struct StridedIterator {
using iterator_category = std::random_access_iterator_tag; using iterator_category = std::random_access_iterator_tag;
using difference_type = IdxT; using difference_type = int32_t;
using value_type = T; using value_type = T;
using reference = value_type&; using reference = value_type&;
using pointer = value_type*; using pointer = value_type*;

View File

@ -107,7 +107,7 @@ struct ContiguousIterator {
: shape_(a.shape()), strides_(a.strides()) { : shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) { if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0); pos_ = Shape(shape_.size(), 0);
} }
} }

View File

@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
int implicit_K = wt.size() / conv_params.O; int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O; int implicit_N = conv_params.O;
// Prepare unfolding array // Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K}; Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
} }
// Prepare unfolding array // Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups}; Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {}); array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes())); in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@ -192,12 +192,12 @@ void conv_1D_gpu(
bool flip) { bool flip) {
// Make conv params // Make conv params
MLXConvParams<1> conv_params{ MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0), /* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ in.shape(2), /* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ wt.shape(0), /* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {in.shape(1)}, /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {wt.shape(1)}, /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {out.shape(1)}, /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]}, /* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]}, /* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]}, /* const int kdil[NDIM] = */ {wt_dilation[0]},
@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
array out, array out,
const MLXConvParams<2>& conv_params, const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) { std::vector<array>& copies_w) {
std::vector<int> padded_shape = { Shape padded_shape = {
conv_params.N, conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0], conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1], conv_params.iS[1] + 2 * conv_params.pad[1],
@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2; padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2; padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
array in_padded(padded_shape, in.dtype(), nullptr, {}); array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
// Fill with zeros // Fill with zeros
array zero_arr = array(0, in.dtype()); array zero_arr = array(0, in.dtype());
@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
copies_w.push_back(in_padded); copies_w.push_back(in_padded);
MLXConvParams<2> conv_params_updated{ MLXConvParams<2> conv_params_updated{
/* const int N = */ in_padded.shape(0), /* const int N = */ static_cast<int>(in_padded.shape(0)),
/* const int C = */ in_padded.shape(3), /* const int C = */ static_cast<int>(in_padded.shape(3)),
/* const int O = */ wt.shape(0), /* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)}, /* const int iS[NDIM] = */
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)}, {static_cast<int>(in_padded.shape(1)),
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)}, static_cast<int>(in_padded.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {1, 1}, /* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0}, /* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1}, /* const int kdil[NDIM] = */ {1, 1},
@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w; int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
// Do filter transform // Do filter transform
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O}; Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {}); array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes())); filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
copies_w.push_back(filt_wg); copies_w.push_back(filt_wg);
{ {
@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
} }
// Do input transform // Do input transform
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C}; Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {}); array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes())); inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
copies_w.push_back(inp_wg); copies_w.push_back(inp_wg);
{ {
@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
} }
// Do batched gemm // Do batched gemm
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O}; Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(out_wg_shape, in.dtype(), nullptr, {}); array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes())); out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
copies_w.push_back(out_wg); copies_w.push_back(out_wg);
{ {
@ -723,12 +727,15 @@ void conv_2D_gpu(
std::vector<array>& copies) { std::vector<array>& copies) {
// Make conv params // Make conv params
MLXConvParams<2> conv_params{ MLXConvParams<2> conv_params{
/* const int N = */ in.shape(0), /* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ in.shape(3), /* const int C = */ static_cast<int>(in.shape(3)),
/* const int O = */ wt.shape(0), /* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)}, /* const int iS[NDIM] = */
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)}, {static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)}, /* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]}, /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]}, /* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]}, /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@ -800,12 +807,21 @@ void conv_3D_gpu(
std::vector<array>& copies) { std::vector<array>& copies) {
// Make conv params // Make conv params
MLXConvParams<3> conv_params{ MLXConvParams<3> conv_params{
/* const int N = */ in.shape(0), /* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ in.shape(4), /* const int C = */ static_cast<int>(in.shape(4)),
/* const int O = */ wt.shape(0), /* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)}, /* const int iS[NDIM] = */
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)}, {static_cast<int>(in.shape(1)),
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)}, static_cast<int>(in.shape(2)),
static_cast<int>(in.shape(3))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)),
static_cast<int>(wt.shape(2)),
static_cast<int>(wt.shape(3))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)),
static_cast<int>(out.shape(2)),
static_cast<int>(out.shape(3))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]}, /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]}, /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */ /* const int kdil[NDIM] = */

View File

@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
} }
// Prepare the temporary accumulator // Prepare the temporary accumulator
std::vector<int> intermediate_shape; Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1); intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(outer_blocks); intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert( intermediate_shape.insert(
@ -806,7 +806,7 @@ void strided_reduce_2pass(
auto [in_type, out_type] = remap_reduce_types(in, op_name); auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the temporary accumulator // Prepare the temporary accumulator
std::vector<int> intermediate_shape; Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1); intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32); intermediate_shape.push_back(32);
intermediate_shape.insert( intermediate_shape.insert(

View File

@ -63,8 +63,8 @@ void pad_gpu(
const array& in, const array& in,
const array& val, const array& val,
array& out, array& out,
std::vector<int> axes, const std::vector<int>& axes,
std::vector<int> low_pad_size, const Shape& low_pad_size,
const Stream& s) { const Stream& s) {
// Fill output with val // Fill output with val
fill_gpu(val, out, s); fill_gpu(val, out, s);

View File

@ -23,8 +23,8 @@ void pad_gpu(
const array& in, const array& in,
const array& val, const array& val,
array& out, array& out,
std::vector<int> axes, const std::vector<int>& axes,
std::vector<int> low_pad_size, const Shape& low_pad_size,
const Stream& s); const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@ -82,7 +82,7 @@ array send(
} }
array recv( array recv(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
int src, int src,
std::optional<Group> group_ /* = std::nullopt */, std::optional<Group> group_ /* = std::nullopt */,

View File

@ -26,7 +26,7 @@ array send(
StreamOrDevice s = {}); StreamOrDevice s = {});
array recv( array recv(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
int src, int src,
std::optional<Group> group = std::nullopt, std::optional<Group> group = std::nullopt,

View File

@ -91,7 +91,7 @@ std::vector<array> AllGather::vjp(
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>& outputs) { const std::vector<array>& outputs) {
auto g = group(); auto g = group();
std::vector<int> starts(primals[0].ndim(), 0); Shape starts(primals[0].ndim(), 0);
auto stops = primals[0].shape(); auto stops = primals[0].shape();
starts[0] = g.rank() * stops[0]; starts[0] = g.rank() * stops[0];
stops[0] += starts[0]; stops[0] += starts[0];

View File

@ -108,7 +108,7 @@ bool disjoint(const CharSet& x, const CharSet& y) {
} }
template <typename T> template <typename T>
size_t term_size(const T& term, std::unordered_map<char, int> dict) { size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
size_t size = 1; size_t size = 1;
for (auto c : term) { for (auto c : term) {
size *= dict[c]; size *= dict[c];
@ -120,7 +120,7 @@ size_t flop_count(
const CharSet& term, const CharSet& term,
bool inner, bool inner,
int num_terms, int num_terms,
std::unordered_map<char, int> dict) { std::unordered_map<char, ShapeElem> dict) {
size_t size = term_size(term, dict); size_t size = term_size(term, dict);
auto op_factor = 1; auto op_factor = 1;
if ((num_terms - 1) > op_factor) { if ((num_terms - 1) > op_factor) {
@ -135,7 +135,7 @@ size_t flop_count(
std::pair<size_t, int> compute_cost_and_scaling( std::pair<size_t, int> compute_cost_and_scaling(
const std::vector<Subscript>& inputs, const std::vector<Subscript>& inputs,
const Subscript& output, const Subscript& output,
std::unordered_map<char, int> dim_map) { std::unordered_map<char, ShapeElem> dim_map) {
CharSet contractions; CharSet contractions;
for (auto& in : inputs) { for (auto& in : inputs) {
contractions.insert(in.set.begin(), in.set.end()); contractions.insert(in.set.begin(), in.set.end());
@ -155,7 +155,7 @@ std::pair<size_t, int> compute_cost_and_scaling(
std::tuple<std::vector<PathNode>, size_t, int> greedy_path( std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
std::vector<Subscript> inputs, std::vector<Subscript> inputs,
const Subscript& output, const Subscript& output,
std::unordered_map<char, int> dim_map, std::unordered_map<char, ShapeElem> dim_map,
size_t cost_limit, size_t cost_limit,
size_t memory_limit) { size_t memory_limit) {
// Helper struct for building the greedy path // Helper struct for building the greedy path
@ -457,7 +457,8 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
} }
Shape idx_shape(n_expand--, 1); Shape idx_shape(n_expand--, 1);
idx_shape[0] = in.shape(axes.back()); idx_shape[0] = in.shape(axes.back());
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s); auto idx = reshape(
arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
for (int i = 0; i < v; ++i) { for (int i = 0; i < v; ++i) {
indices.push_back(idx); indices.push_back(idx);
} }
@ -663,7 +664,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
} }
Subscript output(out_subscript, std::move(out_set)); Subscript output(out_subscript, std::move(out_set));
std::unordered_map<char, int> dim_map; std::unordered_map<char, ShapeElem> dim_map;
std::vector<Subscript> inputs; std::vector<Subscript> inputs;
for (int i = 0; i < in_subscripts.size(); ++i) { for (int i = 0; i < in_subscripts.size(); ++i) {
auto& in = in_subscripts[i]; auto& in = in_subscripts[i];
@ -680,7 +681,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
// Check repeat subscripts are valid // Check repeat subscripts are valid
if (in_set.size() < in.size()) { if (in_set.size() < in.size()) {
std::unordered_map<char, int> local_dims; std::unordered_map<char, ShapeElem> local_dims;
for (int j = 0; j < in.size(); ++j) { for (int j = 0; j < in.size(); ++j) {
auto dim = operands[i].shape(j); auto dim = operands[i].shape(j);
auto inserted = local_dims.insert({in[j], dim}); auto inserted = local_dims.insert({in[j], dim});

View File

@ -670,8 +670,7 @@ array scaled_dot_product_attention(
supports_sdpa_full || supports_sdpa_vector; supports_sdpa_full || supports_sdpa_vector;
if (implementation_supports_use_case) { if (implementation_supports_use_case) {
auto out_shape = auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
return array( return array(
std::move(out_shape), std::move(out_shape),
final_type, final_type,

View File

@ -59,7 +59,7 @@ typedef std::variant<int, bool, Dtype> TemplateArg;
typedef std::function<std::vector<array>( typedef std::function<std::vector<array>(
const std::vector<array>&, const std::vector<array>&,
const std::vector<std::vector<int>>&, const std::vector<Shape>&,
const std::vector<Dtype>&, const std::vector<Dtype>&,
std::tuple<int, int, int>, std::tuple<int, int, int>,
std::tuple<int, int, int>, std::tuple<int, int, int>,

View File

@ -47,8 +47,8 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
} }
} }
std::vector<int> get_shape(const gguf_tensor& tensor) { Shape get_shape(const gguf_tensor& tensor) {
std::vector<int> shape; Shape shape;
// The dimension order in GGML is the reverse of the order used in MLX. // The dimension order in GGML is the reverse of the order used in MLX.
for (int i = tensor.ndim - 1; i >= 0; i--) { for (int i = tensor.ndim - 1; i >= 0; i--) {
shape.push_back(tensor.dim[i]); shape.push_back(tensor.dim[i]);

View File

@ -12,7 +12,7 @@ extern "C" {
namespace mlx::core { namespace mlx::core {
std::vector<int> get_shape(const gguf_tensor& tensor); Shape get_shape(const gguf_tensor& tensor);
void gguf_load_quantized( void gguf_load_quantized(
std::unordered_map<std::string, array>& a, std::unordered_map<std::string, array>& a,
const gguf_tensor& tensor); const gguf_tensor& tensor);

View File

@ -109,7 +109,7 @@ void gguf_load_quantized(
std::string name(tensor.name, tensor.namelen); std::string name(tensor.name, tensor.namelen);
std::vector<int> shape = get_shape(tensor); auto shape = get_shape(tensor);
const uint64_t weights_per_block = 32; const uint64_t weights_per_block = 32;
if (shape[shape.size() - 1] % weights_per_block != 0) { if (shape[shape.size() - 1] % weights_per_block != 0) {
std::ostringstream msg; std::ostringstream msg;
@ -118,7 +118,7 @@ void gguf_load_quantized(
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
std::vector<int> weights_shape = shape; auto weights_shape = shape;
weights_shape.back() /= (weights_per_byte * 4); weights_shape.back() /= (weights_per_byte * 4);
auto w_nbytes = uint32.size() * auto w_nbytes = uint32.size() *
std::accumulate(weights_shape.begin(), std::accumulate(weights_shape.begin(),

View File

@ -271,7 +271,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
bool col_contiguous = header[34] == 'T'; bool col_contiguous = header[34] == 'T';
// Read array shape from header // Read array shape from header
std::vector<int> shape; Shape shape;
size_t st = header.find_last_of('(') + 1; size_t st = header.find_last_of('(') + 1;
size_t ed = header.find_last_of(')'); size_t ed = header.find_last_of(')');

View File

@ -219,15 +219,15 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
const auto n = a.shape(-1); const auto n = a.shape(-1);
const auto rank = a.ndim(); const auto rank = a.ndim();
std::vector<int> u_shape = a.shape(); auto u_shape = a.shape();
u_shape[rank - 2] = m; u_shape[rank - 2] = m;
u_shape[rank - 1] = m; u_shape[rank - 1] = m;
std::vector<int> s_shape = a.shape(); auto s_shape = a.shape();
s_shape.pop_back(); s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n); s_shape[rank - 2] = std::min(m, n);
std::vector<int> vt_shape = a.shape(); auto vt_shape = a.shape();
vt_shape[rank - 2] = n; vt_shape[rank - 2] = n;
vt_shape[rank - 1] = n; vt_shape[rank - 1] = n;
@ -328,8 +328,8 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
array S = outs[1]; array S = outs[1];
array V = outs[2]; array V = outs[2];
std::vector<int> starts(a.ndim(), 0); Shape starts(a.ndim(), 0);
std::vector<int> ends = a.shape(); auto ends = a.shape();
int i = a.ndim() - 2; int i = a.ndim() - 2;
int j = a.ndim() - 1; int j = a.ndim() - 1;
@ -479,7 +479,7 @@ array eigvalsh(
std::string UPLO /* = "L" */, std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigvalsh]"); validate_eigh(a, "[linalg::eigvalsh]");
std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1); Shape out_shape(a.shape().begin(), a.shape().end() - 1);
return array( return array(
std::move(out_shape), std::move(out_shape),
a.dtype(), a.dtype(),
@ -493,7 +493,7 @@ std::pair<array, array> eigh(
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigh]"); validate_eigh(a, "[linalg::eigh]");
auto out = array::make_arrays( auto out = array::make_arrays(
{std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()}, {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()}, {a.dtype(), a.dtype()},
std::make_shared<Eigh>(to_stream(s), UPLO, true), std::make_shared<Eigh>(to_stream(s), UPLO, true),
{a}); {a});

View File

@ -649,7 +649,7 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
// Clamp to bounds // Clamp to bounds
auto st = std::min(s, n - 1); auto st = std::min(s, n - 1);
auto ed = std::max(-1, e); auto ed = e > -1 ? e : -1;
start[i] = st; start[i] = st;
stop[i] = ed > st ? st : ed; stop[i] = ed > st ? st : ed;
@ -659,8 +659,8 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
} else { } else {
// Clamp to bounds // Clamp to bounds
auto st = std::max(0, std::min(s, n)); auto st = std::max(static_cast<ShapeElem>(0), std::min(s, n));
auto ed = std::max(0, std::min(e, n)); auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
start[i] = st; start[i] = st;
stop[i] = ed < st ? st : ed; stop[i] = ed < st ? st : ed;
@ -765,7 +765,7 @@ array slice_update(
std::vector<array> split( std::vector<array> split(
const array& a, const array& a,
const std::vector<int>& indices, const Shape& indices,
int axis, int axis,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
auto ax = axis < 0 ? axis + a.ndim() : axis; auto ax = axis < 0 ? axis + a.ndim() : axis;
@ -809,10 +809,8 @@ std::vector<array> split(
return res; return res;
} }
std::vector<array> split( std::vector<array>
const array& a, split(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {
const std::vector<int>& indices,
StreamOrDevice s /* = {} */) {
return split(a, indices, 0, s); return split(a, indices, 0, s);
} }
@ -834,7 +832,7 @@ split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
auto split_size = q_and_r.quot; auto split_size = q_and_r.quot;
std::vector<int> indices(num_splits - 1); Shape indices(num_splits - 1);
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
indices[i] = (i + 1) * split_size; indices[i] = (i + 1) * split_size;
} }
@ -1104,7 +1102,7 @@ array edge_pad(
/** Pad an array with a constant value */ /** Pad an array with a constant value */
array pad( array pad(
const array& a, const array& a,
const Shape& axes, const std::vector<int>& axes,
const Shape& low_pad_size, const Shape& low_pad_size,
const Shape& high_pad_size, const Shape& high_pad_size,
const array& pad_value /*= array(0)*/, const array& pad_value /*= array(0)*/,
@ -1904,9 +1902,11 @@ array min(
array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size(); int size = a.size();
auto result = argmin(reshape(a, {size}, s), 0, true, s); auto result = argmin(flatten(a, s), 0, true, s);
if (keepdims) { if (keepdims) {
result = reshape(result, std::vector<int>(a.shape().size(), 1), s); std::vector<int> axes(a.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
result = expand_dims(result, axes, s);
} else { } else {
result = squeeze(result, s); result = squeeze(result, s);
} }
@ -1940,9 +1940,11 @@ array argmin(
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size(); int size = a.size();
auto result = argmax(reshape(a, {size}, s), 0, true, s); auto result = argmax(flatten(a, s), 0, true, s);
if (keepdims) { if (keepdims) {
result = reshape(result, Shape(a.shape().size(), 1), s); std::vector<int> axes(a.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
result = expand_dims(result, axes, s);
} else { } else {
result = squeeze(result, s); result = squeeze(result, s);
} }
@ -3238,8 +3240,8 @@ inline int dilate_size(int dim, int dil) {
} }
Shape conv_out_shape( Shape conv_out_shape(
const std::vector<int>& in_shape, const Shape& in_shape,
const std::vector<int>& wt_shape, const Shape& wt_shape,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& pads_lo, const std::vector<int>& pads_lo,
const std::vector<int>& pads_hi, const std::vector<int>& pads_hi,
@ -4329,16 +4331,16 @@ array diagonal(
"[diagonal] axis1 and axis2 cannot be the same axis"); "[diagonal] axis1 and axis2 cannot be the same axis");
} }
auto off1 = std::max(-offset, 0); ShapeElem off1 = std::max(-offset, 0);
auto off2 = std::max(offset, 0); ShapeElem off2 = std::max(offset, 0);
auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2); auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
diag_size = std::max(diag_size, 0); diag_size = diag_size < 0 ? 0 : diag_size;
std::vector<array> indices = { std::vector<array> indices = {
arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)}; arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
std::vector<int> slice_sizes = a.shape(); Shape slice_sizes = a.shape();
slice_sizes[ax1] = 1; slice_sizes[ax1] = 1;
slice_sizes[ax2] = 1; slice_sizes[ax2] = 1;

View File

@ -189,13 +189,10 @@ array slice_update(
std::vector<array> std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {}); std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
int axis,
StreamOrDevice s = {});
std::vector<array> std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {}); split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
std::vector<array>
split(const array& a, const Shape& indices, StreamOrDevice s = {});
/** A vector of coordinate arrays from coordinate vectors. */ /** A vector of coordinate arrays from coordinate vectors. */
std::vector<array> meshgrid( std::vector<array> meshgrid(
@ -253,8 +250,8 @@ array moveaxis(
array pad( array pad(
const array& a, const array& a,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& low_pad_size, const Shape& low_pad_size,
const std::vector<int>& high_pad_size, const Shape& high_pad_size,
const array& pad_value = array(0), const array& pad_value = array(0),
const std::string mode = "constant", const std::string mode = "constant",
StreamOrDevice s = {}); StreamOrDevice s = {});
@ -1453,7 +1450,11 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
array roll(const array& a, int shift, StreamOrDevice s = {}); array roll(const array& a, int shift, StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, StreamOrDevice s = {}); array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
array roll(const array& a, int shift, int axis, StreamOrDevice s = {}); array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {}); array roll(
const array& a,
int shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {}); array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
array roll( array roll(
const array& a, const array& a,

View File

@ -817,10 +817,10 @@ std::vector<array> Concatenate::vjp(
const std::vector<int>& argnums, const std::vector<int>& argnums,
const std::vector<array>&) { const std::vector<array>&) {
auto& cotan = cotangents[0]; auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0); Shape start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape(); Shape stop = cotan.shape();
std::vector<int> sizes; Shape sizes;
sizes.push_back(0); sizes.push_back(0);
for (auto& p : primals) { for (auto& p : primals) {
sizes.push_back(p.shape(axis_)); sizes.push_back(p.shape(axis_));
@ -956,9 +956,9 @@ array conv_weight_backward_patches(
const std::vector<int>& padding, const std::vector<int>& padding,
StreamOrDevice s) { StreamOrDevice s) {
// Resolve Padded input shapes and strides // Resolve Padded input shapes and strides
std::vector<int> padding_starts(in.ndim(), 0); Shape padding_starts(in.ndim(), 0);
std::vector<int> padding_ends = in.shape(); auto padding_ends = in.shape();
std::vector<int> in_padded_shape = in.shape(); auto in_padded_shape = in.shape();
// padded shape // padded shape
for (int i = 1; i < in.ndim() - 1; i++) { for (int i = 1; i < in.ndim() - 1; i++) {
@ -976,8 +976,9 @@ array conv_weight_backward_patches(
// Pad input // Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0); std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1); std::iota(padded_axes.begin(), padded_axes.end(), 1);
Shape padding_(padding.begin(), padding.end());
auto in_padded = pad( auto in_padded = pad(
in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s); in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
// Resolve strided patches // Resolve strided patches
@ -1797,7 +1798,7 @@ std::vector<array> FFT::vjp(
std::vector<int> axes(axes_.begin(), axes_.end()); std::vector<int> axes(axes_.begin(), axes_.end());
if (real_ && inverse_) { if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream()); auto out = fft::fftn(cotangents[0], axes, stream());
auto start = std::vector<int>(out.ndim(), 0); auto start = Shape(out.ndim(), 0);
auto stop = in.shape(); auto stop = in.shape();
out = slice(out, start, stop, stream()); out = slice(out, start, stop, stream());
auto mask_shape = out.shape(); auto mask_shape = out.shape();
@ -1809,7 +1810,7 @@ std::vector<array> FFT::vjp(
mask = concatenate({pad, mask, pad}, axes_.back(), stream()); mask = concatenate({pad, mask, pad}, axes_.back(), stream());
return {multiply(mask, out, stream())}; return {multiply(mask, out, stream())};
} else if (real_) { } else if (real_) {
std::vector<int> n; Shape n;
for (auto ax : axes_) { for (auto ax : axes_) {
n.push_back(in.shape()[ax]); n.push_back(in.shape()[ax]);
} }
@ -1934,10 +1935,11 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
} }
if (indices_vmapped) { if (indices_vmapped) {
// Make a new index array for the vmapped dimension // Make a new index array for the vmapped dimension
auto vmap_inds = arange(0, src.shape(axes[0]), stream()); auto vmap_inds =
arange(static_cast<ShapeElem>(0), src.shape(axes[0]), stream());
// Reshape it so it broadcasts with other index arrays // Reshape it so it broadcasts with other index arrays
{ {
auto shape = std::vector<int>(idx_dims, 1); auto shape = Shape(idx_dims, 1);
shape[out_ax] = vmap_inds.size(); shape[out_ax] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(shape), stream()); vmap_inds = reshape(vmap_inds, std::move(shape), stream());
} }
@ -2628,8 +2630,8 @@ std::vector<array> Pad::vjp(
assert(argnums.size() == 1 && argnums[0] == 0); assert(argnums.size() == 1 && argnums[0] == 0);
auto& cotan = cotangents[0]; auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0); Shape start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape(); auto stop = cotan.shape();
for (auto i : axes_) { for (auto i : axes_) {
start[i] = low_pad_size_[i]; start[i] = low_pad_size_[i];
@ -3019,7 +3021,7 @@ std::vector<array> Reduce::vjp(
const std::vector<array>& outputs) { const std::vector<array>& outputs) {
auto in = primals[0]; auto in = primals[0];
std::vector<int> shape = in.shape(); auto shape = in.shape();
for (auto ax : axes_) { for (auto ax : axes_) {
shape[ax] = 1; shape[ax] = 1;
} }
@ -3044,7 +3046,7 @@ std::vector<array> Reduce::vjp(
if (axes_.size() > 1) { if (axes_.size() > 1) {
std::vector<int> transpose_to; std::vector<int> transpose_to;
std::vector<int> transpose_back; std::vector<int> transpose_back;
std::vector<int> shape_flat; Shape shape_flat;
{ {
// Find the transpose needed to move axes_ to the back and the shape // Find the transpose needed to move axes_ to the back and the shape
// except the reduced over axes. // except the reduced over axes.
@ -3422,7 +3424,7 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
} }
auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream()); auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1); auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);
vmap_inds_shape[0] = vmap_inds.size(); vmap_inds_shape[0] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream()); vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
inputs.insert( inputs.insert(
@ -3607,7 +3609,7 @@ std::vector<array> Slice::vjp(
// Transpose and reshape cotangents // Transpose and reshape cotangents
auto cotan = cotangents[0]; auto cotan = cotangents[0];
if (!ind_axes.empty()) { if (!ind_axes.empty()) {
std::vector<int> cotan_shape; Shape cotan_shape;
for (auto ax : ind_axes) { for (auto ax : ind_axes) {
cotan_shape.push_back(cotan.shape(ax)); cotan_shape.push_back(cotan.shape(ax));
} }
@ -3626,7 +3628,7 @@ std::vector<array> Slice::vjp(
} }
// Make indices broadcastable // Make indices broadcastable
std::vector<int> inds_shape(inds.size(), 1); Shape inds_shape(inds.size(), 1);
for (int i = 0; i < inds.size(); ++i) { for (int i = 0; i < inds.size(); ++i) {
inds_shape[i] = inds[i].size(); inds_shape[i] = inds[i].size();
inds[i] = reshape(inds[i], inds_shape, stream()); inds[i] = reshape(inds[i], inds_shape, stream());
@ -4184,7 +4186,7 @@ std::vector<array> BlockMaskedMM::vjp(
// Slice mask // Slice mask
mask_reshape[mask_ndim - 2] = Y; mask_reshape[mask_ndim - 2] = Y;
mask_reshape[mask_ndim - 1] = X; mask_reshape[mask_ndim - 1] = X;
mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream()); mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());
return mask; return mask;
}; };
@ -4202,7 +4204,7 @@ std::vector<array> BlockMaskedMM::vjp(
} }
// Reshape // Reshape
std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2); Shape r_reshape(r.shape().begin(), r.shape().end() - 2);
r_reshape.push_back(r.shape(-2) / block_size_); r_reshape.push_back(r.shape(-2) / block_size_);
r_reshape.push_back(block_size_); r_reshape.push_back(block_size_);
r_reshape.push_back(r.shape(-1) / block_size_); r_reshape.push_back(r.shape(-1) / block_size_);
@ -4492,7 +4494,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
} }
array out = array( array out = array(
std::vector<int>{}, {},
dtype_, dtype_,
std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_), std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
inputs); inputs);

View File

@ -1088,10 +1088,7 @@ class Full : public UnaryPrimitive {
class Gather : public UnaryPrimitive { class Gather : public UnaryPrimitive {
public: public:
explicit Gather( explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
Stream stream,
std::vector<int> axes,
std::vector<int> slice_sizes)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
axes_(std::move(axes)), axes_(std::move(axes)),
slice_sizes_(std::move(slice_sizes)) {} slice_sizes_(std::move(slice_sizes)) {}
@ -1108,7 +1105,7 @@ class Gather : public UnaryPrimitive {
private: private:
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_; std::vector<int> axes_;
std::vector<int> slice_sizes_; Shape slice_sizes_;
}; };
class Greater : public UnaryPrimitive { class Greater : public UnaryPrimitive {
@ -1503,8 +1500,8 @@ class Pad : public UnaryPrimitive {
explicit Pad( explicit Pad(
Stream stream, Stream stream,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& low_pad_size, const Shape& low_pad_size,
const std::vector<int>& high_pad_size) const Shape& high_pad_size)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
axes_(axes), axes_(axes),
low_pad_size_(low_pad_size), low_pad_size_(low_pad_size),
@ -1520,8 +1517,8 @@ class Pad : public UnaryPrimitive {
private: private:
std::vector<int> axes_; std::vector<int> axes_;
std::vector<int> low_pad_size_; Shape low_pad_size_;
std::vector<int> high_pad_size_; Shape high_pad_size_;
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
@ -1903,9 +1900,9 @@ class Slice : public UnaryPrimitive {
public: public:
explicit Slice( explicit Slice(
Stream stream, Stream stream,
const std::vector<int>& start_indices, const Shape& start_indices,
const std::vector<int>& end_indices, const Shape& end_indices,
const std::vector<int>& strides) const Shape& strides)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
start_indices_(start_indices), start_indices_(start_indices),
end_indices_(end_indices), end_indices_(end_indices),
@ -1920,9 +1917,9 @@ class Slice : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:
std::vector<int> start_indices_; Shape start_indices_;
std::vector<int> end_indices_; Shape end_indices_;
std::vector<int> strides_; Shape strides_;
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
@ -1931,9 +1928,9 @@ class SliceUpdate : public UnaryPrimitive {
public: public:
explicit SliceUpdate( explicit SliceUpdate(
Stream stream, Stream stream,
const std::vector<int>& start_indices, const Shape& start_indices,
const std::vector<int>& end_indices, const Shape& end_indices,
const std::vector<int>& strides) const Shape& strides)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
start_indices_(start_indices), start_indices_(start_indices),
end_indices_(end_indices), end_indices_(end_indices),
@ -1948,9 +1945,9 @@ class SliceUpdate : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override; bool is_equivalent(const Primitive& other) const override;
private: private:
std::vector<int> start_indices_; Shape start_indices_;
std::vector<int> end_indices_; Shape end_indices_;
std::vector<int> strides_; Shape strides_;
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
@ -1997,7 +1994,7 @@ class Sort : public UnaryPrimitive {
class Split : public Primitive { class Split : public Primitive {
public: public:
explicit Split(Stream stream, const std::vector<int>& indices, int axis) explicit Split(Stream stream, const Shape& indices, int axis)
: Primitive(stream), indices_(indices), axis_(axis) {} : Primitive(stream), indices_(indices), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@ -2013,7 +2010,7 @@ class Split : public Primitive {
private: private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs); void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::vector<int> indices_; Shape indices_;
int axis_; int axis_;
}; };

View File

@ -296,7 +296,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const Shape& v) { std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "("; os << "(";
for (int i = 0; i < v.size(); ++i) { for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ","); os << v[i] << ((i == v.size() - 1) ? "" : ",");
@ -305,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, const Shape& v) {
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const Strides& v) { std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "("; os << "(";
for (int i = 0; i < v.size(); ++i) { for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ","); os << v[i] << ((i == v.size() - 1) ? "" : ",");

View File

@ -77,8 +77,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d); std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a); std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const Shape& v); std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const Strides& v); std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
} }

View File

@ -889,13 +889,13 @@ void init_array(nb::module_& m) {
.def( .def(
"reshape", "reshape",
[](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) { [](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {
std::vector<int> shape; mx::Shape shape;
if (!nb::isinstance<int>(shape_[0])) { if (!nb::isinstance<int>(shape_[0])) {
shape = nb::cast<std::vector<int>>(shape_[0]); shape = nb::cast<mx::Shape>(shape_[0]);
} else { } else {
shape = nb::cast<std::vector<int>>(shape_); shape = nb::cast<mx::Shape>(shape_);
} }
return mx::reshape(a, shape, s); return mx::reshape(a, std::move(shape), s);
}, },
"shape"_a, "shape"_a,
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -1182,14 +1182,14 @@ void init_array(nb::module_& m) {
.def( .def(
"split", "split",
[](const mx::array& a, [](const mx::array& a,
const std::variant<int, std::vector<int>>& indices_or_sections, const std::variant<int, mx::Shape>& indices_or_sections,
int axis, int axis,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices_or_sections); pv) { if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
return mx::split(a, *pv, axis, s); return mx::split(a, *pv, axis, s);
} else { } else {
return mx::split( return mx::split(
a, std::get<std::vector<int>>(indices_or_sections), axis, s); a, std::get<mx::Shape>(indices_or_sections), axis, s);
} }
}, },
"indices_or_sections"_a, "indices_or_sections"_a,

View File

@ -181,7 +181,7 @@ void init_fast(nb::module_& parent_module) {
return nb::cpp_function( return nb::cpp_function(
[kernel = std::move(kernel)]( [kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_, const std::vector<ScalarOrArray>& inputs_,
const std::vector<std::vector<int>>& output_shapes, const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes, const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid, std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup, std::tuple<int, int, int> threadgroup,

View File

@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"fft2", "fft2",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"ifft2", "ifft2",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"fftn", "fftn",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"ifftn", "ifftn",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"rfft2", "rfft2",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"irfft2", "irfft2",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"rfftn", "rfftn",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {
@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
m.def( m.def(
"irfftn", "irfftn",
[](const mx::array& a, [](const mx::array& a,
const std::optional<std::vector<int>>& n, const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) { if (axes.has_value() && n.has_value()) {

View File

@ -25,9 +25,9 @@ int get_slice_int(nb::object obj, int default_val) {
} }
void get_slice_params( void get_slice_params(
int& starts, mx::ShapeElem& starts,
int& ends, mx::ShapeElem& ends,
int& strides, mx::ShapeElem& strides,
const nb::slice& in_slice, const nb::slice& in_slice,
int axis_size) { int axis_size) {
// Following numpy's convention // Following numpy's convention
@ -68,9 +68,9 @@ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
return src; return src;
} }
std::vector<int> starts(src.ndim(), 0); mx::Shape starts(src.ndim(), 0);
std::vector<int> ends = src.shape(); auto ends = src.shape();
std::vector<int> strides(src.ndim(), 1); mx::Shape strides(src.ndim(), 1);
// Check and update slice params // Check and update slice params
get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]); get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
@ -119,7 +119,7 @@ mx::array mlx_gather_nd(
auto& idx = indices[i]; auto& idx = indices[i];
if (nb::isinstance<nb::slice>(idx)) { if (nb::isinstance<nb::slice>(idx)) {
int start, end, stride; mx::ShapeElem start, end, stride;
get_slice_params( get_slice_params(
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i)); start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
@ -168,7 +168,7 @@ mx::array mlx_gather_nd(
// Do the gather // Do the gather
std::vector<int> axes(indices.size()); std::vector<int> axes(indices.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
std::vector<int> slice_sizes = src.shape(); auto slice_sizes = src.shape();
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1); std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes); src = gather(src, gather_indices, axes, slice_sizes);
@ -179,9 +179,7 @@ mx::array mlx_gather_nd(
return mx::squeeze(src, axes); return mx::squeeze(src, axes);
} }
auto mlx_expand_ellipsis( auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
const std::vector<int>& shape,
const nb::tuple& entries) {
std::vector<nb::object> indices; std::vector<nb::object> indices;
// Go over all entries and note the position of ellipsis // Go over all entries and note the position of ellipsis
@ -230,7 +228,8 @@ auto mlx_expand_ellipsis(
for (int axis = non_none_indices_before; for (int axis = non_none_indices_before;
axis < shape.size() - non_none_indices_after; axis < shape.size() - non_none_indices_after;
axis++) { axis++) {
indices.push_back(nb::slice(0, shape[axis], 1)); indices.push_back(
nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
non_none_indices++; non_none_indices++;
} }
} }
@ -371,9 +370,9 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Slice handling // Slice handling
{ {
std::vector<int> starts(src.ndim(), 0); mx::Shape starts(src.ndim(), 0);
std::vector<int> ends = src.shape(); auto ends = src.shape();
std::vector<int> strides(src.ndim(), 1); mx::Shape strides(src.ndim(), 1);
int axis = 0; int axis = 0;
for (auto& idx : remaining_indices) { for (auto& idx : remaining_indices) {
if (!idx.is_none()) { if (!idx.is_none()) {
@ -461,8 +460,7 @@ mlx_scatter_args_int(
int s = 0; int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++) for (; s < update.ndim() && update.shape(s) == 1; s++)
; ;
auto up_shape = auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto shape = src.shape(); auto shape = src.shape();
shape[0] = 1; shape[0] = 1;
@ -521,9 +519,9 @@ mlx_scatter_args_slice(
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}}; {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
} }
int start = 0; mx::ShapeElem start = 0;
int end = src.shape(0); auto end = src.shape(0);
int stride = 1; mx::ShapeElem stride = 1;
// Check and update slice params // Check and update slice params
get_slice_params(start, end, stride, in_slice, end); get_slice_params(start, end, stride, in_slice, end);
@ -645,7 +643,7 @@ mlx_scatter_args_nd(
for (int i = 0; i < indices.size(); ++i) { for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i]; auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) { if (nb::isinstance<nb::slice>(pyidx)) {
int start, end, stride; mx::ShapeElem start, end, stride;
auto axis_size = src.shape(ax++); auto axis_size = src.shape(ax++);
get_slice_params( get_slice_params(
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size); start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
@ -654,7 +652,7 @@ mlx_scatter_args_nd(
start = (start < 0) ? start + axis_size : start; start = (start < 0) ? start + axis_size : start;
end = (end < 0) ? end + axis_size : end; end = (end < 0) ? end + axis_size : end;
std::vector<int> idx_shape(idx_ndim, 1); mx::Shape idx_shape(idx_ndim, 1);
// If it's a simple slice, we only need to add the start index // If it's a simple slice, we only need to add the start index
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) { if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {

View File

@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"full", "full",
[](const std::variant<int, std::vector<int>>& shape, [](const std::variant<int, mx::Shape>& shape,
const ScalarOrArray& vals, const ScalarOrArray& vals,
std::optional<mx::Dtype> dtype, std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&shape); pv) { if (auto pv = std::get_if<int>(&shape); pv) {
return mx::full({*pv}, to_array(vals, dtype), s); return mx::full({*pv}, to_array(vals, dtype), s);
} else { } else {
return mx::full( return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
} }
}, },
"shape"_a, "shape"_a,
@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"zeros", "zeros",
[](const std::variant<int, std::vector<int>>& shape, [](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype, std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32); auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) { if (auto pv = std::get_if<int>(&shape); pv) {
return mx::zeros({*pv}, t, s); return mx::zeros({*pv}, t, s);
} else { } else {
return mx::zeros(std::get<std::vector<int>>(shape), t, s); return mx::zeros(std::get<mx::Shape>(shape), t, s);
} }
}, },
"shape"_a, "shape"_a,
@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"ones", "ones",
[](const std::variant<int, std::vector<int>>& shape, [](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype, std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32); auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) { if (auto pv = std::get_if<int>(&shape); pv) {
return mx::ones({*pv}, t, s); return mx::ones({*pv}, t, s);
} else { } else {
return mx::ones(std::get<std::vector<int>>(shape), t, s); return mx::ones(std::get<mx::Shape>(shape), t, s);
} }
}, },
"shape"_a, "shape"_a,
@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
m.def( m.def(
"split", "split",
[](const mx::array& a, [](const mx::array& a,
const std::variant<int, std::vector<int>>& indices_or_sections, const std::variant<int, mx::Shape>& indices_or_sections,
int axis, int axis,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices_or_sections); pv) { if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
return mx::split(a, *pv, axis, s); return mx::split(a, *pv, axis, s);
} else { } else {
return mx::split( return mx::split(
a, std::get<std::vector<int>>(indices_or_sections), axis, s); a, std::get<mx::Shape>(indices_or_sections), axis, s);
} }
}, },
nb::arg(), nb::arg(),
@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
)pbdoc"); )pbdoc");
m.def( m.def(
"broadcast_to", "broadcast_to",
[](const ScalarOrArray& a, [](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
const std::vector<int>& shape,
mx::StreamOrDevice s) {
return mx::broadcast_to(to_array(a), shape, s); return mx::broadcast_to(to_array(a), shape, s);
}, },
nb::arg(), nb::arg(),
@ -4895,24 +4892,16 @@ void init_ops(nb::module_& m) {
m.def( m.def(
"roll", "roll",
[](const mx::array& a, [](const mx::array& a,
const IntOrVec& shift, const std::variant<int, mx::Shape>& shift,
const IntOrVec& axis, const IntOrVec& axis,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
return std::visit( return std::visit(
[&](auto sh, auto ax) -> mx::array { [&](auto sh, auto ax) -> mx::array {
using T = decltype(ax); if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
using V = decltype(sh);
if constexpr (std::is_same_v<V, std::monostate>) {
throw std::invalid_argument(
"[roll] Expected two arguments but only one was given.");
} else {
if constexpr (std::is_same_v<T, std::monostate>) {
return mx::roll(a, sh, s); return mx::roll(a, sh, s);
} else { } else {
return mx::roll(a, sh, ax, s); return mx::roll(a, sh, ax, s);
} }
}
}, },
shift, shift,
axis); axis);

View File

@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
"uniform", "uniform",
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const mx::Shape& shape,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
}, },
"low"_a = 0, "low"_a = 0,
"high"_a = 1, "high"_a = 1,
"shape"_a = std::vector<int>{}, "shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32, "dtype"_a.none() = mx::float32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"normal", "normal",
[](const std::vector<int>& shape, [](const mx::Shape& shape,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
float loc, float loc,
float scale, float scale,
@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::normal( return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s); shape, type.value_or(mx::float32), loc, scale, key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32, "dtype"_a.none() = mx::float32,
"loc"_a = 0.0, "loc"_a = 0.0,
"scale"_a = 1.0, "scale"_a = 1.0,
@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
"multivariate_normal", "multivariate_normal",
[](const mx::array& mean, [](const mx::array& mean,
const mx::array& cov, const mx::array& cov,
const std::vector<int>& shape, const mx::Shape& shape,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
}, },
"mean"_a, "mean"_a,
"cov"_a, "cov"_a,
"shape"_a = std::vector<int>{}, "shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32, "dtype"_a.none() = mx::float32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
"randint", "randint",
[](const ScalarOrArray& low, [](const ScalarOrArray& low,
const ScalarOrArray& high, const ScalarOrArray& high,
const std::vector<int>& shape, const mx::Shape& shape,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
}, },
"low"_a, "low"_a,
"high"_a, "high"_a,
"shape"_a = std::vector<int>{}, "shape"_a = mx::Shape{},
"dtype"_a.none() = mx::int32, "dtype"_a.none() = mx::int32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
m.def( m.def(
"bernoulli", "bernoulli",
[](const ScalarOrArray& p_, [](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape, const std::optional<mx::Shape> shape,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
"truncated_normal", "truncated_normal",
[](const ScalarOrArray& lower_, [](const ScalarOrArray& lower_,
const ScalarOrArray& upper_, const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_, const std::optional<mx::Shape> shape_,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"gumbel", "gumbel",
[](const std::vector<int>& shape, [](const mx::Shape& shape,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next(); auto key = key_ ? key_.value() : default_key().next();
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s); return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32, "dtype"_a.none() = mx::float32,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
"categorical", "categorical",
[](const mx::array& logits, [](const mx::array& logits,
int axis, int axis,
const std::optional<std::vector<int>> shape, const std::optional<mx::Shape> shape,
const std::optional<int> num_samples, const std::optional<int> num_samples,
const std::optional<mx::array>& key_, const std::optional<mx::array>& key_,
mx::StreamOrDevice s) { mx::StreamOrDevice s) {
@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc"); )pbdoc");
m.def( m.def(
"laplace", "laplace",
[](const std::vector<int>& shape, [](const mx::Shape& shape,
std::optional<mx::Dtype> type, std::optional<mx::Dtype> type,
float loc, float loc,
float scale, float scale,
@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::laplace( return mx::random::laplace(
shape, type.value_or(mx::float32), loc, scale, key, s); shape, type.value_or(mx::float32), loc, scale, key, s);
}, },
"shape"_a = std::vector<int>{}, "shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32, "dtype"_a.none() = mx::float32,
"loc"_a = 0.0, "loc"_a = 0.0,
"scale"_a = 1.0, "scale"_a = 1.0,
@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::permutation(std::get<mx::array>(x), axis, key, s); return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
} }
}, },
"shape"_a = std::vector<int>{}, "x"_a,
"axis"_a = 0, "axis"_a = 0,
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),

View File

@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h> #include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h> #include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h> #include <nanobind/stl/pair.h>

View File

@ -395,7 +395,7 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{8, 4}); CHECK_EQ(out[1].shape(), Shape{8, 4});
CHECK_EQ(out[2].shape(), Shape{8, 4}); CHECK_EQ(out[2].shape(), Shape{8, 4});
out = split(x, std::vector<int>{}); out = split(x, Shape{});
CHECK_EQ(out.size(), 1); CHECK_EQ(out.size(), 1);
CHECK_EQ(out[0].shape(), x.shape()); CHECK_EQ(out[0].shape(), x.shape());
@ -405,25 +405,25 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{4, 12}); CHECK_EQ(out[1].shape(), Shape{4, 12});
CHECK_EQ(out[2].shape(), Shape{1, 12}); CHECK_EQ(out[2].shape(), Shape{1, 12});
out = split(x, std::vector<int>{20}); out = split(x, Shape{20});
CHECK_EQ(out.size(), 2); CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), Shape{8, 12}); CHECK_EQ(out[0].shape(), Shape{8, 12});
CHECK_EQ(out[1].shape(), Shape{0, 12}); CHECK_EQ(out[1].shape(), Shape{0, 12});
// Negative indices // Negative indices
out = split(x, std::vector<int>{-5}); out = split(x, Shape{-5});
CHECK_EQ(out[0].shape(), Shape{3, 12}); CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{5, 12}); CHECK_EQ(out[1].shape(), Shape{5, 12});
// Different axis // Different axis
out = split(x, std::vector<int>{2, 8}, 1); out = split(x, {2, 8}, 1);
CHECK_EQ(out[0].shape(), Shape{8, 2}); CHECK_EQ(out[0].shape(), Shape{8, 2});
CHECK_EQ(out[1].shape(), Shape{8, 6}); CHECK_EQ(out[1].shape(), Shape{8, 6});
CHECK_EQ(out[2].shape(), Shape{8, 4}); CHECK_EQ(out[2].shape(), Shape{8, 4});
// Out of order indices // Out of order indices
x = arange(5); x = arange(5);
out = split(x, std::vector<int>{2, 1, 2}); out = split(x, {2, 1, 2});
CHECK(array_equal(out[0], array({0, 1})).item<bool>()); CHECK(array_equal(out[0], array({0, 1})).item<bool>());
CHECK(array_equal(out[1], array({})).item<bool>()); CHECK(array_equal(out[1], array({})).item<bool>());
CHECK(array_equal(out[2], array({1})).item<bool>()); CHECK(array_equal(out[2], array({1})).item<bool>());

View File

@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
CHECK_THROWS(categorical(logits, -3)); CHECK_THROWS(categorical(logits, -3));
// Invalid requested shapes // Invalid requested shapes
CHECK_THROWS(categorical(logits, 1, std::vector<int>{1})); CHECK_THROWS(categorical(logits, 1, Shape{1}));
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11})); CHECK_THROWS(categorical(logits, 1, Shape{11}));
CHECK_THROWS(categorical(logits, 1, {10, 1})); CHECK_THROWS(categorical(logits, 1, {10, 1}));
CHECK_EQ(categorical(logits, -1).shape(), Shape{10}); CHECK_EQ(categorical(logits, -1).shape(), Shape{10});

View File

@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) { auto fun = [](std::vector<array> inputs) {
auto src = inputs[0]; auto src = inputs[0];
auto indices = inputs[1]; auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2}; auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
return std::vector<array>{out}; return std::vector<array>{out};
}; };
auto x = zeros({2, 2, 2, 2}); auto x = zeros({2, 2, 2, 2});
@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) { auto fun = [](std::vector<array> inputs) {
auto src = inputs[0]; auto src = inputs[0];
auto indices = inputs[1]; auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2}; auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
return std::vector<array>{out}; return std::vector<array>{out};
}; };
auto x = zeros({2, 2, 2, 2}); auto x = zeros({2, 2, 2, 2});
@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) { auto fun = [](std::vector<array> inputs) {
auto src = inputs[0]; auto src = inputs[0];
auto indices = inputs[1]; auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2, 2}; auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
return std::vector<array>{out}; return std::vector<array>{out};
}; };
auto x = zeros({2, 2, 2, 2}); auto x = zeros({2, 2, 2, 2});
@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) { auto fun = [](std::vector<array> inputs) {
auto src = inputs[0]; auto src = inputs[0];
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end()); auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
std::vector<int> slice_sizes = {1, 1, 2, 2}; auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
return std::vector<array>{out}; return std::vector<array>{out};
}; };
auto x = zeros({2, 2, 2, 2}); auto x = zeros({2, 2, 2, 2});