mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 19:28:14 +08:00
Fix reshape copy bug (#1253)
This commit is contained in:

committed by
GitHub

parent
bdb36c9a63
commit
03cf033f82
@@ -273,7 +273,18 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto out_strides = make_contiguous_strides<size_t>(in.shape());
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
out_strides,
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
stream());
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
|
Reference in New Issue
Block a user