diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 97a5ae4d4..6ef2da907 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -99,8 +99,10 @@ CudaAllocator::CudaAllocator() loc.id = 0; loc.type = cudaMemLocationTypeNone; cudaMemGetDefaultMemPool(&cuda_pool_, &loc, cudaMemAllocationTypeManaged); - // TODO set that. - // uint64_t threshold = UINT64_MAX; + // TODO: choose a real release-threshold strategy; UINT64_MAX asks the pool to keep freed memory reserved rather than return it to the OS. + uint64_t threshold = UINT64_MAX; + cudaMemPoolSetAttribute( + cuda_pool_, cudaMemPoolAttrReleaseThreshold, &threshold); #endif }