mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 09:18:12 +08:00
More buffer donation in some cases (#1858)
* more donation * fix * add test
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@@ -54,16 +53,34 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
abort_with_exception(error);
|
||||
}
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
|
||||
// Use a set to de-dup input buffers
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.push_back(in.data_shared_ptr());
|
||||
if (in.data_shared_ptr() != nullptr) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.push_back(s.data_shared_ptr());
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
// Erase any input buffers that have a use count > 1
|
||||
// to keep them elligible for donation
|
||||
for (auto it = buffers.begin(); it != buffers.end();) {
|
||||
// Always hold buffers which could belong to outputs
|
||||
// TODO: this shouldn't be necessary, but metal validation
|
||||
// complains if buffers get released before the command buffer is
|
||||
// finished, even if the ready signal has fired
|
||||
if (!signal && it->use_count() > 1) {
|
||||
it = buffers.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& out : outputs) {
|
||||
out.set_status(array::Status::evaluated);
|
||||
}
|
||||
|
Reference in New Issue
Block a user