Fix reshape copy bug (#1253)

This commit is contained in:
Angelos Katharopoulos
2024-07-07 21:37:00 -07:00
committed by GitHub
parent bdb36c9a63
commit 03cf033f82
4 changed files with 107 additions and 45 deletions

View File

@@ -273,7 +273,18 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy_gpu(in, out, CopyType::General);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General,
stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}