diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 334655ffe..d7b9a0328 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -29,11 +29,18 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
 
 int cuda_graph_cache_size() {
   static int cache_size = []() {
-    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
+    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
   }();
   return cache_size;
 }
 
+bool use_cuda_graphs() {
+  static bool use_graphs = []() {
+    return env::get_var("MLX_USE_CUDA_GRAPHS", true);
+  }();
+  return use_graphs;
+}
+
 } // namespace
 
 Device::Device(int device) : device_(device) {
@@ -86,11 +93,18 @@ CommandEncoder& Device::get_command_encoder(Stream s) {
 
 CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) {
   enc.device().make_current();
+  if (!use_cuda_graphs()) {
+    return;
+  }
   CHECK_CUDA_ERROR(
       cudaStreamBeginCapture(enc.stream(), cudaStreamCaptureModeGlobal));
 }
 
 CommandEncoder::CaptureContext::~CaptureContext() {
+  if (!use_cuda_graphs()) {
+    return;
+  }
+
   graph.end_capture(enc.stream());
   if (discard) {
     return;
@@ -105,6 +119,9 @@ CommandEncoder::ConcurrentContext::ConcurrentContext(CommandEncoder& enc)
 
 CommandEncoder::ConcurrentContext::~ConcurrentContext() {
   enc.in_concurrent_ = false;
+  if (!use_cuda_graphs()) {
+    return;
+  }
 
   // Use an empty graph node for synchronization
   CommandEncoder::GraphNode empty{NULL, 'E', std::to_string(enc.node_count_++)};
@@ -193,11 +210,18 @@ void CommandEncoder::add_completed_handler(std::function<void()> task) {
 }
 
 void CommandEncoder::set_input_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
 }
 
 void CommandEncoder::set_output_array(const array& arr) {
+  if (!use_cuda_graphs()) {
+    return;
+  }
+
   auto id = reinterpret_cast<std::uintptr_t>(arr.buffer().ptr());
   active_deps_.push_back(id);
   active_outputs_.push_back(id);
@@ -215,6 +239,11 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cudaLaunchKernel(
+        func, grid_dim, block_dim, params, smem_bytes, stream()));
+    return;
+  }
   cudaKernelNodeParams kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDim = grid_dim;
@@ -230,6 +259,22 @@ void CommandEncoder::add_kernel_node(
     dim3 block_dim,
     uint32_t smem_bytes,
     void** params) {
+  if (!use_cuda_graphs()) {
+    CHECK_CUDA_ERROR(cuLaunchKernel(
+        func,
+        grid_dim.x,
+        grid_dim.y,
+        grid_dim.z,
+        block_dim.x,
+        block_dim.y,
+        block_dim.z,
+        smem_bytes,
+        stream(),
+        params,
+        nullptr));
+    return;
+  }
+
   CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDimX = grid_dim.x;
@@ -256,6 +301,13 @@ void CommandEncoder::add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params) {
 }
 
 void CommandEncoder::add_graph_node(cudaGraph_t child) {
+  if (!use_cuda_graphs()) {
+    CudaGraphExec graph_exec;
+    graph_exec.instantiate(child);
+    device_.make_current();
+    CHECK_CUDA_ERROR(cudaGraphLaunch(graph_exec, stream()));
+    return;
+  }
   cudaGraphNode_t node;
   CHECK_CUDA_ERROR(cudaGraphAddChildGraphNode(&node, graph_, NULL, 0, child));
   insert_graph_dependencies(GraphNode{node, 'G'});
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 7b0ff5629..3526de947 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -76,9 +76,6 @@ class CommandEncoder {
       uint32_t smem_bytes,
       void** params);
 
-  // Low-level graph helpers.
-  void add_kernel_node(const cudaKernelNodeParams& params);
-  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
   void add_graph_node(cudaGraph_t child);
 
   void add_temporary(const array& arr) {
@@ -101,6 +98,9 @@ class CommandEncoder {
   void synchronize();
 
  private:
+  void add_kernel_node(const cudaKernelNodeParams& params);
+  void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params);
+
   struct GraphNode {
     cudaGraphNode_t node;
     // K = kernel
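
Note for reviewers: the toggle is read once and cached, so MLX_USE_CUDA_GRAPHS must be set before the first command is encoded; when it is false, the add_kernel_node/add_graph_node paths above launch work eagerly on the encoder's stream instead of recording CUDA graph nodes. A minimal sketch of the intended semantics follows; get_bool_env is a hypothetical stand-in for env::get_var, whose exact parsing rules are an assumption here.

#include <cstdlib>
#include <string>

// Hypothetical stand-in for env::get_var(name, default): unset or empty falls
// back to the default; any value other than "0" counts as true.
static bool get_bool_env(const char* name, bool default_value) {
  const char* raw = std::getenv(name);
  if (raw == nullptr || *raw == '\0') {
    return default_value;
  }
  return std::string(raw) != "0";
}

// Mirrors use_cuda_graphs() from the diff: evaluated once, then cached for the
// lifetime of the process.
bool use_cuda_graphs() {
  static bool use_graphs = get_bool_env("MLX_USE_CUDA_GRAPHS", /*default=*/true);
  return use_graphs;
}

// Example: launching a process with MLX_USE_CUDA_GRAPHS=0 in the environment
// forces the eager cudaLaunchKernel/cuLaunchKernel fallback above.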