From 4ba4544549d7a082330d974ed019afa85c50d3c4 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 25 Aug 2025 12:56:56 -0700
Subject: [PATCH 1/2] enable cuda graph toggle

---
 mlx/backend/cuda/device.cpp | 60 +++++++++++++++++++++++++++++++++++++
 mlx/backend/cuda/device.h   |  6 ++--
 2 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 334655ffe..f657a8326 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -34,6 +34,14 @@ int cuda_graph_cache_size() {
   return cache_size;
 }
 
+bool use_cuda_graphs() {
+  static bool use_graphs = []() {
+    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
+  }();
+  return use_graphs;
+}
+
+
 } // namespace
 
 Device::Device(int device) : device_(device) {
@@ -86,11 +94,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
   enc.device().make_current();
+  if (!use_cuda_graphs()) {
+    return;
+  }
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
+  if (!use_cuda_graphs()) {
+    return;
+  }
+
   graph.end_capture(enc.stream());
   if (discard) {
     return;
   }
@@ -105,6 +120,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
 
 CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   enc.in_concurrent_ = false;
+  if (!use_cuda_graphs()) {
+    return;
+  }
   // Use an empty graph node for synchronization
   CommandEncoder::GraphNode empty{
       NULL, 'E', std::to_string(enc.node_count_++)};
@@ -193,11 +211,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
 }
 
 void CommandEncoder::set_input_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
 
 void CommandEncoder::set_output_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
+
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
   active_outputs_.push_back(id);
@@ -215,6 +240,17 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cudaLaunchKernel(
+        func,
+        grid_dim,
+        block_dim,
+        params,
+        smem_bytes,
+        stream()
+    ));
+    return;
+  }
   cudaKernelNodeParams kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDim = grid_dim;
@@ -230,6 +266,23 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cuLaunchKernel(
+        func,
+        grid_dim.x,
+        grid_dim.y,
+        grid_dim.z,
+        block_dim.x,
+        block_dim.y,
+        block_dim.z,
+        smem_bytes,
+        stream(),
+        params,
+        nullptr
+    ));
+    return;
+  }
+
   CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDimX = grid_dim.x;
@@ -256,6 +309,13 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 }
 
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  if (!use_cuda_graphs()) {
+    CudaGraphExec graph_exec;
+    graph_exec.instantiate(child);
+    device_.make_current();
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
+    return;
+  }
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
   insert_graph_dependencies(GraphNode{node, 'G'});
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 7b0ff5629..3526de947 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -76,9 +76,6 @@ class CommandEncoder {
       uint32_t smem_bytes,
       void** params);
 
-  // Low-level graph helpers.
-  void add_kernel_node(const cudaKernelNodeParams& params);
-  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
   void add_graph_node(cudaGraph_t child);
 
   void add_temporary(const array& arr) {
@@ -101,6 +98,9 @@ class CommandEncoder {
   void synchronize();
 
  private:
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+
   struct GraphNode {
     cudaGraphNode_t node;
     // K = kernel

From c093fa72c893b1f36be3bcd2de49d73433fa3649 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 26 Aug 2025 07:49:09 -0700
Subject: [PATCH 2/2] increase cache size

---
 mlx/backend/cuda/device.cpp | 34 +++++++++++++---------------------
 1 file changed, 13 insertions(+), 21 deletions(-)

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index f657a8326..d7b9a0328 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -29,7 +29,7 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
 
 int cuda_graph_cache_size() {
   static int cache_size = []() {
-    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
+    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
   }();
   return cache_size;
 }
@@ -41,7 +41,6 @@ bool use_cuda_graphs() {
   return use_graphs;
 }
 
-
 } // namespace
 
 Device::Device(int device) : device_(device) {
@@ -242,13 +241,7 @@ void CommandEncoder::add_kernel_node(
     void** params) {
   if (!use_cuda_graphs()) {
     CHECK_CUDA_ERROR(cudaLaunchKernel(
-        func,
-        grid_dim,
-        block_dim,
-        params,
-        smem_bytes,
-        stream()
-    ));
+        func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;
   }
   cudaKernelNodeParams kernel_params = {0};
@@ -268,18 +261,17 @@ void CommandEncoder::add_kernel_node(
     void** params) {
   if (!use_cuda_graphs()) {
     CHECK_CUDA_ERROR(cuLaunchKernel(
-        func,
-        grid_dim.x,
-        grid_dim.y,
-        grid_dim.z,
-        block_dim.x,
-        block_dim.y,
-        block_dim.z,
-        smem_bytes,
-        stream(),
-        params,
-        nullptr
-    ));
+        func,
+        grid_dim.x,
+        grid_dim.y,
+        grid_dim.z,
+        block_dim.x,
+        block_dim.y,
+        block_dim.z,
+        smem_bytes,
+        stream(),
+        params,
+        nullptr));
     return;
   }
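
Usage note (editor's sketch, not part of the patches): together these two
patches expose a pair of environment toggles. Setting MLX_USE_CUDA_GRAPHS=0
makes every CommandEncoder path above fall back to eager launches
(cudaLaunchKernel / cuLaunchKernel, or an immediate cudaGraphLaunch for child
graphs) instead of recording nodes into a captured CUDA graph, which is
useful for isolating graph-capture issues; MLX_CUDA_GRAPH_CACHE_SIZE (default
raised from 100 to 400 by the second patch) bounds the graph cache. A minimal
standalone C++ sketch of the read-once semantics, where get_env_int is a
hypothetical stand-in for MLX's env::get_var helper:

    #include <cstdlib>

    // Hypothetical stand-in for env::get_var: read an integer-valued
    // environment variable, falling back to a default when it is unset.
    static int get_env_int(const char* name, int default_value) {
      const char* raw = std::getenv(name);
      return raw ? std::atoi(raw) : default_value;
    }

    // Mirrors the patches: each value is latched on first use (static
    // local), graphs default to enabled, and the cache defaults to 400.
    bool use_cuda_graphs() {
      static bool use_graphs = get_env_int("MLX_USE_CUDA_GRAPHS", 1) != 0;
      return use_graphs;
    }

    int cuda_graph_cache_size() {
      static int cache_size = get_env_int("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
      return cache_size;
    }

Because both values are read exactly once per process, they must be set in
the environment before the first kernel is encoded.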