Files
mlx/mlx/allocator.h
Awni Hannun df58b4133a
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
[CUDA] Reduce use of managed memory (#2725)
* Use async cuda malloc managed with cuda 13

* add pool threshold

* refactor for regular cuda malloc

* load eval gpu for cuda

* remove use of cuda pool, use cuda free async

* fix

* fix

* fix

* fix

* fix + comment
2025-11-05 16:05:23 -08:00

53 lines
1.1 KiB
C++

// Copyright © 2023 Apple Inc.
#pragma once
#include <cstdlib>
namespace mlx::core::allocator {
// Thin handle around a single opaque buffer pointer.
// WARNING: Only Buffer objects obtained from mlx::allocator, or ones that
// wrap raw pointers obtained from it, are supported.
class Buffer {
 public:
  // Wrap an existing pointer. The constructor is explicit so that a bare
  // void* never converts into a Buffer implicitly.
  explicit Buffer(void* ptr) : ptr_{ptr} {}

  // Get the raw data pointer from the buffer. Only declared here; the
  // definition is not part of this header.
  void* raw_ptr();

  // Get the stored buffer pointer (mutable and read-only overloads).
  void* ptr() {
    return ptr_;
  }
  const void* ptr() const {
    return ptr_;
  }

 private:
  // The wrapped pointer, exactly as passed to the constructor.
  void* ptr_;
};
// Namespace-level allocation entry points. Both are only declared in this
// header; their definitions live elsewhere.
// NOTE(review): presumably these forward to allocator() and `size` is a
// byte count -- confirm against the implementation file.
Buffer malloc(size_t size);
// Takes the Buffer by value; the caller hands over a copy of the handle.
void free(Buffer buffer);
/** Abstract base class for a memory allocator. */
class Allocator {
 public:
  Allocator() = default;

  // Allocators are neither copyable nor movable.
  Allocator(const Allocator&) = delete;
  Allocator(Allocator&&) = delete;
  Allocator& operator=(const Allocator&) = delete;
  Allocator& operator=(Allocator&&) = delete;

  // Pure virtual operations every concrete allocator must override.
  virtual Buffer malloc(size_t size) = 0;
  virtual void free(Buffer buffer) = 0;
  virtual size_t size(Buffer buffer) const = 0;

  // Virtual so that destroying through an Allocator pointer or reference
  // runs the derived class destructor.
  virtual ~Allocator() = default;
};
// Returns a reference to an Allocator instance. Only declared here; the
// definition is outside this header.
// NOTE(review): presumably the single allocator for the active backend --
// confirm against the implementation.
Allocator& allocator();
} // namespace mlx::core::allocator