diff --git a/mlx/array.cpp b/mlx/array.cpp
index eeda019a7..9cf36416f 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -25,7 +25,7 @@ bool retain_graph() {
 } // namespace
 
 array::array(const std::complex& val, Dtype dtype /* = complex64 */)
-    : array_desc_(std::make_shared(std::vector{}, dtype)) {
+    : array_desc_(std::make_shared(Shape{}, dtype)) {
   auto cval = static_cast(val);
   init(&cval);
 }
@@ -61,14 +61,14 @@ std::vector array::make_arrays(
 
 array::array(std::initializer_list data)
     : array_desc_(std::make_shared(
-          std::vector{static_cast(data.size())},
+          Shape{static_cast(data.size())},
           float32)) {
   init(data.begin());
 }
 
 array::array(std::initializer_list data, Dtype dtype)
     : array_desc_(std::make_shared(
-          std::vector{static_cast(data.size())},
+          Shape{static_cast(data.size())},
           dtype)) {
   init(data.begin());
 }
@@ -322,7 +322,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
 }
 
 array::ArrayIterator::reference array::ArrayIterator::operator*() const {
-  auto start = std::vector(arr.ndim(), 0);
+  auto start = Shape(arr.ndim(), 0);
   auto end = arr.shape();
   auto shape = arr.shape();
   shape.erase(shape.begin());
diff --git a/mlx/array.h b/mlx/array.h
index d76a1c0e0..d4ed48b6c 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -17,7 +17,8 @@ namespace mlx::core {
 class Primitive;
 
 using Deleter = std::function;
-using Shape = std::vector;
+using ShapeElem = int32_t;
+using Shape = std::vector;
 using Strides = std::vector;
 
 class array {
@@ -498,7 +499,7 @@ class array {
 
 template 
 array::array(T val, Dtype dtype /* = TypeToDtype() */)
-    : array_desc_(std::make_shared(std::vector{}, dtype)) {
+    : array_desc_(std::make_shared(Shape{}, dtype)) {
   init(&val);
 }
@@ -516,7 +517,7 @@ array::array(
     std::initializer_list data,
     Dtype dtype /* = TypeToDtype() */)
     : array_desc_(std::make_shared(
-          std::vector{static_cast(data.size())},
+          Shape{static_cast(data.size())},
           dtype)) {
   init(data.begin());
 }
diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp
index 4c782089a..75d679d86 100644
--- a/mlx/backend/common/compiled.cpp
+++ b/mlx/backend/common/compiled.cpp
@@ -130,7 +130,7 @@ std::string build_lib_name(
 
 bool compiled_check_contiguity(
     const std::vector& inputs,
-    const std::vector& shape) {
+    const Shape& shape) {
   bool contiguous = true;
   bool all_contig = true;
   bool all_row_contig = true;
diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h
index a08a53e68..72959fdc9 100644
--- a/mlx/backend/common/compiled.h
+++ b/mlx/backend/common/compiled.h
@@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
 // Check if we can use a contiguous operation given inputs and the output shape
 bool compiled_check_contiguity(
     const std::vector& inputs,
-    const std::vector& shape);
+    const Shape& shape);
 
 // Allocate space for the outputs possibly with input donation
 void compiled_allocate_outputs(
diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp
index 879c0312e..3f6d324fb 100644
--- a/mlx/backend/common/conv.cpp
+++ b/mlx/backend/common/conv.cpp
@@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
   auto conv_dtype = float32;
 
   // Pad input
-  std::vector padded_shape = {N, iH + 2 * padding[0], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});
 
   // Fill with zeros
@@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
   // Materialize strided view
-  std::vector strided_reshape = {N * oH, wH * C};
+  Shape strided_reshape = {N * oH, wH * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
@@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
   auto conv_dtype = out.dtype();
 
   // Pad input
-  std::vector padded_shape = {
-      N, iH + 2 * padding[0], iW + 2 * padding[1], C};
+  Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
   array in_padded(padded_shape, conv_dtype, nullptr, {});
 
   // Fill with zeros
@@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
       in_padded, strided_strides, flags, in_strided_view.size(), 0);
 
   // Materialize strided view
-  std::vector strided_reshape = {N * oH * oW, wH * wW * C};
+  Shape strided_reshape = {N * oH * oW, wH * wW * C};
   array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
   copy(in_strided_view, in_strided, CopyType::General);
@@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
     const std::vector& wt_dilation,
     const bool flip) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
-  const auto iDim = std::vector(
-      in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
-  const auto oDim = std::vector(
+  const auto iDim =
+      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
+  const auto oDim = Shape(
       out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
   const int O = wt.shape(0); // Out channels
   const int C = wt.shape(-1); // In channels
-  const auto wDim = std::vector(
-      wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
+  const auto wDim =
+      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
 
   auto conv_dtype = float32;
 
   // Pad input
-  std::vector padded_shape(in.shape().size());
+  Shape padded_shape(in.shape().size());
   padded_shape.front() = N;
   for (size_t i = 0; i < iDim.size(); i++) {
     padded_shape[i + 1] = iDim[i] + 2 * padding[i];
diff --git a/mlx/backend/common/sort.cpp b/mlx/backend/common/sort.cpp
index 29e4d9d5e..e2f6d48bd 100644
--- a/mlx/backend/common/sort.cpp
+++ b/mlx/backend/common/sort.cpp
@@ -14,10 +14,10 @@ namespace mlx::core {
 
 namespace {
 
-template 
+template 
 struct StridedIterator {
   using iterator_category = std::random_access_iterator_tag;
-  using difference_type = IdxT;
+  using difference_type = int32_t;
   using value_type = T;
   using reference = value_type&;
   using pointer = value_type*;
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index 99d49f150..5f3376fdf 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -107,7 +107,7 @@ struct ContiguousIterator {
       : shape_(a.shape()), strides_(a.strides()) {
     if (!shape_.empty()) {
       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-      pos_ = std::vector(shape_.size(), 0);
+      pos_ = Shape(shape_.size(), 0);
     }
   }
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 554d280a9..3e42f7d2f 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
   int implicit_K = wt.size() / conv_params.O;
   int implicit_N = conv_params.O;
   // Prepare unfolding array
-  std::vector unfolded_shape{implicit_M, implicit_K};
+  Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
   }
 
   // Prepare unfolding array
-  std::vector unfolded_shape{implicit_M, implicit_K * groups};
+  Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
   in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -192,12 +192,12 @@ void conv_1D_gpu(
     bool flip) {
   // Make conv params
   MLXConvParams<1> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(2),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1)},
-      /* const int wS[NDIM] = */ {wt.shape(1)},
-      /* const int oS[NDIM] = */ {out.shape(1)},
+      /* const int N = */ static_cast(in.shape(0)),
+      /* const int C = */ static_cast(in.shape(2)),
+      /* const int O = */ static_cast(wt.shape(0)),
+      /* const int iS[NDIM] = */ {static_cast(in.shape(1))},
+      /* const int wS[NDIM] = */ {static_cast(wt.shape(1))},
+      /* const int oS[NDIM] = */ {static_cast(out.shape(1))},
       /* const int str[NDIM] = */ {wt_strides[0]},
       /* const int pad[NDIM] = */ {padding[0]},
       /* const int kdil[NDIM] = */ {wt_dilation[0]},
@@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
     array out,
     const MLXConvParams<2>& conv_params,
     std::vector& copies_w) {
-  std::vector padded_shape = {
+  Shape padded_shape = {
       conv_params.N,
       conv_params.iS[0] + 2 * conv_params.pad[0],
       conv_params.iS[1] + 2 * conv_params.pad[1],
@@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
   padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
   padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
 
-  array in_padded(padded_shape, in.dtype(), nullptr, {});
+  array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
 
   // Fill with zeros
   array zero_arr = array(0, in.dtype());
@@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
   copies_w.push_back(in_padded);
 
   MLXConvParams<2> conv_params_updated{
-      /* const int N = */ in_padded.shape(0),
-      /* const int C = */ in_padded.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast(in_padded.shape(0)),
+      /* const int C = */ static_cast(in_padded.shape(3)),
+      /* const int O = */ static_cast(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast(in_padded.shape(1)),
+       static_cast(in_padded.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast(wt.shape(1)), static_cast(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast(out.shape(1)), static_cast(out.shape(2))},
       /* const int str[NDIM] = */ {1, 1},
       /* const int pad[NDIM] = */ {0, 0},
       /* const int kdil[NDIM] = */ {1, 1},
@@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
   int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
 
   // Do filter transform
-  std::vector filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
-  array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
+  Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
+  array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
   filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
@@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do input transform
-  std::vector inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
-  array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
+  Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
+  array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
   inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
@@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
   }
 
   // Do batched gemm
-  std::vector out_wg_shape = {8 * 8, N_tiles, conv_params.O};
-  array out_wg(out_wg_shape, in.dtype(), nullptr, {});
+  Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
+  array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
   out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
@@ -723,12 +727,15 @@ void conv_2D_gpu(
     std::vector& copies) {
   // Make conv params
   MLXConvParams<2> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(3),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
+      /* const int N = */ static_cast(in.shape(0)),
+      /* const int C = */ static_cast(in.shape(3)),
+      /* const int O = */ static_cast(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast(in.shape(1)), static_cast(in.shape(2))},
+      /* const int wS[NDIM] = */
+      {static_cast(wt.shape(1)), static_cast(wt.shape(2))},
+      /* const int oS[NDIM] = */
+      {static_cast(out.shape(1)), static_cast(out.shape(2))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
       /* const int pad[NDIM] = */ {padding[0], padding[1]},
       /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@@ -800,12 +807,21 @@ void conv_3D_gpu(
    std::vector& copies) {
   // Make conv params
   MLXConvParams<3> conv_params{
-      /* const int N = */ in.shape(0),
-      /* const int C = */ in.shape(4),
-      /* const int O = */ wt.shape(0),
-      /* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
-      /* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
-      /* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
+      /* const int N = */ static_cast(in.shape(0)),
+      /* const int C = */ static_cast(in.shape(4)),
+      /* const int O = */ static_cast(wt.shape(0)),
+      /* const int iS[NDIM] = */
+      {static_cast(in.shape(1)),
+       static_cast(in.shape(2)),
+       static_cast(in.shape(3))},
+      /* const int wS[NDIM] = */
+      {static_cast(wt.shape(1)),
+       static_cast(wt.shape(2)),
+       static_cast(wt.shape(3))},
+      /* const int oS[NDIM] = */
+      {static_cast(out.shape(1)),
+       static_cast(out.shape(2)),
+       static_cast(out.shape(3))},
       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
       /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
       /* const int kdil[NDIM] = */
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 05deff057..36e1266e6 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
   }
 
   // Prepare the temporary accumulator
-  std::vector intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(outer_blocks);
   intermediate_shape.insert(
@@ -806,7 +806,7 @@ void strided_reduce_2pass(
   auto [in_type, out_type] = remap_reduce_types(in, op_name);
 
   // Prepare the temporary accumulator
-  std::vector intermediate_shape;
+  Shape intermediate_shape;
   intermediate_shape.reserve(out.ndim() + 1);
   intermediate_shape.push_back(32);
   intermediate_shape.insert(
diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp
index 26664b78e..b37d9c316 100644
--- a/mlx/backend/metal/slicing.cpp
+++ b/mlx/backend/metal/slicing.cpp
@@ -63,8 +63,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector axes,
-    std::vector low_pad_size,
+    const std::vector& axes,
+    const Shape& low_pad_size,
     const Stream& s) {
   // Fill output with val
   fill_gpu(val, out, s);
diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/metal/slicing.h
index 5c62b7b73..7c48214a4 100644
--- a/mlx/backend/metal/slicing.h
+++ b/mlx/backend/metal/slicing.h
@@ -23,8 +23,8 @@ void pad_gpu(
     const array& in,
     const array& val,
     array& out,
-    std::vector axes,
-    std::vector low_pad_size,
+    const std::vector& axes,
+    const Shape& low_pad_size,
     const Stream& s);
 
 } // namespace mlx::core
diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
index 64b1cbae1..8f3148778 100644
--- a/mlx/distributed/ops.cpp
+++ b/mlx/distributed/ops.cpp
@@ -82,7 +82,7 @@ array send(
 }
 
 array recv(
-    std::vector shape,
+    Shape shape,
     Dtype dtype,
     int src,
     std::optional group_ /* = std::nullopt */,
diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h
index 5e9a06515..9430106b1 100644
--- a/mlx/distributed/ops.h
+++ b/mlx/distributed/ops.h
@@ -26,7 +26,7 @@ array send(
     StreamOrDevice s = {});
 
 array recv(
-    std::vector shape,
+    Shape shape,
     Dtype dtype,
     int src,
     std::optional group = std::nullopt,
diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
index 7d8499b99..84ce86ffa 100644
--- a/mlx/distributed/primitives.cpp
+++ b/mlx/distributed/primitives.cpp
@@ -91,7 +91,7 @@ std::vector AllGather::vjp(
     const std::vector& argnums,
     const std::vector& outputs) {
   auto g = group();
-  std::vector starts(primals[0].ndim(), 0);
+  Shape starts(primals[0].ndim(), 0);
   auto stops = primals[0].shape();
   starts[0] = g.rank() * stops[0];
   stops[0] += starts[0];
diff --git a/mlx/einsum.cpp b/mlx/einsum.cpp
index ce14b2315..2858f9110 100644
--- a/mlx/einsum.cpp
+++ b/mlx/einsum.cpp
@@ -108,7 +108,7 @@ bool disjoint(const CharSet& x, const CharSet& y) {
 }
 
 template 
-size_t term_size(const T& term, std::unordered_map dict) {
+size_t term_size(const T& term, std::unordered_map dict) {
   size_t size = 1;
   for (auto c : term) {
     size *= dict[c];
   }
@@ -120,7 +120,7 @@ size_t flop_count(
     const CharSet& term,
     bool inner,
     int num_terms,
-    std::unordered_map dict) {
+    std::unordered_map dict) {
   size_t size = term_size(term, dict);
   auto op_factor = 1;
   if ((num_terms - 1) > op_factor) {
@@ -135,7 +135,7 @@ size_t flop_count(
 std::pair compute_cost_and_scaling(
     const std::vector& inputs,
     const Subscript& output,
-    std::unordered_map dim_map) {
+    std::unordered_map dim_map) {
   CharSet contractions;
   for (auto& in : inputs) {
     contractions.insert(in.set.begin(), in.set.end());
   }
@@ -155,7 +155,7 @@ std::pair compute_cost_and_scaling(
 std::tuple, size_t, int> greedy_path(
     std::vector inputs,
     const Subscript& output,
-    std::unordered_map dim_map,
+    std::unordered_map dim_map,
     size_t cost_limit,
     size_t memory_limit) {
   // Helper struct for building the greedy path
@@ -457,7 +457,8 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
     }
     Shape idx_shape(n_expand--, 1);
     idx_shape[0] = in.shape(axes.back());
-    auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
+    auto idx = reshape(
+        arange(static_cast(in.shape(axes.back())), s), idx_shape, s);
     for (int i = 0; i < v; ++i) {
       indices.push_back(idx);
     }
@@ -663,7 +664,7 @@ std::pair, PathInfo> einsum_path_helper(
   }
 
   Subscript output(out_subscript, std::move(out_set));
-  std::unordered_map dim_map;
+  std::unordered_map dim_map;
   std::vector inputs;
   for (int i = 0; i < in_subscripts.size(); ++i) {
     auto& in = in_subscripts[i];
@@ -680,7 +681,7 @@ std::pair, PathInfo> einsum_path_helper(
 
     // Check repeat subscripts are valid
     if (in_set.size() < in.size()) {
-      std::unordered_map local_dims;
+      std::unordered_map local_dims;
       for (int j = 0; j < in.size(); ++j) {
         auto dim = operands[i].shape(j);
         auto inserted = local_dims.insert({in[j], dim});
diff --git a/mlx/fast.cpp b/mlx/fast.cpp
index c7dfb36e6..01089d82d 100644
--- a/mlx/fast.cpp
+++ b/mlx/fast.cpp
@@ -670,8 +670,7 @@ array scaled_dot_product_attention(
       supports_sdpa_full || supports_sdpa_vector;
 
   if (implementation_supports_use_case) {
-    auto out_shape =
-        std::vector({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
+    auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
     return array(
         std::move(out_shape),
         final_type,
diff --git a/mlx/fast.h b/mlx/fast.h
index ddc3512b5..0b9608eec 100644
--- a/mlx/fast.h
+++ b/mlx/fast.h
@@ -59,7 +59,7 @@ typedef std::variant TemplateArg;
 
 typedef std::function(
     const std::vector&,
-    const std::vector>&,
+    const std::vector&,
    const std::vector&,
    std::tuple,
    std::tuple,
diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
index c1a1d03bf..ed6ea11dd 100644
--- a/mlx/io/gguf.cpp
+++ b/mlx/io/gguf.cpp
@@ -47,8 +47,8 @@ std::optional gguf_type_to_dtype(const uint32_t& gguf_type) {
   }
 }
 
-std::vector get_shape(const gguf_tensor& tensor) {
-  std::vector shape;
+Shape get_shape(const gguf_tensor& tensor) {
+  Shape shape;
   // The dimension order in GGML is the reverse of the order used in MLX.
   for (int i = tensor.ndim - 1; i >= 0; i--) {
     shape.push_back(tensor.dim[i]);
diff --git a/mlx/io/gguf.h b/mlx/io/gguf.h
index 170fd6b0a..fa5bc458d 100644
--- a/mlx/io/gguf.h
+++ b/mlx/io/gguf.h
@@ -12,7 +12,7 @@ extern "C" {
 
 namespace mlx::core {
 
-std::vector get_shape(const gguf_tensor& tensor);
+Shape get_shape(const gguf_tensor& tensor);
 void gguf_load_quantized(
     std::unordered_map& a,
     const gguf_tensor& tensor);
diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp
index 8e6a5b2f9..148ed6c47 100644
--- a/mlx/io/gguf_quants.cpp
+++ b/mlx/io/gguf_quants.cpp
@@ -109,7 +109,7 @@ void gguf_load_quantized(
 
   std::string name(tensor.name, tensor.namelen);
 
-  std::vector shape = get_shape(tensor);
+  auto shape = get_shape(tensor);
   const uint64_t weights_per_block = 32;
   if (shape[shape.size() - 1] % weights_per_block != 0) {
     std::ostringstream msg;
@@ -118,7 +118,7 @@ void gguf_load_quantized(
     throw std::runtime_error(msg.str());
   }
 
-  std::vector weights_shape = shape;
+  auto weights_shape = shape;
   weights_shape.back() /= (weights_per_byte * 4);
   auto w_nbytes = uint32.size() *
       std::accumulate(weights_shape.begin(),
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index bde12fd19..51c5805d9 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -271,7 +271,7 @@ array load(std::shared_ptr in_stream, StreamOrDevice s) {
   bool col_contiguous = header[34] == 'T';
 
   // Read array shape from header
-  std::vector shape;
+  Shape shape;
 
   size_t st = header.find_last_of('(') + 1;
   size_t ed = header.find_last_of(')');
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index daf5573fc..5a69f1eae 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -219,15 +219,15 @@ std::vector svd(const array& a, StreamOrDevice s /* = {} */) {
   const auto n = a.shape(-1);
   const auto rank = a.ndim();
 
-  std::vector u_shape = a.shape();
+  auto u_shape = a.shape();
   u_shape[rank - 2] = m;
   u_shape[rank - 1] = m;
 
-  std::vector s_shape = a.shape();
+  auto s_shape = a.shape();
   s_shape.pop_back();
   s_shape[rank - 2] = std::min(m, n);
 
-  std::vector vt_shape = a.shape();
+  auto vt_shape = a.shape();
   vt_shape[rank - 2] = n;
   vt_shape[rank - 1] = n;
@@ -328,8 +328,8 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
   array S = outs[1];
   array V = outs[2];
 
-  std::vector starts(a.ndim(), 0);
-  std::vector ends = a.shape();
+  Shape starts(a.ndim(), 0);
+  auto ends = a.shape();
   int i = a.ndim() - 2;
   int j = a.ndim() - 1;
@@ -479,7 +479,7 @@ array eigvalsh(
     std::string UPLO /* = "L" */,
     StreamOrDevice s /* = {} */) {
   validate_eigh(a, "[linalg::eigvalsh]");
-  std::vector out_shape(a.shape().begin(), a.shape().end() - 1);
+  Shape out_shape(a.shape().begin(), a.shape().end() - 1);
   return array(
       std::move(out_shape),
       a.dtype(),
@@ -493,7 +493,7 @@ std::pair eigh(
     StreamOrDevice s /* = {} */) {
   validate_eigh(a, "[linalg::eigh]");
   auto out = array::make_arrays(
-      {std::vector(a.shape().begin(), a.shape().end() - 1), a.shape()},
+      {Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
       {a.dtype(), a.dtype()},
       std::make_shared(to_stream(s), UPLO, true),
       {a});
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index a0a259580..3b54b43af 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -649,7 +649,7 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
 
       // Clamp to bounds
       auto st = std::min(s, n - 1);
-      auto ed = std::max(-1, e);
+      auto ed = e > -1 ? e : -1;
 
       start[i] = st;
       stop[i] = ed > st ? st : ed;
@@ -659,8 +659,8 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
     } else {
       // Clamp to bounds
-      auto st = std::max(0, std::min(s, n));
-      auto ed = std::max(0, std::min(e, n));
+      auto st = std::max(static_cast(0), std::min(s, n));
+      auto ed = std::max(static_cast(0), std::min(e, n));
 
       start[i] = st;
       stop[i] = ed < st ? st : ed;
@@ -765,7 +765,7 @@ array slice_update(
 
 std::vector split(
     const array& a,
-    const std::vector& indices,
+    const Shape& indices,
     int axis,
     StreamOrDevice s /* = {} */) {
   auto ax = axis < 0 ? axis + a.ndim() : axis;
@@ -809,10 +809,8 @@ std::vector split(
   return res;
 }
 
-std::vector split(
-    const array& a,
-    const std::vector& indices,
-    StreamOrDevice s /* = {} */) {
+std::vector
+split(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {
   return split(a, indices, 0, s);
 }
@@ -834,7 +832,7 @@ split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
     throw std::invalid_argument(msg.str());
   }
   auto split_size = q_and_r.quot;
-  std::vector indices(num_splits - 1);
+  Shape indices(num_splits - 1);
   for (int i = 0; i < indices.size(); ++i) {
     indices[i] = (i + 1) * split_size;
   }
@@ -1104,7 +1102,7 @@ array edge_pad(
 /** Pad an array with a constant value */
 array pad(
     const array& a,
-    const Shape& axes,
+    const std::vector& axes,
     const Shape& low_pad_size,
     const Shape& high_pad_size,
     const array& pad_value /*= array(0)*/,
@@ -1904,9 +1902,11 @@ array min(
 
 array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmin(reshape(a, {size}, s), 0, true, s);
+  auto result = argmin(flatten(a, s), 0, true, s);
   if (keepdims) {
-    result = reshape(result, std::vector(a.shape().size(), 1), s);
+    std::vector axes(a.ndim() - 1);
+    std::iota(axes.begin(), axes.end(), 0);
+    result = expand_dims(result, axes, s);
   } else {
     result = squeeze(result, s);
   }
@@ -1940,9 +1940,11 @@ array argmin(
 
 array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
-  auto result = argmax(reshape(a, {size}, s), 0, true, s);
+  auto result = argmax(flatten(a, s), 0, true, s);
   if (keepdims) {
-    result = reshape(result, Shape(a.shape().size(), 1), s);
+    std::vector axes(a.ndim() - 1);
+    std::iota(axes.begin(), axes.end(), 0);
+    result = expand_dims(result, axes, s);
   } else {
     result = squeeze(result, s);
   }
@@ -3238,8 +3240,8 @@ inline int dilate_size(int dim, int dil) {
 }
 
 Shape conv_out_shape(
-    const std::vector& in_shape,
-    const std::vector& wt_shape,
+    const Shape& in_shape,
+    const Shape& wt_shape,
     const std::vector& strides,
     const std::vector& pads_lo,
     const std::vector& pads_hi,
@@ -4329,16 +4331,16 @@ array diagonal(
         "[diagonal] axis1 and axis2 cannot be the same axis");
   }
 
-  auto off1 = std::max(-offset, 0);
-  auto off2 = std::max(offset, 0);
+  ShapeElem off1 = std::max(-offset, 0);
+  ShapeElem off2 = std::max(offset, 0);
 
   auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
-  diag_size = std::max(diag_size, 0);
+  diag_size = diag_size < 0 ? 0 : diag_size;
 
   std::vector indices = {
       arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
 
-  std::vector slice_sizes = a.shape();
+  Shape slice_sizes = a.shape();
   slice_sizes[ax1] = 1;
   slice_sizes[ax2] = 1;
diff --git a/mlx/ops.h b/mlx/ops.h
index 7576774b5..d6e456c88 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -189,13 +189,10 @@ array slice_update(
 std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
 std::vector split(const array& a, int num_splits, StreamOrDevice s = {});
-std::vector split(
-    const array& a,
-    const std::vector& indices,
-    int axis,
-    StreamOrDevice s = {});
 std::vector
-split(const array& a, const std::vector& indices, StreamOrDevice s = {});
+split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
+std::vector
+split(const array& a, const Shape& indices, StreamOrDevice s = {});
 
 /** A vector of coordinate arrays from coordinate vectors. */
 std::vector meshgrid(
@@ -253,8 +250,8 @@ array moveaxis(
 array pad(
     const array& a,
     const std::vector& axes,
-    const std::vector& low_pad_size,
-    const std::vector& high_pad_size,
+    const Shape& low_pad_size,
+    const Shape& high_pad_size,
     const array& pad_value = array(0),
     const std::string mode = "constant",
     StreamOrDevice s = {});
@@ -1453,7 +1450,11 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
 array roll(const array& a, int shift, StreamOrDevice s = {});
 array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
 array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
-array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
+array roll(
+    const array& a,
+    int shift,
+    const std::vector& axes,
+    StreamOrDevice s = {});
 array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
 array roll(
     const array& a,
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 9603e3cf1..aa8f16c9f 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -817,10 +817,10 @@ std::vector Concatenate::vjp(
     const std::vector& argnums,
     const std::vector&) {
   auto& cotan = cotangents[0];
-  std::vector start(cotan.ndim(), 0);
-  std::vector stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  Shape stop = cotan.shape();
 
-  std::vector sizes;
+  Shape sizes;
   sizes.push_back(0);
   for (auto& p : primals) {
     sizes.push_back(p.shape(axis_));
@@ -956,9 +956,9 @@ array conv_weight_backward_patches(
     const std::vector& padding,
     StreamOrDevice s) {
   // Resolve Padded input shapes and strides
-  std::vector padding_starts(in.ndim(), 0);
-  std::vector padding_ends = in.shape();
-  std::vector in_padded_shape = in.shape();
+  Shape padding_starts(in.ndim(), 0);
+  auto padding_ends = in.shape();
+  auto in_padded_shape = in.shape();
 
   // padded shape
   for (int i = 1; i < in.ndim() - 1; i++) {
@@ -976,8 +976,9 @@ array conv_weight_backward_patches(
 
   // Pad input
   std::vector padded_axes(in.ndim() - 2, 0);
   std::iota(padded_axes.begin(), padded_axes.end(), 1);
+  Shape padding_(padding.begin(), padding.end());
   auto in_padded = pad(
-      in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
+      in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
 
   // Resolve strided patches
@@ -1797,7 +1798,7 @@ std::vector FFT::vjp(
   std::vector axes(axes_.begin(), axes_.end());
   if (real_ && inverse_) {
     auto out = fft::fftn(cotangents[0], axes, stream());
-    auto start = std::vector(out.ndim(), 0);
+    auto start = Shape(out.ndim(), 0);
     auto stop = in.shape();
     out = slice(out, start, stop, stream());
     auto mask_shape = out.shape();
@@ -1809,7 +1810,7 @@ std::vector FFT::vjp(
     mask = concatenate({pad, mask, pad}, axes_.back(), stream());
     return {multiply(mask, out, stream())};
   } else if (real_) {
-    std::vector n;
+    Shape n;
     for (auto ax : axes_) {
       n.push_back(in.shape()[ax]);
     }
@@ -1934,10 +1935,11 @@ std::pair, std::vector> Gather::vmap(
   }
   if (indices_vmapped) {
     // Make a new index array for the vmapped dimension
-    auto vmap_inds = arange(0, src.shape(axes[0]), stream());
+    auto vmap_inds =
+        arange(static_cast(0), src.shape(axes[0]), stream());
     // Reshape it so it broadcasts with other index arrays
     {
-      auto shape = std::vector(idx_dims, 1);
+      auto shape = Shape(idx_dims, 1);
       shape[out_ax] = vmap_inds.size();
       vmap_inds = reshape(vmap_inds, std::move(shape), stream());
     }
@@ -2628,8 +2630,8 @@ std::vector Pad::vjp(
   assert(argnums.size() == 1 && argnums[0] == 0);
 
   auto& cotan = cotangents[0];
-  std::vector start(cotan.ndim(), 0);
-  std::vector stop = cotan.shape();
+  Shape start(cotan.ndim(), 0);
+  auto stop = cotan.shape();
 
   for (auto i : axes_) {
     start[i] = low_pad_size_[i];
@@ -3019,7 +3021,7 @@ std::vector Reduce::vjp(
     const std::vector& outputs) {
   auto in = primals[0];
 
-  std::vector shape = in.shape();
+  auto shape = in.shape();
   for (auto ax : axes_) {
     shape[ax] = 1;
   }
@@ -3044,7 +3046,7 @@ std::vector Reduce::vjp(
     if (axes_.size() > 1) {
       std::vector transpose_to;
       std::vector transpose_back;
-      std::vector shape_flat;
+      Shape shape_flat;
       {
         // Find the transpose needed to move axes_ to the back and the shape
        // except the reduced over axes.
@@ -3422,7 +3424,7 @@ std::pair, std::vector> Scatter::vmap(
   }
 
   auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
-  auto vmap_inds_shape = std::vector(inputs[1].ndim(), 1);
+  auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);
   vmap_inds_shape[0] = vmap_inds.size();
   vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
   inputs.insert(
@@ -3607,7 +3609,7 @@ std::vector Slice::vjp(
     // Transpose and reshape cotangents
     auto cotan = cotangents[0];
     if (!ind_axes.empty()) {
-      std::vector cotan_shape;
+      Shape cotan_shape;
       for (auto ax : ind_axes) {
         cotan_shape.push_back(cotan.shape(ax));
       }
@@ -3626,7 +3628,7 @@ std::vector Slice::vjp(
     }
 
     // Make indices broadcastable
-    std::vector inds_shape(inds.size(), 1);
+    Shape inds_shape(inds.size(), 1);
     for (int i = 0; i < inds.size(); ++i) {
       inds_shape[i] = inds[i].size();
       inds[i] = reshape(inds[i], inds_shape, stream());
     }
@@ -4184,7 +4186,7 @@ std::vector BlockMaskedMM::vjp(
     // Slice mask
     mask_reshape[mask_ndim - 2] = Y;
     mask_reshape[mask_ndim - 1] = X;
-    mask = slice(mask, std::vector(mask_ndim, 0), mask_reshape, stream());
+    mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());
 
     return mask;
   };
@@ -4202,7 +4204,7 @@ std::vector BlockMaskedMM::vjp(
     }
 
     // Reshape
-    std::vector r_reshape(r.shape().begin(), r.shape().end() - 2);
+    Shape r_reshape(r.shape().begin(), r.shape().end() - 2);
     r_reshape.push_back(r.shape(-2) / block_size_);
     r_reshape.push_back(block_size_);
     r_reshape.push_back(r.shape(-1) / block_size_);
@@ -4492,7 +4494,7 @@ std::pair, std::vector> NumberOfElements::vmap(
   }
 
   array out = array(
-      std::vector{},
+      {},
       dtype_,
       std::make_shared(stream(), new_axes, inverted_, dtype_),
       inputs);
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 55a87cf18..88b7a63ed 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1088,10 +1088,7 @@ class Full : public UnaryPrimitive {
 
 class Gather : public UnaryPrimitive {
  public:
-  explicit Gather(
-      Stream stream,
-      std::vector axes,
-      std::vector slice_sizes)
+  explicit Gather(Stream stream, std::vector axes, Shape slice_sizes)
       : UnaryPrimitive(stream),
         axes_(std::move(axes)),
         slice_sizes_(std::move(slice_sizes)) {}
@@ -1108,7 +1105,7 @@ class Gather : public UnaryPrimitive {
  private:
   void eval(const std::vector& inputs, array& out);
   std::vector axes_;
-  std::vector slice_sizes_;
+  Shape slice_sizes_;
 };
 
 class Greater : public UnaryPrimitive {
@@ -1503,8 +1500,8 @@ class Pad : public UnaryPrimitive {
   explicit Pad(
       Stream stream,
       const std::vector& axes,
-      const std::vector& low_pad_size,
-      const std::vector& high_pad_size)
+      const Shape& low_pad_size,
+      const Shape& high_pad_size)
       : UnaryPrimitive(stream),
         axes_(axes),
         low_pad_size_(low_pad_size),
@@ -1520,8 +1517,8 @@ class Pad : public UnaryPrimitive {
 
  private:
   std::vector axes_;
-  std::vector low_pad_size_;
-  std::vector high_pad_size_;
+  Shape low_pad_size_;
+  Shape high_pad_size_;
 
   void eval(const std::vector& inputs, array& out);
 };
@@ -1903,9 +1900,9 @@ class Slice : public UnaryPrimitive {
  public:
   explicit Slice(
       Stream stream,
-      const std::vector& start_indices,
-      const std::vector& end_indices,
-      const std::vector& strides)
+      const Shape& start_indices,
+      const Shape& end_indices,
+      const Shape& strides)
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
@@ -1920,9 +1917,9 @@ class Slice : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
 
  private:
-  std::vector start_indices_;
-  std::vector end_indices_;
-  std::vector strides_;
+  Shape start_indices_;
+  Shape end_indices_;
+  Shape strides_;
 
   void eval(const std::vector& inputs, array& out);
 };
@@ -1931,9 +1928,9 @@ class SliceUpdate : public UnaryPrimitive {
  public:
   explicit SliceUpdate(
       Stream stream,
-      const std::vector& start_indices,
-      const std::vector& end_indices,
-      const std::vector& strides)
+      const Shape& start_indices,
+      const Shape& end_indices,
+      const Shape& strides)
       : UnaryPrimitive(stream),
         start_indices_(start_indices),
         end_indices_(end_indices),
@@ -1948,9 +1945,9 @@ class SliceUpdate : public UnaryPrimitive {
   bool is_equivalent(const Primitive& other) const override;
 
  private:
-  std::vector start_indices_;
-  std::vector end_indices_;
-  std::vector strides_;
+  Shape start_indices_;
+  Shape end_indices_;
+  Shape strides_;
 
   void eval(const std::vector& inputs, array& out);
 };
@@ -1997,7 +1994,7 @@ class Sort : public UnaryPrimitive {
 
 class Split : public Primitive {
  public:
-  explicit Split(Stream stream, const std::vector& indices, int axis)
+  explicit Split(Stream stream, const Shape& indices, int axis)
       : Primitive(stream), indices_(indices), axis_(axis) {}
 
   void eval_cpu(const std::vector& inputs, std::vector& outputs)
@@ -2013,7 +2010,7 @@ class Split : public Primitive {
  private:
   void eval(const std::vector& inputs, std::vector& outputs);
 
-  std::vector indices_;
+  Shape indices_;
   int axis_;
 };
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 6d05ad5f8..3471ef566 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -296,7 +296,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
   return os;
 }
 
-std::ostream& operator<<(std::ostream& os, const Shape& v) {
+std::ostream& operator<<(std::ostream& os, const std::vector& v) {
   os << "(";
   for (int i = 0; i < v.size(); ++i) {
     os << v[i] << ((i == v.size() - 1) ? "" : ",");
   }
@@ -305,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, const Shape& v) {
   return os;
 }
 
-std::ostream& operator<<(std::ostream& os, const Strides& v) {
+std::ostream& operator<<(std::ostream& os, const std::vector& v) {
   os << "(";
   for (int i = 0; i < v.size(); ++i) {
     os << v[i] << ((i == v.size() - 1) ? "" : ",");
diff --git a/mlx/utils.h b/mlx/utils.h
index 04f59feaa..730bf0315 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -77,8 +77,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
 std::ostream& operator<<(std::ostream& os, const Dtype& d);
 std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
 std::ostream& operator<<(std::ostream& os, array a);
-std::ostream& operator<<(std::ostream& os, const Shape& v);
-std::ostream& operator<<(std::ostream& os, const Strides& v);
+std::ostream& operator<<(std::ostream& os, const std::vector& v);
+std::ostream& operator<<(std::ostream& os, const std::vector& v);
 
 inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
   return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
 }
diff --git a/python/src/array.cpp b/python/src/array.cpp
index d1c56ae55..f35236ede 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -889,13 +889,13 @@ void init_array(nb::module_& m) {
       .def(
           "reshape",
           [](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {
-            std::vector shape;
+            mx::Shape shape;
             if (!nb::isinstance(shape_[0])) {
-              shape = nb::cast>(shape_[0]);
+              shape = nb::cast(shape_[0]);
             } else {
-              shape = nb::cast>(shape_);
+              shape = nb::cast(shape_);
             }
-            return mx::reshape(a, shape, s);
+            return mx::reshape(a, std::move(shape), s);
           },
           "shape"_a,
           "stream"_a = nb::none(),
@@ -1182,14 +1182,14 @@ void init_array(nb::module_& m) {
      .def(
          "split",
          [](const mx::array& a,
-            const std::variant>& indices_or_sections,
+            const std::variant& indices_or_sections,
             int axis,
             mx::StreamOrDevice s) {
            if (auto pv = std::get_if(&indices_or_sections); pv) {
              return mx::split(a, *pv, axis, s);
            } else {
              return mx::split(
-                 a, std::get>(indices_or_sections), axis, s);
+                 a, std::get(indices_or_sections), axis, s);
            }
          },
          "indices_or_sections"_a,
diff --git a/python/src/fast.cpp b/python/src/fast.cpp
index 103c5e76d..5f04f3e69 100644
--- a/python/src/fast.cpp
+++ b/python/src/fast.cpp
@@ -181,7 +181,7 @@ void init_fast(nb::module_& parent_module) {
         return nb::cpp_function(
             [kernel = std::move(kernel)](
                 const std::vector& inputs_,
-                const std::vector>& output_shapes,
+                const std::vector& output_shapes,
                 const std::vector& output_dtypes,
                 std::tuple grid,
                 std::tuple threadgroup,
diff --git a/python/src/fft.cpp b/python/src/fft.cpp
index 986cd8f67..5ad4702e2 100644
--- a/python/src/fft.cpp
+++ b/python/src/fft.cpp
@@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "fft2",
       [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
@@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
       "ifft2",
       [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
@@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
      "fftn",
      [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
@@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
      "ifftn",
      [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
@@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
      "rfft2",
      [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
@@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
      "irfft2",
      [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
@@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
      "rfftn",
      [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
@@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
   m.def(
      "irfftn",
      [](const mx::array& a,
-         const std::optional>& n,
+         const std::optional& n,
         const std::optional>& axes,
         mx::StreamOrDevice s) {
        if (axes.has_value() && n.has_value()) {
diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp
index 3c042ce33..40aa2eabb 100644
--- a/python/src/indexing.cpp
+++ b/python/src/indexing.cpp
@@ -25,9 +25,9 @@ int get_slice_int(nb::object obj, int default_val) {
 }
 
 void get_slice_params(
-    int& starts,
-    int& ends,
-    int& strides,
+    mx::ShapeElem& starts,
+    mx::ShapeElem& ends,
+    mx::ShapeElem& strides,
     const nb::slice& in_slice,
     int axis_size) {
   // Following numpy's convention
@@ -68,9 +68,9 @@ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
     return src;
   }
 
-  std::vector starts(src.ndim(), 0);
-  std::vector ends = src.shape();
-  std::vector strides(src.ndim(), 1);
+  mx::Shape starts(src.ndim(), 0);
+  auto ends = src.shape();
+  mx::Shape strides(src.ndim(), 1);
 
   // Check and update slice params
   get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
@@ -119,7 +119,7 @@ mx::array mlx_gather_nd(
     auto& idx = indices[i];
 
     if (nb::isinstance(idx)) {
-      int start, end, stride;
+      mx::ShapeElem start, end, stride;
       get_slice_params(
           start, end, stride, nb::cast(idx), src.shape(i));
@@ -168,7 +168,7 @@ mx::array mlx_gather_nd(
   // Do the gather
   std::vector axes(indices.size());
   std::iota(axes.begin(), axes.end(), 0);
-  std::vector slice_sizes = src.shape();
+  auto slice_sizes = src.shape();
   std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
 
   src = gather(src, gather_indices, axes, slice_sizes);
@@ -179,9 +179,7 @@ mx::array mlx_gather_nd(
   return mx::squeeze(src, axes);
 }
 
-auto mlx_expand_ellipsis(
-    const std::vector& shape,
-    const nb::tuple& entries) {
+auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
   std::vector indices;
 
   // Go over all entries and note the position of ellipsis
@@ -230,7 +228,8 @@ auto mlx_expand_ellipsis(
     for (int axis = non_none_indices_before;
          axis < shape.size() - non_none_indices_after;
         axis++) {
-      indices.push_back(nb::slice(0, shape[axis], 1));
+      indices.push_back(
+          nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
       non_none_indices++;
     }
   }
@@ -371,9 +370,9 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
 
   // Slice handling
   {
-    std::vector starts(src.ndim(), 0);
-    std::vector ends = src.shape();
-    std::vector strides(src.ndim(), 1);
+    mx::Shape starts(src.ndim(), 0);
+    auto ends = src.shape();
+    mx::Shape strides(src.ndim(), 1);
     int axis = 0;
     for (auto& idx : remaining_indices) {
       if (!idx.is_none()) {
@@ -461,8 +460,7 @@ mlx_scatter_args_int(
   int s = 0;
   for (; s < update.ndim() && update.shape(s) == 1; s++)
     ;
-  auto up_shape =
-      std::vector(update.shape().begin() + s, update.shape().end());
+  auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
 
   auto shape = src.shape();
   shape[0] = 1;
@@ -521,9 +519,9 @@ mlx_scatter_args_slice(
        {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
   }
 
-  int start = 0;
-  int end = src.shape(0);
-  int stride = 1;
+  mx::ShapeElem start = 0;
+  auto end = src.shape(0);
+  mx::ShapeElem stride = 1;
 
   // Check and update slice params
   get_slice_params(start, end, stride, in_slice, end);
@@ -645,7 +643,7 @@ mlx_scatter_args_nd(
   for (int i = 0; i < indices.size(); ++i) {
     auto& pyidx = indices[i];
     if (nb::isinstance(pyidx)) {
-      int start, end, stride;
+      mx::ShapeElem start, end, stride;
       auto axis_size = src.shape(ax++);
      get_slice_params(
          start, end, stride, nb::cast(pyidx), axis_size);
@@ -654,7 +652,7 @@ mlx_scatter_args_nd(
       start = (start < 0) ? start + axis_size : start;
       end = (end < 0) ? end + axis_size : end;
 
-      std::vector idx_shape(idx_ndim, 1);
+      mx::Shape idx_shape(idx_ndim, 1);
 
       // If it's a simple slice, we only need to add the start index
       if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index abfbbbc7c..1f5ecb5cf 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
   )pbdoc");
   m.def(
       "full",
-      [](const std::variant>& shape,
+      [](const std::variant& shape,
         const ScalarOrArray& vals,
         std::optional dtype,
         mx::StreamOrDevice s) {
        if (auto pv = std::get_if(&shape); pv) {
          return mx::full({*pv}, to_array(vals, dtype), s);
        } else {
-          return mx::full(
-              std::get>(shape), to_array(vals, dtype), s);
+          return mx::full(std::get(shape), to_array(vals, dtype), s);
        }
      },
      "shape"_a,
@@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
   )pbdoc");
   m.def(
      "zeros",
-      [](const std::variant>& shape,
+      [](const std::variant& shape,
         std::optional dtype,
         mx::StreamOrDevice s) {
        auto t = dtype.value_or(mx::float32);
        if (auto pv = std::get_if(&shape); pv) {
          return mx::zeros({*pv}, t, s);
        } else {
-          return mx::zeros(std::get>(shape), t, s);
+          return mx::zeros(std::get(shape), t, s);
        }
      },
      "shape"_a,
@@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
   )pbdoc");
   m.def(
      "ones",
-      [](const std::variant>& shape,
+      [](const std::variant& shape,
         std::optional dtype,
         mx::StreamOrDevice s) {
        auto t = dtype.value_or(mx::float32);
        if (auto pv = std::get_if(&shape); pv) {
          return mx::ones({*pv}, t, s);
        } else {
-          return mx::ones(std::get>(shape), t, s);
+          return mx::ones(std::get(shape), t, s);
        }
      },
      "shape"_a,
@@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
   m.def(
      "split",
      [](const mx::array& a,
-         const std::variant>& indices_or_sections,
+         const std::variant& indices_or_sections,
         int axis,
         mx::StreamOrDevice s) {
        if (auto pv = std::get_if(&indices_or_sections); pv) {
          return mx::split(a, *pv, axis, s);
        } else {
          return mx::split(
-              a, std::get>(indices_or_sections), axis, s);
+              a, std::get(indices_or_sections), axis, s);
        }
      },
      nb::arg(),
@@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
   )pbdoc");
   m.def(
      "broadcast_to",
-      [](const ScalarOrArray& a,
-         const std::vector& shape,
-         mx::StreamOrDevice s) {
+      [](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
        return mx::broadcast_to(to_array(a), shape, s);
      },
      nb::arg(),
@@ -4895,23 +4892,15 @@ void init_ops(nb::module_& m) {
   m.def(
      "roll",
      [](const mx::array& a,
-         const IntOrVec& shift,
+         const std::variant& shift,
         const IntOrVec& axis,
         mx::StreamOrDevice s) {
        return std::visit(
            [&](auto sh, auto ax) -> mx::array {
-              using T = decltype(ax);
-              using V = decltype(sh);
-
-              if constexpr (std::is_same_v) {
-                throw std::invalid_argument(
-                    "[roll] Expected two arguments but only one was given.");
+              if constexpr (std::is_same_v) {
+                return mx::roll(a, sh, s);
              } else {
-                if constexpr (std::is_same_v) {
-                  return mx::roll(a, sh, s);
-                } else {
-                  return mx::roll(a, sh, ax, s);
-                }
+                return mx::roll(a, sh, ax, s);
              }
            },
            shift,
diff --git a/python/src/random.cpp b/python/src/random.cpp
index b67dfc219..49f3220cb 100644
--- a/python/src/random.cpp
+++ b/python/src/random.cpp
@@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
      "uniform",
      [](const ScalarOrArray& low,
         const ScalarOrArray& high,
-         const std::vector& shape,
+         const mx::Shape& shape,
         std::optional type,
         const std::optional& key_,
         mx::StreamOrDevice s) {
@@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
      },
      "low"_a = 0,
      "high"_a = 1,
-      "shape"_a = std::vector{},
+      "shape"_a = mx::Shape{},
      "dtype"_a.none() = mx::float32,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
@@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
   )pbdoc");
   m.def(
      "normal",
-      [](const std::vector& shape,
+      [](const mx::Shape& shape,
         std::optional type,
         float loc,
         float scale,
        const std::optional& key_,
        mx::StreamOrDevice s) {
        auto key = key_ ? key_.value() : default_key().next();
        return mx::random::normal(
            shape, type.value_or(mx::float32), loc, scale, key, s);
      },
-      "shape"_a = std::vector{},
+      "shape"_a = mx::Shape{},
      "dtype"_a.none() = mx::float32,
      "loc"_a = 0.0,
      "scale"_a = 1.0,
@@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
      "multivariate_normal",
      [](const mx::array& mean,
         const mx::array& cov,
-         const std::vector& shape,
+         const mx::Shape& shape,
         std::optional type,
         const std::optional& key_,
         mx::StreamOrDevice s) {
@@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
      },
      "mean"_a,
      "cov"_a,
-      "shape"_a = std::vector{},
+      "shape"_a = mx::Shape{},
      "dtype"_a.none() = mx::float32,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
@@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
      "randint",
      [](const ScalarOrArray& low,
         const ScalarOrArray& high,
-         const std::vector& shape,
+         const mx::Shape& shape,
         std::optional type,
         const std::optional& key_,
         mx::StreamOrDevice s) {
@@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
      },
      "low"_a,
      "high"_a,
-      "shape"_a = std::vector{},
+      "shape"_a = mx::Shape{},
      "dtype"_a.none() = mx::int32,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
@@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
   m.def(
      "bernoulli",
      [](const ScalarOrArray& p_,
-         const std::optional> shape,
+         const std::optional shape,
         const std::optional& key_,
         mx::StreamOrDevice s) {
        auto key = key_ ? key_.value() : default_key().next();
@@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
      "truncated_normal",
      [](const ScalarOrArray& lower_,
         const ScalarOrArray& upper_,
-         const std::optional> shape_,
+         const std::optional shape_,
         std::optional type,
         const std::optional& key_,
         mx::StreamOrDevice s) {
@@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
   )pbdoc");
   m.def(
      "gumbel",
-      [](const std::vector& shape,
+      [](const mx::Shape& shape,
         std::optional type,
         const std::optional& key_,
         mx::StreamOrDevice s) {
        auto key = key_ ? key_.value() : default_key().next();
        return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
      },
-      "shape"_a = std::vector{},
+      "shape"_a = mx::Shape{},
      "dtype"_a.none() = mx::float32,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
@@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
      "categorical",
      [](const mx::array& logits,
         int axis,
-         const std::optional> shape,
+         const std::optional shape,
         const std::optional num_samples,
         const std::optional& key_,
         mx::StreamOrDevice s) {
@@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
   )pbdoc");
   m.def(
      "laplace",
-      [](const std::vector& shape,
+      [](const mx::Shape& shape,
         std::optional type,
         float loc,
         float scale,
        const std::optional& key_,
        mx::StreamOrDevice s) {
        auto key = key_ ? key_.value() : default_key().next();
        return mx::random::laplace(
            shape, type.value_or(mx::float32), loc, scale, key, s);
      },
-      "shape"_a = std::vector{},
+      "shape"_a = mx::Shape{},
      "dtype"_a.none() = mx::float32,
      "loc"_a = 0.0,
      "scale"_a = 1.0,
@@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
          return mx::random::permutation(std::get(x), axis, key, s);
        }
      },
-      "shape"_a = std::vector{},
+      "x"_a,
      "axis"_a = 0,
      "key"_a = nb::none(),
      "stream"_a = nb::none(),
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index c7fd5b4c7..569c7aabc 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -1,4 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
+
 #include 
 #include 
 #include 
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 6842717cc..3d038cd30 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -395,7 +395,7 @@ TEST_CASE("test split") {
   CHECK_EQ(out[1].shape(), Shape{8, 4});
   CHECK_EQ(out[2].shape(), Shape{8, 4});
 
-  out = split(x, std::vector{});
+  out = split(x, Shape{});
   CHECK_EQ(out.size(), 1);
   CHECK_EQ(out[0].shape(), x.shape());
@@ -405,25 +405,25 @@ TEST_CASE("test split") {
   CHECK_EQ(out[1].shape(), Shape{4, 12});
   CHECK_EQ(out[2].shape(), Shape{1, 12});
 
-  out = split(x, std::vector{20});
+  out = split(x, Shape{20});
   CHECK_EQ(out.size(), 2);
   CHECK_EQ(out[0].shape(), Shape{8, 12});
   CHECK_EQ(out[1].shape(), Shape{0, 12});
 
   // Negative indices
-  out = split(x, std::vector{-5});
+  out = split(x, Shape{-5});
   CHECK_EQ(out[0].shape(), Shape{3, 12});
   CHECK_EQ(out[1].shape(), Shape{5, 12});
 
   // Different axis
-  out = split(x, std::vector{2, 8}, 1);
+  out = split(x, {2, 8}, 1);
   CHECK_EQ(out[0].shape(), Shape{8, 2});
   CHECK_EQ(out[1].shape(), Shape{8, 6});
   CHECK_EQ(out[2].shape(), Shape{8, 4});
 
   // Out of order indices
   x = arange(5);
-  out = split(x, std::vector{2, 1, 2});
+  out = split(x, {2, 1, 2});
   CHECK(array_equal(out[0], array({0, 1})).item());
   CHECK(array_equal(out[1], array({})).item());
   CHECK(array_equal(out[2], array({1})).item());
diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp
index 7ab72c075..1a9e1aa78 100644
--- a/tests/random_tests.cpp
+++ b/tests/random_tests.cpp
@@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
   CHECK_THROWS(categorical(logits, -3));
 
   // Invalid requested shapes
-  CHECK_THROWS(categorical(logits, 1, std::vector{1}));
-  CHECK_THROWS(categorical(logits, 1, std::vector{11}));
+  CHECK_THROWS(categorical(logits, 1, Shape{1}));
+  CHECK_THROWS(categorical(logits, 1, Shape{11}));
   CHECK_THROWS(categorical(logits, 1, {10, 1}));
 
   CHECK_EQ(categorical(logits, -1).shape(), Shape{10});
diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp
index ba5b528ae..88aac6991 100644
--- a/tests/vmap_tests.cpp
+++ b/tests/vmap_tests.cpp
@@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
    auto fun = [](std::vector inputs) {
      auto src = inputs[0];
      auto indices = inputs[1];
-      std::vector slice_sizes = {1, 2, 2};
-      auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
+      auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
      return std::vector{out};
    };
    auto x = zeros({2, 2, 2, 2});
@@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
    auto fun = [](std::vector inputs) {
      auto src = inputs[0];
      auto indices = inputs[1];
-      std::vector slice_sizes = {1, 2, 2};
-      auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
+      auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
      return std::vector{out};
    };
    auto x = zeros({2, 2, 2, 2});
@@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
    auto fun = [](std::vector inputs) {
      auto src = inputs[0];
      auto indices = inputs[1];
-      std::vector slice_sizes = {1, 2, 2, 2};
-      auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
+      auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
      return std::vector{out};
    };
    auto x = zeros({2, 2, 2, 2});
@@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
    auto fun = [](std::vector inputs) {
      auto src = inputs[0];
      auto indices = std::vector(inputs.begin() + 1, inputs.end());
-      std::vector slice_sizes = {1, 1, 2, 2};
-      auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
+      auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
      return std::vector{out};
    };
    auto x = zeros({2, 2, 2, 2});