More shape type (#1705)

* more shape type

* fix
This commit is contained in:
Awni Hannun 2024-12-19 08:08:20 -08:00 committed by GitHub
parent f17536af9c
commit e03f0372b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 260 additions and 258 deletions

View File

@ -25,7 +25,7 @@ bool retain_graph() {
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
auto cval = static_cast<complex64_t>(val);
init(&cval);
}
@ -61,14 +61,14 @@ std::vector<array> array::make_arrays(
array::array(std::initializer_list<float> data)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
float32)) {
init(data.begin());
}
array::array(std::initializer_list<int> data, Dtype dtype)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
dtype)) {
init(data.begin());
}
@ -322,7 +322,7 @@ array::ArrayIterator::ArrayIterator(const array& arr, int idx)
}
array::ArrayIterator::reference array::ArrayIterator::operator*() const {
auto start = std::vector<int>(arr.ndim(), 0);
auto start = Shape(arr.ndim(), 0);
auto end = arr.shape();
auto shape = arr.shape();
shape.erase(shape.begin());

View File

@ -17,7 +17,8 @@ namespace mlx::core {
class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
class array {
@ -498,7 +499,7 @@ class array {
template <typename T>
array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::vector<int>{}, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
init(&val);
}
@ -516,7 +517,7 @@ array::array(
std::initializer_list<T> data,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(
std::vector<int>{static_cast<int>(data.size())},
Shape{static_cast<ShapeElem>(data.size())},
dtype)) {
init(data.begin());
}

View File

@ -130,7 +130,7 @@ std::string build_lib_name(
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape) {
const Shape& shape) {
bool contiguous = true;
bool all_contig = true;
bool all_row_contig = true;

View File

@ -56,7 +56,7 @@ inline bool is_scalar(const array& x) {
// Check if we can use a contiguous operation given inputs and the output shape
bool compiled_check_contiguity(
const std::vector<array>& inputs,
const std::vector<int>& shape);
const Shape& shape);
// Allocate space for the outputs possibly with input donation
void compiled_allocate_outputs(

View File

@ -726,7 +726,7 @@ void explicit_gemm_conv_1D_cpu(
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape = {N, iH + 2 * padding[0], C};
Shape padded_shape = {N, iH + 2 * padding[0], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@ -765,7 +765,7 @@ void explicit_gemm_conv_1D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N * oH, wH * C};
Shape strided_reshape = {N * oH, wH * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
@ -843,8 +843,7 @@ void explicit_gemm_conv_2D_cpu(
auto conv_dtype = out.dtype();
// Pad input
std::vector<int> padded_shape = {
N, iH + 2 * padding[0], iW + 2 * padding[1], C};
Shape padded_shape = {N, iH + 2 * padding[0], iW + 2 * padding[1], C};
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
@ -881,7 +880,7 @@ void explicit_gemm_conv_2D_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N * oH * oW, wH * wW * C};
Shape strided_reshape = {N * oH * oW, wH * wW * C};
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
@ -934,19 +933,19 @@ void explicit_gemm_conv_ND_cpu(
const std::vector<int>& wt_dilation,
const bool flip) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>(
const auto iDim =
Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = Shape(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
const auto wDim = std::vector<int>(
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
const auto wDim =
Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape(in.shape().size());
Shape padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];

View File

@ -14,10 +14,10 @@ namespace mlx::core {
namespace {
template <typename T, typename IdxT = int32_t>
template <typename T>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = IdxT;
using difference_type = int32_t;
using value_type = T;
using reference = value_type&;
using pointer = value_type*;

View File

@ -107,7 +107,7 @@ struct ContiguousIterator {
: shape_(a.shape()), strides_(a.strides()) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
pos_ = Shape(shape_.size(), 0);
}
}

View File

@ -34,7 +34,7 @@ void explicit_gemm_conv_ND_gpu(
int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O;
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K};
Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@ -113,7 +113,7 @@ void explicit_gemm_conv_group_ND_gpu(
}
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@ -192,12 +192,12 @@ void conv_1D_gpu(
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(2),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1)},
/* const int wS[NDIM] = */ {wt.shape(1)},
/* const int oS[NDIM] = */ {out.shape(1)},
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(2)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
/* const int str[NDIM] = */ {wt_strides[0]},
/* const int pad[NDIM] = */ {padding[0]},
/* const int kdil[NDIM] = */ {wt_dilation[0]},
@ -541,7 +541,7 @@ void winograd_conv_2D_gpu(
array out,
const MLXConvParams<2>& conv_params,
std::vector<array>& copies_w) {
std::vector<int> padded_shape = {
Shape padded_shape = {
conv_params.N,
conv_params.iS[0] + 2 * conv_params.pad[0],
conv_params.iS[1] + 2 * conv_params.pad[1],
@ -550,7 +550,7 @@ void winograd_conv_2D_gpu(
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
array in_padded(padded_shape, in.dtype(), nullptr, {});
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
// Fill with zeros
array zero_arr = array(0, in.dtype());
@ -575,12 +575,16 @@ void winograd_conv_2D_gpu(
copies_w.push_back(in_padded);
MLXConvParams<2> conv_params_updated{
/* const int N = */ in_padded.shape(0),
/* const int C = */ in_padded.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in_padded.shape(1), in_padded.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int N = */ static_cast<int>(in_padded.shape(0)),
/* const int C = */ static_cast<int>(in_padded.shape(3)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in_padded.shape(1)),
static_cast<int>(in_padded.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {1, 1},
/* const int pad[NDIM] = */ {0, 0},
/* const int kdil[NDIM] = */ {1, 1},
@ -607,8 +611,8 @@ void winograd_conv_2D_gpu(
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
// Do filter transform
std::vector<int> filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(filt_wg_shape, wt.dtype(), nullptr, {});
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
copies_w.push_back(filt_wg);
{
@ -634,8 +638,8 @@ void winograd_conv_2D_gpu(
}
// Do input transform
std::vector<int> inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(inp_wg_shape, in.dtype(), nullptr, {});
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
copies_w.push_back(inp_wg);
{
@ -661,8 +665,8 @@ void winograd_conv_2D_gpu(
}
// Do batched gemm
std::vector<int> out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(out_wg_shape, in.dtype(), nullptr, {});
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
copies_w.push_back(out_wg);
{
@ -723,12 +727,15 @@ void conv_2D_gpu(
std::vector<array>& copies) {
// Make conv params
MLXConvParams<2> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(3),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2)},
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(3)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
/* const int pad[NDIM] = */ {padding[0], padding[1]},
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
@ -800,12 +807,21 @@ void conv_3D_gpu(
std::vector<array>& copies) {
// Make conv params
MLXConvParams<3> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(4),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
/* const int N = */ static_cast<int>(in.shape(0)),
/* const int C = */ static_cast<int>(in.shape(4)),
/* const int O = */ static_cast<int>(wt.shape(0)),
/* const int iS[NDIM] = */
{static_cast<int>(in.shape(1)),
static_cast<int>(in.shape(2)),
static_cast<int>(in.shape(3))},
/* const int wS[NDIM] = */
{static_cast<int>(wt.shape(1)),
static_cast<int>(wt.shape(2)),
static_cast<int>(wt.shape(3))},
/* const int oS[NDIM] = */
{static_cast<int>(out.shape(1)),
static_cast<int>(out.shape(2)),
static_cast<int>(out.shape(3))},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */

View File

@ -635,7 +635,7 @@ void strided_reduce_longcolumn(
}
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert(
@ -806,7 +806,7 @@ void strided_reduce_2pass(
auto [in_type, out_type] = remap_reduce_types(in, op_name);
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
Shape intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32);
intermediate_shape.insert(

View File

@ -63,8 +63,8 @@ void pad_gpu(
const array& in,
const array& val,
array& out,
std::vector<int> axes,
std::vector<int> low_pad_size,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s) {
// Fill output with val
fill_gpu(val, out, s);

View File

@ -23,8 +23,8 @@ void pad_gpu(
const array& in,
const array& val,
array& out,
std::vector<int> axes,
std::vector<int> low_pad_size,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Stream& s);
} // namespace mlx::core

View File

@ -82,7 +82,7 @@ array send(
}
array recv(
std::vector<int> shape,
Shape shape,
Dtype dtype,
int src,
std::optional<Group> group_ /* = std::nullopt */,

View File

@ -26,7 +26,7 @@ array send(
StreamOrDevice s = {});
array recv(
std::vector<int> shape,
Shape shape,
Dtype dtype,
int src,
std::optional<Group> group = std::nullopt,

View File

@ -91,7 +91,7 @@ std::vector<array> AllGather::vjp(
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto g = group();
std::vector<int> starts(primals[0].ndim(), 0);
Shape starts(primals[0].ndim(), 0);
auto stops = primals[0].shape();
starts[0] = g.rank() * stops[0];
stops[0] += starts[0];

View File

@ -108,7 +108,7 @@ bool disjoint(const CharSet& x, const CharSet& y) {
}
template <typename T>
size_t term_size(const T& term, std::unordered_map<char, int> dict) {
size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
size_t size = 1;
for (auto c : term) {
size *= dict[c];
@ -120,7 +120,7 @@ size_t flop_count(
const CharSet& term,
bool inner,
int num_terms,
std::unordered_map<char, int> dict) {
std::unordered_map<char, ShapeElem> dict) {
size_t size = term_size(term, dict);
auto op_factor = 1;
if ((num_terms - 1) > op_factor) {
@ -135,7 +135,7 @@ size_t flop_count(
std::pair<size_t, int> compute_cost_and_scaling(
const std::vector<Subscript>& inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map) {
std::unordered_map<char, ShapeElem> dim_map) {
CharSet contractions;
for (auto& in : inputs) {
contractions.insert(in.set.begin(), in.set.end());
@ -155,7 +155,7 @@ std::pair<size_t, int> compute_cost_and_scaling(
std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
std::vector<Subscript> inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map,
std::unordered_map<char, ShapeElem> dim_map,
size_t cost_limit,
size_t memory_limit) {
// Helper struct for building the greedy path
@ -457,7 +457,8 @@ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
}
Shape idx_shape(n_expand--, 1);
idx_shape[0] = in.shape(axes.back());
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
auto idx = reshape(
arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
for (int i = 0; i < v; ++i) {
indices.push_back(idx);
}
@ -663,7 +664,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
}
Subscript output(out_subscript, std::move(out_set));
std::unordered_map<char, int> dim_map;
std::unordered_map<char, ShapeElem> dim_map;
std::vector<Subscript> inputs;
for (int i = 0; i < in_subscripts.size(); ++i) {
auto& in = in_subscripts[i];
@ -680,7 +681,7 @@ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
// Check repeat subscripts are valid
if (in_set.size() < in.size()) {
std::unordered_map<char, int> local_dims;
std::unordered_map<char, ShapeElem> local_dims;
for (int j = 0; j < in.size(); ++j) {
auto dim = operands[i].shape(j);
auto inserted = local_dims.insert({in[j], dim});

View File

@ -670,8 +670,7 @@ array scaled_dot_product_attention(
supports_sdpa_full || supports_sdpa_vector;
if (implementation_supports_use_case) {
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
auto out_shape = Shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
return array(
std::move(out_shape),
final_type,

View File

@ -59,7 +59,7 @@ typedef std::variant<int, bool, Dtype> TemplateArg;
typedef std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<std::vector<int>>&,
const std::vector<Shape>&,
const std::vector<Dtype>&,
std::tuple<int, int, int>,
std::tuple<int, int, int>,

View File

@ -47,8 +47,8 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
}
}
std::vector<int> get_shape(const gguf_tensor& tensor) {
std::vector<int> shape;
Shape get_shape(const gguf_tensor& tensor) {
Shape shape;
// The dimension order in GGML is the reverse of the order used in MLX.
for (int i = tensor.ndim - 1; i >= 0; i--) {
shape.push_back(tensor.dim[i]);

View File

@ -12,7 +12,7 @@ extern "C" {
namespace mlx::core {
std::vector<int> get_shape(const gguf_tensor& tensor);
Shape get_shape(const gguf_tensor& tensor);
void gguf_load_quantized(
std::unordered_map<std::string, array>& a,
const gguf_tensor& tensor);

View File

@ -109,7 +109,7 @@ void gguf_load_quantized(
std::string name(tensor.name, tensor.namelen);
std::vector<int> shape = get_shape(tensor);
auto shape = get_shape(tensor);
const uint64_t weights_per_block = 32;
if (shape[shape.size() - 1] % weights_per_block != 0) {
std::ostringstream msg;
@ -118,7 +118,7 @@ void gguf_load_quantized(
throw std::runtime_error(msg.str());
}
std::vector<int> weights_shape = shape;
auto weights_shape = shape;
weights_shape.back() /= (weights_per_byte * 4);
auto w_nbytes = uint32.size() *
std::accumulate(weights_shape.begin(),

View File

@ -271,7 +271,7 @@ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
bool col_contiguous = header[34] == 'T';
// Read array shape from header
std::vector<int> shape;
Shape shape;
size_t st = header.find_last_of('(') + 1;
size_t ed = header.find_last_of(')');

View File

@ -219,15 +219,15 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
const auto n = a.shape(-1);
const auto rank = a.ndim();
std::vector<int> u_shape = a.shape();
auto u_shape = a.shape();
u_shape[rank - 2] = m;
u_shape[rank - 1] = m;
std::vector<int> s_shape = a.shape();
auto s_shape = a.shape();
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
std::vector<int> vt_shape = a.shape();
auto vt_shape = a.shape();
vt_shape[rank - 2] = n;
vt_shape[rank - 1] = n;
@ -328,8 +328,8 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
array S = outs[1];
array V = outs[2];
std::vector<int> starts(a.ndim(), 0);
std::vector<int> ends = a.shape();
Shape starts(a.ndim(), 0);
auto ends = a.shape();
int i = a.ndim() - 2;
int j = a.ndim() - 1;
@ -479,7 +479,7 @@ array eigvalsh(
std::string UPLO /* = "L" */,
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigvalsh]");
std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
return array(
std::move(out_shape),
a.dtype(),
@ -493,7 +493,7 @@ std::pair<array, array> eigh(
StreamOrDevice s /* = {} */) {
validate_eigh(a, "[linalg::eigh]");
auto out = array::make_arrays(
{std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
{a.dtype(), a.dtype()},
std::make_shared<Eigh>(to_stream(s), UPLO, true),
{a});

View File

@ -649,7 +649,7 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
// Clamp to bounds
auto st = std::min(s, n - 1);
auto ed = std::max(-1, e);
auto ed = e > -1 ? e : -1;
start[i] = st;
stop[i] = ed > st ? st : ed;
@ -659,8 +659,8 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
} else {
// Clamp to bounds
auto st = std::max(0, std::min(s, n));
auto ed = std::max(0, std::min(e, n));
auto st = std::max(static_cast<ShapeElem>(0), std::min(s, n));
auto ed = std::max(static_cast<ShapeElem>(0), std::min(e, n));
start[i] = st;
stop[i] = ed < st ? st : ed;
@ -765,7 +765,7 @@ array slice_update(
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
const Shape& indices,
int axis,
StreamOrDevice s /* = {} */) {
auto ax = axis < 0 ? axis + a.ndim() : axis;
@ -809,10 +809,8 @@ std::vector<array> split(
return res;
}
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
StreamOrDevice s /* = {} */) {
std::vector<array>
split(const array& a, const Shape& indices, StreamOrDevice s /* = {} */) {
return split(a, indices, 0, s);
}
@ -834,7 +832,7 @@ split(const array& a, int num_splits, int axis, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg.str());
}
auto split_size = q_and_r.quot;
std::vector<int> indices(num_splits - 1);
Shape indices(num_splits - 1);
for (int i = 0; i < indices.size(); ++i) {
indices[i] = (i + 1) * split_size;
}
@ -1104,7 +1102,7 @@ array edge_pad(
/** Pad an array with a constant value */
array pad(
const array& a,
const Shape& axes,
const std::vector<int>& axes,
const Shape& low_pad_size,
const Shape& high_pad_size,
const array& pad_value /*= array(0)*/,
@ -1904,9 +1902,11 @@ array min(
array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size();
auto result = argmin(reshape(a, {size}, s), 0, true, s);
auto result = argmin(flatten(a, s), 0, true, s);
if (keepdims) {
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
std::vector<int> axes(a.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
result = expand_dims(result, axes, s);
} else {
result = squeeze(result, s);
}
@ -1940,9 +1940,11 @@ array argmin(
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size();
auto result = argmax(reshape(a, {size}, s), 0, true, s);
auto result = argmax(flatten(a, s), 0, true, s);
if (keepdims) {
result = reshape(result, Shape(a.shape().size(), 1), s);
std::vector<int> axes(a.ndim() - 1);
std::iota(axes.begin(), axes.end(), 0);
result = expand_dims(result, axes, s);
} else {
result = squeeze(result, s);
}
@ -3238,8 +3240,8 @@ inline int dilate_size(int dim, int dil) {
}
Shape conv_out_shape(
const std::vector<int>& in_shape,
const std::vector<int>& wt_shape,
const Shape& in_shape,
const Shape& wt_shape,
const std::vector<int>& strides,
const std::vector<int>& pads_lo,
const std::vector<int>& pads_hi,
@ -4329,16 +4331,16 @@ array diagonal(
"[diagonal] axis1 and axis2 cannot be the same axis");
}
auto off1 = std::max(-offset, 0);
auto off2 = std::max(offset, 0);
ShapeElem off1 = std::max(-offset, 0);
ShapeElem off2 = std::max(offset, 0);
auto diag_size = std::min(a.shape(ax1) - off1, a.shape(ax2) - off2);
diag_size = std::max(diag_size, 0);
diag_size = diag_size < 0 ? 0 : diag_size;
std::vector<array> indices = {
arange(off1, off1 + diag_size, s), arange(off2, off2 + diag_size, s)};
std::vector<int> slice_sizes = a.shape();
Shape slice_sizes = a.shape();
slice_sizes[ax1] = 1;
slice_sizes[ax2] = 1;

View File

@ -189,13 +189,10 @@ array slice_update(
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
int axis,
StreamOrDevice s = {});
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
std::vector<array>
split(const array& a, const Shape& indices, StreamOrDevice s = {});
/** A vector of coordinate arrays from coordinate vectors. */
std::vector<array> meshgrid(
@ -253,8 +250,8 @@ array moveaxis(
array pad(
const array& a,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size,
const Shape& low_pad_size,
const Shape& high_pad_size,
const array& pad_value = array(0),
const std::string mode = "constant",
StreamOrDevice s = {});
@ -1453,7 +1450,11 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
array roll(const array& a, int shift, StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
array roll(
const array& a,
int shift,
const std::vector<int>& axes,
StreamOrDevice s = {});
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
array roll(
const array& a,

View File

@ -817,10 +817,10 @@ std::vector<array> Concatenate::vjp(
const std::vector<int>& argnums,
const std::vector<array>&) {
auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape();
Shape start(cotan.ndim(), 0);
Shape stop = cotan.shape();
std::vector<int> sizes;
Shape sizes;
sizes.push_back(0);
for (auto& p : primals) {
sizes.push_back(p.shape(axis_));
@ -956,9 +956,9 @@ array conv_weight_backward_patches(
const std::vector<int>& padding,
StreamOrDevice s) {
// Resolve Padded input shapes and strides
std::vector<int> padding_starts(in.ndim(), 0);
std::vector<int> padding_ends = in.shape();
std::vector<int> in_padded_shape = in.shape();
Shape padding_starts(in.ndim(), 0);
auto padding_ends = in.shape();
auto in_padded_shape = in.shape();
// padded shape
for (int i = 1; i < in.ndim() - 1; i++) {
@ -976,8 +976,9 @@ array conv_weight_backward_patches(
// Pad input
std::vector<int> padded_axes(in.ndim() - 2, 0);
std::iota(padded_axes.begin(), padded_axes.end(), 1);
Shape padding_(padding.begin(), padding.end());
auto in_padded = pad(
in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
in, padded_axes, padding_, padding_, array(0, in.dtype()), "constant", s);
// Resolve strided patches
@ -1797,7 +1798,7 @@ std::vector<array> FFT::vjp(
std::vector<int> axes(axes_.begin(), axes_.end());
if (real_ && inverse_) {
auto out = fft::fftn(cotangents[0], axes, stream());
auto start = std::vector<int>(out.ndim(), 0);
auto start = Shape(out.ndim(), 0);
auto stop = in.shape();
out = slice(out, start, stop, stream());
auto mask_shape = out.shape();
@ -1809,7 +1810,7 @@ std::vector<array> FFT::vjp(
mask = concatenate({pad, mask, pad}, axes_.back(), stream());
return {multiply(mask, out, stream())};
} else if (real_) {
std::vector<int> n;
Shape n;
for (auto ax : axes_) {
n.push_back(in.shape()[ax]);
}
@ -1934,10 +1935,11 @@ std::pair<std::vector<array>, std::vector<int>> Gather::vmap(
}
if (indices_vmapped) {
// Make a new index array for the vmapped dimension
auto vmap_inds = arange(0, src.shape(axes[0]), stream());
auto vmap_inds =
arange(static_cast<ShapeElem>(0), src.shape(axes[0]), stream());
// Reshape it so it broadcasts with other index arrays
{
auto shape = std::vector<int>(idx_dims, 1);
auto shape = Shape(idx_dims, 1);
shape[out_ax] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(shape), stream());
}
@ -2628,8 +2630,8 @@ std::vector<array> Pad::vjp(
assert(argnums.size() == 1 && argnums[0] == 0);
auto& cotan = cotangents[0];
std::vector<int> start(cotan.ndim(), 0);
std::vector<int> stop = cotan.shape();
Shape start(cotan.ndim(), 0);
auto stop = cotan.shape();
for (auto i : axes_) {
start[i] = low_pad_size_[i];
@ -3019,7 +3021,7 @@ std::vector<array> Reduce::vjp(
const std::vector<array>& outputs) {
auto in = primals[0];
std::vector<int> shape = in.shape();
auto shape = in.shape();
for (auto ax : axes_) {
shape[ax] = 1;
}
@ -3044,7 +3046,7 @@ std::vector<array> Reduce::vjp(
if (axes_.size() > 1) {
std::vector<int> transpose_to;
std::vector<int> transpose_back;
std::vector<int> shape_flat;
Shape shape_flat;
{
// Find the transpose needed to move axes_ to the back and the shape
// except the reduced over axes.
@ -3422,7 +3424,7 @@ std::pair<std::vector<array>, std::vector<int>> Scatter::vmap(
}
auto vmap_inds = arange(vmap_size, inputs[1].dtype(), stream());
auto vmap_inds_shape = std::vector<int>(inputs[1].ndim(), 1);
auto vmap_inds_shape = Shape(inputs[1].ndim(), 1);
vmap_inds_shape[0] = vmap_inds.size();
vmap_inds = reshape(vmap_inds, std::move(vmap_inds_shape), stream());
inputs.insert(
@ -3607,7 +3609,7 @@ std::vector<array> Slice::vjp(
// Transpose and reshape cotangents
auto cotan = cotangents[0];
if (!ind_axes.empty()) {
std::vector<int> cotan_shape;
Shape cotan_shape;
for (auto ax : ind_axes) {
cotan_shape.push_back(cotan.shape(ax));
}
@ -3626,7 +3628,7 @@ std::vector<array> Slice::vjp(
}
// Make indices broadcastable
std::vector<int> inds_shape(inds.size(), 1);
Shape inds_shape(inds.size(), 1);
for (int i = 0; i < inds.size(); ++i) {
inds_shape[i] = inds[i].size();
inds[i] = reshape(inds[i], inds_shape, stream());
@ -4184,7 +4186,7 @@ std::vector<array> BlockMaskedMM::vjp(
// Slice mask
mask_reshape[mask_ndim - 2] = Y;
mask_reshape[mask_ndim - 1] = X;
mask = slice(mask, std::vector<int>(mask_ndim, 0), mask_reshape, stream());
mask = slice(mask, Shape(mask_ndim, 0), mask_reshape, stream());
return mask;
};
@ -4202,7 +4204,7 @@ std::vector<array> BlockMaskedMM::vjp(
}
// Reshape
std::vector<int> r_reshape(r.shape().begin(), r.shape().end() - 2);
Shape r_reshape(r.shape().begin(), r.shape().end() - 2);
r_reshape.push_back(r.shape(-2) / block_size_);
r_reshape.push_back(block_size_);
r_reshape.push_back(r.shape(-1) / block_size_);
@ -4492,7 +4494,7 @@ std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
}
array out = array(
std::vector<int>{},
{},
dtype_,
std::make_shared<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
inputs);

View File

@ -1088,10 +1088,7 @@ class Full : public UnaryPrimitive {
class Gather : public UnaryPrimitive {
public:
explicit Gather(
Stream stream,
std::vector<int> axes,
std::vector<int> slice_sizes)
explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes)
: UnaryPrimitive(stream),
axes_(std::move(axes)),
slice_sizes_(std::move(slice_sizes)) {}
@ -1108,7 +1105,7 @@ class Gather : public UnaryPrimitive {
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> axes_;
std::vector<int> slice_sizes_;
Shape slice_sizes_;
};
class Greater : public UnaryPrimitive {
@ -1503,8 +1500,8 @@ class Pad : public UnaryPrimitive {
explicit Pad(
Stream stream,
const std::vector<int>& axes,
const std::vector<int>& low_pad_size,
const std::vector<int>& high_pad_size)
const Shape& low_pad_size,
const Shape& high_pad_size)
: UnaryPrimitive(stream),
axes_(axes),
low_pad_size_(low_pad_size),
@ -1520,8 +1517,8 @@ class Pad : public UnaryPrimitive {
private:
std::vector<int> axes_;
std::vector<int> low_pad_size_;
std::vector<int> high_pad_size_;
Shape low_pad_size_;
Shape high_pad_size_;
void eval(const std::vector<array>& inputs, array& out);
};
@ -1903,9 +1900,9 @@ class Slice : public UnaryPrimitive {
public:
explicit Slice(
Stream stream,
const std::vector<int>& start_indices,
const std::vector<int>& end_indices,
const std::vector<int>& strides)
const Shape& start_indices,
const Shape& end_indices,
const Shape& strides)
: UnaryPrimitive(stream),
start_indices_(start_indices),
end_indices_(end_indices),
@ -1920,9 +1917,9 @@ class Slice : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> start_indices_;
std::vector<int> end_indices_;
std::vector<int> strides_;
Shape start_indices_;
Shape end_indices_;
Shape strides_;
void eval(const std::vector<array>& inputs, array& out);
};
@ -1931,9 +1928,9 @@ class SliceUpdate : public UnaryPrimitive {
public:
explicit SliceUpdate(
Stream stream,
const std::vector<int>& start_indices,
const std::vector<int>& end_indices,
const std::vector<int>& strides)
const Shape& start_indices,
const Shape& end_indices,
const Shape& strides)
: UnaryPrimitive(stream),
start_indices_(start_indices),
end_indices_(end_indices),
@ -1948,9 +1945,9 @@ class SliceUpdate : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> start_indices_;
std::vector<int> end_indices_;
std::vector<int> strides_;
Shape start_indices_;
Shape end_indices_;
Shape strides_;
void eval(const std::vector<array>& inputs, array& out);
};
@ -1997,7 +1994,7 @@ class Sort : public UnaryPrimitive {
class Split : public Primitive {
public:
explicit Split(Stream stream, const std::vector<int>& indices, int axis)
explicit Split(Stream stream, const Shape& indices, int axis)
: Primitive(stream), indices_(indices), axis_(axis) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@ -2013,7 +2010,7 @@ class Split : public Primitive {
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
std::vector<int> indices_;
Shape indices_;
int axis_;
};

View File

@ -296,7 +296,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}
std::ostream& operator<<(std::ostream& os, const Shape& v) {
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
@ -305,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, const Shape& v) {
return os;
}
std::ostream& operator<<(std::ostream& os, const Strides& v) {
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");

View File

@ -77,8 +77,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const Shape& v);
std::ostream& operator<<(std::ostream& os, const Strides& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}

View File

@ -889,13 +889,13 @@ void init_array(nb::module_& m) {
.def(
"reshape",
[](const mx::array& a, nb::args shape_, mx::StreamOrDevice s) {
std::vector<int> shape;
mx::Shape shape;
if (!nb::isinstance<int>(shape_[0])) {
shape = nb::cast<std::vector<int>>(shape_[0]);
shape = nb::cast<mx::Shape>(shape_[0]);
} else {
shape = nb::cast<std::vector<int>>(shape_);
shape = nb::cast<mx::Shape>(shape_);
}
return mx::reshape(a, shape, s);
return mx::reshape(a, std::move(shape), s);
},
"shape"_a,
"stream"_a = nb::none(),
@ -1182,14 +1182,14 @@ void init_array(nb::module_& m) {
.def(
"split",
[](const mx::array& a,
const std::variant<int, std::vector<int>>& indices_or_sections,
const std::variant<int, mx::Shape>& indices_or_sections,
int axis,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
return mx::split(a, *pv, axis, s);
} else {
return mx::split(
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
a, std::get<mx::Shape>(indices_or_sections), axis, s);
}
},
"indices_or_sections"_a,

View File

@ -181,7 +181,7 @@ void init_fast(nb::module_& parent_module) {
return nb::cpp_function(
[kernel = std::move(kernel)](
const std::vector<ScalarOrArray>& inputs_,
const std::vector<std::vector<int>>& output_shapes,
const std::vector<mx::Shape>& output_shapes,
const std::vector<mx::Dtype>& output_dtypes,
std::tuple<int, int, int> grid,
std::tuple<int, int, int> threadgroup,

View File

@ -79,7 +79,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"fft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@ -115,7 +115,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"ifft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@ -151,7 +151,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"fftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@ -188,7 +188,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"ifftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@ -294,7 +294,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"rfft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@ -336,7 +336,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"irfft2",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@ -378,7 +378,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"rfftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {
@ -420,7 +420,7 @@ void init_fft(nb::module_& parent_module) {
m.def(
"irfftn",
[](const mx::array& a,
const std::optional<std::vector<int>>& n,
const std::optional<mx::Shape>& n,
const std::optional<std::vector<int>>& axes,
mx::StreamOrDevice s) {
if (axes.has_value() && n.has_value()) {

View File

@ -25,9 +25,9 @@ int get_slice_int(nb::object obj, int default_val) {
}
void get_slice_params(
int& starts,
int& ends,
int& strides,
mx::ShapeElem& starts,
mx::ShapeElem& ends,
mx::ShapeElem& strides,
const nb::slice& in_slice,
int axis_size) {
// Following numpy's convention
@ -68,9 +68,9 @@ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
return src;
}
std::vector<int> starts(src.ndim(), 0);
std::vector<int> ends = src.shape();
std::vector<int> strides(src.ndim(), 1);
mx::Shape starts(src.ndim(), 0);
auto ends = src.shape();
mx::Shape strides(src.ndim(), 1);
// Check and update slice params
get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
@ -119,7 +119,7 @@ mx::array mlx_gather_nd(
auto& idx = indices[i];
if (nb::isinstance<nb::slice>(idx)) {
int start, end, stride;
mx::ShapeElem start, end, stride;
get_slice_params(
start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
@ -168,7 +168,7 @@ mx::array mlx_gather_nd(
// Do the gather
std::vector<int> axes(indices.size());
std::iota(axes.begin(), axes.end(), 0);
std::vector<int> slice_sizes = src.shape();
auto slice_sizes = src.shape();
std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
src = gather(src, gather_indices, axes, slice_sizes);
@ -179,9 +179,7 @@ mx::array mlx_gather_nd(
return mx::squeeze(src, axes);
}
auto mlx_expand_ellipsis(
const std::vector<int>& shape,
const nb::tuple& entries) {
auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
std::vector<nb::object> indices;
// Go over all entries and note the position of ellipsis
@ -230,7 +228,8 @@ auto mlx_expand_ellipsis(
for (int axis = non_none_indices_before;
axis < shape.size() - non_none_indices_after;
axis++) {
indices.push_back(nb::slice(0, shape[axis], 1));
indices.push_back(
nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
non_none_indices++;
}
}
@ -371,9 +370,9 @@ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
// Slice handling
{
std::vector<int> starts(src.ndim(), 0);
std::vector<int> ends = src.shape();
std::vector<int> strides(src.ndim(), 1);
mx::Shape starts(src.ndim(), 0);
auto ends = src.shape();
mx::Shape strides(src.ndim(), 1);
int axis = 0;
for (auto& idx : remaining_indices) {
if (!idx.is_none()) {
@ -461,8 +460,7 @@ mlx_scatter_args_int(
int s = 0;
for (; s < update.ndim() && update.shape(s) == 1; s++)
;
auto up_shape =
std::vector<int>(update.shape().begin() + s, update.shape().end());
auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
auto shape = src.shape();
shape[0] = 1;
@ -521,9 +519,9 @@ mlx_scatter_args_slice(
{}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
}
int start = 0;
int end = src.shape(0);
int stride = 1;
mx::ShapeElem start = 0;
auto end = src.shape(0);
mx::ShapeElem stride = 1;
// Check and update slice params
get_slice_params(start, end, stride, in_slice, end);
@ -645,7 +643,7 @@ mlx_scatter_args_nd(
for (int i = 0; i < indices.size(); ++i) {
auto& pyidx = indices[i];
if (nb::isinstance<nb::slice>(pyidx)) {
int start, end, stride;
mx::ShapeElem start, end, stride;
auto axis_size = src.shape(ax++);
get_slice_params(
start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
@ -654,7 +652,7 @@ mlx_scatter_args_nd(
start = (start < 0) ? start + axis_size : start;
end = (end < 0) ? end + axis_size : end;
std::vector<int> idx_shape(idx_ndim, 1);
mx::Shape idx_shape(idx_ndim, 1);
// If it's a simple slice, we only need to add the start index
if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {

View File

@ -1571,15 +1571,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"full",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
const ScalarOrArray& vals,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::full({*pv}, to_array(vals, dtype), s);
} else {
return mx::full(
std::get<std::vector<int>>(shape), to_array(vals, dtype), s);
return mx::full(std::get<mx::Shape>(shape), to_array(vals, dtype), s);
}
},
"shape"_a,
@ -1606,14 +1605,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"zeros",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::zeros({*pv}, t, s);
} else {
return mx::zeros(std::get<std::vector<int>>(shape), t, s);
return mx::zeros(std::get<mx::Shape>(shape), t, s);
}
},
"shape"_a,
@ -1652,14 +1651,14 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"ones",
[](const std::variant<int, std::vector<int>>& shape,
[](const std::variant<int, mx::Shape>& shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
auto t = dtype.value_or(mx::float32);
if (auto pv = std::get_if<int>(&shape); pv) {
return mx::ones({*pv}, t, s);
} else {
return mx::ones(std::get<std::vector<int>>(shape), t, s);
return mx::ones(std::get<mx::Shape>(shape), t, s);
}
},
"shape"_a,
@ -2481,14 +2480,14 @@ void init_ops(nb::module_& m) {
m.def(
"split",
[](const mx::array& a,
const std::variant<int, std::vector<int>>& indices_or_sections,
const std::variant<int, mx::Shape>& indices_or_sections,
int axis,
mx::StreamOrDevice s) {
if (auto pv = std::get_if<int>(&indices_or_sections); pv) {
return mx::split(a, *pv, axis, s);
} else {
return mx::split(
a, std::get<std::vector<int>>(indices_or_sections), axis, s);
a, std::get<mx::Shape>(indices_or_sections), axis, s);
}
},
nb::arg(),
@ -2744,9 +2743,7 @@ void init_ops(nb::module_& m) {
)pbdoc");
m.def(
"broadcast_to",
[](const ScalarOrArray& a,
const std::vector<int>& shape,
mx::StreamOrDevice s) {
[](const ScalarOrArray& a, const mx::Shape& shape, mx::StreamOrDevice s) {
return mx::broadcast_to(to_array(a), shape, s);
},
nb::arg(),
@ -4895,23 +4892,15 @@ void init_ops(nb::module_& m) {
m.def(
"roll",
[](const mx::array& a,
const IntOrVec& shift,
const std::variant<int, mx::Shape>& shift,
const IntOrVec& axis,
mx::StreamOrDevice s) {
return std::visit(
[&](auto sh, auto ax) -> mx::array {
using T = decltype(ax);
using V = decltype(sh);
if constexpr (std::is_same_v<V, std::monostate>) {
throw std::invalid_argument(
"[roll] Expected two arguments but only one was given.");
if constexpr (std::is_same_v<decltype(ax), std::monostate>) {
return mx::roll(a, sh, s);
} else {
if constexpr (std::is_same_v<T, std::monostate>) {
return mx::roll(a, sh, s);
} else {
return mx::roll(a, sh, ax, s);
}
return mx::roll(a, sh, ax, s);
}
},
shift,

View File

@ -108,7 +108,7 @@ void init_random(nb::module_& parent_module) {
"uniform",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@ -123,7 +123,7 @@ void init_random(nb::module_& parent_module) {
},
"low"_a = 0,
"high"_a = 1,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@ -150,7 +150,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"normal",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
float loc,
float scale,
@ -160,7 +160,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::normal(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
@ -185,7 +185,7 @@ void init_random(nb::module_& parent_module) {
"multivariate_normal",
[](const mx::array& mean,
const mx::array& cov,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@ -195,7 +195,7 @@ void init_random(nb::module_& parent_module) {
},
"mean"_a,
"cov"_a,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@ -227,7 +227,7 @@ void init_random(nb::module_& parent_module) {
"randint",
[](const ScalarOrArray& low,
const ScalarOrArray& high,
const std::vector<int>& shape,
const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@ -242,7 +242,7 @@ void init_random(nb::module_& parent_module) {
},
"low"_a,
"high"_a,
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::int32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@ -268,7 +268,7 @@ void init_random(nb::module_& parent_module) {
m.def(
"bernoulli",
[](const ScalarOrArray& p_,
const std::optional<std::vector<int>> shape,
const std::optional<mx::Shape> shape,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
@ -306,7 +306,7 @@ void init_random(nb::module_& parent_module) {
"truncated_normal",
[](const ScalarOrArray& lower_,
const ScalarOrArray& upper_,
const std::optional<std::vector<int>> shape_,
const std::optional<mx::Shape> shape_,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@ -350,14 +350,14 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"gumbel",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
auto key = key_ ? key_.value() : default_key().next();
return mx::random::gumbel(shape, type.value_or(mx::float32), key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"key"_a = nb::none(),
"stream"_a = nb::none(),
@ -384,7 +384,7 @@ void init_random(nb::module_& parent_module) {
"categorical",
[](const mx::array& logits,
int axis,
const std::optional<std::vector<int>> shape,
const std::optional<mx::Shape> shape,
const std::optional<int> num_samples,
const std::optional<mx::array>& key_,
mx::StreamOrDevice s) {
@ -434,7 +434,7 @@ void init_random(nb::module_& parent_module) {
)pbdoc");
m.def(
"laplace",
[](const std::vector<int>& shape,
[](const mx::Shape& shape,
std::optional<mx::Dtype> type,
float loc,
float scale,
@ -444,7 +444,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::laplace(
shape, type.value_or(mx::float32), loc, scale, key, s);
},
"shape"_a = std::vector<int>{},
"shape"_a = mx::Shape{},
"dtype"_a.none() = mx::float32,
"loc"_a = 0.0,
"scale"_a = 1.0,
@ -479,7 +479,7 @@ void init_random(nb::module_& parent_module) {
return mx::random::permutation(std::get<mx::array>(x), axis, key, s);
}
},
"shape"_a = std::vector<int>{},
"x"_a,
"axis"_a = 0,
"key"_a = nb::none(),
"stream"_a = nb::none(),

View File

@ -1,4 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>

View File

@ -395,7 +395,7 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{8, 4});
CHECK_EQ(out[2].shape(), Shape{8, 4});
out = split(x, std::vector<int>{});
out = split(x, Shape{});
CHECK_EQ(out.size(), 1);
CHECK_EQ(out[0].shape(), x.shape());
@ -405,25 +405,25 @@ TEST_CASE("test split") {
CHECK_EQ(out[1].shape(), Shape{4, 12});
CHECK_EQ(out[2].shape(), Shape{1, 12});
out = split(x, std::vector<int>{20});
out = split(x, Shape{20});
CHECK_EQ(out.size(), 2);
CHECK_EQ(out[0].shape(), Shape{8, 12});
CHECK_EQ(out[1].shape(), Shape{0, 12});
// Negative indices
out = split(x, std::vector<int>{-5});
out = split(x, Shape{-5});
CHECK_EQ(out[0].shape(), Shape{3, 12});
CHECK_EQ(out[1].shape(), Shape{5, 12});
// Different axis
out = split(x, std::vector<int>{2, 8}, 1);
out = split(x, {2, 8}, 1);
CHECK_EQ(out[0].shape(), Shape{8, 2});
CHECK_EQ(out[1].shape(), Shape{8, 6});
CHECK_EQ(out[2].shape(), Shape{8, 4});
// Out of order indices
x = arange(5);
out = split(x, std::vector<int>{2, 1, 2});
out = split(x, {2, 1, 2});
CHECK(array_equal(out[0], array({0, 1})).item<bool>());
CHECK(array_equal(out[1], array({})).item<bool>());
CHECK(array_equal(out[2], array({1})).item<bool>());

View File

@ -611,8 +611,8 @@ TEST_CASE("test categorical") {
CHECK_THROWS(categorical(logits, -3));
// Invalid requested shapes
CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
CHECK_THROWS(categorical(logits, 1, Shape{1}));
CHECK_THROWS(categorical(logits, 1, Shape{11}));
CHECK_THROWS(categorical(logits, 1, {10, 1}));
CHECK_EQ(categorical(logits, -1).shape(), Shape{10});

View File

@ -335,8 +335,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 2);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 2);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@ -351,8 +350,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2}), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@ -365,8 +363,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = inputs[1];
std::vector<int> slice_sizes = {1, 2, 2, 2};
auto out = squeeze(gather(src, indices, 0, slice_sizes), 1);
auto out = squeeze(gather(src, indices, 0, {1, 2, 2, 2}), 1);
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});
@ -380,8 +377,7 @@ TEST_CASE("test vmap gather") {
auto fun = [](std::vector<array> inputs) {
auto src = inputs[0];
auto indices = std::vector<array>(inputs.begin() + 1, inputs.end());
std::vector<int> slice_sizes = {1, 1, 2, 2};
auto out = squeeze(gather(src, indices, {0, 1}, slice_sizes), {1, 2});
auto out = squeeze(gather(src, indices, {0, 1}, {1, 1, 2, 2}), {1, 2});
return std::vector<array>{out};
};
auto x = zeros({2, 2, 2, 2});