[CUDA] Simplify allocator (#2392)

* simplify allocator and fixe race with small pool

* Don't use shared event in worker

* use cuda buffer in small pool

* comment

* comment
This commit is contained in:
Awni Hannun
2025-07-22 08:24:01 -07:00
committed by GitHub
parent 74eccbf3fa
commit 1e496ddb82
9 changed files with 100 additions and 162 deletions

View File

@@ -7,13 +7,10 @@
#include <mutex>
#include <set>
#include <thread>
#include <utility>
namespace mlx::core::cu {
class Worker;
using allocator::Buffer;
// Stores cuda-managed unified memory.
@@ -24,13 +21,14 @@ struct CudaBuffer {
class SmallSizePool {
private:
struct Block {
union Block {
Block* next;
CudaBuffer buf;
};
void* buffer_{nullptr};
Block* buffer_{nullptr};
void* data_{nullptr};
Block* next_free_{nullptr};
void* end_{nullptr};
public:
SmallSizePool();
@@ -39,9 +37,9 @@ class SmallSizePool {
SmallSizePool(const SmallSizePool&) = delete;
SmallSizePool& operator=(const SmallSizePool&) = delete;
void* malloc();
void free(void* p);
bool in_pool(void* p);
CudaBuffer* malloc();
void free(CudaBuffer* buf);
bool in_pool(CudaBuffer* buf);
};
class CudaAllocator : public allocator::Allocator {
@@ -50,15 +48,6 @@ class CudaAllocator : public allocator::Allocator {
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const;
size_t get_peak_memory() const;
void reset_peak_memory();
@@ -69,13 +58,11 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache();
private:
void cuda_free(CudaBuffer* buf);
CudaAllocator();
friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_;
size_t memory_limit_;
size_t max_pool_size_;