From bbf14239538c18a9f8a3d61f52ab0f29589410f0 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 30 Sep 2025 16:08:46 -0700
Subject: [PATCH] wait for tasks in cuda (#2636)

---
 mlx/backend/cuda/allocator.cpp |  2 +-
 mlx/backend/cuda/device.cpp    | 22 ++++++++++------------
 mlx/backend/cuda/device.h      |  2 +-
 mlx/backend/cuda/eval.cpp      | 16 ++++++++++++++--
 4 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 5eb10b8ac..329906a13 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -86,7 +86,7 @@ CudaAllocator::CudaAllocator()
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
-  memory_limit_ = total * 0.8;
+  memory_limit_ = total * 0.95;
   max_pool_size_ = memory_limit_;
 }
 
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 27e702743..7d0da0580 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -14,10 +14,6 @@ namespace mlx::core::cu {
 
 namespace {
 
-// Can be tuned with MLX_MAX_OPS_PER_BUFFER
-// This should be less than 255
-constexpr int default_max_nodes_per_graph = 20;
-
 #define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
 
 void check_cudnn_error(const char* name, cudnnStatus_t err) {
@@ -95,6 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
 
 CommandEncoder::CaptureContext::~CaptureContext() {
   if (!use_cuda_graphs()) {
+    enc.node_count_++;
     return;
   }
 
@@ -221,12 +218,6 @@ void CommandEncoder::set_output_array(const array& arr) {
   active_outputs_.push_back(id);
 }
 
-void CommandEncoder::maybe_commit() {
-  if (node_count_ >= env::max_ops_per_buffer(default_max_nodes_per_graph)) {
-    commit();
-  }
-}
-
 void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,
@@ -234,6 +225,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cudaLaunchKernel(
         func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;
@@ -254,6 +246,7 @@ void CommandEncoder::add_kernel_node(
     uint32_t smem_bytes,
     void** params) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CHECK_CUDA_ERROR(cuLaunchKernel(
         func,
         grid_dim.x,
@@ -296,6 +289,7 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
   if (!use_cuda_graphs()) {
+    node_count_++;
     CudaGraphExec graph_exec;
     graph_exec.instantiate(child);
     device_.make_current();
@@ -307,12 +301,16 @@ void CommandEncoder::add_graph_node(cudaGraph_t child) {
   insert_graph_dependencies(GraphNode{node, 'G'});
 }
 
+int CommandEncoder::get_num_ops() {
+  return node_count_;
+}
+
 void CommandEncoder::commit() {
   nvtx3::scoped_range r("CommandEncoder::commit");
   if (!temporaries_.empty()) {
    add_completed_handler([temporaries = std::move(temporaries_)]() {});
   }
-  if (node_count_ > 0) {
+  if (use_cuda_graphs() && node_count_ > 0) {
     if (!from_nodes_.empty()) {
       CHECK_CUDA_ERROR(cudaGraphAddDependencies(
           graph_,
@@ -355,7 +353,6 @@ void CommandEncoder::commit() {
     CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream_));
 
     // Reset state
-    node_count_ = 0;
     graph_node_count_ = 0;
     empty_node_count_ = 0;
     from_nodes_.clear();
@@ -367,6 +364,7 @@ void CommandEncoder::commit() {
 
   // Put completion handlers in a batch.
   worker_.commit(stream_);
+  node_count_ = 0;
 }
 
 void CommandEncoder::synchronize() {
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 93667e736..d18092328 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -83,7 +83,7 @@ class CommandEncoder {
   }
 
   void add_completed_handler(std::function<void()> task);
-  void maybe_commit();
+  int get_num_ops();
   void commit();
 
   Device& device() {
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index 379d65423..07b3ad63e 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -5,11 +5,15 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/gpu/available.h"
 #include "mlx/primitives.h"
+#include "mlx/scheduler.h"
 
 #include <nvtx3/nvtx3.hpp>
 
 namespace mlx::core::gpu {
 
+// Can be tuned with MLX_MAX_OPS_PER_BUFFER
+constexpr int default_max_nodes_per_graph = 20;
+
 bool is_available() {
   return true;
 }
@@ -36,7 +40,8 @@ void eval(array& arr) {
     arr.primitive().eval_gpu(arr.inputs(), outputs);
   }
 
-  auto& encoder = cu::get_command_encoder(arr.primitive().stream());
+  auto& stream = arr.primitive().stream();
+  auto& encoder = cu::get_command_encoder(stream);
   // Keep used buffers alive until kernel finishes running.
   for (auto& in : arr.inputs()) {
     // Except for the donated one.
@@ -47,7 +52,14 @@ void eval(array& arr) {
   for (auto& s : arr.siblings()) {
     encoder.add_temporary(s);
   }
-  encoder.maybe_commit();
+
+  if (encoder.get_num_ops() >=
+      env::max_ops_per_buffer(default_max_nodes_per_graph)) {
+    scheduler::notify_new_task(stream);
+    encoder.add_completed_handler(
+        [stream]() { scheduler::notify_task_completion(stream); });
+    encoder.commit();
+  }
 }
 
 void finalize(Stream s) {
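
Note (annotation, not part of the patch): taken together, the hunks above change the flush policy in eval(). Instead of CommandEncoder::maybe_commit() silently flushing every N nodes, eval() now checks get_num_ops() itself and, before committing a full command buffer, registers that buffer with the scheduler as an outstanding task and retires it from a completion handler. That accounting is what lets later synchronization "wait for tasks": a buffer committed mid-eval stays visible to the scheduler until its handlers have run. The standalone C++ sketch below models this commit/backpressure pattern; every name in it (Scheduler, Encoder, max_ops_per_buffer) is a hypothetical stand-in, not an MLX API, and a plain worker thread plays the role of the stream completion callbacks that worker_.commit(stream_) schedules in the real encoder.

    // Standalone sketch (not MLX code) of the pattern introduced by this patch:
    // announce each committed buffer as an outstanding task, retire it from a
    // completion handler, and let a final wait block until everything drains.
    #include <condition_variable>
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <thread>
    #include <vector>

    // Hypothetical stand-in for the scheduler's notify_new_task /
    // notify_task_completion accounting: a counted set of in-flight buffers.
    struct Scheduler {
      std::mutex m;
      std::condition_variable cv;
      int outstanding = 0;

      void notify_new_task() {
        std::lock_guard<std::mutex> lk(m);
        ++outstanding;
      }
      void notify_task_completion() {
        {
          std::lock_guard<std::mutex> lk(m);
          --outstanding;
        }
        cv.notify_all();
      }
      void wait_for_tasks() {  // what a blocking synchronize relies on
        std::unique_lock<std::mutex> lk(m);
        cv.wait(lk, [&] { return outstanding == 0; });
      }
    };

    // Hypothetical encoder: accumulates "nodes" and flushes them as one
    // buffer; a worker thread runs the completion handlers afterwards.
    struct Encoder {
      int node_count = 0;
      std::vector<std::function<void()>> handlers;
      std::vector<std::thread> workers;

      void add_completed_handler(std::function<void()> f) {
        handlers.push_back(std::move(f));
      }
      void commit() {
        // Hand the batch to a "GPU" thread; handlers fire when it finishes.
        workers.emplace_back([hs = std::move(handlers)]() mutable {
          for (auto& h : hs) {
            h();
          }
        });
        handlers.clear();
        node_count = 0;  // reset on commit, as the patched commit() now does
      }
      ~Encoder() {
        for (auto& w : workers) {
          w.join();
        }
      }
    };

    int main() {
      constexpr int max_ops_per_buffer = 20;  // cf. MLX_MAX_OPS_PER_BUFFER
      Scheduler sched;
      Encoder enc;

      for (int i = 0; i < 100; ++i) {
        enc.node_count++;  // pretend one op was recorded into the buffer
        if (enc.node_count >= max_ops_per_buffer) {
          // The eval() pattern from the patch: announce the buffer as an
          // outstanding task first, retire it from its completion handler.
          sched.notify_new_task();
          enc.add_completed_handler(
              [&sched] { sched.notify_task_completion(); });
          enc.commit();
        }
      }
      sched.wait_for_tasks();  // returns only after all five buffers retire
      std::cout << "all committed buffers completed\n";
      return 0;
    }

One design point the sketch tries to mirror: the task is counted with notify_new_task() on the submitting thread, before commit(), not from inside the handler. Counting it any later would leave a window in which a buffer is in flight but invisible to the scheduler, so a concurrent wait could return too early.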