Increase the default CUDA graph cache size (MLX_CUDA_GRAPH_CACHE_SIZE) from 100 to 400

This commit is contained in:
Awni Hannun 2025-08-26 07:49:09 -07:00
parent 4ba4544549
commit c093fa72c8

View File

@ -29,7 +29,7 @@ void check_cudnn_error(const char* name, cudnnStatus_t err) {
int cuda_graph_cache_size() { int cuda_graph_cache_size() {
static int cache_size = []() { static int cache_size = []() {
return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100); return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 400);
}(); }();
return cache_size; return cache_size;
} }
@ -41,7 +41,6 @@ bool use_cuda_graphs() {
return use_graphs; return use_graphs;
} }
} // namespace } // namespace
Device::Device(int device) : device_(device) { Device::Device(int device) : device_(device) {
@ -242,13 +241,7 @@ void CommandEncoder::add_kernel_node(
void** params) { void** params) {
if (!use_cuda_graphs()) { if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cudaLaunchKernel( CHECK_CUDA_ERROR(cudaLaunchKernel(
func, func, grid_dim, block_dim, params, smem_bytes, stream()));
grid_dim,
block_dim,
params,
smem_bytes,
stream()
));
return; return;
} }
cudaKernelNodeParams kernel_params = {0}; cudaKernelNodeParams kernel_params = {0};
@ -268,18 +261,17 @@ void CommandEncoder::add_kernel_node(
void** params) { void** params) {
if (!use_cuda_graphs()) { if (!use_cuda_graphs()) {
CHECK_CUDA_ERROR(cuLaunchKernel( CHECK_CUDA_ERROR(cuLaunchKernel(
func, func,
grid_dim.x, grid_dim.x,
grid_dim.y, grid_dim.y,
grid_dim.z, grid_dim.z,
block_dim.x, block_dim.x,
block_dim.y, block_dim.y,
block_dim.z, block_dim.z,
smem_bytes, smem_bytes,
stream(), stream(),
params, params,
nullptr nullptr));
));
return; return;
} }