[CUDA] Reduce use of managed memory (#2725)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled

* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
This commit is contained in:
Awni Hannun
2025-11-05 16:05:23 -08:00
committed by GitHub
parent 27778156dc
commit df58b4133a
79 changed files with 795 additions and 515 deletions

View File

@@ -3,6 +3,7 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device/config.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/dtype_utils.h"
@@ -23,7 +24,7 @@ class CommandEncoder;
// Return pointer alignment of |x|'s data.
inline uint8_t get_alignment(const array& x) {
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(x.data<void>());
uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
for (; alignment < 32; alignment *= 2) {
if (address % (alignment * 2)) {
return alignment;
@@ -56,7 +57,7 @@ inline std::array<T, MAX_NDIM> vector_key(const Vec<T>& vec) {
// Helpers used by get_data_ptrs to get pointers.
inline void* get_data_ptr(const array& arr) {
return const_cast<void*>(arr.data<void>());
return const_cast<void*>(gpu_ptr<void>(arr));
}
template <typename T, typename = std::enable_if_t<std::is_scalar_v<T>>>