	[CUDA] Simplify allocator (#2392)
Author: Awni Hannun

* Simplify the allocator and fix a race with the small pool
* Don't use a shared event in the worker
* Use a CudaBuffer in the small pool
* Comment cleanups
@@ -2,7 +2,6 @@
 
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"
 
 #include <cuda_runtime.h>
@@ -25,52 +24,58 @@ constexpr int small_block_size = 8;
 constexpr int small_pool_size = 4 * page_size;
 
 SmallSizePool::SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
-  end_ = reinterpret_cast<void*>(
-      reinterpret_cast<char*>(buffer_) + small_pool_size);
-  next_free_ = reinterpret_cast<Block*>(buffer_);
-
   auto num_blocks = small_pool_size / small_block_size;
+  buffer_ = new Block[num_blocks];
+
+  next_free_ = buffer_;
+
+  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
+
   auto curr = next_free_;
-  for (size_t i = 0; i < num_blocks - 1; ++i) {
-    curr->next = reinterpret_cast<Block*>(
-        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+  for (size_t i = 1; i < num_blocks; ++i) {
+    curr->next = buffer_ + i;
     curr = curr->next;
   }
   curr->next = nullptr;
 }
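The constructor above now keeps two separate allocations: a plain host array of Block entries that carries the free-list links and CudaBuffer headers, and a cudaMallocManaged arena (data_) that holds the actual 8-byte blocks handed out to scalars. Below is a standalone sketch of the same free-list scheme; the names (Pool, kBlocks, kBlockSize) are illustrative, and plain host memory stands in for the managed arena:

```cpp
#include <cassert>
#include <cstddef>

struct Buf {
  void* data;
  size_t size;
};

// A slot is either a link in the free list (while free) or a live Buf
// header (while allocated), so the pool needs no side-table bookkeeping.
union Block {
  Block* next;
  Buf buf;
};

class Pool {
  static constexpr size_t kBlocks = 512;
  static constexpr size_t kBlockSize = 8;
  Block slots_[kBlocks];
  char arena_[kBlocks * kBlockSize]; // stands in for the managed arena
  Block* next_free_ = slots_;

 public:
  Pool() {
    for (size_t i = 1; i < kBlocks; ++i) {
      slots_[i - 1].next = slots_ + i; // chain every slot into the free list
    }
    slots_[kBlocks - 1].next = nullptr;
  }
  Buf* malloc() {
    if (next_free_ == nullptr) {
      return nullptr; // exhausted; the real allocator falls back to CUDA
    }
    Block* b = next_free_;
    size_t i = b - slots_; // slot index selects the matching arena offset
    next_free_ = next_free_->next;
    b->buf = {arena_ + i * kBlockSize, kBlockSize};
    return &b->buf;
  }
  void free(Buf* buf) {
    auto* b = reinterpret_cast<Block*>(buf); // recover the enclosing slot
    b->next = next_free_;
    next_free_ = b;
  }
  bool in_pool(Buf* buf) const {
    auto* b = reinterpret_cast<const Block*>(buf);
    return b >= slots_ && b < slots_ + kBlocks;
  }
};

int main() {
  Pool pool;
  Buf* a = pool.malloc();
  assert(a != nullptr && pool.in_pool(a) && a->size == 8);
  pool.free(a);
}
```

Splitting the links from the data is also what makes in_pool() a pointer-range test on the host array rather than on the managed buffer that user data lives in.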
 
 SmallSizePool::~SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaFree(buffer_));
+  CHECK_CUDA_ERROR(cudaFree(data_));
+  delete[] buffer_;
 }
 
-void* SmallSizePool::malloc() {
+CudaBuffer* SmallSizePool::malloc() {
   if (next_free_ == nullptr) {
     return nullptr;
   }
   Block* b = next_free_;
+  uint64_t i = next_free_ - buffer_;
   next_free_ = next_free_->next;
-  return static_cast<void*>(b);
+  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
+  b->buf.size = small_block_size;
+  return &b->buf;
 }
 
-void SmallSizePool::free(void* p) {
-  auto b = static_cast<Block*>(p);
+void SmallSizePool::free(CudaBuffer* buf) {
+  auto b = reinterpret_cast<Block*>(buf);
   b->next = next_free_;
   next_free_ = b;
 }
 
-bool SmallSizePool::in_pool(void* p) {
-  return (p >= buffer_) && (p < end_);
+bool SmallSizePool::in_pool(CudaBuffer* buf) {
+  constexpr int num_blocks = (small_pool_size / small_block_size);
+  auto b = reinterpret_cast<Block*>(buf);
+  int64_t block_num = b - buffer_;
+  return block_num >= 0 && block_num < num_blocks;
 }
 
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) {
-            cuda_free(buf->data);
-            delete buf;
-          }) {
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -92,28 +97,26 @@ Buffer CudaAllocator::malloc(size_t size) {
 
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
-    size_t mem_required = get_active_memory() + get_cache_memory() + size;
-    if (mem_required >= memory_limit_) {
-      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    // If we have a lot of memory pressure try to reclaim memory from the cache.
+    int64_t mem_to_free =
+        get_active_memory() + get_cache_memory() + size - memory_limit_;
+    if (mem_to_free > 0) {
+      buffer_cache_.release_cached_buffers(mem_to_free);
     }
 
-    lock.unlock();
-    buf = new CudaBuffer{nullptr, size};
-
     // Try the scalar pool first
     if (size <= small_block_size) {
-      buf->data = scalar_pool_.malloc();
+      buf = scalar_pool_.malloc();
     }
-    if (!buf->data) {
+    lock.unlock();
+    if (!buf) {
+      buf = new CudaBuffer{nullptr, size};
       cudaError_t err = cudaMallocManaged(&buf->data, size);
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
         throw std::runtime_error(fmt::format(
             "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
       }
     }
-
     lock.lock();
   }
   active_memory_ += size;
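One detail in the hunk above is worth calling out (my reading; the diff does not explain it): mem_to_free is a signed int64_t because with unsigned size_t arithmetic, subtracting the limit before checking would wrap to a huge value whenever usage is under the limit. A minimal demonstration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  size_t required = 100, limit = 200; // usage below the limit
  // Unsigned subtraction wraps: prints 18446744073709551516, not -100.
  std::printf("unsigned: %zu\n", required - limit);
  // Signed difference, as in the new code: <= 0 means nothing to free.
  int64_t mem_to_free =
      static_cast<int64_t>(required) - static_cast<int64_t>(limit);
  std::printf("signed: %lld\n", static_cast<long long>(mem_to_free));
}
```

The hunk also moves the scalar-pool attempt before lock.unlock(), which is the race fix named in the commit message: the pool's free list is now only touched while the allocator mutex is held.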
@@ -123,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
   if (get_cache_memory() > max_pool_size_) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
-
   return Buffer{buf};
 }
 
@@ -138,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
   if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
-    lock.unlock();
-    cuda_free(buf->data);
-    delete buf;
+    cuda_free(buf);
   }
 }
 
@@ -152,30 +152,13 @@ size_t CudaAllocator::size(Buffer buffer) const {
   return buf->size;
 }
 
-void CudaAllocator::register_this_thread() {
-  std::lock_guard lock(worker_mutex_);
-  allowed_threads_.insert(std::this_thread::get_id());
-}
-
-void CudaAllocator::cuda_free(void* buf) {
-  // If cuda_free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
+// This must be called with mutex_ acquired.
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf);
+    cudaFree(buf->data);
+    delete buf;
   }
 }
 
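With the thread-registration detour gone, cuda_free() is an ordinary private helper with a documented precondition: the caller already holds mutex_. A sketch of that caller-holds-lock convention (hypothetical names, not the MLX API):

```cpp
#include <mutex>

class Alloc {
  std::mutex mutex_;

  // Precondition: mutex_ is held by the caller. Locking here again would
  // deadlock, so this helper never touches the mutex itself.
  void unlocked_release(void* p) { /* pool check + cudaFree-style release */ }

 public:
  void free(void* p) {
    std::lock_guard<std::mutex> lock(mutex_);
    unlocked_release(p);
  }
};

int main() {
  Alloc a;
  a.free(nullptr);
}
```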
@@ -7,13 +7,10 @@
 
 #include <mutex>
 #include <set>
-#include <thread>
 #include <utility>
 
 namespace mlx::core::cu {
 
-class Worker;
-
 using allocator::Buffer;
 
 // Stores cuda-managed unified memory.
@@ -24,13 +21,14 @@ struct CudaBuffer {
 
 class SmallSizePool {
  private:
-  struct Block {
+  union Block {
     Block* next;
+    CudaBuffer buf;
   };
 
-  void* buffer_{nullptr};
+  Block* buffer_{nullptr};
+  void* data_{nullptr};
   Block* next_free_{nullptr};
-  void* end_{nullptr};
 
  public:
   SmallSizePool();
@@ -39,9 +37,9 @@ class SmallSizePool {
   SmallSizePool(const SmallSizePool&) = delete;
   SmallSizePool& operator=(const SmallSizePool&) = delete;
 
-  void* malloc();
-  void free(void* p);
-  bool in_pool(void* p);
+  CudaBuffer* malloc();
+  void free(CudaBuffer* buf);
+  bool in_pool(CudaBuffer* buf);
 };
 
 class CudaAllocator : public allocator::Allocator {
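SmallSizePool::free() and in_pool() in allocator.cpp cast a CudaBuffer* straight back to a Block*. That is sound because Block is now a union, and every member of a union starts at the union's own address. A self-contained check of the two layout properties the code leans on (types re-declared here for illustration; sizes assume a 64-bit platform):

```cpp
#include <cassert>
#include <cstddef>

struct CudaBuffer {
  void* data;
  size_t size;
};

union Block {
  Block* next;    // valid while the slot sits on the free list
  CudaBuffer buf; // valid while the slot is handed out
};

int main() {
  Block b;
  // Union members share the union's address, so &b.buf round-trips to &b.
  assert(static_cast<void*>(&b) == static_cast<void*>(&b.buf));
  // The union costs no more than its largest member: no per-slot overhead.
  static_assert(sizeof(Block) == sizeof(CudaBuffer), "no per-slot overhead");
}
```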
@@ -50,15 +48,6 @@ class CudaAllocator : public allocator::Allocator {
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
 
-  // Register current thread as safe to free buffers.
-  // In cuda freeing a buffer implicitly synchronizes stream, and for threads
-  // that may be waited by gpu stream (for example cpu stream threads), freeing
-  // buffers there would result in dead lock.
-  void register_this_thread();
-
-  // Call cudaFree in the safe thread.
-  void cuda_free(void* buf);
-
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -69,13 +58,11 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
+  void cuda_free(CudaBuffer* buf);
+
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  std::mutex worker_mutex_;
-  std::unique_ptr<Worker> worker_;
-  std::set<std::thread::id> allowed_threads_;
-
   std::mutex mutex_;
   size_t memory_limit_;
   size_t max_pool_size_;

@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
   }
 
   // Put completion handlers in a batch.
-  worker_.end_batch();
   worker_.commit(stream_);
 }
 
@@ -315,7 +314,6 @@ void CommandEncoder::synchronize() {
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
-  worker_.end_batch();
   commit();
   f.wait();
 }

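The synchronize() path keeps its promise/future shape: register a completion handler that fulfills a promise, then block on the matching future; only the now-redundant end_batch() call is dropped. The pattern in isolation (the extra thread stands in for the worker that runs completion handlers):

```cpp
#include <functional>
#include <future>
#include <thread>
#include <vector>

int main() {
  std::vector<std::function<void()>> handlers;

  auto p = std::make_shared<std::promise<void>>();
  std::future<void> f = p->get_future();
  handlers.push_back([p = std::move(p)]() { p->set_value(); });

  // Stand-in for the worker draining handlers after the stream finishes.
  std::thread worker([&] {
    for (auto& h : handlers) {
      h();
    }
  });

  f.wait(); // returns only after the handler has run
  worker.join();
}
```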
@@ -19,8 +19,6 @@ void new_stream(Stream s) {
   cudaFree(nullptr);
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
-  // The main thread is safe to free buffers.
-  cu::allocator().register_this_thread();
 }
 
 void eval(array& arr) {

@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }
 
+SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
+  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
+}
+
 SharedEvent::SharedEvent() {
-  // Allocate cuda::atomic on managed memory.
-  Atomic* ac;
-  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
-  new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
-    ptr->~Atomic();
-    allocator().cuda_free(ptr);
-  });
+  buf_ = std::shared_ptr<Buffer>(
+      new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
+        allocator().free(*ptr);
+        delete ptr;
+      });
+  *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }
 
 void SharedEvent::wait(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(ac_.get(), value);
+  event_wait(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(Stream s, uint64_t value) {
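The new SharedEvent constructor no longer placement-constructs the atomic: it takes sizeof(Atomic) bytes from the regular allocator and zero-stores a plain uint64_t through the raw pointer. That relies on the lock-free atomic sharing its underlying integer's representation, which the standard does not strictly guarantee but which holds in practice. A sketch of that assumption, using std::atomic as a stand-in for cuda::atomic:

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstdlib>

using Atomic = std::atomic<uint64_t>; // stand-in for cuda::atomic

int main() {
  // The assumption, spelled out; fails to compile where it does not hold.
  static_assert(sizeof(Atomic) == sizeof(uint64_t), "same representation");

  void* raw = std::malloc(sizeof(Atomic)); // stand-in for allocator().malloc
  *static_cast<uint64_t*>(raw) = 0;        // zero-store, no placement-new
  auto* ac = static_cast<Atomic*>(raw);    // viewed as an atomic afterwards
  assert(ac->is_lock_free() && ac->load() == 0);
  ac->store(7); // event_signal() writes new values through the same view
  assert(ac->load() == 7);
  std::free(raw);
}
```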
@@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
     auto& encoder = get_command_encoder(s);
     encoder.commit();
     wait(encoder.stream(), value);
-    encoder.add_completed_handler([ac = ac_]() {});
+    encoder.add_completed_handler([buf = buf_]() {});
   }
 }
 
 void SharedEvent::signal(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal");
-  event_signal(ac_.get(), value);
+  event_signal(to_atomic(buf_), value);
 }
 
 void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
-  event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }
 
 void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
     auto& encoder = get_command_encoder(s);
     encoder.commit();
     signal(encoder.stream(), value);
-    encoder.add_completed_handler([ac = ac_]() {});
+    encoder.add_completed_handler([buf = buf_]() {});
   }
 }
 
 bool SharedEvent::is_signaled(uint64_t value) const {
   nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
-  return ac_->load() >= value;
+  return to_atomic(buf_)->load() >= value;
 }
 
 uint64_t SharedEvent::value() const {
   nvtx3::scoped_range r("cu::SharedEvent::value");
-  return ac_->load();
+  return to_atomic(buf_)->load();
 }
 
 } // namespace cu

@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "mlx/allocator.h"
 #include "mlx/stream.h"
 
 #include <cuda_runtime.h>
@@ -55,12 +56,8 @@ class SharedEvent {
   bool is_signaled(uint64_t value) const;
   uint64_t value() const;
 
-  const std::shared_ptr<Atomic>& atomic() const {
-    return ac_;
-  }
-
  private:
-  std::shared_ptr<Atomic> ac_;
+  std::shared_ptr<mlx::core::allocator::Buffer> buf_;
 };
 
 } // namespace mlx::core::cu

@@ -1,7 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 
 namespace mlx::core::cu {
@@ -12,10 +11,10 @@ Worker::Worker()
 
 Worker::~Worker() {
   {
-    std::lock_guard lock(worker_mutex_);
+    std::lock_guard lock(mtx_);
     stop_ = true;
   }
-  worker_event_.signal(batch_ + 1);
+  cond_.notify_one();
   worker_.join();
 }
 
@@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
   pending_tasks_.push_back(std::move(task));
 }
 
-void Worker::consume_in_this_thread() {
-  for (auto& task : pending_tasks_) {
-    task();
-  }
-  pending_tasks_.clear();
-}
-
-void Worker::end_batch() {
-  batch_++;
-  {
-    std::lock_guard lock(worker_mutex_);
-    worker_tasks_[batch_] = std::move(pending_tasks_);
-  }
-  uncommited_batches_++;
-}
-
-void Worker::commit() {
-  if (uncommited_batches_ == 0) {
-    return;
-  }
-  uncommited_batches_ = 0;
-  worker_event_.signal(batch_);
+void Worker::signal(void* data) {
+  auto w = static_cast<Worker*>(data);
+  {
+    std::lock_guard lock(w->mtx_);
+    w->signaled_batch_++;
+  }
+  w->cond_.notify_one();
 }
 
 void Worker::commit(cudaStream_t stream) {
-  if (uncommited_batches_ == 0) {
+  // Move pending tasks into tasks
+  if (pending_tasks_.empty()) {
     return;
   }
-  uncommited_batches_ = 0;
-  // Signal the |worker_event_| in |signal_stream_| after the kernels in
-  // |stream_| finish running.
+  {
+    std::lock_guard lock(mtx_);
+    // Move pending tasks into ready tasks
+    worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
+  }
   signal_event_.record(stream);
   signal_event_.wait(signal_stream_);
-  worker_event_.signal(signal_stream_, batch_);
+  cudaLaunchHostFunc(signal_stream_, signal, this);
 }
 
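The worker no longer signals itself through a SharedEvent; commit() records an event on the compute stream, makes signal_stream_ wait on it, and then enqueues a host callback with cudaLaunchHostFunc(), so signal() runs only after the batch's kernels have finished. A minimal standalone use of that API (error checks omitted; host callbacks must not call CUDA APIs):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

void on_done(void* data) {
  // Runs on a CUDA-internal thread once prior work on the stream completes.
  std::printf("stream finished: %s\n", static_cast<const char*>(data));
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // ... enqueue kernels or copies on `stream` here ...

  const char* tag = "batch 1";
  cudaLaunchHostFunc(stream, on_done, const_cast<char*>(tag));

  cudaStreamSynchronize(stream); // on_done has run by the time this returns
  cudaStreamDestroy(stream);
}
```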
 void Worker::thread_fn() {
-  // The worker thread is safe to free buffers.
-  allocator().register_this_thread();
-
+  uint64_t current_batch = 0;
   while (!stop_) {
-    uint64_t batch = worker_event_.value();
     Tasks tasks;
     {
-      std::lock_guard lock(worker_mutex_);
-      // Move tasks in signaled batches.
-      auto end = worker_tasks_.upper_bound(batch);
+      std::unique_lock<std::mutex> lk(mtx_);
+      cond_.wait(lk, [this, &current_batch] {
+        return this->signaled_batch_ > current_batch || this->stop_;
+      });
+      current_batch = signaled_batch_;
+      auto end = worker_tasks_.upper_bound(current_batch);
       for (auto it = worker_tasks_.begin(); it != end; ++it) {
         if (tasks.empty()) {
           tasks = std::move(it->second);
@@ -85,7 +72,6 @@ void Worker::thread_fn() {
       auto task = std::move(tasks[i]);
       task();
     }
-    worker_event_.wait(batch + 1);
   }
 }
 
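Taken together, commit() and thread_fn() now form a standard condition-variable pipeline: batches are filed under the mutex, the stream callback bumps signaled_batch_ and notifies, and the worker drains every batch up to the signaled count. The same loop in a self-contained form (the free functions and globals are illustrative, not the Worker API):

```cpp
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>
#include <mutex>
#include <thread>
#include <vector>

std::mutex mtx;
std::condition_variable cond;
uint64_t committed = 0, signaled = 0;
bool stop = false;
std::map<uint64_t, std::vector<std::function<void()>>> batches;

void worker_loop() {
  uint64_t current = 0;
  while (true) {
    std::vector<std::function<void()>> tasks;
    {
      std::unique_lock<std::mutex> lk(mtx);
      cond.wait(lk, [&] { return signaled > current || stop; });
      if (stop && signaled == current) {
        return; // nothing left to drain
      }
      current = signaled;
      auto end = batches.upper_bound(current);
      for (auto it = batches.begin(); it != end; ++it) {
        for (auto& t : it->second) {
          tasks.push_back(std::move(t));
        }
      }
      batches.erase(batches.begin(), end);
    }
    for (auto& t : tasks) {
      t(); // run tasks outside the lock, as thread_fn() does
    }
  }
}

int main() {
  std::thread w(worker_loop);
  {
    std::lock_guard<std::mutex> lk(mtx);
    batches[++committed].push_back([] { std::printf("task ran\n"); });
    signaled = committed; // stand-in for the cudaLaunchHostFunc callback
  }
  cond.notify_one();
  {
    std::lock_guard<std::mutex> lk(mtx);
    stop = true;
  }
  cond.notify_one();
  w.join();
}
```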

@@ -5,6 +5,7 @@
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
 
+#include <condition_variable>
 #include <functional>
 #include <map>
 #include <mutex>
@@ -24,38 +25,24 @@ class Worker {
   // Add a pending |task| that will run when consumed or commited.
   void add_task(std::function<void()> task);
 
-  // Run pending tasks immediately in current thread.
-  void consume_in_this_thread();
-
-  // Put pending tasks in a batch.
-  void end_batch();
-
-  // Inform worker thread to run current batches now.
-  void commit();
-
   // Inform worker thread to run current batches after kernels in |stream|
   // finish running.
   void commit(cudaStream_t stream);
 
-  // Return how many batches have been added but not committed yet.
-  size_t uncommited_batches() const {
-    return uncommited_batches_;
-  }
-
  private:
-  void thread_fn();
+  static void signal(void*);
 
-  uint64_t batch_{0};
-  size_t uncommited_batches_{0};
+  void thread_fn();
+  std::mutex mtx_;
+  std::condition_variable cond_;
+
+  uint64_t committed_batch_{0};
+  uint64_t signaled_batch_{0};
 
   // Cuda stream and event for signaling kernel completion.
   CudaStream signal_stream_;
   CudaEvent signal_event_;
 
-  // Worker thread.
-  SharedEvent worker_event_;
-  std::thread worker_;
-  std::mutex worker_mutex_;
   bool stop_{false};
 
   // Tasks are put in |pending_tasks_| first, and then moved to
@@ -63,6 +50,7 @@ class Worker {
   using Tasks = std::vector<std::function<void()>>;
   Tasks pending_tasks_;
   std::map<uint64_t, Tasks> worker_tasks_;
+  std::thread worker_;
 };
 
 } // namespace mlx::core::cu
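One quiet change in this hunk (an inference, not stated in the diff): std::thread worker_ moved to be the last member of the class. Members are constructed in declaration order, so the thread, which starts running thread_fn() as soon as it is constructed, can only ever observe a fully built mutex, condition variable, and task map. A tiny illustration of the rule:

```cpp
#include <cassert>
#include <thread>

struct Ordered {
  int state = 42; // constructed before t below

  // Constructed last, so the thread may safely read `state` right away.
  std::thread t{[this] { assert(state == 42); }};

  ~Ordered() {
    t.join();
  }
};

int main() {
  Ordered o;
}
```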

@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
 
     auto pool = metal::new_scoped_memory_pool();
 
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache
+    // If we have a lot of memory pressure try to reclaim memory from the cache
     if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
       num_resources_ -=
           buffer_cache_.release_cached_buffers(mem_required - gc_limit_);