[CUDA] Save primitive inputs faster (#2449)

* Add more nvtx loggings

* [CUDA] Saving primitive inputs faster

* Remove unneeded check
This commit is contained in:
Cheng 2025-08-01 10:16:06 +09:00 committed by GitHub
parent 86c6a15571
commit b26d88591c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 11 additions and 9 deletions

View File

@ -269,6 +269,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
}
void CommandEncoder::commit() {
nvtx3::scoped_range r("CommandEncoder::commit");
if (!temporaries_.empty()) {
add_completed_handler([temporaries = std::move(temporaries_)]() {});
}

View File

@ -36,18 +36,15 @@ void eval(array& arr) {
auto& encoder = cu::get_command_encoder(arr.primitive().stream());
// Keep used buffers alive until kernel finishes running.
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
// Except for the donated one.
if (in.data_shared_ptr() != arr.data_shared_ptr()) {
encoder.add_temporary(in);
}
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
encoder.add_temporary(s);
}
// Remove the output if it was donated to by an input.
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
encoder.add_completed_handler([buffers = std::move(buffers)]() {});
encoder.maybe_commit();
}

View File

@ -5,6 +5,8 @@
#include "mlx/backend/gpu/copy.h"
#include "mlx/fast_primitives.h"
#include <nvtx3/nvtx3.hpp>
namespace mlx::core {
namespace {
@ -42,6 +44,7 @@ inline array ensure_row_contiguous_matrix(
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
nvtx3::scoped_range r("AffineQuantize::eval_gpu");
auto& s = stream();
auto& d = cu::device(s.device);
auto& enc = d.get_command_encoder(s);

View File

@ -192,7 +192,7 @@ void ternary_op_gpu(
}
void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("select::eval_gpu");
nvtx3::scoped_range r("Select::eval_gpu");
auto& s = out.primitive().stream();
ternary_op_gpu<cu::Select>(inputs, out, s);
}

View File

@ -133,6 +133,7 @@ void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
MLX_PROFILER_RANGE("Pad::eval_gpu");
// Inputs must be base input array and scalar val array
assert(inputs.size() == 2);
auto& in = inputs[0];