[CUDA] Save primitive inputs faster (#2449)

* Add more NVTX logging

* [CUDA] Saving primitive inputs faster

* Remove unneeded check
This commit is contained in:
Cheng
2025-08-01 10:16:06 +09:00
committed by GitHub
parent 86c6a15571
commit b26d88591c
5 changed files with 11 additions and 9 deletions

View File

@@ -36,18 +36,15 @@ void eval(array& arr) {
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
// Except for the donated one.
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
encoder.add_temporary(in);
}
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
encoder.add_temporary(s);
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
encoder.maybe_commit();
}