mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Saving primitive inputs faster
This commit is contained in:
@@ -36,18 +36,17 @@ void eval(array& arr) {
|
|||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||||
// Keep used buffers alive until kernel finishes running.
|
// Keep used buffers alive until kernel finishes running.
|
||||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
buffers.insert(in.data_shared_ptr());
|
// Except for the donated one.
|
||||||
|
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
|
||||||
|
encoder.add_temporary(in);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto& s : arr.siblings()) {
|
for (auto& s : arr.siblings()) {
|
||||||
buffers.insert(s.data_shared_ptr());
|
if (s.data_shared_ptr() != arr.data_shared_ptr()) {
|
||||||
|
encoder.add_temporary(s);
|
||||||
}
|
}
|
||||||
// Remove the output if it was donated to by an input.
|
|
||||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
|
||||||
buffers.erase(it);
|
|
||||||
}
|
}
|
||||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
|
||||||
encoder.maybe_commit();
|
encoder.maybe_commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user