From c093fa72c893b1f36be3bcd2de49d73433fa3649 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 26 Aug 2025 07:49:09 -0700
Subject: [PATCH] increase cache size

---
 mlx/backend/cuda/device.cpp | 34 +++++++++++++---------------------
 1 file changed, 13 insertions(+), 21 deletions(-)

diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index f657a8326..d7b9a0328 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -29,7 +29,7 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
 
 int cuda_graph_cache_size() {
   static int cache_size = []() {
-    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
+    return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
   }();
   return cache_size;
 }
@@ -41,7 +41,6 @@ bool use_cuda_graphs() {
   return use_graphs;
 }
 
-
 } // namespace
 
 Device::Device(int device) : device_(device) {
@@ -242,13 +241,7 @@ void CommandEncoder::add_kernel_node(
     void** params) {
   if (!use_cuda_graphs()) {
     CHECK_CUDA_ERROR(cudaLaunchKernel(
-        func,
-        grid_dim,
-        block_dim,
-        params,
-        smem_bytes,
-        stream()
-    ));
+        func, grid_dim, block_dim, params, smem_bytes, stream()));
     return;
   }
   cudaKernelNodeParams kernel_params = {0};
@@ -268,18 +261,17 @@ void CommandEncoder::add_kernel_node(
     void** params) {
   if (!use_cuda_graphs()) {
     CHECK_CUDA_ERROR(cuLaunchKernel(
-        func,
-        grid_dim.x,
-        grid_dim.y,
-        grid_dim.z,
-        block_dim.x,
-        block_dim.y,
-        block_dim.z,
-        smem_bytes,
-        stream(),
-        params,
-        nullptr
-    ));
+        func,
+        grid_dim.x,
+        grid_dim.y,
+        grid_dim.z,
+        block_dim.x,
+        block_dim.y,
+        block_dim.z,
+        smem_bytes,
+        stream(),
+        params,
+        nullptr));
     return;
   }