mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Fix reshape copy bug (#1253)
This commit is contained in:
parent
bdb36c9a63
commit
03cf033f82
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/copy.h"
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -142,29 +143,31 @@ void copy_general(
|
|||||||
const std::vector<int>& data_shape,
|
const std::vector<int>& data_shape,
|
||||||
const std::vector<stride_t>& i_strides,
|
const std::vector<stride_t>& i_strides,
|
||||||
int64_t i_offset) {
|
int64_t i_offset) {
|
||||||
switch (src.ndim()) {
|
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||||
|
data_shape, std::vector<std::vector<stride_t>>{i_strides});
|
||||||
|
switch (new_shape.size()) {
|
||||||
case 1:
|
case 1:
|
||||||
copy_general_dim1<SrcT, DstT, stride_t>(
|
copy_general_dim1<SrcT, DstT, stride_t>(
|
||||||
src, dst, data_shape, i_strides, i_offset);
|
src, dst, new_shape, new_strides[0], i_offset);
|
||||||
return;
|
return;
|
||||||
case 2:
|
case 2:
|
||||||
copy_general_dim2<SrcT, DstT, stride_t>(
|
copy_general_dim2<SrcT, DstT, stride_t>(
|
||||||
src, dst, data_shape, i_strides, i_offset);
|
src, dst, new_shape, new_strides[0], i_offset);
|
||||||
return;
|
return;
|
||||||
case 3:
|
case 3:
|
||||||
copy_general_dim3<SrcT, DstT, stride_t>(
|
copy_general_dim3<SrcT, DstT, stride_t>(
|
||||||
src, dst, data_shape, i_strides, i_offset);
|
src, dst, new_shape, new_strides[0], i_offset);
|
||||||
return;
|
return;
|
||||||
case 4:
|
case 4:
|
||||||
copy_general_dim4<SrcT, DstT, stride_t>(
|
copy_general_dim4<SrcT, DstT, stride_t>(
|
||||||
src, dst, data_shape, i_strides, i_offset);
|
src, dst, new_shape, new_strides[0], i_offset);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||||
auto dst_ptr = dst.data<DstT>();
|
auto dst_ptr = dst.data<DstT>();
|
||||||
for (size_t i = 0; i < dst.size(); ++i) {
|
for (size_t i = 0; i < dst.size(); ++i) {
|
||||||
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
|
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
|
||||||
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -195,10 +198,10 @@ inline void copy_general_general_dims(
|
|||||||
const std::vector<int>& data_shape,
|
const std::vector<int>& data_shape,
|
||||||
const std::vector<stride_t>& i_strides,
|
const std::vector<stride_t>& i_strides,
|
||||||
const std::vector<stride_t>& o_strides,
|
const std::vector<stride_t>& o_strides,
|
||||||
stride_t i_offset,
|
int64_t i_offset,
|
||||||
stride_t o_offset) {
|
int64_t o_offset) {
|
||||||
if constexpr (D > 1) {
|
if constexpr (D > 1) {
|
||||||
int axis = src.ndim() - D;
|
int axis = data_shape.size() - D;
|
||||||
auto stride_src = i_strides[axis];
|
auto stride_src = i_strides[axis];
|
||||||
auto stride_dst = o_strides[axis];
|
auto stride_dst = o_strides[axis];
|
||||||
auto N = data_shape[axis];
|
auto N = data_shape[axis];
|
||||||
@ -209,7 +212,7 @@ inline void copy_general_general_dims(
|
|||||||
o_offset += stride_dst;
|
o_offset += stride_dst;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int axis = src.ndim() - 1;
|
int axis = data_shape.size() - 1;
|
||||||
auto stride_src = i_strides[axis];
|
auto stride_src = i_strides[axis];
|
||||||
auto stride_dst = o_strides[axis];
|
auto stride_dst = o_strides[axis];
|
||||||
auto N = data_shape[axis];
|
auto N = data_shape[axis];
|
||||||
@ -230,38 +233,76 @@ void copy_general_general(
|
|||||||
const std::vector<int>& data_shape,
|
const std::vector<int>& data_shape,
|
||||||
const std::vector<stride_t>& i_strides,
|
const std::vector<stride_t>& i_strides,
|
||||||
const std::vector<stride_t>& o_strides,
|
const std::vector<stride_t>& o_strides,
|
||||||
stride_t i_offset,
|
int64_t i_offset,
|
||||||
stride_t o_offset) {
|
int64_t o_offset) {
|
||||||
switch (src.ndim()) {
|
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||||
|
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
|
||||||
|
switch (new_shape.size()) {
|
||||||
case 1:
|
case 1:
|
||||||
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
|
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
|
||||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
src,
|
||||||
|
dst,
|
||||||
|
new_shape,
|
||||||
|
new_strides[0],
|
||||||
|
new_strides[1],
|
||||||
|
i_offset,
|
||||||
|
o_offset);
|
||||||
return;
|
return;
|
||||||
case 2:
|
case 2:
|
||||||
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
|
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
|
||||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
src,
|
||||||
|
dst,
|
||||||
|
new_shape,
|
||||||
|
new_strides[0],
|
||||||
|
new_strides[1],
|
||||||
|
i_offset,
|
||||||
|
o_offset);
|
||||||
return;
|
return;
|
||||||
case 3:
|
case 3:
|
||||||
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
|
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
|
||||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
src,
|
||||||
|
dst,
|
||||||
|
new_shape,
|
||||||
|
new_strides[0],
|
||||||
|
new_strides[1],
|
||||||
|
i_offset,
|
||||||
|
o_offset);
|
||||||
return;
|
return;
|
||||||
case 4:
|
case 4:
|
||||||
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
|
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
|
||||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
src,
|
||||||
|
dst,
|
||||||
|
new_shape,
|
||||||
|
new_strides[0],
|
||||||
|
new_strides[1],
|
||||||
|
i_offset,
|
||||||
|
o_offset);
|
||||||
return;
|
return;
|
||||||
case 5:
|
case 5:
|
||||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
src,
|
||||||
|
dst,
|
||||||
|
new_shape,
|
||||||
|
new_strides[0],
|
||||||
|
new_strides[1],
|
||||||
|
i_offset,
|
||||||
|
o_offset);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int size = std::accumulate(
|
int size = std::accumulate(
|
||||||
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
|
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
|
||||||
for (int i = 0; i < src.size(); i += size) {
|
for (int i = 0; i < src.size(); i += size) {
|
||||||
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
|
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
|
||||||
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
|
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
|
||||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||||
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
|
src,
|
||||||
|
dst,
|
||||||
|
new_shape,
|
||||||
|
new_strides[0],
|
||||||
|
new_strides[1],
|
||||||
|
src_offset,
|
||||||
|
dst_offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -444,8 +485,17 @@ void copy_inplace(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template void copy_inplace<size_t>(
|
||||||
void copy_inplace<int64_t>(
|
const array& src,
|
||||||
|
array& dst,
|
||||||
|
const std::vector<int>& data_shape,
|
||||||
|
const std::vector<size_t>& i_strides,
|
||||||
|
const std::vector<size_t>& o_strides,
|
||||||
|
int64_t i_offset,
|
||||||
|
int64_t o_offset,
|
||||||
|
CopyType ctype);
|
||||||
|
|
||||||
|
template void copy_inplace<int64_t>(
|
||||||
const array& src,
|
const array& src,
|
||||||
array& dst,
|
array& dst,
|
||||||
const std::vector<int>& data_shape,
|
const std::vector<int>& data_shape,
|
||||||
@ -453,24 +503,6 @@ void copy_inplace<int64_t>(
|
|||||||
const std::vector<int64_t>& o_strides,
|
const std::vector<int64_t>& o_strides,
|
||||||
int64_t i_offset,
|
int64_t i_offset,
|
||||||
int64_t o_offset,
|
int64_t o_offset,
|
||||||
CopyType ctype) {
|
CopyType ctype);
|
||||||
switch (ctype) {
|
|
||||||
case CopyType::General:
|
|
||||||
case CopyType::GeneralGeneral:
|
|
||||||
return copy_inplace_dispatch(
|
|
||||||
src,
|
|
||||||
dst,
|
|
||||||
ctype,
|
|
||||||
data_shape,
|
|
||||||
i_strides,
|
|
||||||
o_strides,
|
|
||||||
i_offset,
|
|
||||||
o_offset);
|
|
||||||
|
|
||||||
case CopyType::Scalar:
|
|
||||||
case CopyType::Vector:
|
|
||||||
return copy_inplace_dispatch(src, dst, ctype);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -405,7 +405,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||||
|
|
||||||
if (copy_necessary) {
|
if (copy_necessary) {
|
||||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
auto out_strides = make_contiguous_strides<size_t>(in.shape());
|
||||||
|
copy_inplace<size_t>(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
in.shape(),
|
||||||
|
in.strides(),
|
||||||
|
out_strides,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
CopyType::General);
|
||||||
} else {
|
} else {
|
||||||
shared_buffer_reshape(in, out_strides, out);
|
shared_buffer_reshape(in, out_strides, out);
|
||||||
}
|
}
|
||||||
|
@ -29,6 +29,15 @@ inline size_t elem_to_loc(int elem, const array& a) {
|
|||||||
return elem_to_loc(elem, a.shape(), a.strides());
|
return elem_to_loc(elem, a.shape(), a.strides());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute row-major (C-contiguous) strides, in elements, for `shape`.
///
/// The last axis gets stride 1; every other axis gets the product of the
/// extents of all axes that follow it, e.g. {2, 3, 4} -> {12, 4, 1}.
///
/// @tparam stride_t Integer type used for the returned strides.
/// @param shape Extent of each axis; may be empty.
/// @return One stride per axis of `shape` (empty when `shape` is empty).
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
  std::vector<stride_t> strides(shape.size(), 1);
  // Count down with an unsigned index that never goes below 1, so empty and
  // rank-1 shapes skip the loop without any unsigned-to-signed narrowing.
  for (size_t axis = shape.size(); axis > 1; --axis) {
    // Stride of the outer axis = stride of the inner axis * inner extent.
    strides[axis - 2] = strides[axis - 1] * static_cast<stride_t>(shape[axis - 1]);
  }
  return strides;
}
|
||||||
|
|
||||||
// Collapse dims that are contiguous to possibly route to a better kernel
|
// Collapse dims that are contiguous to possibly route to a better kernel
|
||||||
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
||||||
// should return {{2, 4}, {{1, 2}}}.
|
// should return {{2, 4}, {{1, 2}}}.
|
||||||
|
@ -273,7 +273,18 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||||
|
|
||||||
if (copy_necessary) {
|
if (copy_necessary) {
|
||||||
copy_gpu(in, out, CopyType::General);
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
auto out_strides = make_contiguous_strides<size_t>(in.shape());
|
||||||
|
copy_gpu_inplace(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
in.shape(),
|
||||||
|
in.strides(),
|
||||||
|
out_strides,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
CopyType::General,
|
||||||
|
stream());
|
||||||
} else {
|
} else {
|
||||||
shared_buffer_reshape(in, out_strides, out);
|
shared_buffer_reshape(in, out_strides, out);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user