diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp index e93778056..2728414be 100644 --- a/mlx/backend/cuda/cudnn_utils.cpp +++ b/mlx/backend/cuda/cudnn_utils.cpp @@ -132,14 +132,18 @@ bool prepare_cudnn_plan( void** data_ptrs, F&& execute) { int workspace_size = plan.getWorkspaceSize(); - array workspace( - workspace_size > 0 ? cu::malloc_async(workspace_size, encoder.stream()) - : allocator::Buffer(nullptr), - {workspace_size}, - uint8); + void* workspace_ptr = nullptr; + if (workspace_size > 0) { + array workspace( + cu::malloc_async(workspace_size, encoder.stream()), + {workspace_size}, + uint8); + encoder.add_temporary(workspace); + workspace_ptr = gpu_ptr(workspace); + } auto args = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(gpu_ptr(workspace)) + .setWorkspacePointer(workspace_ptr) .setDataPointers(num_args, data_ptrs) .setUids(num_args, uids) .build(); @@ -151,7 +155,6 @@ bool prepare_cudnn_plan( return false; } - encoder.add_temporary(workspace); return true; }