Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -13,8 +13,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis];
std::vector<size_t> strides = in.strides();
std::vector<int> shape = in.shape();
Strides strides = in.strides();
Shape shape = in.shape();
strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) {

View File

@@ -178,10 +178,10 @@ void binary_op_dims(
const T* b,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -212,10 +212,10 @@ void binary_op_dispatch_dims(
array& out,
Op op,
int dim,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides) {
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides) {
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>();
@@ -258,10 +258,10 @@ void binary_op_dispatch_dims(
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
size_t stride = out_strides[dim - 4];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator b_it(shape, b_strides, dim - 3);
auto stride = out_strides[dim - 4];
for (int64_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc,
b_ptr + b_it.loc,
@@ -327,7 +327,7 @@ void binary_op(
const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
}
@@ -337,7 +337,7 @@ void binary_op(
auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) {
}

View File

@@ -16,10 +16,10 @@ void binary_op_dims(
U* out_a,
U* out_b,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& out_strides,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc,

View File

@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
@@ -141,7 +141,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
std::pair<bool, Strides> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
@@ -151,8 +151,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
return {false, Strides(out.ndim(), 0)};
}
// Firstly let's collapse all the contiguous dimensions of the input
@@ -160,7 +159,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
Strides out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
@@ -183,7 +182,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
@@ -249,18 +248,6 @@ void Split::eval(
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);
@@ -268,7 +255,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
Strides out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
@@ -285,8 +272,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {

View File

@@ -165,7 +165,7 @@ void compiled_allocate_outputs(
bool move_buffers /* = false */) {
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
Strides strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {

View File

@@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape = {N, oH, wH, C};
Shape strided_shape = {N, oH, wH, C};
std::vector<size_t> strided_strides = {
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[1],
@@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C};
Shape strided_shape = {N, oH, oW, wH, wW, C};
std::vector<size_t> strided_strides = {
Strides strided_strides = {
in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1],
@@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
Shape strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
@@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu(
}
strided_shape.back() = C;
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
Strides strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
@@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N, C};
Shape strided_reshape = {N, C};
for (const auto& o : oDim) {
strided_reshape[0] *= o;
}

View File

@@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT, typename StrideT, int D>
template <typename SrcT, typename DstT, int D>
inline void copy_dims(
const SrcT* src,
DstT* dst,
const std::vector<int>& shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
const Shape& shape,
const Strides& i_strides,
const Strides& o_strides,
int axis) {
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
@@ -40,7 +40,7 @@ inline void copy_dims(
for (int i = 0; i < N; i++) {
if constexpr (D > 1) {
copy_dims<SrcT, DstT, StrideT, D - 1>(
copy_dims<SrcT, DstT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1);
} else {
*dst = static_cast<DstT>(*src);
@@ -50,13 +50,13 @@ inline void copy_dims(
}
}
template <typename SrcT, typename DstT, typename StrideT>
template <typename SrcT, typename DstT>
void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset) {
if (data_shape.empty()) {
@@ -65,30 +65,30 @@ void copy_general_general(
*dst_ptr = val;
return;
}
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
auto [shape, strides] =
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size();
if (ndim == 1) {
copy_dims<SrcT, DstT, StrideT, 1>(
copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 2) {
copy_dims<SrcT, DstT, StrideT, 2>(
copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
} else if (ndim == 3) {
copy_dims<SrcT, DstT, StrideT, 3>(
copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return;
}
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
for (StrideT elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>(
ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator out(shape, strides[1], ndim - 3);
auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (int64_t elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc,
dst_ptr + out.loc,
shape,
@@ -102,37 +102,37 @@ void copy_general_general(
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename StrideT>
template <typename SrcT, typename DstT>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>&,
const Shape& data_shape,
const Strides& i_strides,
const Strides&,
int64_t i_offset,
int64_t o_offset) {
copy_general_general<SrcT, DstT, StrideT>(
copy_general_general<SrcT, DstT>(
src,
dst,
data_shape,
i_strides,
make_contiguous_strides<StrideT>(data_shape),
make_contiguous_strides(data_shape),
i_offset,
o_offset);
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>(
copy_general_general<SrcT, DstT>(
src,
dst,
src.shape(),
src.strides(),
make_contiguous_strides<size_t>(src.shape()),
make_contiguous_strides(src.shape()),
0,
0);
}
@@ -282,13 +282,12 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename StrideT>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<StrideT>& i_strides,
const std::vector<StrideT>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
@@ -311,24 +310,4 @@ void copy_inplace(
}
}
template void copy_inplace<size_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -26,13 +26,12 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
const Shape& data_shape,
const Strides& i_strides,
const Strides& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);

View File

@@ -130,7 +130,7 @@ inline void matmul_common_general(
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};

View File

@@ -32,7 +32,7 @@ void gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& slice_sizes) {
const Shape& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed
@@ -80,11 +80,10 @@ void gather(
T* dst_ptr = out.data<T>();
size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it;
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) {
src_it = std::move(
ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
}
for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0;
@@ -119,7 +118,7 @@ void dispatch_gather(
const std::vector<array>& inds,
array& out,
const std::vector<int>& axes,
const std::vector<int>& size) {
const Shape& size) {
switch (out.dtype()) {
case bool_:
gather<bool, IdxT>(src, inds, out, axes, size);
@@ -223,16 +222,16 @@ void scatter(
auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1;
std::vector<int> update_shape(
Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1;
for (auto us : update_shape) {
update_size *= us;
}
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator update_it(updates);
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0;

View File

@@ -19,10 +19,10 @@ inline void mask_matrix(
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str,
const int64_t X_data_str,
const int64_t Y_data_str,
const int64_t X_mask_str,
const int64_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
@@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
@@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
int Y,
size_t X_data_str,
size_t Y_data_str) {
size_t mask_offset = elem_to_loc(
auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
auto X_mask_str = mask.strides()[mask.ndim() - 2];
auto Y_mask_str = mask.strides()[mask.ndim() - 1];
if (mask.dtype() == bool_) {
return mask_matrix(
@@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
@@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
auto batch_shape_A = get_batch_dims(a.shape());
auto batch_strides_A = get_batch_dims(a.strides());
auto batch_shape_B = get_batch_dims(b.shape());
auto batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();

View File

@@ -498,14 +498,15 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
auto copy_needed = std::any_of(
strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
Strides ostrides{out.strides().begin(), out.strides().end()};
copy_inplace(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
@@ -523,7 +524,7 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
}
}
size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out);
}
}
@@ -550,11 +551,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
Strides upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),

View File

@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) {
// Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false;
std::vector<size_t> strides = in.strides();
auto strides = in.strides();
strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M;
in.set_data(

View File

@@ -174,19 +174,19 @@ void reduce_dispatch_min_max(
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
const Shape& shape,
const Strides& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
auto size = shape[dim];
auto stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}

View File

@@ -38,13 +38,10 @@ enum ReductionOpType {
struct ReductionPlan {
ReductionOpType type;
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
ReductionPlan(
ReductionOpType type_,
std::vector<int> shape_,
std::vector<size_t> strides_)
ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
@@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Should this be in utils?
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides);
const Shape& shape,
const Strides& strides);
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
@@ -113,9 +110,6 @@ void reduction_op(
return;
}
std::vector<int> shape;
std::vector<size_t> strides;
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>();
@@ -135,7 +129,7 @@ void reduction_op(
U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost.
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
@@ -181,7 +175,7 @@ void reduction_op(
plan.strides.pop_back();
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides);
@@ -211,7 +205,7 @@ void reduction_op(
if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes);
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides);
U val = init;

View File

@@ -4,11 +4,11 @@
namespace mlx::core {
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
auto shape = x.shape();
auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
@@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
Shape shape = {x.shape(axes[0])};
Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]);
@@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
std::vector<std::pair<int, int64_t>> reductions;
for (auto a : axes) {
if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
@@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
}
}
std::vector<int> shape;
std::vector<size_t> strides;
Shape shape;
Strides strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
@@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
int64_t size = 1;
bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
size_t stride_i = x.strides()[i];
int shape_i = x.shape(i);
auto stride_i = x.strides()[i];
auto shape_i = x.shape(i);
if (stride_i == 0) {
if (shape_i == 1) {
continue;

View File

@@ -4,24 +4,22 @@
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
std::tuple<int64_t, Strides> prepare_slice(
const array& in,
const std::vector<int>& start_indices,
const std::vector<int>& strides) {
const Shape& start_indices,
const Shape& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
Strides inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
return std::make_tuple(data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out) {

View File

@@ -6,14 +6,14 @@
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
std::tuple<int64_t, Strides> prepare_slice(
const array& in,
const std::vector<int>& start_indices,
const std::vector<int>& strides);
const Shape& start_indices,
const Shape& strides);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
const Strides& out_strides,
size_t data_offset,
size_t data_size,
array& out);

View File

@@ -25,7 +25,7 @@ struct StridedIterator {
// Constructors
StridedIterator() = default;
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
@@ -99,7 +99,7 @@ struct StridedIterator {
}
private:
size_t stride_;
int64_t stride_;
T* ptr_;
};
@@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) {
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
auto axis_stride = out.strides()[axis];
auto axis_size = out.shape(axis);
// Perform sorting in place
ContiguousIterator<size_t> src_it(
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
@@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) {
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
// Perform sorting
ContiguousIterator<size_t> in_it(
ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;
@@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) {
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
ContiguousIterator<size_t> src_it(
ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc;
@@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
auto in_stride = in.strides()[axis];
auto out_stride = out.strides()[axis];
auto axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition
ContiguousIterator<size_t> in_it(
ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it(
ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc;

View File

@@ -78,11 +78,11 @@ void ternary_op_dims(
const T3* c,
U* out,
Op op,
const std::vector<int>& shape,
const std::vector<size_t>& a_strides,
const std::vector<size_t>& b_strides,
const std::vector<size_t>& c_strides,
const std::vector<size_t>& out_strides,
const Shape& shape,
const Strides& a_strides,
const Strides& b_strides,
const Strides& c_strides,
const Strides& out_strides,
int axis) {
auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis];
@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
return;
}
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3];
ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator c_it(shape, c_strides, ndim - 2);
auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc,

View File

@@ -15,7 +15,7 @@ void move_or_copy(const array& in, array& out) {
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
@@ -26,15 +26,13 @@ void move_or_copy(
}
}
template <typename StrideT>
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<std::vector<StrideT>>& strides,
StrideT size_cap) {
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
int64_t size_cap) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
Shape to_collapse;
if (shape.size() > 0) {
if (shape[0] != 1) {
to_collapse.push_back(0);
@@ -43,7 +41,7 @@ collapse_contiguous_dims_impl(
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
size *= shape[i];
for (const std::vector<StrideT>& st : strides) {
for (const auto& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false;
size = shape[i];
@@ -60,8 +58,8 @@ collapse_contiguous_dims_impl(
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<StrideT>> out_strides(strides.size());
Shape out_shape;
std::vector<Strides> out_strides(strides.size());
for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) {
++i;
@@ -76,7 +74,7 @@ collapse_contiguous_dims_impl(
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<StrideT>& st = strides[j];
const auto& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]);
}
i = k + 1;
@@ -91,29 +89,12 @@ collapse_contiguous_dims_impl(
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
template <typename StrideT>
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
std::vector<int> collapsed_shape;
std::vector<StrideT> collapsed_strides;
std::pair<Shape, Strides> collapse_contiguous_dims(
const Shape& shape,
const Strides& strides,
int64_t size_cap) {
Shape collapsed_shape;
Strides collapsed_strides;
if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]);
@@ -123,7 +104,7 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
continue;
} else if (
strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]);
} else {
@@ -136,25 +117,10 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
return std::make_pair(collapsed_shape, collapsed_strides);
}
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a,
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>(
a.shape(), a.strides(), size_cap);
int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
}
} // namespace mlx::core

View File

@@ -8,12 +8,9 @@
namespace mlx::core {
template <typename StrideT>
inline StrideT elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
StrideT loc = 0;
inline int64_t
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int64_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
@@ -22,16 +19,15 @@ inline StrideT elem_to_loc(
return loc;
}
inline size_t elem_to_loc(int elem, const array& a) {
inline int64_t elem_to_loc(int elem, const array& a) {
if (a.flags().row_contiguous) {
return elem;
}
return elem_to_loc(elem, a.shape(), a.strides());
}
template <typename StrideT>
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<StrideT> strides(shape.size(), 1);
inline Strides make_contiguous_strides(const Shape& shape) {
Strides strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
@@ -44,22 +40,15 @@ std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<int64_t>>& strides,
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<int32_t>::max()) {
std::vector<std::vector<size_t>> strides;
std::vector<Strides> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
@@ -73,19 +62,14 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
}
// The single array version of the above.
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
std::pair<Shape, Strides> collapse_contiguous_dims(
const Shape& shape,
const Strides& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
std::pair<Shape, Strides> collapse_contiguous_dims(
const array& a,
size_t size_cap = std::numeric_limits<int32_t>::max());
int64_t size_cap = std::numeric_limits<int32_t>::max());
template <typename StrideT>
struct ContiguousIterator {
inline void step() {
int dims = shape_.size();
@@ -102,7 +86,7 @@ struct ContiguousIterator {
loc += strides_[i];
}
void seek(StrideT n) {
void seek(int64_t n) {
loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]);
@@ -128,32 +112,29 @@ struct ContiguousIterator {
}
explicit ContiguousIterator(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
const Shape& shape,
const Strides& strides,
int dims)
: shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) {
if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0);
pos_ = Shape(shape_.size(), 0);
}
}
StrideT loc{0};
int64_t loc{0};
private:
std::vector<int> shape_;
std::vector<StrideT> strides_;
std::vector<int> pos_;
Shape shape_;
Strides strides_;
Shape pos_;
};
template <typename StrideT>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
size_t no_broadcast_data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
int64_t f_stride = 1;
int64_t b_stride = 1;
bool is_row_contiguous = true;
bool is_col_contiguous = true;
@@ -182,7 +163,7 @@ void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const std::vector<size_t>& strides,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);