From b26d88591c4c261418d1ac99e973743f96bb898d Mon Sep 17 00:00:00 2001
From: Cheng <zcbenz@gmail.com>
Date: Fri, 1 Aug 2025 10:16:06 +0900
Subject: [PATCH] [CUDA] Save primitive inputs faster (#2449)

* Add more nvtx loggings

* [CUDA] Saving primitive inputs faster

* Remove unneeded check
---
 mlx/backend/cuda/device.cpp              |  1 +
 mlx/backend/cuda/eval.cpp                | 13 +++++--------
 mlx/backend/cuda/quantized/quantized.cpp |  3 +++
 mlx/backend/cuda/ternary.cu              |  2 +-
 mlx/backend/gpu/primitives.cpp           |  1 +
 5 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 5871ce3e22..59357febf0 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -269,6 +269,7 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
 }
 
 void CommandEncoder::commit() {
+  nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
     add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index 0e1477e950..0c2690dcf4 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -36,18 +36,15 @@ void eval(array& arr) {
   auto& encoder = cu::get_command_encoder(arr.primitive().stream());
 
   // Keep used buffers alive until kernel finishes running.
-  std::unordered_set<std::shared_ptr<array::Data>> buffers;
   for (auto& in : arr.inputs()) {
-    buffers.insert(in.data_shared_ptr());
+    // Except for the donated one.
+    if (in.data_shared_ptr() != arr.data_shared_ptr()) {
+      encoder.add_temporary(in);
+    }
   }
   for (auto& s : arr.siblings()) {
-    buffers.insert(s.data_shared_ptr());
+    encoder.add_temporary(s);
   }
-  // Remove the output if it was donated to by an input.
-  if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
-    buffers.erase(it);
-  }
-  encoder.add_completed_handler([buffers = std::move(buffers)]() {});
   encoder.maybe_commit();
 }
 
diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp
index f495af53b3..256f2c7d50 100644
--- a/mlx/backend/cuda/quantized/quantized.cpp
+++ b/mlx/backend/cuda/quantized/quantized.cpp
@@ -5,6 +5,8 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/fast_primitives.h"
 
+#include <nvtx3/nvtx3.hpp>
+
 namespace mlx::core {
 
 namespace {
@@ -42,6 +44,7 @@ inline array ensure_row_contiguous_matrix(
 void fast::AffineQuantize::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
+  nvtx3::scoped_range r("AffineQuantize::eval_gpu");
   auto& s = stream();
   auto& d = cu::device(s.device);
   auto& enc = d.get_command_encoder(s);
diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu
index 58d3fa119a..bc4097d99c 100644
--- a/mlx/backend/cuda/ternary.cu
+++ b/mlx/backend/cuda/ternary.cu
@@ -192,7 +192,7 @@ void ternary_op_gpu(
 }
 
 void Select::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("select::eval_gpu");
+  nvtx3::scoped_range r("Select::eval_gpu");
   auto& s = out.primitive().stream();
   ternary_op_gpu(inputs, out, s);
 }
diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp
index 1adb859180..56d389b4f7 100644
--- a/mlx/backend/gpu/primitives.cpp
+++ b/mlx/backend/gpu/primitives.cpp
@@ -133,6 +133,7 @@ void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
 }
 
 void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
+  MLX_PROFILER_RANGE("Pad::eval_gpu");
   // Inputs must be base input array and scalar val array
   assert(inputs.size() == 2);
   auto& in = inputs[0];