Add a SliceUpdate op and primitive (#850)

* Enable copy to work with int64 strides
* Fix uniform buffer indices for copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
Jagrit Digani, 2024-03-20 10:39:25 -07:00, committed by GitHub
commit cec8661113 (parent 73a8c090e0)
21 changed files with 1147 additions and 506 deletions
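For readers skimming the diff: the new primitive's output is a copy of the first input in which the sub-array selected by the start indices and strides is overwritten by the second (update) input. A minimal, self-contained C++ sketch of that behaviour on a flat 1-D buffer, with illustrative names only (not the MLX API):

#include <iostream>
#include <vector>

// out = copy(in); out[start : start + step * upd.size() : step] = upd
std::vector<float> slice_update_1d(
    const std::vector<float>& in,
    const std::vector<float>& upd,
    int start,
    int step) {
  std::vector<float> out = in; // materialize the destination first
  int loc = start;
  for (float v : upd) {
    out[loc] = v; // scatter the update through the strided slice
    loc += step;
  }
  return out;
}

int main() {
  std::vector<float> in{0, 1, 2, 3, 4, 5, 6, 7};
  auto out = slice_update_1d(in, {10, 20, 30}, /*start=*/1, /*step=*/2);
  for (float v : out) std::cout << v << ' '; // 0 10 2 20 4 30 6 7
  std::cout << '\n';
}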


@@ -69,6 +69,7 @@ DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)


@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <numeric>
@@ -25,121 +25,196 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim1(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[0];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim1(const array& src, array& dst) {
return copy_general_dim1<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[1];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[2];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += i_strides[3];
}
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
}
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
}
} else {
int axis = src.ndim() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
@@ -148,37 +223,56 @@ inline void copy_general_general_dims(
}
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
switch (src.ndim()) {
case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
}
int size = std::accumulate(
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
}
}
template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
@@ -187,54 +281,103 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst, args...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst, args...);
}
}
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype, args...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype, args...);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype, args...);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype, args...);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype, args...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype, args...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype, args...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype, args...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype, args...);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype, args...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype, args...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype, args...);
break;
}
}
template <typename... Args>
inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, args...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, args...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, args...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, args...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, args...);
break;
case int8:
copy<int8_t>(src, dst, ctype, args...);
break;
case int16:
copy<int16_t>(src, dst, ctype, args...);
break;
case int32:
copy<int32_t>(src, dst, ctype, args...);
break;
case int64:
copy<int64_t>(src, dst, ctype, args...);
break;
case float16:
copy<float16_t>(src, dst, ctype, args...);
break;
case float32:
copy<float>(src, dst, ctype, args...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, args...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, args...);
break;
}
}
@@ -242,47 +385,7 @@ void copy(const array& src, array& dst, CopyType ctype) {
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
return copy_inplace_dispatch(src, dst, ctype);
}
void copy(const array& src, array& dst, CopyType ctype) {
@@ -312,4 +415,62 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
template <>
void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core
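Worth noting: the reworked CPU helpers take an explicit shape, signed strides, and an element offset instead of reading them off the source array, which is what allows a negative-stride view (for example a slice with step -1) to be copied. A standalone C++ sketch of that stride walk, with illustrative values (not the MLX API):

#include <cstdint>
#include <iostream>
#include <vector>

// Copy a 1-D strided view (possibly negative stride) into a contiguous buffer.
// This mirrors what copy_general_dim1 does with (i_strides, i_offset).
std::vector<float> copy_strided_view(
    const std::vector<float>& src,
    int n,              // data_shape[0]
    int64_t stride,     // i_strides[0], may be negative
    int64_t offset) {   // i_offset, an element (not byte) offset
  std::vector<float> dst(n);
  int64_t src_idx = offset;
  for (int i = 0; i < n; ++i) {
    dst[i] = src[src_idx];
    src_idx += stride;
  }
  return dst;
}

int main() {
  std::vector<float> src{0, 1, 2, 3, 4, 5};
  // View src[5::-2] -> {5, 3, 1}: offset 5, stride -2, length 3.
  for (float v : copy_strided_view(src, 3, -2, 5)) std::cout << v << ' ';
  std::cout << '\n'; // prints: 5 3 1
}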


@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -26,4 +26,15 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core


@@ -94,6 +94,7 @@ DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)


@@ -651,36 +651,33 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
@@ -694,7 +691,87 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
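Reading SliceUpdate::prepare_slice and the copy_inplace call together: the start indices and steps become an element offset plus per-axis output strides, and the update is then scattered through them into an already-materialized copy of the input. A self-contained 2-D C++ sketch of that arithmetic (illustrative values, not the MLX API):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A 4x5 row-major "array" stored flat; in-place update of out[1:4:2, 0:5:2].
  std::vector<float> out(4 * 5, 0.0f);
  std::vector<int64_t> in_strides{5, 1};   // row-major strides of out
  std::vector<int> start{1, 0}, step{2, 2};
  std::vector<int> upd_shape{2, 3};        // shape of the update operand

  // What SliceUpdate::prepare_slice computes:
  int64_t data_offset = start[0] * in_strides[0] + start[1] * in_strides[1];
  std::vector<int64_t> out_strides{in_strides[0] * step[0],
                                   in_strides[1] * step[1]};

  // What copy_inplace(upd, out, ..., o_offset = data_offset, GeneralGeneral) does:
  float v = 1.0f;
  for (int i = 0; i < upd_shape[0]; ++i)
    for (int j = 0; j < upd_shape[1]; ++j)
      out[data_offset + i * out_strides[0] + j * out_strides[1]] = v++;

  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 5; ++j) std::cout << out[i * 5 + j] << ' ';
    std::cout << '\n';
  }
}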
void Split::eval(


@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -8,11 +8,12 @@
namespace mlx::core {
template <typename stride_t>
inline stride_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
stride_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
@@ -34,10 +35,11 @@ inline size_t elem_to_loc(int elem, const array& a) {
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<stride_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
@@ -45,7 +47,7 @@ collapse_contiguous_dims(
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
@@ -62,7 +64,7 @@ collapse_contiguous_dims(
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
@@ -70,7 +72,7 @@ collapse_contiguous_dims(
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
@@ -94,4 +96,27 @@ collapse_contiguous_dims(Arrays... xs) {
std::vector<array>{std::forward<Arrays>(xs)...});
}
template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
bool is_row_contiguous = true;
bool is_col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
} // namespace mlx::core
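check_contiguity factors out the flag computation that Slice::eval used to do inline: row contiguity is checked against a stride product accumulated from the last axis, column contiguity from the first, and data_size multiplies in only the positive-stride dims. A standalone C++ restatement with a worked example (it mirrors the helper above and is for illustration only):

#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

// Standalone restatement of check_contiguity; the real helper lives in
// mlx/backend/common/utils.h and is templated on the stride type.
std::tuple<int64_t, bool, bool> check_contiguity(
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t data_size = 1;
  int64_t f_stride = 1; // forward accumulator (column-major check)
  int64_t b_stride = 1; // backward accumulator (row-major check)
  bool is_row_contiguous = true;
  bool is_col_contiguous = true;
  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) {
      data_size *= shape[i];
    }
  }
  return {data_size, is_row_contiguous, is_col_contiguous};
}

int main() {
  // A 2x3 row-major array has strides {3, 1}: row contiguous, 6 stored elements.
  auto [n, row, col] = check_contiguity({2, 3}, {3, 1});
  std::cout << n << ' ' << row << ' ' << col << '\n'; // prints: 6 1 0
}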


@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
@@ -37,15 +37,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& strides_in_pre,
const std::vector<stride_t>& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s) {
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
auto& d = metal::device(s.device);
std::ostringstream kname;
@@ -72,39 +79,44 @@ void copy_gpu_inplace(
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype());
out_offset *= size_of(out.dtype());
set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
set_array_buffer(compute_encoder, out, out_offset, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
int ndim = shape.size();
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
set_vector_bytes(compute_encoder, shape, ndim, 2);
}
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) {
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
int rest = data_size / (dim0 * dim1);
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
@@ -120,4 +132,25 @@ void copy_gpu_inplace(
}
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s) {
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
}
} // namespace mlx::core


@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -7,12 +7,34 @@
namespace mlx::core {
// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace(
const array& src,
array& out,
CopyType ctype,
const Stream& s);
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s);
} // namespace mlx::core


@@ -1,29 +1,29 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
@@ -31,61 +31,61 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
@@ -94,10 +94,10 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
@@ -106,10 +106,10 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
@@ -118,11 +118,11 @@ template <typename T, typename U>
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
@@ -131,12 +131,12 @@ template <typename T, typename U, int DIM>
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
@@ -146,70 +146,70 @@ template <typename T, typename U>
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \
[[kernel]] void copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \
[[kernel]] void copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \
[[kernel]] void copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
@@ -218,21 +218,21 @@
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
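One consequence of the kernel changes above is that every copy variant now uses the same fixed argument slots, so the host side can bind identical indices regardless of which kernel is selected (the scalar and vector kernels simply do not read the shape/stride slots). A hedged host-side summary of that layout as a C++ enum, for illustration only (this is not an MLX header):

// Fixed Metal buffer indices used by the copy kernels in this commit.
enum CopyKernelBuffer : int {
  kSrc = 0,        // device const T* src              [[buffer(0)]]
  kDst = 1,        // device U* dst                    [[buffer(1)]]
  kShape = 2,      // constant const int* src_shape    [[buffer(2)]]
  kSrcStrides = 3, // constant const int64_t* src_strides [[buffer(3)]]
  kDstStrides = 4, // constant const int64_t* dst_strides [[buffer(4)]]
  kNdim = 5,       // constant const int& ndim         [[buffer(5)]]
};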


@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -65,12 +65,18 @@ struct Limits<bool> {
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint elem,
device const int* shape,
device const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
@@ -78,12 +84,13 @@ inline size_t elem_to_loc(
return loc;
}
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint elem,
constant const int* shape,
constant const stride_t* strides,
int ndim) {
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
@@ -91,52 +98,59 @@ inline size_t elem_to_loc(
return loc;
}
// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint3 elem,
constant const int* shape,
constant const stride_t* strides,
int ndim) {
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
return elem * stride;
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
template <int NDIM>
METAL_FUNC size_t elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides) {
size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
MLX_MTL_PRAGMA_UNROLL
for (int d = NDIM - 2; d >= 0; --d) {
elem /= shape[d + 1];
loc += (elem % shape[d]) * strides[d];
}
return loc;
}
template <int NDIM>
METAL_FUNC size_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM]) {
@@ -148,33 +162,59 @@ inline size_t elem_to_loc_nd(
return loc;
}
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
uint elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM]) {
int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
MLX_MTL_PRAGMA_UNROLL
for (int d = NDIM - 2; d >= 0; --d) {
elem /= shape[d + 1];
loc += (elem % shape[d]) * strides[d];
}
return loc;
}
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM]) {
int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
METAL_FUNC uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
METAL_FUNC uint3 elem_to_loc_3_nd(
uint3 elem, uint3 elem,
constant const int* shape, constant const int* shape,
constant const size_t* a_strides, constant const size_t* a_strides,
@ -198,18 +238,21 @@ inline uint3 elem_to_loc_3_nd(
return loc; return loc;
} }
inline uint2 elem_to_loc_2_nd( ///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with fixed N dims
template <int NDIM>
METAL_FUNC uint2 elem_to_loc_2_nd(
uint3 elem, uint3 elem,
constant const int* shape, constant const int shape[NDIM],
constant const size_t* a_strides, constant const size_t a_strides[NDIM],
constant const size_t* b_strides, constant const size_t b_strides[NDIM]) {
int ndim) {
uint2 loc = { uint2 loc = {
static_cast<uint>( static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>( static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = ndim - 3; d >= 0; --d) { for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d]; uint l = elem.z % shape[d];
loc.x += l * a_strides[d]; loc.x += l * a_strides[d];
loc.y += l * b_strides[d]; loc.y += l * b_strides[d];
@ -219,55 +262,26 @@ inline uint2 elem_to_loc_2_nd(
} }
template <int NDIM> template <int NDIM>
inline uint elem_to_loc_nd( METAL_FUNC uint3 elem_to_loc_3_nd(
uint elem, uint3 elem,
device const int* shape, constant const int shape[NDIM],
device const size_t* strides); constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
template <> constant const size_t c_strides[NDIM]) {
inline uint elem_to_loc_nd<1>( uint3 loc = {
uint elem, static_cast<uint>(
device const int* shape, elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
device const size_t* strides) { static_cast<uint>(
return (elem % shape[0]) * strides[0]; elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
} static_cast<uint>(
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
template <> for (int d = NDIM - 3; d >= 0; --d) {
inline uint elem_to_loc_nd<2>( uint l = elem.z % shape[d];
uint elem, loc.x += l * a_strides[d];
device const int* shape, loc.y += l * b_strides[d];
device const size_t* strides) { loc.z += l * c_strides[d];
uint loc = (elem % shape[1]) * strides[1]; elem.z /= shape[d];
elem /= shape[1]; }
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<3>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<4>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[3]) * strides[3];
elem /= shape[3];
loc += (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc; return loc;
} }
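For readers less familiar with strided indexing: every one of these helpers implements the same row-major linear-index-to-storage-offset conversion, only specialized for different index widths, integer types, and address spaces. A minimal host-side C++ sketch of that arithmetic (illustrative only, not part of this diff):

#include <cstdint>
#include <vector>

// Map a flat row-major element index to a storage offset for an array with
// the given shape and (possibly non-contiguous) strides. Mirrors the logic of
// elem_to_loc_nd above: peel off the innermost axis, then walk outward.
int64_t elem_to_loc(uint32_t elem,
                    const std::vector<int>& shape,
                    const std::vector<int64_t>& strides) {
  int ndim = static_cast<int>(shape.size());
  int64_t loc = (elem % shape[ndim - 1]) * strides[ndim - 1];
  for (int d = ndim - 2; d >= 0; --d) {
    elem /= shape[d + 1];
    loc += (elem % shape[d]) * strides[d];
  }
  return loc;
}

For example, with shape {2, 3} and strides {1, 2} (the transpose of a contiguous 3x2 array), element 4 maps to (4 % 3) * 2 + (4 / 3) * 1 = 3.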

View File

@@ -206,7 +206,7 @@ inline auto collapse_batches(const array& a, const array& b) {
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};

  auto [batch_shape, batch_strides] =
      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
  auto A_batch_stride = batch_strides[0];
  auto B_batch_stride = batch_strides[1];
@@ -237,8 +237,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
  std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
  std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};

  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
  auto A_batch_stride = batch_strides[0];
  auto B_batch_stride = batch_strides[1];

View File

@@ -865,7 +865,73 @@ void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
}

void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate output strides, initial offset, and whether a copy is needed
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_gpu_inplace(
/* const array& in = */ in,
/* array& out = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General,
/* const Stream& s = */ stream());
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
}
}
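The shared-buffer path works because a basic slice is just a new view of the same buffer: an element offset plus per-axis strides. A sketch of that arithmetic under the usual strided-view convention (the real computation lives in the primitive's prepare_slice helper; this standalone version is only illustrative):

#include <cstdint>
#include <utility>
#include <vector>

// Where a (start, stride) slice of a strided array lives inside the source
// buffer: element (i0, i1, ...) of the slice maps to
//   offset + i0 * out_strides[0] + i1 * out_strides[1] + ...
std::pair<int64_t, std::vector<int64_t>> slice_offset_and_strides(
    const std::vector<int>& start,
    const std::vector<int>& strides,
    const std::vector<int64_t>& in_strides) {
  int64_t offset = 0;
  std::vector<int64_t> out_strides(in_strides.size());
  for (size_t d = 0; d < in_strides.size(); ++d) {
    offset += int64_t(start[d]) * in_strides[d];
    out_strides[d] = in_strides[d] * strides[d];  // a negative stride flips direction
  }
  return {offset, out_strides};
}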
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate output strides and initial offset into out
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_gpu_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ stream());
}
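In other words, SliceUpdate first materializes the source into the output and then writes the update into the strided region, which is why the second copy uses CopyType::GeneralGeneral with an output offset. A rough 1-D CPU sketch of the same semantics (illustrative only; the real kernel handles arbitrary rank via the strided copy machinery above):

#include <vector>

// Semantics of SliceUpdate for a 1-D array: out is a copy of src with the
// elements selected by (start, stop, stride) replaced by upd.
std::vector<float> slice_update_1d(const std::vector<float>& src,
                                   const std::vector<float>& upd,
                                   int start, int stop, int stride) {
  std::vector<float> out = src;  // materialize the source first
  int j = 0;
  if (stride > 0) {
    for (int i = start; i < stop; i += stride) {
      out[i] = upd[j++];
    }
  } else {
    for (int i = start; i > stop; i += stride) {  // negative stride walks down
      out[i] = upd[j++];
    }
  }
  return out;
}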
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -9,16 +9,43 @@ namespace mlx::core {

namespace {

inline void
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
  auto offset = a.data<char>() -
      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
  enc->setBuffer(a_buf, offset, idx);
}
inline void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int64_t offset,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
const std::vector<T>& vec,
size_t nelems,
int idx) {
enc->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
const std::vector<T>& vec,
int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}
std::string type_to_name(const array& a) {
  std::string tname;
  switch (a.dtype()) {

View File

@@ -87,6 +87,7 @@ NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)

View File

@@ -445,6 +445,60 @@ array expand_dims(
  return reshape(a, out_shape, s);
}
// Slice helper
namespace {
inline auto normalize_slice(
const std::vector<int>& shape,
std::vector<int>& start,
std::vector<int>& stop,
std::vector<int>& strides) {
std::vector<int> out_shape(shape.size());
bool has_neg_strides = false;
for (int i = 0; i < shape.size(); ++i) {
// Following numpy docs
// Negative i and j are interpreted as n + i and n + j where n is
// the number of elements in the corresponding dimension. Negative
// k makes stepping go towards smaller indices
auto n = shape[i];
auto s = start[i];
s = s < 0 ? s + n : s;
auto e = stop[i];
e = e < 0 ? e + n : e;
// Note: -ve strides require start >= stop
if (strides[i] < 0) {
has_neg_strides = true;
// Clamp to bounds
auto st = std::min(s, n - 1);
auto ed = std::max(-1, e);
start[i] = st;
stop[i] = ed > st ? st : ed;
auto str = -strides[i];
out_shape[i] = (start[i] - stop[i] + str - 1) / str;
} else {
// Clamp to bounds
auto st = std::max(0, std::min(s, n));
auto ed = std::max(0, std::min(e, n));
start[i] = st;
stop[i] = ed < st ? st : ed;
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
}
}
return std::make_pair(has_neg_strides, out_shape);
}
} // namespace
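A standalone 1-D version of this normalization, with a few hand-checked cases, makes the negative-stride branch concrete (illustrative only; the real helper is n-dimensional and also rewrites start/stop in place):

#include <algorithm>
#include <cassert>

// Number of elements selected by a normalized slice of a length-n axis.
int normalized_slice_size(int n, int s, int e, int k) {
  s = s < 0 ? s + n : s;
  e = e < 0 ? e + n : e;
  if (k < 0) {
    int st = std::min(s, n - 1);       // clamp start into [.., n-1]
    int ed = std::max(-1, e);          // clamp stop into [-1, ..]
    ed = ed > st ? st : ed;
    return (st - ed + (-k) - 1) / (-k);
  }
  int st = std::max(0, std::min(s, n));
  int ed = std::max(0, std::min(e, n));
  ed = ed < st ? st : ed;
  return (ed - st + k - 1) / k;
}

int main() {
  assert(normalized_slice_size(8, 2, 6, 1) == 4);    // elements 2, 3, 4, 5
  assert(normalized_slice_size(8, 5, 1, -1) == 4);   // elements 5, 4, 3, 2
  assert(normalized_slice_size(8, -1, -9, -2) == 4); // elements 7, 5, 3, 1
  return 0;
}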
array slice(
    const array& a,
    std::vector<int> start,
@@ -459,113 +513,13 @@ array slice(
    throw std::invalid_argument(msg.str());
  }

  auto [has_neg_strides, out_shape] =
      normalize_slice(a.shape(), start, stop, strides);

  if (!has_neg_strides && out_shape == a.shape()) {
    return a;
  }

  return array(
      out_shape,
      a.dtype(),
@@ -582,6 +536,56 @@ array slice(
  return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
}
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
StreamOrDevice s /* = {} */) {
// Check dimensions
if (start.size() != src.ndim() || stop.size() != src.ndim() ||
strides.size() != src.ndim()) {
std::ostringstream msg;
msg << "[slice] Invalid number of indices or strides for "
<< "array with dimension " << src.ndim() << ".";
throw std::invalid_argument(msg.str());
}
// Process slice dimensions
auto [has_neg_strides, upd_shape] =
normalize_slice(src.shape(), start, stop, strides);
// Broadcast update shape to slice shape
auto upd_shape_broadcast = broadcast_shapes(upd_shape, update.shape());
auto update_broadcasted = broadcast_to(update, upd_shape_broadcast, s);
// If the entire src is the slice, just return the update
if (!has_neg_strides && upd_shape == src.shape()) {
return astype(update_broadcasted, src.dtype(), s);
}
return array(
src.shape(),
src.dtype(),
std::make_unique<SliceUpdate>(
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
{src, update});
}
/** Update a slice from the source array with stride 1 in each dimension */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s /* = {} */) {
auto strides = std::vector<int>(src.ndim(), 1);
return slice_update(
src, update, std::move(start), std::move(stop), std::move(strides), s);
}
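A small usage sketch of the new op, mirroring the C++ tests further down (it assumes the mlx/mlx.h umbrella header and the mlx::core namespace):

#include "mlx/mlx.h"  // assumed umbrella header

using namespace mlx::core;

int main() {
  // Write {1, 2, 3, 4} into positions 2..5 of an 8-element zero vector.
  array x = zeros({8}, float32);
  array y = array({1.0f, 2.0f, 3.0f, 4.0f}, {4}, float32);
  array out = slice_update(x, y, {2}, {6}, {1});
  // out is {0, 0, 1, 2, 3, 4, 0, 0}

  // A negative stride writes the update back to front.
  out = slice_update(x, y, {5}, {1}, {-1});
  // out is {0, 0, 4, 3, 2, 1, 0, 0}
  return 0;
}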
std::vector<array> split(
    const array& a,
    const std::vector<int>& indices,

View File

@@ -177,6 +177,23 @@ array slice(
    const std::vector<int>& stop,
    StreamOrDevice s = {});
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
StreamOrDevice s = {});
/** Update a slice from the source array with stride 1 in each dimension */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});

View File

@@ -2849,6 +2849,114 @@ bool Slice::is_equivalent(const Primitive& other) const {
      end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
}
std::pair<std::vector<array>, std::vector<int>> SliceUpdate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 2);
assert(axes.size() == 2);
auto start = start_indices_;
auto stop = end_indices_;
auto strides = strides_;
auto src = inputs[0];
auto upd = inputs[1];
auto src_ax = axes[0];
auto upd_ax = axes[1];
// No vmapping needed
if (src_ax == -1 && upd_ax == -1) {
return {{slice_update(src, upd, start, stop, strides, stream())}, {-1}};
}
// Broadcast src
if (src_ax == -1) {
src = expand_dims(src, upd_ax, stream());
auto shape = src.shape();
shape[upd_ax] = upd.shape(upd_ax);
src = broadcast_to(src, shape, stream());
src_ax = upd_ax;
}
// Broadcast upd
if (upd_ax == -1) {
upd = expand_dims(upd, src_ax, stream());
upd_ax = src_ax;
}
if (src_ax != upd_ax) {
upd = moveaxis(upd, upd_ax, src_ax, stream());
}
start.insert(start.begin() + src_ax, 0);
stop.insert(stop.begin() + src_ax, src.shape(src_ax));
strides.insert(strides.begin() + src_ax, 1);
return {{slice_update(src, upd, start, stop, strides, stream())}, {src_ax}};
}
std::vector<array> SliceUpdate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
// Check inputs
assert(primals.size() == 2);
auto& cotan = cotangents[0];
auto& src = primals[0];
auto& upd = primals[1];
std::vector<array> vjps;
for (int num : argnums) {
// Vjp for source
if (num == 0) {
auto grad = slice_update(
cotan,
zeros_like(upd, stream()),
start_indices_,
end_indices_,
strides_,
stream());
vjps.push_back(grad);
}
// Vjp for updates
else {
auto grad =
slice(cotan, start_indices_, end_indices_, strides_, stream());
vjps.push_back(grad);
}
}
return vjps;
}
std::vector<array> SliceUpdate::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Check inputs
assert(primals.size() == 2);
return {slice_update(
tangents[0],
tangents[1],
start_indices_,
end_indices_,
strides_,
stream())};
}
bool SliceUpdate::is_equivalent(const Primitive& other) const {
const SliceUpdate& s_other = static_cast<const SliceUpdate&>(other);
return (
start_indices_ == s_other.start_indices_ &&
end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
}
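The gradient rules above are the intuitive ones: the source receives the cotangent with the updated region zeroed out, and the update receives the slice of the cotangent. A hedged sketch with concrete values, using only ops that appear in this diff plus zeros/ones (assumed to exist with these signatures):

#include "mlx/mlx.h"  // assumed umbrella header

using namespace mlx::core;

int main() {
  // Cotangent of out = slice_update(src, upd, {2}, {6}, {1}), with src of
  // shape {8} and upd of shape {4}.
  array cotan = ones({8}, float32);

  // d/d src: pass the cotangent through, but zero the region upd overwrote.
  array grad_src = slice_update(cotan, zeros({4}, float32), {2}, {6}, {1});
  // grad_src is {1, 1, 0, 0, 0, 0, 1, 1}

  // d/d upd: the cotangent restricted to the updated region.
  array grad_upd = slice(cotan, {2}, {6}, {1});
  // grad_upd is {1, 1, 1, 1}
  return 0;
}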
std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {

View File

@@ -1660,6 +1660,44 @@ class Slice : public UnaryPrimitive {
  std::vector<int> strides_;

  void eval(const std::vector<array>& inputs, array& out);
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out);
};
class SliceUpdate : public UnaryPrimitive {
public:
explicit SliceUpdate(
Stream stream,
const std::vector<int>& start_indices,
const std::vector<int>& end_indices,
const std::vector<int>& strides)
: UnaryPrimitive(stream),
start_indices_(start_indices),
end_indices_(end_indices),
strides_(strides){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(SliceUpdate)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> start_indices_;
std::vector<int> end_indices_;
std::vector<int> strides_;
void eval(const std::vector<array>& inputs, array& out);
std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
};

class Softmax : public UnaryPrimitive {

View File

@@ -329,4 +329,13 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
  return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once
@@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);

inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
  return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}

View File

@@ -198,6 +198,31 @@ TEST_CASE("test slice") {
  CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
}
TEST_CASE("test slice update") {
array x = array({0., 0., 0., 0., 0., 0., 0., 0.}, {8}, float32);
array y = array(
{
1.,
2.,
3.,
4.,
},
{4},
float32);
auto out = slice_update(x, y, {2}, {6}, {1});
CHECK(array_equal(slice(out, {2}, {6}, {1}), y).item<bool>());
out = slice_update(x, y, {5}, {1}, {-1});
CHECK(array_equal(slice(out, {5}, {1}, {-1}), y).item<bool>());
x = reshape(x, {2, 4});
out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});
out = reshape(out, {8});
CHECK(array_equal(slice(out, {0}, {4}, {1}), y).item<bool>());
CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
}
TEST_CASE("test split") { TEST_CASE("test split") {
array x = array(1); array x = array(1);
CHECK_THROWS(split(x, 0)); CHECK_THROWS(split(x, 0));