From 529842fed9ff3a789dafb5ea8eb8d67ff3becbc4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 3 Nov 2025 16:43:19 -0800 Subject: [PATCH] fix --- mlx/backend/cuda/cudnn_utils.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp index e93778056..2728414be 100644 --- a/mlx/backend/cuda/cudnn_utils.cpp +++ b/mlx/backend/cuda/cudnn_utils.cpp @@ -132,14 +132,18 @@ bool prepare_cudnn_plan( void** data_ptrs, F&& execute) { int workspace_size = plan.getWorkspaceSize(); - array workspace( - workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream()) - : allocator::Buffer(nullptr), - {workspace_size}, - uint8); + void* workspace_ptr = nullptr; + if (workspace_size > 0) { + array workspace( + cu::malloc_async(workspace_size, encoder.stream()), + {workspace_size}, + uint8); + encoder.add_temporary(workspace); + workspace_ptr = gpu_ptr(workspace); + } auto args = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(gpu_ptr(workspace)) + .setWorkspacePointer(workspace_ptr) .setDataPointers(num_args, data_ptrs) .setUids(num_args, uids) .build(); @@ -151,7 +155,6 @@ bool prepare_cudnn_plan( return false; } - encoder.add_temporary(workspace); return true; }