mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
fix donation in scan (#1917)
This commit is contained in:
@@ -17,10 +17,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
std::vector<array> copies;
|
||||
bool donate = inputs[0].is_donatable();
|
||||
auto in = inputs[0];
|
||||
if (in.flags().contiguous && in.strides()[axis_] != 0) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
if (donate && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
@@ -32,8 +32,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
copies.push_back(arr_copy);
|
||||
in = arr_copy;
|
||||
in = std::move(arr_copy);
|
||||
out.move_shared_buffer(in);
|
||||
}
|
||||
|
||||
@@ -127,8 +126,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
MTL::Size group_dims(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user