From 5722c147de361c4154aae117758445d8511e7220 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Thu, 21 Aug 2025 22:57:20 -0400 Subject: [PATCH] [CUDA] Update calls to `cudaMemAdvise` and `cudaGraphAddDependencies` for CUDA 13 (#2525) * [CUDA] Update cudaMemAdvise and cudaGraphAddDependencies for CUDA 13 These functions' signatures changed in CUDA 13, so we differentiate between CUDA 13 and preceding releases at compile time. * Mention NVIDIA in ACKNOWLEDGMENTS.md --- ACKNOWLEDGMENTS.md | 5 +++++ mlx/backend/cuda/allocator.cpp | 9 ++++++++- mlx/backend/cuda/device.cpp | 8 +++++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 786c9042c..6e479e5a6 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -25,6 +25,11 @@ MLX was developed with contributions from the following individuals: +# Organizations + +MLX has received contributions from the following companies: +- NVIDIA Corporation & Affiliates + # Third-Party Software MLX leverages several third-party software, listed here together with diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 93bf48542..5eb10b8ac 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -30,8 +30,15 @@ SmallSizePool::SmallSizePool() { next_free_ = buffer_; CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size)); +#if CUDART_VERSION >= 13000 + cudaMemLocation loc; + loc.type = cudaMemLocationTypeDevice; + loc.id = 0; +#else + int loc = 0; +#endif // CUDART_VERSION >= 13000 CHECK_CUDA_ERROR( - cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0)); + cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc)); auto curr = next_free_; for (size_t i = 1; i < num_blocks; ++i) { diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 371ae020c..334655ffe 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -269,7 +269,13 @@ void CommandEncoder::commit() { if (node_count_ > 0) { if (!from_nodes_.empty()) { CHECK_CUDA_ERROR(cudaGraphAddDependencies( - graph_, from_nodes_.data(), to_nodes_.data(), from_nodes_.size())); + graph_, + from_nodes_.data(), + to_nodes_.data(), +#if CUDART_VERSION >= 13000 + nullptr, // edgeData +#endif // CUDART_VERSION >= 13000 + from_nodes_.size())); } graph_key_ += ".";