mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Save primitive inputs faster (#2449)
* Add more nvtx loggings * [CUDA] Saving primitive inputs faster * Remove unneeded check
This commit is contained in:
@@ -36,18 +36,15 @@ void eval(array& arr) {
|
||||
|
||||
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
|
||||
// Keep used buffers alive until kernel finishes running.
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
// Except for the donated one.
|
||||
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
|
||||
encoder.add_temporary(in);
|
||||
}
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
encoder.add_temporary(s);
|
||||
}
|
||||
// Remove the output if it was donated to by an input.
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
|
||||
encoder.maybe_commit();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user