Mirror of https://github.com/ml-explore/mlx.git
Add a SliceUpdate op and primitive (#850)

* Enable copy to work with int64 strides
* Fix uniform buffer indices for copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive

parent 73a8c090e0
commit cec8661113
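What the new primitive computes, as a minimal standalone sketch: SliceUpdate materializes its output as a copy of the input and then writes the update through a strided view of that output (see SliceUpdate::eval in the diff below). The helper name, the flat buffers, and the fixed 2-D case here are illustrative only, not MLX's actual array API:

// Sketch of slice-update semantics on plain buffers:
//   out = in;  out[start0::step0, start1::step1] = upd
#include <cassert>
#include <cstdint>
#include <vector>

// Write `upd` (row-major, shape upd_shape) into the strided view of `out`
// described by per-axis element strides and a starting offset. Strides are
// signed so reversing (negative-step) slices also work -- the reason this
// commit moves the copy paths from size_t to int64_t strides.
void slice_update_2d(
    std::vector<float>& out,
    const std::vector<float>& upd,
    const std::vector<int>& upd_shape, // {rows, cols}
    const std::vector<int64_t>& out_strides, // per-axis, in elements
    int64_t offset) {
  assert(upd_shape.size() == 2 && out_strides.size() == 2);
  for (int i = 0; i < upd_shape[0]; ++i) {
    for (int j = 0; j < upd_shape[1]; ++j) {
      int64_t dst = offset + i * out_strides[0] + j * out_strides[1];
      out[dst] = upd[i * upd_shape[1] + j];
    }
  }
}

int main() {
  // out is a 4x4 row-major matrix of zeros; write a 2x2 update into
  // out[1:3, 1:3], i.e. offset = 1*4 + 1 = 5, strides = {4, 1}.
  std::vector<float> out(16, 0.0f);
  std::vector<float> upd{1, 2, 3, 4};
  slice_update_2d(out, upd, {2, 2}, {4, 1}, /* offset = */ 5);
  assert(out[5] == 1 && out[6] == 2 && out[9] == 3 && out[10] == 4);
}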
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -69,6 +69,7 @@ DEFAULT(Select)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
 DEFAULT(Slice)
+DEFAULT(SliceUpdate)
 DEFAULT_MULTI(Split)
 DEFAULT(Sort)
 DEFAULT(StopGradient)
--- a/mlx/backend/common/copy.cpp
+++ b/mlx/backend/common/copy.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <numeric>
 
@@ -25,121 +25,196 @@ void copy_vector(const array& src, array& dst) {
   std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
 }
 
-template <typename SrcT, typename DstT>
-void copy_general_dim1(const array& src, array& dst) {
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general_dim1(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
   const SrcT* src_ptr = src.data<SrcT>();
   DstT* dst_ptr = dst.data<DstT>();
-  size_t src_idx = 0;
-  size_t dst_idx = 0;
-  for (size_t i = 0; i < src.shape()[0]; ++i) {
+  stride_t src_idx = i_offset;
+  stride_t dst_idx = 0;
+  for (int i = 0; i < data_shape[0]; ++i) {
     dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
-    src_idx += src.strides()[0];
+    src_idx += i_strides[0];
   }
 }
 
 template <typename SrcT, typename DstT>
-void copy_general_dim2(const array& src, array& dst) {
+inline void copy_general_dim1(const array& src, array& dst) {
+  return copy_general_dim1<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), 0);
+}
+
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general_dim2(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
   const SrcT* src_ptr = src.data<SrcT>();
   DstT* dst_ptr = dst.data<DstT>();
-  size_t src_idx = 0;
-  size_t dst_idx = 0;
-  for (size_t i = 0; i < src.shape()[0]; ++i) {
-    for (size_t j = 0; j < src.shape()[1]; ++j) {
+  stride_t src_idx = i_offset;
+  stride_t dst_idx = 0;
+  for (int i = 0; i < data_shape[0]; ++i) {
+    for (int j = 0; j < data_shape[1]; ++j) {
       dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
-      src_idx += src.strides()[1];
+      src_idx += i_strides[1];
     }
-    src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
+    src_idx += i_strides[0] - i_strides[1] * data_shape[1];
   }
 }
 
 template <typename SrcT, typename DstT>
-void copy_general_dim3(const array& src, array& dst) {
+inline void copy_general_dim2(const array& src, array& dst) {
+  return copy_general_dim2<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), 0);
+}
+
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general_dim3(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
   const SrcT* src_ptr = src.data<SrcT>();
   DstT* dst_ptr = dst.data<DstT>();
-  size_t src_idx = 0;
-  size_t dst_idx = 0;
-  for (size_t i = 0; i < src.shape()[0]; ++i) {
-    for (size_t j = 0; j < src.shape()[1]; ++j) {
-      for (size_t k = 0; k < src.shape()[2]; ++k) {
+  stride_t src_idx = i_offset;
+  stride_t dst_idx = 0;
+  for (int i = 0; i < data_shape[0]; ++i) {
+    for (int j = 0; j < data_shape[1]; ++j) {
+      for (int k = 0; k < data_shape[2]; ++k) {
         dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
-        src_idx += src.strides()[2];
+        src_idx += i_strides[2];
       }
-      src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
+      src_idx += i_strides[1] - i_strides[2] * data_shape[2];
    }
-    src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
+    src_idx += i_strides[0] - i_strides[1] * data_shape[1];
  }
 }
 
 template <typename SrcT, typename DstT>
-void copy_general_dim4(const array& src, array& dst) {
+inline void copy_general_dim3(const array& src, array& dst) {
+  return copy_general_dim3<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), 0);
+}
+
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general_dim4(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
   const SrcT* src_ptr = src.data<SrcT>();
   DstT* dst_ptr = dst.data<DstT>();
-  size_t src_idx = 0;
-  size_t dst_idx = 0;
-  for (size_t i = 0; i < src.shape()[0]; ++i) {
-    for (size_t j = 0; j < src.shape()[1]; ++j) {
-      for (size_t k = 0; k < src.shape()[2]; ++k) {
-        for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
+  stride_t src_idx = i_offset;
+  stride_t dst_idx = 0;
+  for (int i = 0; i < data_shape[0]; ++i) {
+    for (int j = 0; j < data_shape[1]; ++j) {
+      for (int k = 0; k < data_shape[2]; ++k) {
+        for (int ii = 0; ii < data_shape[3]; ++ii) {
           dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
-          src_idx += src.strides()[3];
+          src_idx += i_strides[3];
         }
-        src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
+        src_idx += i_strides[2] - i_strides[3] * data_shape[3];
      }
-      src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
+      src_idx += i_strides[1] - i_strides[2] * data_shape[2];
    }
-    src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
+    src_idx += i_strides[0] - i_strides[1] * data_shape[1];
  }
 }
 
 template <typename SrcT, typename DstT>
-void copy_general(const array& src, array& dst) {
+inline void copy_general_dim4(const array& src, array& dst) {
+  return copy_general_dim4<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), 0);
+}
+
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
   switch (src.ndim()) {
     case 1:
-      copy_general_dim1<SrcT, DstT>(src, dst);
+      copy_general_dim1<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
       return;
     case 2:
-      copy_general_dim2<SrcT, DstT>(src, dst);
+      copy_general_dim2<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
       return;
     case 3:
-      copy_general_dim3<SrcT, DstT>(src, dst);
+      copy_general_dim3<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
       return;
     case 4:
-      copy_general_dim4<SrcT, DstT>(src, dst);
+      copy_general_dim4<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
       return;
   }
 
-  auto src_ptr = src.data<SrcT>();
+  auto src_ptr = src.data<SrcT>() + i_offset;
   auto dst_ptr = dst.data<DstT>();
   for (size_t i = 0; i < dst.size(); ++i) {
-    size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
+    stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
     dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
   }
 }
 
-template <typename SrcT, typename DstT, int D>
+template <typename SrcT, typename DstT>
+inline void copy_general(const array& src, array& dst) {
+  return copy_general<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), 0);
+}
+
+template <typename SrcT, typename DstT, typename stride_t>
+inline void copy_general(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset) {
+  return copy_general<SrcT, DstT, stride_t>(
+      src, dst, data_shape, i_strides, i_offset);
+}
+
+template <typename SrcT, typename DstT, typename stride_t, int D>
 inline void copy_general_general_dims(
     const array& src,
     array& dst,
-    size_t offset_src,
-    size_t offset_dst) {
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    stride_t i_offset,
+    stride_t o_offset) {
   if constexpr (D > 1) {
     int axis = src.ndim() - D;
-    auto stride_src = src.strides()[axis];
-    auto stride_dst = dst.strides()[axis];
-    auto N = src.shape(axis);
+    auto stride_src = i_strides[axis];
+    auto stride_dst = o_strides[axis];
+    auto N = data_shape[axis];
     for (int i = 0; i < N; i++) {
-      copy_general_general_dims<SrcT, DstT, D - 1>(
-          src, dst, offset_src, offset_dst);
-      offset_src += stride_src;
-      offset_dst += stride_dst;
+      copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+      i_offset += stride_src;
+      o_offset += stride_dst;
     }
   } else {
     int axis = src.ndim() - 1;
-    auto stride_src = src.strides()[axis];
-    auto stride_dst = dst.strides()[axis];
-    auto N = src.shape(axis);
-    const SrcT* src_ptr = src.data<SrcT>() + offset_src;
-    DstT* dst_ptr = dst.data<DstT>() + offset_dst;
+    auto stride_src = i_strides[axis];
+    auto stride_dst = o_strides[axis];
+    auto N = data_shape[axis];
+    const SrcT* src_ptr = src.data<SrcT>() + i_offset;
+    DstT* dst_ptr = dst.data<DstT>() + o_offset;
     for (int i = 0; i < N; i++) {
       *dst_ptr = static_cast<DstT>(*src_ptr);
       src_ptr += stride_src;
@@ -148,37 +223,56 @@ inline void copy_general_general_dims(
   }
 }
 
-template <typename SrcT, typename DstT>
-void copy_general_general(const array& src, array& dst) {
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general_general(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    stride_t i_offset,
+    stride_t o_offset) {
   switch (src.ndim()) {
     case 1:
-      copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
+      copy_general_general_dims<SrcT, DstT, stride_t, 1>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
       return;
     case 2:
-      copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
+      copy_general_general_dims<SrcT, DstT, stride_t, 2>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
      return;
     case 3:
-      copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
+      copy_general_general_dims<SrcT, DstT, stride_t, 3>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
       return;
     case 4:
-      copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
+      copy_general_general_dims<SrcT, DstT, stride_t, 4>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
       return;
     case 5:
-      copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
+      copy_general_general_dims<SrcT, DstT, stride_t, 5>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
       return;
   }
 
   int size = std::accumulate(
-      src.shape().begin() - 5, src.shape().end(), 1, std::multiplies<int>());
+      data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
   for (int i = 0; i < src.size(); i += size) {
-    size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
-    size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
-    copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
+    stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
+    stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
+    copy_general_general_dims<SrcT, DstT, stride_t, 5>(
+        src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
  }
 }
 
 template <typename SrcT, typename DstT>
-void copy(const array& src, array& dst, CopyType ctype) {
+inline void copy_general_general(const array& src, array& dst) {
+  return copy_general_general<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
+}
+
+template <typename SrcT, typename DstT, typename... Args>
+void copy(const array& src, array& dst, CopyType ctype, Args... args) {
   switch (ctype) {
     case CopyType::Scalar:
       copy_single<SrcT, DstT>(src, dst);
@@ -187,54 +281,103 @@ void copy(const array& src, array& dst, CopyType ctype) {
       copy_vector<SrcT, DstT>(src, dst);
       return;
     case CopyType::General:
-      copy_general<SrcT, DstT>(src, dst);
+      copy_general<SrcT, DstT>(src, dst, args...);
       return;
     case CopyType::GeneralGeneral:
-      copy_general_general<SrcT, DstT>(src, dst);
+      copy_general_general<SrcT, DstT>(src, dst, args...);
   }
 }
 
-template <typename SrcT>
-void copy(const array& src, array& dst, CopyType ctype) {
+template <typename SrcT, typename... Args>
+void copy(const array& src, array& dst, CopyType ctype, Args... args) {
   switch (dst.dtype()) {
     case bool_:
-      copy<SrcT, bool>(src, dst, ctype);
+      copy<SrcT, bool>(src, dst, ctype, args...);
       break;
     case uint8:
-      copy<SrcT, uint8_t>(src, dst, ctype);
+      copy<SrcT, uint8_t>(src, dst, ctype, args...);
       break;
     case uint16:
-      copy<SrcT, uint16_t>(src, dst, ctype);
+      copy<SrcT, uint16_t>(src, dst, ctype, args...);
      break;
     case uint32:
-      copy<SrcT, uint32_t>(src, dst, ctype);
+      copy<SrcT, uint32_t>(src, dst, ctype, args...);
       break;
     case uint64:
-      copy<SrcT, uint64_t>(src, dst, ctype);
+      copy<SrcT, uint64_t>(src, dst, ctype, args...);
       break;
     case int8:
-      copy<SrcT, int8_t>(src, dst, ctype);
+      copy<SrcT, int8_t>(src, dst, ctype, args...);
       break;
     case int16:
-      copy<SrcT, int16_t>(src, dst, ctype);
+      copy<SrcT, int16_t>(src, dst, ctype, args...);
       break;
     case int32:
-      copy<SrcT, int32_t>(src, dst, ctype);
+      copy<SrcT, int32_t>(src, dst, ctype, args...);
       break;
     case int64:
-      copy<SrcT, int64_t>(src, dst, ctype);
+      copy<SrcT, int64_t>(src, dst, ctype, args...);
       break;
     case float16:
-      copy<SrcT, float16_t>(src, dst, ctype);
+      copy<SrcT, float16_t>(src, dst, ctype, args...);
       break;
     case float32:
-      copy<SrcT, float>(src, dst, ctype);
+      copy<SrcT, float>(src, dst, ctype, args...);
       break;
     case bfloat16:
-      copy<SrcT, bfloat16_t>(src, dst, ctype);
+      copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
       break;
     case complex64:
-      copy<SrcT, complex64_t>(src, dst, ctype);
+      copy<SrcT, complex64_t>(src, dst, ctype, args...);
       break;
   }
 }
 
+template <typename... Args>
+inline void copy_inplace_dispatch(
+    const array& src,
+    array& dst,
+    CopyType ctype,
+    Args... args) {
+  switch (src.dtype()) {
+    case bool_:
+      copy<bool>(src, dst, ctype, args...);
+      break;
+    case uint8:
+      copy<uint8_t>(src, dst, ctype, args...);
+      break;
+    case uint16:
+      copy<uint16_t>(src, dst, ctype, args...);
+      break;
+    case uint32:
+      copy<uint32_t>(src, dst, ctype, args...);
+      break;
+    case uint64:
+      copy<uint64_t>(src, dst, ctype, args...);
+      break;
+    case int8:
+      copy<int8_t>(src, dst, ctype, args...);
+      break;
+    case int16:
+      copy<int16_t>(src, dst, ctype, args...);
+      break;
+    case int32:
+      copy<int32_t>(src, dst, ctype, args...);
+      break;
+    case int64:
+      copy<int64_t>(src, dst, ctype, args...);
+      break;
+    case float16:
+      copy<float16_t>(src, dst, ctype, args...);
+      break;
+    case float32:
+      copy<float>(src, dst, ctype, args...);
+      break;
+    case bfloat16:
+      copy<bfloat16_t>(src, dst, ctype, args...);
+      break;
+    case complex64:
+      copy<complex64_t>(src, dst, ctype, args...);
+      break;
+  }
+}
+
@@ -242,47 +385,7 @@ void copy(const array& src, array& dst, CopyType ctype) {
 } // namespace
 
 void copy_inplace(const array& src, array& dst, CopyType ctype) {
-  switch (src.dtype()) {
-    case bool_:
-      copy<bool>(src, dst, ctype);
-      break;
-    case uint8:
-      copy<uint8_t>(src, dst, ctype);
-      break;
-    case uint16:
-      copy<uint16_t>(src, dst, ctype);
-      break;
-    case uint32:
-      copy<uint32_t>(src, dst, ctype);
-      break;
-    case uint64:
-      copy<uint64_t>(src, dst, ctype);
-      break;
-    case int8:
-      copy<int8_t>(src, dst, ctype);
-      break;
-    case int16:
-      copy<int16_t>(src, dst, ctype);
-      break;
-    case int32:
-      copy<int32_t>(src, dst, ctype);
-      break;
-    case int64:
-      copy<int64_t>(src, dst, ctype);
-      break;
-    case float16:
-      copy<float16_t>(src, dst, ctype);
-      break;
-    case float32:
-      copy<float>(src, dst, ctype);
-      break;
-    case bfloat16:
-      copy<bfloat16_t>(src, dst, ctype);
-      break;
-    case complex64:
-      copy<complex64_t>(src, dst, ctype);
-      break;
-  }
+  return copy_inplace_dispatch(src, dst, ctype);
 }
 
 void copy(const array& src, array& dst, CopyType ctype) {
@@ -312,4 +415,62 @@ void copy(const array& src, array& dst, CopyType ctype) {
   copy_inplace(src, dst, ctype);
 }
 
+template <typename stride_t>
+void copy_inplace(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype) {
+  switch (ctype) {
+    case CopyType::General:
+    case CopyType::GeneralGeneral:
+      return copy_inplace_dispatch(
+          src,
+          dst,
+          ctype,
+          data_shape,
+          i_strides,
+          o_strides,
+          i_offset,
+          o_offset);
+
+    case CopyType::Scalar:
+    case CopyType::Vector:
+      return copy_inplace_dispatch(src, dst, ctype);
+  }
+}
+
+template <>
+void copy_inplace<int64_t>(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<int64_t>& i_strides,
+    const std::vector<int64_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype) {
+  switch (ctype) {
+    case CopyType::General:
+    case CopyType::GeneralGeneral:
+      return copy_inplace_dispatch(
+          src,
+          dst,
+          ctype,
+          data_shape,
+          i_strides,
+          o_strides,
+          i_offset,
+          o_offset);
+
+    case CopyType::Scalar:
+    case CopyType::Vector:
+      return copy_inplace_dispatch(src, dst, ctype);
+  }
+}
+
 } // namespace mlx::core
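Why the stride type changes from size_t to a signed stride_t throughout this file: a reversing slice such as x[::-1] has a negative element stride, and arithmetic on an unsigned index wraps modulo 2^64 instead of walking backwards, so any comparison or scaling of that index is then wrong. A tiny standalone illustration (plain C++, no MLX types):

// Sketch: copying x[::-1] needs a negative stride and a signed index type.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int> src{1, 2, 3, 4};
  std::vector<int> dst(4);
  int64_t stride = -1; // element stride of the reversed view
  int64_t idx = 3;     // offset of the first element of the view
  for (int i = 0; i < 4; ++i) {
    dst[i] = src[idx];
    idx += stride; // well-defined descent only because idx is signed
  }
  assert((dst == std::vector<int>{4, 3, 2, 1}));
}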
--- a/mlx/backend/common/copy.h
+++ b/mlx/backend/common/copy.h
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
@@ -26,4 +26,15 @@ enum class CopyType {
 void copy(const array& src, array& dst, CopyType ctype);
 void copy_inplace(const array& src, array& dst, CopyType ctype);
 
+template <typename stride_t>
+void copy_inplace(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype);
+
 } // namespace mlx::core
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -94,6 +94,7 @@ DEFAULT(Sign)
 DEFAULT(Sin)
 DEFAULT(Sinh)
 DEFAULT(Slice)
+DEFAULT(SliceUpdate)
 DEFAULT(Softmax)
 DEFAULT(Sort)
 DEFAULT_MULTI(Split)
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -651,36 +651,33 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Slice::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  auto& in = inputs[0];
-  auto strides = in.strides();
-  auto flags = in.flags();
-  size_t data_offset = 0;
+std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
+    const array& in) {
+  int64_t data_offset = 0;
+  bool copy_needed = false;
+  std::vector<int64_t> inp_strides(in.ndim(), 0);
   for (int i = 0; i < in.ndim(); ++i) {
     data_offset += start_indices_[i] * in.strides()[i];
-    strides[i] *= strides_[i];
+    inp_strides[i] = in.strides()[i] * strides_[i];
+
+    copy_needed |= strides_[i] < 0;
   }
 
+  return std::make_tuple(copy_needed, data_offset, inp_strides);
+}
+
+void Slice::shared_buffer_slice(
+    const array& in,
+    const std::vector<size_t>& out_strides,
+    size_t data_offset,
+    array& out) {
   // Compute row/col contiguity
-  size_t data_size = 1;
-  size_t f_stride = 1;
-  size_t b_stride = 1;
-  flags.row_contiguous = true;
-  flags.col_contiguous = true;
-  for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
-    flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
-    flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
-    f_stride *= out.shape(i);
-    b_stride *= out.shape(ri);
-    if (strides[i] > 0) {
-      data_size *= out.shape(i);
-    }
-  }
+  auto [data_size, is_row_contiguous, is_col_contiguous] =
+      check_contiguity(out.shape(), out_strides);
+
+  auto flags = in.flags();
+  flags.row_contiguous = is_row_contiguous;
+  flags.col_contiguous = is_col_contiguous;
 
   if (data_size == 1) {
     // Broadcasted scalar array is contiguous.
@@ -694,7 +691,87 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
     flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
   }
 
-  out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
+  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
+}
+
+void Slice::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+
+  // Calculate out strides, initial offset and if copy needs to be made
+  auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
+
+  // Do copy if needed
+  if (copy_needed) {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
+    copy_inplace<int64_t>(
+        /* const array& src = */ in,
+        /* array& dst = */ out,
+        /* const std::vector<int>& data_shape = */ out.shape(),
+        /* const std::vector<stride_t>& i_strides = */ inp_strides,
+        /* const std::vector<stride_t>& o_strides = */ ostrides,
+        /* int64_t i_offset = */ data_offset,
+        /* int64_t o_offset = */ 0,
+        /* CopyType ctype = */ CopyType::General);
+  } else {
+    std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
+    shared_buffer_slice(in, ostrides, data_offset, out);
+  }
+}
+
+std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
+    const array& in) {
+  int64_t data_offset = 0;
+  std::vector<int64_t> inp_strides(in.ndim(), 0);
+  for (int i = 0; i < in.ndim(); ++i) {
+    data_offset += start_indices_[i] * in.strides()[i];
+    inp_strides[i] = in.strides()[i] * strides_[i];
+  }
+
+  return std::make_tuple(data_offset, inp_strides);
+}
+
+void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+  auto& upd = inputs[1];
+
+  if (upd.size() == 0) {
+    out.copy_shared_buffer(in);
+    return;
+  }
+
+  // Check if materialization is needed
+  auto ctype = in.flags().contiguous && in.size() == in.data_size()
+      ? CopyType::Vector
+      : CopyType::General;
+  copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
+
+  // Calculate out strides, initial offset and if copy needs to be made
+  auto [data_offset, out_strides] = prepare_slice(out);
+
+  // Do copy
+  std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
+  copy_inplace<int64_t>(
+      /* const array& src = */ upd,
+      /* array& dst = */ out,
+      /* const std::vector<int>& data_shape = */ upd.shape(),
+      /* const std::vector<stride_t>& i_strides = */ upd_strides,
+      /* const std::vector<stride_t>& o_strides = */ out_strides,
+      /* int64_t i_offset = */ 0,
+      /* int64_t o_offset = */ data_offset,
+      /* CopyType ctype = */ CopyType::GeneralGeneral);
 }
 
 void Split::eval(
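Note the fast/slow path split above: the shared-buffer (zero-copy) path stores view strides as size_t, so a reversing slice cannot be expressed as a view; prepare_slice flags any negative step as copy_needed and Slice::eval materializes the result instead. A hypothetical standalone version of that bookkeeping, exercised on a concrete case:

// Sketch of what prepare_slice computes (free function for illustration).
#include <cassert>
#include <cstdint>
#include <tuple>
#include <vector>

std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
    const std::vector<int64_t>& in_strides, // parent strides, in elements
    const std::vector<int>& start,
    const std::vector<int>& step) {
  int64_t offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> strides(in_strides.size());
  for (size_t i = 0; i < in_strides.size(); ++i) {
    offset += start[i] * in_strides[i];
    strides[i] = in_strides[i] * step[i];
    copy_needed |= step[i] < 0; // reversing slices can't share the buffer
  }
  return {copy_needed, offset, strides};
}

int main() {
  // A 4x6 row-major parent, sliced with start {1, 5} and step {2, -1}.
  auto [copy_needed, offset, strides] =
      prepare_slice({6, 1}, {1, 5}, {2, -1});
  assert(copy_needed);  // axis 1 is reversed
  assert(offset == 11); // 1*6 + 5*1
  assert(strides[0] == 12 && strides[1] == -1);
}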
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
@@ -8,11 +8,12 @@
 
 namespace mlx::core {
 
-inline size_t elem_to_loc(
+template <typename stride_t>
+inline stride_t elem_to_loc(
     int elem,
     const std::vector<int>& shape,
-    const std::vector<size_t>& strides) {
-  size_t loc = 0;
+    const std::vector<stride_t>& strides) {
+  stride_t loc = 0;
   for (int i = shape.size() - 1; i >= 0; --i) {
     auto q_and_r = ldiv(elem, shape[i]);
     loc += q_and_r.rem * strides[i];
@@ -34,10 +35,11 @@ inline size_t elem_to_loc(int elem, const array& a) {
 //
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
-inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
+template <typename stride_t>
+inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
 collapse_contiguous_dims(
     const std::vector<int>& shape,
-    const std::vector<std::vector<size_t>> strides) {
+    const std::vector<std::vector<stride_t>> strides) {
   // Make a vector that has axes separated with -1. Collapse all axes between
   // -1.
   std::vector<int> to_collapse;
@@ -45,7 +47,7 @@ collapse_contiguous_dims(
   to_collapse.push_back(0);
   for (int i = 1; i < shape.size(); i++) {
     bool contiguous = true;
-    for (const std::vector<size_t>& st : strides) {
+    for (const std::vector<stride_t>& st : strides) {
       if (st[i] * shape[i] != st[i - 1]) {
         contiguous = false;
       }
@@ -62,7 +64,7 @@ collapse_contiguous_dims(
   }
 
   std::vector<int> out_shape;
-  std::vector<std::vector<size_t>> out_strides(strides.size());
+  std::vector<std::vector<stride_t>> out_strides(strides.size());
   for (int i = 0; i < to_collapse.size(); i++) {
     int current_shape = shape[to_collapse[i]];
     while (to_collapse[++i] != -1) {
@@ -70,7 +72,7 @@ collapse_contiguous_dims(
     }
     out_shape.push_back(current_shape);
     for (int j = 0; j < strides.size(); j++) {
-      const std::vector<size_t>& st = strides[j];
+      const std::vector<stride_t>& st = strides[j];
       out_strides[j].push_back(st[to_collapse[i - 1]]);
     }
   }
@@ -94,4 +96,27 @@ collapse_contiguous_dims(Arrays... xs) {
       std::vector<array>{std::forward<Arrays>(xs)...});
 }
 
+template <typename stride_t>
+inline auto check_contiguity(
+    const std::vector<int>& shape,
+    const std::vector<stride_t>& strides) {
+  size_t data_size = 1;
+  size_t f_stride = 1;
+  size_t b_stride = 1;
+  bool is_row_contiguous = true;
+  bool is_col_contiguous = true;
+
+  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
+    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
+    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+    f_stride *= shape[i];
+    b_stride *= shape[ri];
+    if (strides[i] > 0) {
+      data_size *= shape[i];
+    }
+  }
+
+  return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
+}
+
 } // namespace mlx::core
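check_contiguity factors the flag computation out of the old Slice::eval: the forward pass (f_stride) detects column-major order, the backward pass (b_stride) detects row-major order, and broadcast axes (stride 0) are excluded from data_size. A standalone copy of that logic, exercised on two small cases (illustrative only):

#include <cassert>
#include <cstdint>
#include <tuple>
#include <vector>

std::tuple<size_t, bool, bool> check_contiguity(
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  size_t data_size = 1;
  int64_t f_stride = 1, b_stride = 1;
  bool row = true, col = true;
  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
    col &= strides[i] == f_stride || shape[i] == 1;   // Fortran order
    row &= strides[ri] == b_stride || shape[ri] == 1; // C order
    f_stride *= shape[i];
    b_stride *= shape[ri];
    if (strides[i] > 0) {
      data_size *= shape[i]; // broadcast axes (stride 0) carry no data
    }
  }
  return {data_size, row, col};
}

int main() {
  // 2x3 row-major: strides {3, 1}.
  auto [n1, row1, col1] = check_contiguity({2, 3}, {3, 1});
  assert(n1 == 6 && row1 && !col1);
  // 2x3 column-major: strides {1, 2}.
  auto [n2, row2, col2] = check_contiguity({2, 3}, {1, 2});
  assert(n2 == 6 && !row2 && col2);
}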
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <sstream>
 
@@ -37,15 +37,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
   copy_gpu(in, out, ctype, out.primitive().stream());
 }
 
+template <typename stride_t>
 void copy_gpu_inplace(
     const array& in,
     array& out,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& strides_in_pre,
+    const std::vector<stride_t>& strides_out_pre,
+    int64_t inp_offset,
+    int64_t out_offset,
     CopyType ctype,
     const Stream& s) {
   // Try to collapse contiguous dims
-  auto [shape, strides] = collapse_contiguous_dims(in, out);
-  auto& strides_in = strides[0];
-  auto& strides_out = strides[1];
+  auto [shape, strides] = collapse_contiguous_dims(
+      data_shape, std::vector{strides_in_pre, strides_out_pre});
+  auto& strides_in_ = strides[0];
+  auto& strides_out_ = strides[1];
 
   auto& d = metal::device(s.device);
   std::ostringstream kname;
@@ -72,39 +79,44 @@ void copy_gpu_inplace(
   auto compute_encoder = d.get_command_encoder(s.index);
   compute_encoder->setComputePipelineState(kernel);
   bool donate_in = in.data_shared_ptr() == nullptr;
-  set_array_buffer(compute_encoder, donate_in ? out : in, 0);
-  set_array_buffer(compute_encoder, out, 1);
+
+  inp_offset *= size_of(in.dtype());
+  out_offset *= size_of(out.dtype());
+
+  set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
+  set_array_buffer(compute_encoder, out, out_offset, 1);
+
   if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
-    size_t ndim = shape.size();
+    int ndim = shape.size();
+    std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
+    std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
+
     if (ndim > 3) {
-      compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
-      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
-      if (ctype == CopyType::GeneralGeneral) {
-        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
-      }
-    } else {
-      // The shape is implicit in the grid for <= 3D
-      compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
-      if (ctype == CopyType::GeneralGeneral) {
-        compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
-      }
+      set_vector_bytes(compute_encoder, shape, ndim, 2);
     }
+    set_vector_bytes(compute_encoder, strides_in, ndim, 3);
+    if (ctype == CopyType::GeneralGeneral) {
+      set_vector_bytes(compute_encoder, strides_out, ndim, 4);
+    }
 
     if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
-      compute_encoder->setBytes(
-          &ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
+      compute_encoder->setBytes(&ndim, sizeof(int), 5);
     }
 
     int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
-    int rest = in.size() / (dim0 * dim1);
+
+    size_t data_size = 1;
+    for (auto& s : shape)
+      data_size *= s;
+    int rest = data_size / (dim0 * dim1);
 
     // NB assuming thread_group_size is a power of 2 larger than 32 x 32
     NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
     if (thread_group_size != 1024) {
       throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
     }
 
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
     compute_encoder->dispatchThreads(grid_dims, group_dims);
@@ -120,4 +132,25 @@ void copy_gpu_inplace(
   }
 }
 
+void copy_gpu_inplace(
+    const array& in,
+    array& out,
+    CopyType ctype,
+    const Stream& s) {
+  return copy_gpu_inplace(
+      in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
+}
+
+void copy_gpu_inplace(
+    const array& in,
+    array& out,
+    const std::vector<int64_t>& istride,
+    int64_t ioffset,
+    CopyType ctype,
+    const Stream& s) {
+  std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
+  return copy_gpu_inplace(
+      in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
+}
+
 } // namespace mlx::core
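Two details in the GPU path worth calling out. First, the element offsets are converted to byte offsets (inp_offset *= size_of(in.dtype())) before being passed to the buffer bindings, since Metal buffer offsets are in bytes. Second, the dispatch grid is now sized from the collapsed copy region rather than in.size(): with a nonzero offset the region being copied can be smaller than the whole source array. A tiny sketch of that grid computation (plain C++, mirroring the code above):

#include <cassert>
#include <vector>

int main() {
  std::vector<int> shape{5, 4, 3, 2}; // collapsed copy region
  int ndim = shape.size();
  int dim0 = ndim > 0 ? shape[ndim - 1] : 1; // fastest-moving axis -> x
  int dim1 = ndim > 1 ? shape[ndim - 2] : 1; // next axis -> y
  size_t data_size = 1;
  for (auto s : shape)
    data_size *= s;
  int rest = data_size / (dim0 * dim1); // remaining axes fold into z
  assert(dim0 == 2 && dim1 == 3 && rest == 20);
}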
--- a/mlx/backend/metal/copy.h
+++ b/mlx/backend/metal/copy.h
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #pragma once
 
@@ -7,12 +7,34 @@
 
 namespace mlx::core {
 
+// Generic copy inplace
+template <typename stride_t>
+void copy_gpu_inplace(
+    const array& in,
+    array& out,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    int64_t i_offset,
+    int64_t o_offset,
+    CopyType ctype,
+    const Stream& s);
+
 void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
 void copy_gpu(const array& src, array& out, CopyType ctype);
 
 void copy_gpu_inplace(
     const array& src,
     array& out,
     CopyType ctype,
     const Stream& s);
 
+void copy_gpu_inplace(
+    const array& in,
+    array& out,
+    const std::vector<int64_t>& istride,
+    int64_t ioffset,
+    CopyType ctype,
+    const Stream& s);
+
 } // namespace mlx::core
--- a/mlx/backend/metal/kernels/copy.metal
+++ b/mlx/backend/metal/kernels/copy.metal
@@ -1,29 +1,29 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include "mlx/backend/metal/kernels/bf16.h"
 #include "mlx/backend/metal/kernels/utils.h"
 
 template <typename T, typename U>
 [[kernel]] void copy_s(
-    device const T* src,
-    device U* dst,
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
     uint index [[thread_position_in_grid]]) {
   dst[index] = static_cast<U>(src[0]);
 }
 
 template <typename T, typename U>
 [[kernel]] void copy_v(
-    device const T* src,
-    device U* dst,
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
     uint index [[thread_position_in_grid]]) {
   dst[index] = static_cast<U>(src[index]);
 }
 
 template <typename T, typename U>
 [[kernel]] void copy_g_nd1(
-    device const T* src,
-    device U* dst,
-    constant const size_t& src_stride,
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t& src_stride [[buffer(3)]],
     uint index [[thread_position_in_grid]]) {
   auto src_idx = elem_to_loc_1(index, src_stride);
   dst[index] = static_cast<U>(src[src_idx]);
@@ -31,61 +31,61 @@ template <typename T, typename U>
 
 template <typename T, typename U>
 [[kernel]] void copy_g_nd2(
-    device const T* src,
-    device U* dst,
-    constant const size_t src_strides[2],
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t* src_strides [[buffer(3)]],
     uint2 index [[thread_position_in_grid]],
     uint2 grid_dim [[threads_per_grid]]) {
   auto src_idx = elem_to_loc_2(index, src_strides);
-  size_t dst_idx = index.x + (size_t)grid_dim.x * index.y;
+  int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
 template <typename T, typename U>
 [[kernel]] void copy_g_nd3(
-    device const T* src,
-    device U* dst,
-    constant const size_t src_strides[3],
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t* src_strides [[buffer(3)]],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto src_idx = elem_to_loc_3(index, src_strides);
-  size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
 template <typename T, typename U, int DIM>
 [[kernel]] void copy_g_nd(
-    device const T* src,
-    device U* dst,
-    constant const int src_shape[DIM],
-    constant const size_t src_strides[DIM],
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int* src_shape [[buffer(2)]],
+    constant const int64_t* src_strides [[buffer(3)]],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
-  size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
 template <typename T, typename U>
 [[kernel]] void copy_g(
-    device const T* src,
-    device U* dst,
-    constant const int* src_shape,
-    constant const size_t* src_strides,
-    constant const int& ndim,
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int* src_shape [[buffer(2)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]],
     uint3 grid_dim [[threads_per_grid]]) {
   auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
-  size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
+  int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
   dst[dst_idx] = static_cast<U>(src[src_idx]);
 }
 
 template <typename T, typename U>
 [[kernel]] void copy_gg_nd1(
-    device const T* src,
-    device U* dst,
-    constant const size_t& src_stride,
-    constant const size_t& dst_stride,
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t& src_stride [[buffer(3)]],
+    constant const int64_t& dst_stride [[buffer(4)]],
     uint index [[thread_position_in_grid]]) {
   auto src_idx = elem_to_loc_1(index, src_stride);
   auto dst_idx = elem_to_loc_1(index, dst_stride);
@@ -94,10 +94,10 @@ template <typename T, typename U>
 
 template <typename T, typename U>
 [[kernel]] void copy_gg_nd2(
-    device const T* src,
-    device U* dst,
-    constant const size_t src_strides[2],
-    constant const size_t dst_strides[2],
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int64_t* dst_strides [[buffer(4)]],
     uint2 index [[thread_position_in_grid]]) {
   auto src_idx = elem_to_loc_2(index, src_strides);
   auto dst_idx = elem_to_loc_2(index, dst_strides);
@@ -106,10 +106,10 @@ template <typename T, typename U>
 
 template <typename T, typename U>
 [[kernel]] void copy_gg_nd3(
-    device const T* src,
-    device U* dst,
-    constant const size_t src_strides[3],
-    constant const size_t dst_strides[3],
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int64_t* dst_strides [[buffer(4)]],
     uint3 index [[thread_position_in_grid]]) {
   auto src_idx = elem_to_loc_3(index, src_strides);
   auto dst_idx = elem_to_loc_3(index, dst_strides);
@@ -118,11 +118,11 @@ template <typename T, typename U>
 
 template <typename T, typename U, int DIM>
 [[kernel]] void copy_gg_nd(
-    device const T* src,
-    device U* dst,
-    constant const int src_shape[DIM],
-    constant const size_t src_strides[DIM],
-    constant const size_t dst_strides[DIM],
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int* src_shape [[buffer(2)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int64_t* dst_strides [[buffer(4)]],
     uint3 index [[thread_position_in_grid]]) {
   auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
   auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
@@ -131,12 +131,12 @@ template <typename T, typename U, int DIM>
 
 template <typename T, typename U>
 [[kernel]] void copy_gg(
-    device const T* src,
-    device U* dst,
-    constant const int* src_shape,
-    constant const size_t* src_strides,
-    constant const size_t* dst_strides,
-    constant const int& ndim,
+    device const T* src [[buffer(0)]],
+    device U* dst [[buffer(1)]],
+    constant const int* src_shape [[buffer(2)]],
+    constant const int64_t* src_strides [[buffer(3)]],
+    constant const int64_t* dst_strides [[buffer(4)]],
+    constant const int& ndim [[buffer(5)]],
     uint3 index [[thread_position_in_grid]]) {
   auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
   auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
@@ -146,70 +146,70 @@ template <typename T, typename U>
 #define instantiate_copy(name, itype, otype, ctype) \
   template [[host_name(name)]] \
   [[kernel]] void copy_##ctype<itype, otype>( \
-      device const itype* src, \
-      device otype* dst, \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
       uint index [[thread_position_in_grid]]);
 
 #define instantiate_copy_g_dim(name, itype, otype, dims) \
   template [[host_name(name "_" #dims)]] \
   [[kernel]] void copy_g_nd<itype, otype, dims>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const int src_shape[dims], \
-      constant const size_t src_strides[dims], \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int* src_shape [[buffer(2)]], \
+      constant const int64_t* src_strides [[buffer(3)]], \
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]); \
   template [[host_name("g" name "_" #dims)]] \
   [[kernel]] void copy_gg_nd<itype, otype, dims>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const int src_shape[dims], \
-      constant const size_t src_strides[dims], \
-      constant const size_t dst_strides[dims], \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int* src_shape [[buffer(2)]], \
+      constant const int64_t* src_strides [[buffer(3)]], \
+      constant const int64_t* dst_strides [[buffer(4)]], \
      uint3 index [[thread_position_in_grid]]);
 
 
 #define instantiate_copy_g_nd(name, itype, otype) \
   template [[host_name(name "_1")]] \
   [[kernel]] void copy_g_nd1<itype, otype>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const size_t& src_stride, \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int64_t& src_stride [[buffer(3)]], \
      uint index [[thread_position_in_grid]]); \
   template [[host_name(name "_2")]] \
   [[kernel]] void copy_g_nd2<itype, otype>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const size_t src_strides[2], \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int64_t* src_strides [[buffer(3)]], \
       uint2 index [[thread_position_in_grid]], \
       uint2 grid_dim [[threads_per_grid]]); \
   template [[host_name(name "_3")]] \
   [[kernel]] void copy_g_nd3<itype, otype>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const size_t src_strides[3], \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int64_t* src_strides [[buffer(3)]], \
       uint3 index [[thread_position_in_grid]], \
       uint3 grid_dim [[threads_per_grid]]); \
   template [[host_name("g" name "_1")]] \
   [[kernel]] void copy_gg_nd1<itype, otype>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const size_t& src_stride, \
-      constant const size_t& dst_stride, \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int64_t& src_stride [[buffer(3)]], \
+      constant const int64_t& dst_stride [[buffer(4)]], \
      uint index [[thread_position_in_grid]]); \
   template [[host_name("g" name "_2")]] \
   [[kernel]] void copy_gg_nd2<itype, otype>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const size_t src_strides[2], \
-      constant const size_t dst_strides[2], \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int64_t* src_strides [[buffer(3)]], \
+      constant const int64_t* dst_strides [[buffer(4)]], \
      uint2 index [[thread_position_in_grid]]); \
   template [[host_name("g" name "_3")]] \
   [[kernel]] void copy_gg_nd3<itype, otype>( \
-      device const itype* src, \
-      device otype* dst, \
-      constant const size_t src_strides[3], \
-      constant const size_t dst_strides[3], \
+      device const itype* src [[buffer(0)]], \
+      device otype* dst [[buffer(1)]], \
+      constant const int64_t* src_strides [[buffer(3)]], \
+      constant const int64_t* dst_strides [[buffer(4)]], \
      uint3 index [[thread_position_in_grid]]); \
   instantiate_copy_g_dim(name, itype, otype, 4) \
   instantiate_copy_g_dim(name, itype, otype, 5)
|
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||||
@ -218,21 +218,21 @@ template <typename T, typename U>
|
|||||||
#define instantiate_copy_g(name, itype, otype) \
|
#define instantiate_copy_g(name, itype, otype) \
|
||||||
template [[host_name(name)]] \
|
template [[host_name(name)]] \
|
||||||
[[kernel]] void copy_g<itype, otype>( \
|
[[kernel]] void copy_g<itype, otype>( \
|
||||||
device const itype* src, \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst, \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int* src_shape, \
|
constant const int* src_shape [[buffer(2)]], \
|
||||||
constant const size_t* src_strides, \
|
constant const int64_t* src_strides [[buffer(3)]], \
|
||||||
constant const int& ndim, \
|
constant const int& ndim [[buffer(5)]], \
|
||||||
uint3 index [[thread_position_in_grid]], \
|
uint3 index [[thread_position_in_grid]], \
|
||||||
uint3 grid_dim [[threads_per_grid]]); \
|
uint3 grid_dim [[threads_per_grid]]); \
|
||||||
template [[host_name("g" name)]] \
|
template [[host_name("g" name)]] \
|
||||||
[[kernel]] void copy_gg<itype, otype>( \
|
[[kernel]] void copy_gg<itype, otype>( \
|
||||||
device const itype* src, \
|
device const itype* src [[buffer(0)]], \
|
||||||
device otype* dst, \
|
device otype* dst [[buffer(1)]], \
|
||||||
constant const int* src_shape, \
|
constant const int* src_shape [[buffer(2)]], \
|
||||||
constant const size_t* src_strides, \
|
constant const int64_t* src_strides [[buffer(3)]], \
|
||||||
constant const size_t* dst_strides, \
|
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||||
constant const int& ndim, \
|
constant const int& ndim [[buffer(5)]], \
|
||||||
uint3 index [[thread_position_in_grid]]);
|
uint3 index [[thread_position_in_grid]]);
|
||||||
|
|
||||||
#define instantiate_copy_all(tname, itype, otype) \
|
#define instantiate_copy_all(tname, itype, otype) \
|
||||||
|
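A note on the instantiations above: the [[buffer(n)]] indices are fixed per argument kind (0 src, 1 dst, 2 shape, 3 source strides, 4 destination strides, 5 ndim) rather than packed densely. That is why copy_g_nd1 binds its single stride at buffer(3) and the fixed-dimension kernels skip buffer(2) entirely: the host side can bind each argument at the same slot regardless of which kernel specialization it ends up dispatching.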
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #pragma once

@@ -65,12 +65,18 @@ struct Limits<bool> {
 // Indexing utils
 ///////////////////////////////////////////////////////////////////////////////

-inline size_t elem_to_loc(
+#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
+
+///////////////////////////////////////////////////////////////////////////////
+// Single Array with generic dims
+
+template <typename stride_t>
+METAL_FUNC stride_t elem_to_loc(
     uint elem,
     device const int* shape,
-    device const size_t* strides,
+    device const stride_t* strides,
     int ndim) {
-  size_t loc = 0;
+  stride_t loc = 0;
   for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
     loc += (elem % shape[i]) * strides[i];
     elem /= shape[i];
@@ -78,12 +84,13 @@ inline size_t elem_to_loc(
   return loc;
 }

-inline size_t elem_to_loc(
+template <typename stride_t>
+METAL_FUNC stride_t elem_to_loc(
     uint elem,
     constant const int* shape,
-    constant const size_t* strides,
+    constant const stride_t* strides,
     int ndim) {
-  size_t loc = 0;
+  stride_t loc = 0;
   for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
     loc += (elem % shape[i]) * strides[i];
     elem /= shape[i];
@@ -91,52 +98,59 @@ inline size_t elem_to_loc(
   return loc;
 }

-template <int NDIM>
-inline uint3 elem_to_loc_3_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM],
-    constant const size_t c_strides[NDIM]) {
-  uint3 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    loc.z += l * c_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-template <int NDIM>
-inline uint2 elem_to_loc_2_nd(
-    uint3 elem,
-    constant const int shape[NDIM],
-    constant const size_t a_strides[NDIM],
-    constant const size_t b_strides[NDIM]) {
-  uint2 loc = {
-      static_cast<uint>(
-          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
-      static_cast<uint>(
-          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
-  for (int d = NDIM - 3; d >= 0; --d) {
-    uint l = elem.z % shape[d];
-    loc.x += l * a_strides[d];
-    loc.y += l * b_strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
+// Non templated version to handle arbitrary dims
+template <typename stride_t>
+METAL_FUNC stride_t elem_to_loc(
+    uint3 elem,
+    constant const int* shape,
+    constant const stride_t* strides,
+    int ndim) {
+  stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
+  for (int d = ndim - 3; d >= 0; --d) {
+    loc += (elem.z % shape[d]) * strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Single Array with fixed N dims
+
+template <typename stride_t>
+METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
+  return elem * stride;
+}
+
+template <typename stride_t>
+METAL_FUNC stride_t
+elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
+  return elem.x * strides[1] + elem.y * strides[0];
+}
+
+template <typename stride_t>
+METAL_FUNC stride_t
+elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
+  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
+}
+
+template <int NDIM>
+METAL_FUNC size_t elem_to_loc_nd(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides) {
+  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
+
+  MLX_MTL_PRAGMA_UNROLL
+  for (int d = NDIM - 2; d >= 0; --d) {
+    elem /= shape[d + 1];
+    loc += (elem % shape[d]) * strides[d];
+  }
+
+  return loc;
+}

 template <int NDIM>
-inline size_t elem_to_loc_nd(
+METAL_FUNC size_t elem_to_loc_nd(
     uint3 elem,
     constant const int shape[NDIM],
     constant const size_t strides[NDIM]) {
@@ -148,33 +162,59 @@ inline size_t elem_to_loc_nd(
   return loc;
 }

-inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
-  return elem * stride;
-}
-
-inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
-  return elem.x * strides[1] + elem.y * strides[0];
-}
-
-inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
-  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
-}
-
-// Non templated version to handle arbitrary dims
-inline size_t elem_to_loc(
-    uint3 elem,
-    constant const int* shape,
-    constant const size_t* strides,
-    int ndim) {
-  size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
-  for (int d = ndim - 3; d >= 0; --d) {
-    loc += (elem.z % shape[d]) * strides[d];
-    elem.z /= shape[d];
-  }
-  return loc;
-}
-
-inline uint3 elem_to_loc_3_nd(
+template <int NDIM>
+METAL_FUNC int64_t elem_to_loc_nd(
+    uint elem,
+    constant const int shape[NDIM],
+    constant const int64_t strides[NDIM]) {
+  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
+
+  MLX_MTL_PRAGMA_UNROLL
+  for (int d = NDIM - 2; d >= 0; --d) {
+    elem /= shape[d + 1];
+    loc += (elem % shape[d]) * strides[d];
+  }
+
+  return loc;
+}
+
+template <int NDIM>
+METAL_FUNC int64_t elem_to_loc_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const int64_t strides[NDIM]) {
+  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
+  for (int d = NDIM - 3; d >= 0; --d) {
+    loc += (elem.z % shape[d]) * strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Multiple Arrays with generic dims
+
+METAL_FUNC uint2 elem_to_loc_2_nd(
+    uint3 elem,
+    constant const int* shape,
+    constant const size_t* a_strides,
+    constant const size_t* b_strides,
+    int ndim) {
+  uint2 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
+  for (int d = ndim - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+METAL_FUNC uint3 elem_to_loc_3_nd(
     uint3 elem,
     constant const int* shape,
     constant const size_t* a_strides,
@@ -198,18 +238,21 @@ inline uint3 elem_to_loc_3_nd(
   return loc;
 }

-inline uint2 elem_to_loc_2_nd(
-    uint3 elem,
-    constant const int* shape,
-    constant const size_t* a_strides,
-    constant const size_t* b_strides,
-    int ndim) {
+///////////////////////////////////////////////////////////////////////////////
+// Multiple Arrays with fixed N dims
+
+template <int NDIM>
+METAL_FUNC uint2 elem_to_loc_2_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const size_t a_strides[NDIM],
+    constant const size_t b_strides[NDIM]) {
   uint2 loc = {
       static_cast<uint>(
-          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
+          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
       static_cast<uint>(
-          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
-  for (int d = ndim - 3; d >= 0; --d) {
+          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
+  for (int d = NDIM - 3; d >= 0; --d) {
     uint l = elem.z % shape[d];
     loc.x += l * a_strides[d];
     loc.y += l * b_strides[d];
@@ -219,55 +262,26 @@ inline uint2 elem_to_loc_2_nd(
 }

 template <int NDIM>
-inline uint elem_to_loc_nd(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides);
-
-template <>
-inline uint elem_to_loc_nd<1>(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  return (elem % shape[0]) * strides[0];
-}
-
-template <>
-inline uint elem_to_loc_nd<2>(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  uint loc = (elem % shape[1]) * strides[1];
-  elem /= shape[1];
-  loc += (elem % shape[0]) * strides[0];
-  return loc;
-}
-
-template <>
-inline uint elem_to_loc_nd<3>(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  uint loc = (elem % shape[2]) * strides[2];
-  elem /= shape[2];
-  loc += (elem % shape[1]) * strides[1];
-  elem /= shape[1];
-  loc += (elem % shape[0]) * strides[0];
-  return loc;
-}
-
-template <>
-inline uint elem_to_loc_nd<4>(
-    uint elem,
-    device const int* shape,
-    device const size_t* strides) {
-  uint loc = (elem % shape[3]) * strides[3];
-  elem /= shape[3];
-  loc += (elem % shape[2]) * strides[2];
-  elem /= shape[2];
-  loc += (elem % shape[1]) * strides[1];
-  elem /= shape[1];
-  loc += (elem % shape[0]) * strides[0];
+METAL_FUNC uint3 elem_to_loc_3_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const size_t a_strides[NDIM],
+    constant const size_t b_strides[NDIM],
+    constant const size_t c_strides[NDIM]) {
+  uint3 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
+  for (int d = NDIM - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    loc.z += l * c_strides[d];
+    elem.z /= shape[d];
+  }
   return loc;
 }
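elem_to_loc is the workhorse of every strided kernel in this file: it peels a flat row-major element index into per-dimension coordinates, innermost first, and accumulates coordinate times stride. A small host-side C++ mirror of that loop with hand-checked cases may help; this is an illustrative sketch, not part of the commit:

#include <cassert>
#include <cstdint>
#include <vector>

// Host-side mirror of the generic elem_to_loc above. Signed strides are
// the point of this commit: a reversed view has negative strides.
int64_t elem_to_loc(
    uint32_t elem,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0 && elem > 0; --i) {
    loc += static_cast<int64_t>(elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored column-major: shape {2, 3}, strides {1, 2}.
  // Flat element 4 has coordinates (1, 1), so loc = 1*1 + 1*2 = 3.
  assert(elem_to_loc(4, {2, 3}, {1, 2}) == 3);
  // With a negative stride the location steps backwards from the offset.
  assert(elem_to_loc(1, {4}, {-1}) == -1);
  return 0;
}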
@@ -206,7 +206,7 @@ inline auto collapse_batches(const array& a, const array& b) {
   std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};

   auto [batch_shape, batch_strides] =
-      collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride});
+      collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});

   auto A_batch_stride = batch_strides[0];
   auto B_batch_stride = batch_strides[1];
@@ -237,8 +237,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
   std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
   std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};

-  auto [batch_shape, batch_strides] =
-      collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride});
+  auto [batch_shape, batch_strides] = collapse_contiguous_dims(
+      A_bshape, std::vector{A_bstride, B_bstride, C_bstride});

   auto A_batch_stride = batch_strides[0];
   auto B_batch_stride = batch_strides[1];
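The added std::vector{...} wrappers above are not cosmetic: with collapse_contiguous_dims presumably templated over the stride type after this change, a bare braced-init-list argument has no type of its own, so template argument deduction fails; class template argument deduction on an explicit std::vector{...} supplies one. A minimal sketch of the same failure mode (hypothetical names, not from the diff):

#include <vector>

template <typename T>
void take(const std::vector<std::vector<T>>& groups) {}

int main() {
  std::vector<size_t> a{1, 2}, b{3, 4};
  // take({a, b});          // error: cannot deduce T from a braced list
  take(std::vector{a, b});  // OK: CTAD yields std::vector<std::vector<size_t>>
  return 0;
}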
@@ -865,7 +865,73 @@ void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
 }

 void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
-  eval(inputs, out);
+  assert(inputs.size() == 1);
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+
+  // Calculate out strides, initial offset and if copy needs to be made
+  auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
+
+  // Do copy if needed
+  if (copy_needed) {
+    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
+    copy_gpu_inplace(
+        /* const array& in = */ in,
+        /* array& out = */ out,
+        /* const std::vector<int>& data_shape = */ out.shape(),
+        /* const std::vector<stride_t>& i_strides = */ inp_strides,
+        /* const std::vector<stride_t>& o_strides = */ ostrides,
+        /* int64_t i_offset = */ data_offset,
+        /* int64_t o_offset = */ 0,
+        /* CopyType ctype = */ CopyType::General,
+        /* const Stream& s = */ stream());
+  } else {
+    std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
+    shared_buffer_slice(in, ostrides, data_offset, out);
+  }
+}
+
+void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 2);
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+
+  auto& in = inputs[0];
+  auto& upd = inputs[1];
+
+  if (upd.size() == 0) {
+    out.copy_shared_buffer(in);
+    return;
+  }
+
+  // Check if materialization is needed
+  auto ctype = in.flags().contiguous && in.size() == in.data_size()
+      ? CopyType::Vector
+      : CopyType::General;
+  copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
+
+  // Calculate out strides, initial offset and if copy needs to be made
+  auto [data_offset, out_strides] = prepare_slice(out);
+
+  // Do copy
+  std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
+  copy_gpu_inplace<int64_t>(
+      /* const array& src = */ upd,
+      /* array& dst = */ out,
+      /* const std::vector<int>& data_shape = */ upd.shape(),
+      /* const std::vector<stride_t>& i_strides = */ upd_strides,
+      /* const std::vector<stride_t>& o_strides = */ out_strides,
+      /* int64_t i_offset = */ 0,
+      /* int64_t o_offset = */ data_offset,
+      /* CopyType ctype = */ CopyType::GeneralGeneral,
+      /* const Stream& s = */ stream());
 }

 void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
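SliceUpdate::eval_gpu above is a two-pass scheme: first materialize src into the output buffer (a vector copy when src is contiguous), then run a GeneralGeneral copy that writes upd through a strided view of that buffer. The view is nothing more than an element offset plus rescaled, possibly negative, strides. prepare_slice is private to the primitive, so the following is a hedged host-side sketch of the usual slice arithmetic rather than a quote of it:

#include <cstdint>
#include <utility>
#include <vector>

// Offset of the first sliced element and the strides of the strided view,
// both in elements. Negative slice strides make the view strides negative,
// which is why the copy path switched from size_t to int64_t.
std::pair<int64_t, std::vector<int64_t>> slice_view(
    const std::vector<int>& start,
    const std::vector<int>& slice_strides,
    const std::vector<size_t>& in_strides) {
  int64_t offset = 0;
  std::vector<int64_t> view_strides(in_strides.size());
  for (size_t i = 0; i < in_strides.size(); ++i) {
    offset += static_cast<int64_t>(start[i]) * static_cast<int64_t>(in_strides[i]);
    view_strides[i] =
        static_cast<int64_t>(slice_strides[i]) * static_cast<int64_t>(in_strides[i]);
  }
  return {offset, view_strides};
}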
@@ -9,16 +9,43 @@ namespace mlx::core {

 namespace {

-void set_array_buffer(
-    MTL::ComputeCommandEncoder* enc,
-    const array& a,
-    int idx) {
+inline void
+set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
   auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
   auto offset = a.data<char>() -
       static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
   enc->setBuffer(a_buf, offset, idx);
 }

+inline void set_array_buffer(
+    MTL::ComputeCommandEncoder* enc,
+    const array& a,
+    int64_t offset,
+    int idx) {
+  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+  auto base_offset = a.data<char>() -
+      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+  base_offset += offset;
+  enc->setBuffer(a_buf, base_offset, idx);
+}
+
+template <typename T>
+inline void set_vector_bytes(
+    MTL::ComputeCommandEncoder* enc,
+    const std::vector<T>& vec,
+    size_t nelems,
+    int idx) {
+  enc->setBytes(vec.data(), nelems * sizeof(T), idx);
+}
+
+template <typename T>
+inline void set_vector_bytes(
+    MTL::ComputeCommandEncoder* enc,
+    const std::vector<T>& vec,
+    int idx) {
+  return set_vector_bytes(enc, vec, vec.size(), idx);
+}
+
 std::string type_to_name(const array& a) {
   std::string tname;
   switch (a.dtype()) {
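These helpers pair with the buffer numbering in the copy kernels: the offset-taking set_array_buffer overload is what lets a kernel start partway into an array's Metal buffer (the slice's data offset, in bytes), and set_vector_bytes pushes shapes and strides as small constant buffers. A hedged, illustrative dispatch fragment follows; the encoder, arrays, and stride vectors are assumed to exist at the call site:

// Bind src at its slice offset (elements -> bytes), dst at its base,
// then shape and strides at the fixed slots 2, 3, 4 and ndim at 5.
set_array_buffer(enc, src, i_offset * src.itemsize(), 0);
set_array_buffer(enc, dst, 1);
set_vector_bytes(enc, shape, 2);
set_vector_bytes(enc, strides_in, 3);
set_vector_bytes(enc, strides_out, 4);
enc->setBytes(&ndim, sizeof(int), 5);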
@@ -87,6 +87,7 @@ NO_GPU(Sign)
 NO_GPU(Sin)
 NO_GPU(Sinh)
 NO_GPU(Slice)
+NO_GPU(SliceUpdate)
 NO_GPU(Softmax)
 NO_GPU(Sort)
 NO_GPU_MULTI(Split)
mlx/ops.cpp
@@ -445,6 +445,60 @@ array expand_dims(
   return reshape(a, out_shape, s);
 }

+// Slice helper
+namespace {
+
+inline auto normalize_slice(
+    const std::vector<int>& shape,
+    std::vector<int>& start,
+    std::vector<int>& stop,
+    std::vector<int>& strides) {
+  std::vector<int> out_shape(shape.size());
+  bool has_neg_strides = false;
+
+  for (int i = 0; i < shape.size(); ++i) {
+    // Following numpy docs
+    // Negative i and j are interpreted as n + i and n + j where n is
+    // the number of elements in the corresponding dimension. Negative
+    // k makes stepping go towards smaller indices
+
+    auto n = shape[i];
+    auto s = start[i];
+    s = s < 0 ? s + n : s;
+    auto e = stop[i];
+    e = e < 0 ? e + n : e;
+
+    // Note: -ve strides require start >= stop
+    if (strides[i] < 0) {
+      has_neg_strides = true;
+
+      // Clamp to bounds
+      auto st = std::min(s, n - 1);
+      auto ed = std::max(-1, e);
+
+      start[i] = st;
+      stop[i] = ed > st ? st : ed;
+
+      auto str = -strides[i];
+      out_shape[i] = (start[i] - stop[i] + str - 1) / str;
+
+    } else {
+      // Clamp to bounds
+      auto st = std::max(0, std::min(s, n));
+      auto ed = std::max(0, std::min(e, n));
+
+      start[i] = st;
+      stop[i] = ed < st ? st : ed;
+
+      out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
+    }
+  }
+
+  return std::make_pair(has_neg_strides, out_shape);
+}
+
+} // namespace
+
 array slice(
     const array& a,
     std::vector<int> start,
@@ -459,113 +513,13 @@ array slice(
     throw std::invalid_argument(msg.str());
   }

-  std::vector<int> negatively_strided_axes;
-  std::vector<std::vector<int>> negatively_strided_slices;
-  std::vector<int> out_shape(a.ndim());
-  for (int i = 0; i < a.ndim(); ++i) {
-    // Following numpy docs
-    // Negative i and j are interpreted as n + i and n + j where n is
-    // the number of elements in the corresponding dimension. Negative
-    // k makes stepping go towards smaller indices
-
-    auto n = a.shape(i);
-    auto s = start[i];
-    s = s < 0 ? s + n : s;
-    auto e = stop[i];
-    e = e < 0 ? e + n : e;
-
-    // Note: We pass positive strides to the primitive and then flip
-    // the axes later as needed
-    if (strides[i] < 0) {
-      negatively_strided_axes.push_back(i);
-      auto st = std::min(s, n - 1);
-      auto ed = std::max(e, -1);
-      negatively_strided_slices.push_back({st, ed, strides[i]});
-      start[i] = 0;
-      stop[i] = n;
-      strides[i] = 1;
-    } else {
-      start[i] = s;
-      stop[i] = e < s ? s : e;
-    }
-
-    // Clamp to bounds
-    start[i] = std::max(0, std::min(start[i], n));
-    stop[i] = std::max(0, std::min(stop[i], n));
-
-    out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
-  }
-
-  // If strides are negative, slice and then make a copy with axes flipped
-  if (negatively_strided_axes.size() > 0) {
-    // First, take the slice of the positively strided axes
-    auto out = array(
-        out_shape,
-        a.dtype(),
-        std::make_unique<Slice>(
-            to_stream(s),
-            std::move(start),
-            std::move(stop),
-            std::move(strides)),
-        {a});
-
-    std::vector<array> indices;
-    std::vector<int> slice_sizes = out.shape();
-    std::vector<int> t_axes(out.ndim(), -1);
-    std::vector<int> out_reshape(out.ndim(), -1);
-
-    int n_axes = negatively_strided_axes.size();
-    for (int i = 0; i < n_axes; i++) {
-      // Get axis and corresponding slice
-      auto ax = negatively_strided_axes[i];
-      auto sl = negatively_strided_slices[i];
-
-      // Get indices for the slice
-      auto ax_idx = arange(sl[0], sl[1], sl[2], s);
-
-      // Reshape indices for broadcast as needed
-      std::vector<int> ax_idx_shape(n_axes, 1);
-      ax_idx_shape[i] = ax_idx.size();
-      ax_idx = reshape(ax_idx, ax_idx_shape, s);
-
-      // Add indices to list
-      indices.push_back(ax_idx);
-
-      // Set slice size for axis
-      slice_sizes[ax] = 1;
-
-      // Gather moves the axis up, remainder needs to be squeezed
-      out_reshape[i] = indices[i].size();
-
-      // Gather moves the axis up, needs to be transposed
-      t_axes[ax] = i;
-    }
-
-    // Prepare out_reshape to squeeze gathered dims
-    // Prepare to transpose dims as needed
-    int j = n_axes;
-    for (int i = 0; j < out.ndim() && i < out.ndim(); i++) {
-      if (t_axes[i] < 0) {
-        t_axes[i] = j;
-        out_reshape[j] = out_shape[i];
-        j++;
-      }
-    }
-
-    // Gather
-    out = gather(out, indices, negatively_strided_axes, slice_sizes, s);
-
-    // Squeeze dims
-    out = reshape(out, out_reshape, s);
-
-    // Transpose dims
-    out = transpose(out, t_axes, s);
-
-    return out;
-  }
-  if (out_shape == a.shape()) {
+  auto [has_neg_strides, out_shape] =
+      normalize_slice(a.shape(), start, stop, strides);
+
+  if (!has_neg_strides && out_shape == a.shape()) {
     return a;
   }

   return array(
       out_shape,
       a.dtype(),
@@ -582,6 +536,56 @@ array slice(
   return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
 }

+/** Update a slice from the source array */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    std::vector<int> strides,
+    StreamOrDevice s /* = {} */) {
+  // Check dimensions
+  if (start.size() != src.ndim() || stop.size() != src.ndim() ||
+      strides.size() != src.ndim()) {
+    std::ostringstream msg;
+    msg << "[slice] Invalid number of indices or strides for "
+        << "array with dimension " << src.ndim() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  // Process slice dimensions
+  auto [has_neg_strides, upd_shape] =
+      normalize_slice(src.shape(), start, stop, strides);
+
+  // Broadcast update shape to slice shape
+  auto upd_shape_broadcast = broadcast_shapes(upd_shape, update.shape());
+  auto update_broadcasted = broadcast_to(update, upd_shape_broadcast, s);
+
+  // If the entire src is the slice, just return the update
+  if (!has_neg_strides && upd_shape == src.shape()) {
+    return astype(update_broadcasted, src.dtype(), s);
+  }
+
+  return array(
+      src.shape(),
+      src.dtype(),
+      std::make_unique<SliceUpdate>(
+          to_stream(s), std::move(start), std::move(stop), std::move(strides)),
+      {src, update});
+}
+
+/** Update a slice from the source array with stride 1 in each dimension */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    StreamOrDevice s /* = {} */) {
+  auto strides = std::vector<int>(src.ndim(), 1);
+  return slice_update(
+      src, update, std::move(start), std::move(stop), std::move(strides), s);
+}
+
 std::vector<array> split(
     const array& a,
     const std::vector<int>& indices,
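A worked case makes the negative-stride branch of normalize_slice concrete (hand-computed from the code above): for a dimension of size n = 8 with start = 5, stop = 1, stride = -1, nothing is remapped, the clamps leave st = 5 and ed = 1, and since ed > st is false the bounds stay (5, 1); with str = 1 the output extent is (5 - 1 + 1 - 1) / 1 = 4, that is, the elements at indices 5, 4, 3 and 2. This is exactly the reversed range the second case of the new test below exercises.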
mlx/ops.h
@@ -177,6 +177,23 @@ array slice(
     const std::vector<int>& stop,
     StreamOrDevice s = {});

+/** Update a slice from the source array */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    std::vector<int> strides,
+    StreamOrDevice s = {});
+
+/** Update a slice from the source array with stride 1 in each dimension */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    StreamOrDevice s = {});
+
 /** Split an array into sub-arrays along a given axis. */
 std::vector<array>
 split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
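A short usage sketch of the two overloads declared above, assuming the usual mlx C++ entry points (zeros, ones, float32); shapes and values are illustrative, and note that slice_update returns a new array rather than mutating src:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto src = zeros({4, 4}, float32);
  auto upd = ones({2, 2}, float32);

  // Strided overload: write upd at rows {0, 2} and columns {0, 2}.
  auto a = slice_update(src, upd, {0, 0}, {4, 4}, {2, 2});

  // Stride-1 overload: write upd into the top-left 2x2 block.
  auto b = slice_update(src, upd, {0, 0}, {2, 2});
  return 0;
}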
@@ -2849,6 +2849,114 @@ bool Slice::is_equivalent(const Primitive& other) const {
       end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
 }

+std::pair<std::vector<array>, std::vector<int>> SliceUpdate::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 2);
+  assert(axes.size() == 2);
+
+  auto start = start_indices_;
+  auto stop = end_indices_;
+  auto strides = strides_;
+
+  auto src = inputs[0];
+  auto upd = inputs[1];
+
+  auto src_ax = axes[0];
+  auto upd_ax = axes[1];
+
+  // No vmapping needed
+  if (src_ax == -1 && upd_ax == -1) {
+    return {{slice_update(src, upd, start, stop, strides, stream())}, {-1}};
+  }
+
+  // Broadcast src
+  if (src_ax == -1) {
+    src = expand_dims(src, upd_ax, stream());
+    auto shape = src.shape();
+    shape[upd_ax] = upd.shape(upd_ax);
+    src = broadcast_to(src, shape, stream());
+    src_ax = upd_ax;
+  }
+
+  // Broadcast upd
+  if (upd_ax == -1) {
+    upd = expand_dims(upd, src_ax, stream());
+    upd_ax = src_ax;
+  }
+
+  if (src_ax != upd_ax) {
+    upd = moveaxis(upd, upd_ax, src_ax, stream());
+  }
+
+  start.insert(start.begin() + src_ax, 0);
+  stop.insert(stop.begin() + src_ax, src.shape(src_ax));
+  strides.insert(strides.begin() + src_ax, 1);
+
+  return {{slice_update(src, upd, start, stop, strides, stream())}, {src_ax}};
+}
+
+std::vector<array> SliceUpdate::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  // Check inputs
+  assert(primals.size() == 2);
+
+  auto& cotan = cotangents[0];
+  auto& src = primals[0];
+  auto& upd = primals[1];
+
+  std::vector<array> vjps;
+
+  for (int num : argnums) {
+    // Vjp for source
+    if (num == 0) {
+      auto grad = slice_update(
+          cotan,
+          zeros_like(upd, stream()),
+          start_indices_,
+          end_indices_,
+          strides_,
+          stream());
+
+      vjps.push_back(grad);
+    }
+    // Vjp for updates
+    else {
+      auto grad =
+          slice(cotan, start_indices_, end_indices_, strides_, stream());
+
+      vjps.push_back(grad);
+    }
+  }
+
+  return vjps;
+}
+
+std::vector<array> SliceUpdate::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  // Check inputs
+  assert(primals.size() == 2);
+  return {slice_update(
+      tangents[0],
+      tangents[1],
+      start_indices_,
+      end_indices_,
+      strides_,
+      stream())};
+}
+
+bool SliceUpdate::is_equivalent(const Primitive& other) const {
+  const SliceUpdate& s_other = static_cast<const SliceUpdate&>(other);
+  return (
+      start_indices_ == s_other.start_indices_ &&
+      end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
+}
+
 std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
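The gradient rules above follow from what the op computes: the output agrees with upd on the slice and with src everywhere else. Concretely, with src = {s0, s1, s2}, upd = {u}, start = {1}, stop = {2}, the output is {s0, u, s2}; a cotangent {c0, c1, c2} then flows back as {c0, 0, c2} to src (the updated position never saw src, hence the zeros_like write) and as {c1} to upd (the slice of the cotangent), which is precisely what the two branches of SliceUpdate::vjp compute.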
@@ -1660,6 +1660,44 @@ class Slice : public UnaryPrimitive {
   std::vector<int> strides_;

   void eval(const std::vector<array>& inputs, array& out);
+
+  std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
+      const array& in);
+  void shared_buffer_slice(
+      const array& in,
+      const std::vector<size_t>& out_strides,
+      size_t data_offset,
+      array& out);
+};
+
+class SliceUpdate : public UnaryPrimitive {
+ public:
+  explicit SliceUpdate(
+      Stream stream,
+      const std::vector<int>& start_indices,
+      const std::vector<int>& end_indices,
+      const std::vector<int>& strides)
+      : UnaryPrimitive(stream),
+        start_indices_(start_indices),
+        end_indices_(end_indices),
+        strides_(strides){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(SliceUpdate)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  std::vector<int> start_indices_;
+  std::vector<int> end_indices_;
+  std::vector<int> strides_;
+
+  void eval(const std::vector<array>& inputs, array& out);
+
+  std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
 };

 class Softmax : public UnaryPrimitive {
@@ -329,4 +329,13 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
   return os;
 }

+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
+  os << "(";
+  for (int i = 0; i < v.size(); ++i) {
+    os << v[i] << ((i == v.size() - 1) ? "" : ",");
+  }
+  os << ")";
+  return os;
+}
+
 } // namespace mlx::core
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #pragma once

@@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
 std::ostream& operator<<(std::ostream& os, array a);
 std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
 std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
 inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
   return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
 }
@@ -198,6 +198,31 @@ TEST_CASE("test slice") {
   CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
 }

+TEST_CASE("test slice update") {
+  array x = array({0., 0., 0., 0., 0., 0., 0., 0.}, {8}, float32);
+  array y = array(
+      {
+          1.,
+          2.,
+          3.,
+          4.,
+      },
+      {4},
+      float32);
+
+  auto out = slice_update(x, y, {2}, {6}, {1});
+  CHECK(array_equal(slice(out, {2}, {6}, {1}), y).item<bool>());
+
+  out = slice_update(x, y, {5}, {1}, {-1});
+  CHECK(array_equal(slice(out, {5}, {1}, {-1}), y).item<bool>());
+
+  x = reshape(x, {2, 4});
+  out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});
+  out = reshape(out, {8});
+  CHECK(array_equal(slice(out, {0}, {4}, {1}), y).item<bool>());
+  CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
+}
+
 TEST_CASE("test split") {
   array x = array(1);
   CHECK_THROWS(split(x, 0));
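The last case in the new test doubles as a broadcasting check: y has shape {4} while the updated region of the reshaped x has shape {2, 4}, so the broadcast_to inside slice_update replicates y across the leading axis and both halves of the flattened result compare equal to y.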