mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-31)

[CUDA] Simplify allocator (#2392)

* Simplify the allocator and fix a race with the small pool
* Don't use a shared event in the worker
* Use a CUDA buffer in the small pool
* Comments

This commit is contained in:
parent 74eccbf3fa
commit 1e496ddb82
CUDA allocator implementation (allocator.cpp):

@@ -2,7 +2,6 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"
 
 #include <cuda_runtime.h>
@@ -25,52 +24,58 @@ constexpr int small_block_size = 8;
 constexpr int small_pool_size = 4 * page_size;
 
 SmallSizePool::SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
-  end_ = reinterpret_cast<void*>(
-      reinterpret_cast<char*>(buffer_) + small_pool_size);
-  next_free_ = reinterpret_cast<Block*>(buffer_);
 
   auto num_blocks = small_pool_size / small_block_size;
+  buffer_ = new Block[num_blocks];
+  next_free_ = buffer_;
+
+  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
+
   auto curr = next_free_;
-  for (size_t i = 0; i < num_blocks - 1; ++i) {
-    curr->next = reinterpret_cast<Block*>(
-        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+  for (size_t i = 1; i < num_blocks; ++i) {
+    curr->next = buffer_ + i;
     curr = curr->next;
   }
   curr->next = nullptr;
 }
 
 SmallSizePool::~SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaFree(buffer_));
+  CHECK_CUDA_ERROR(cudaFree(data_));
+  delete[] buffer_;
 }
 
-void* SmallSizePool::malloc() {
+CudaBuffer* SmallSizePool::malloc() {
   if (next_free_ == nullptr) {
     return nullptr;
   }
   Block* b = next_free_;
+  uint64_t i = next_free_ - buffer_;
   next_free_ = next_free_->next;
-  return static_cast<void*>(b);
+  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
+  b->buf.size = small_block_size;
+  return &b->buf;
 }
 
-void SmallSizePool::free(void* p) {
-  auto b = static_cast<Block*>(p);
+void SmallSizePool::free(CudaBuffer* buf) {
+  auto b = reinterpret_cast<Block*>(buf);
   b->next = next_free_;
   next_free_ = b;
 }
 
-bool SmallSizePool::in_pool(void* p) {
-  return (p >= buffer_) && (p < end_);
+bool SmallSizePool::in_pool(CudaBuffer* buf) {
+  constexpr int num_blocks = (small_pool_size / small_block_size);
+  auto b = reinterpret_cast<Block*>(buf);
+  int64_t block_num = b - buffer_;
+  return block_num >= 0 && block_num < num_blocks;
 }
 
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) {
-            cuda_free(buf->data);
-            delete buf;
-          }) {
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
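
A minimal, host-only sketch of the free-list scheme the new SmallSizePool uses: bookkeeping nodes live in an ordinary host array, while the buffers callers receive point into a separately allocated data region. This is illustrative only; the names are not the MLX types, and plain std::malloc stands in for the cudaMallocManaged-backed data region the real pool uses.

    #include <cstddef>
    #include <cstdlib>

    struct FakeBuffer {
      void* data;
      std::size_t size;
    };

    class FixedPool {
     public:
      FixedPool(std::size_t block_size, std::size_t num_blocks)
          : block_size_(block_size), num_blocks_(num_blocks) {
        blocks_ = new Node[num_blocks_];
        data_ = std::malloc(block_size_ * num_blocks_);
        // Chain every node onto the free list.
        for (std::size_t i = 0; i + 1 < num_blocks_; ++i) {
          blocks_[i].next = &blocks_[i + 1];
        }
        blocks_[num_blocks_ - 1].next = nullptr;
        next_free_ = blocks_;
      }
      ~FixedPool() {
        std::free(data_);
        delete[] blocks_;
      }

      FakeBuffer* malloc() {
        if (!next_free_) {
          return nullptr; // exhausted; the caller falls back to the big allocator
        }
        Node* n = next_free_;
        next_free_ = n->next;
        std::size_t i = n - blocks_; // node index doubles as the data offset
        n->buf.data = static_cast<char*>(data_) + i * block_size_;
        n->buf.size = block_size_;
        return &n->buf;
      }

      void free(FakeBuffer* buf) {
        // The buffer is a union member, so its address is the node's address.
        Node* n = reinterpret_cast<Node*>(buf);
        n->next = next_free_;
        next_free_ = n;
      }

      bool in_pool(FakeBuffer* buf) const {
        auto idx = reinterpret_cast<Node*>(buf) - blocks_;
        return idx >= 0 && static_cast<std::size_t>(idx) < num_blocks_;
      }

     private:
      union Node {
        Node* next;     // valid while the block is free
        FakeBuffer buf; // valid while the block is handed out
      };
      std::size_t block_size_;
      std::size_t num_blocks_;
      Node* blocks_{nullptr};
      void* data_{nullptr};
      Node* next_free_{nullptr};
    };

Because the free list now lives in plain host memory, walking it never touches the managed data region, which is the point of splitting buffer_ (bookkeeping) from data_ (the managed pool).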
@@ -92,28 +97,26 @@ Buffer CudaAllocator::malloc(size_t size) {
 
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
-    size_t mem_required = get_active_memory() + get_cache_memory() + size;
-    if (mem_required >= memory_limit_) {
-      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    // If we have a lot of memory pressure try to reclaim memory from the cache.
+    int64_t mem_to_free =
+        get_active_memory() + get_cache_memory() + size - memory_limit_;
+    if (mem_to_free > 0) {
+      buffer_cache_.release_cached_buffers(mem_to_free);
     }
-
-    lock.unlock();
-    buf = new CudaBuffer{nullptr, size};
-
     // Try the scalar pool first
     if (size <= small_block_size) {
-      buf->data = scalar_pool_.malloc();
+      buf = scalar_pool_.malloc();
     }
-    if (!buf->data) {
+    lock.unlock();
+    if (!buf) {
+      buf = new CudaBuffer{nullptr, size};
       cudaError_t err = cudaMallocManaged(&buf->data, size);
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
         throw std::runtime_error(fmt::format(
            "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
       }
     }
 
     lock.lock();
   }
   active_memory_ += size;
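
A note on the reclaim arithmetic above: the old code first computed an unsigned mem_required and compared it against the limit before subtracting, while the new code folds both steps into one signed quantity, where a non-positive value simply means there is nothing to reclaim. A toy, self-contained illustration of why the signed form is the convenient one (the numbers are made up):

    #include <cstdint>
    #include <cstdio>

    int main() {
      std::uint64_t active = 100, cache = 50, request = 10, limit = 1000;
      // Signed form: negative means the allocator is under its limit.
      std::int64_t mem_to_free =
          static_cast<std::int64_t>(active + cache + request) -
          static_cast<std::int64_t>(limit);
      std::printf("mem_to_free = %lld\n", static_cast<long long>(mem_to_free));
      // Subtracting the same quantities in unsigned arithmetic would wrap
      // around to a huge value, which is why the old code needed the explicit
      // >= check before subtracting.
      std::uint64_t wrapped = active + cache + request - limit;
      std::printf("unsigned difference wraps to %llu\n",
                  static_cast<unsigned long long>(wrapped));
      return 0;
    }

Note also that the small pool is now consulted while the lock is still held, and only the slow cudaMallocManaged path runs with the lock released, which closes the race mentioned in the commit message.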
@@ -123,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
   if (get_cache_memory() > max_pool_size_) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
-
   return Buffer{buf};
 }
 
@@ -138,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
   if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
-    lock.unlock();
-    cuda_free(buf->data);
-    delete buf;
+    cuda_free(buf);
   }
 }
 
@@ -152,30 +152,13 @@ size_t CudaAllocator::size(Buffer buffer) const {
   return buf->size;
 }
 
-void CudaAllocator::register_this_thread() {
-  std::lock_guard lock(worker_mutex_);
-  allowed_threads_.insert(std::this_thread::get_id());
-}
-
-void CudaAllocator::cuda_free(void* buf) {
-  // If cuda_free() is called from an unregistered thread, reschedule the call
-  // to the worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
+// This must be called with mutex_ acquired
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf);
+    cudaFree(buf->data);
+    delete buf;
   }
 }
 
CUDA allocator header (mlx/backend/cuda/allocator.h):

@@ -7,13 +7,10 @@
 
 #include <mutex>
 #include <set>
-#include <thread>
 #include <utility>
 
 namespace mlx::core::cu {
 
-class Worker;
-
 using allocator::Buffer;
 
 // Stores cuda-managed unified memory.
@@ -24,13 +21,14 @@ struct CudaBuffer {
 
 class SmallSizePool {
  private:
-  struct Block {
+  union Block {
     Block* next;
+    CudaBuffer buf;
   };
 
-  void* buffer_{nullptr};
+  Block* buffer_{nullptr};
+  void* data_{nullptr};
   Block* next_free_{nullptr};
-  void* end_{nullptr};
 
  public:
   SmallSizePool();
@@ -39,9 +37,9 @@ class SmallSizePool {
   SmallSizePool(const SmallSizePool&) = delete;
   SmallSizePool& operator=(const SmallSizePool&) = delete;
 
-  void* malloc();
-  void free(void* p);
-  bool in_pool(void* p);
+  CudaBuffer* malloc();
+  void free(CudaBuffer* buf);
+  bool in_pool(CudaBuffer* buf);
 };
 
 class CudaAllocator : public allocator::Allocator {
@@ -50,15 +48,6 @@ class CudaAllocator : public allocator::Allocator {
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
 
-  // Register current thread as safe to free buffers.
-  // In cuda freeing a buffer implicitly synchronizes stream, and for threads
-  // that may be waited by gpu stream (for example cpu stream threads), freeing
-  // buffers there would result in dead lock.
-  void register_this_thread();
-
-  // Call cudaFree in the safe thread.
-  void cuda_free(void* buf);
-
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -69,13 +58,11 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
+  void cuda_free(CudaBuffer* buf);
+
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  std::mutex worker_mutex_;
-  std::unique_ptr<Worker> worker_;
-  std::set<std::thread::id> allowed_threads_;
-
   std::mutex mutex_;
   size_t memory_limit_;
   size_t max_pool_size_;
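
The reinterpret_casts in free() and in_pool() above lean on the fact that the CudaBuffer handed out by the pool is a member of the union Block, so its address is the address of the enclosing Block. A tiny standalone check of that layout assumption, using illustrative types rather than the MLX definitions:

    #include <cassert>
    #include <cstddef>

    struct Buf {
      void* data;
      std::size_t size;
    };

    union Block {
      Block* next; // used while the block sits on the free list
      Buf buf;     // used while the block is handed out to a caller
    };

    int main() {
      Block b;
      Buf* handed_out = &b.buf;                                 // what malloc() returns
      Block* recovered = reinterpret_cast<Block*>(handed_out);  // what free() recovers
      assert(recovered == &b);
      assert(static_cast<void*>(&b) == static_cast<void*>(&b.buf));
      return 0;
    }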
CUDA command encoder:

@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
   }
 
   // Put completion handlers in a batch.
-  worker_.end_batch();
   worker_.commit(stream_);
 }
 
@@ -315,7 +314,6 @@ void CommandEncoder::synchronize() {
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
-  worker_.end_batch();
   commit();
   f.wait();
 }
CUDA stream setup (new_stream):

@@ -19,8 +19,6 @@ void new_stream(Stream s) {
   cudaFree(nullptr);
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
-  // The main thread is safe to free buffers.
-  cu::allocator().register_this_thread();
 }
 
 void eval(array& arr) {
CUDA event implementation (event.cu):

@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }
 
+SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
+  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
+}
+
 SharedEvent::SharedEvent() {
-  // Allocate cuda::atomic on managed memory.
-  Atomic* ac;
-  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
-  new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
-    ptr->~Atomic();
-    allocator().cuda_free(ptr);
-  });
+  buf_ = std::shared_ptr<Buffer>(
+      new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
+        allocator().free(*ptr);
+        delete ptr;
+      });
+  *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }
 
 void SharedEvent::wait(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(ac_.get(), value);
+  event_wait(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(Stream s, uint64_t value) {
@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
|
|||||||
auto& encoder = get_command_encoder(s);
|
auto& encoder = get_command_encoder(s);
|
||||||
encoder.commit();
|
encoder.commit();
|
||||||
wait(encoder.stream(), value);
|
wait(encoder.stream(), value);
|
||||||
encoder.add_completed_handler([ac = ac_]() {});
|
encoder.add_completed_handler([buf = buf_]() {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(uint64_t value) {
|
void SharedEvent::signal(uint64_t value) {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
nvtx3::scoped_range r("cu::SharedEvent::signal");
|
||||||
event_signal(ac_.get(), value);
|
event_signal(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
|
||||||
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
|
event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SharedEvent::signal(Stream s, uint64_t value) {
|
void SharedEvent::signal(Stream s, uint64_t value) {
|
||||||
@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
|
|||||||
auto& encoder = get_command_encoder(s);
|
auto& encoder = get_command_encoder(s);
|
||||||
encoder.commit();
|
encoder.commit();
|
||||||
signal(encoder.stream(), value);
|
signal(encoder.stream(), value);
|
||||||
encoder.add_completed_handler([ac = ac_]() {});
|
encoder.add_completed_handler([buf = buf_]() {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SharedEvent::is_signaled(uint64_t value) const {
|
bool SharedEvent::is_signaled(uint64_t value) const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
|
||||||
return ac_->load() >= value;
|
return to_atomic(buf_)->load() >= value;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t SharedEvent::value() const {
|
uint64_t SharedEvent::value() const {
|
||||||
nvtx3::scoped_range r("cu::SharedEvent::value");
|
nvtx3::scoped_range r("cu::SharedEvent::value");
|
||||||
return ac_->load();
|
return to_atomic(buf_)->load();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cu
|
} // namespace cu
|
||||||
CUDA event header (mlx/backend/cuda/event.h):

@@ -2,6 +2,7 @@
 
 #pragma once
 
+#include "mlx/allocator.h"
 #include "mlx/stream.h"
 
 #include <cuda_runtime.h>
@ -55,12 +56,8 @@ class SharedEvent {
|
|||||||
bool is_signaled(uint64_t value) const;
|
bool is_signaled(uint64_t value) const;
|
||||||
uint64_t value() const;
|
uint64_t value() const;
|
||||||
|
|
||||||
const std::shared_ptr<Atomic>& atomic() const {
|
|
||||||
return ac_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Atomic> ac_;
|
std::shared_ptr<mlx::core::allocator::Buffer> buf_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
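
For context on the SharedEvent change: the atomic still lives in CUDA managed memory, but it is now obtained through the regular allocator (wrapped in a shared Buffer with a custom deleter) instead of a raw cudaMallocManaged call plus the old cuda_free() escape hatch. A standalone sketch of the underlying idea, a 64-bit cuda::atomic placed in managed memory that a kernel signals and the host observes. This is an illustration of the pattern, not the MLX API, and error handling is omitted.

    #include <cuda/atomic>
    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstdio>
    #include <new>

    using Atomic = cuda::atomic<std::uint64_t, cuda::thread_scope_system>;

    __global__ void signal_kernel(Atomic* ac, std::uint64_t value) {
      ac->store(value); // visible to the host through managed memory
    }

    int main() {
      Atomic* ac = nullptr;
      cudaMallocManaged(&ac, sizeof(Atomic));
      new (ac) Atomic(0); // construct the atomic in place in managed memory

      signal_kernel<<<1, 1>>>(ac, 1);
      cudaDeviceSynchronize();

      // Host-side "is_signaled"-style check.
      std::printf("signaled: %s\n", ac->load() >= 1 ? "yes" : "no");

      ac->~Atomic();
      cudaFree(ac);
      return 0;
    }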
CUDA worker implementation (worker.cpp):

@@ -1,7 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 
 namespace mlx::core::cu {
@ -12,10 +11,10 @@ Worker::Worker()
|
|||||||
|
|
||||||
Worker::~Worker() {
|
Worker::~Worker() {
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::lock_guard lock(mtx_);
|
||||||
stop_ = true;
|
stop_ = true;
|
||||||
}
|
}
|
||||||
worker_event_.signal(batch_ + 1);
|
cond_.notify_one();
|
||||||
worker_.join();
|
worker_.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
|
|||||||
pending_tasks_.push_back(std::move(task));
|
pending_tasks_.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::consume_in_this_thread() {
|
void Worker::signal(void* data) {
|
||||||
for (auto& task : pending_tasks_) {
|
auto w = static_cast<Worker*>(data);
|
||||||
task();
|
|
||||||
}
|
|
||||||
pending_tasks_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Worker::end_batch() {
|
|
||||||
batch_++;
|
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::lock_guard lock(w->mtx_);
|
||||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
w->signaled_batch_++;
|
||||||
}
|
}
|
||||||
uncommited_batches_++;
|
w->cond_.notify_one();
|
||||||
}
|
|
||||||
|
|
||||||
void Worker::commit() {
|
|
||||||
if (uncommited_batches_ == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
uncommited_batches_ = 0;
|
|
||||||
worker_event_.signal(batch_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::commit(cudaStream_t stream) {
|
void Worker::commit(cudaStream_t stream) {
|
||||||
if (uncommited_batches_ == 0) {
|
// Move pending tasks into tasks
|
||||||
|
if (pending_tasks_.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
uncommited_batches_ = 0;
|
{
|
||||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
std::lock_guard lock(mtx_);
|
||||||
// |stream_| finish running.
|
// Move pending tasks into ready tasks
|
||||||
|
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
|
||||||
|
}
|
||||||
signal_event_.record(stream);
|
signal_event_.record(stream);
|
||||||
signal_event_.wait(signal_stream_);
|
signal_event_.wait(signal_stream_);
|
||||||
worker_event_.signal(signal_stream_, batch_);
|
cudaLaunchHostFunc(signal_stream_, signal, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::thread_fn() {
|
void Worker::thread_fn() {
|
||||||
// The worker thread is safe to free buffers.
|
|
||||||
allocator().register_this_thread();
|
|
||||||
|
|
||||||
while (!stop_) {
|
while (!stop_) {
|
||||||
uint64_t batch = worker_event_.value();
|
uint64_t current_batch = 0;
|
||||||
Tasks tasks;
|
Tasks tasks;
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::unique_lock<std::mutex> lk(mtx_);
|
||||||
// Move tasks in signaled batches.
|
cond_.wait(lk, [this, ¤t_batch] {
|
||||||
auto end = worker_tasks_.upper_bound(batch);
|
return this->signaled_batch_ > current_batch || this->stop_;
|
||||||
|
});
|
||||||
|
current_batch = signaled_batch_;
|
||||||
|
auto end = worker_tasks_.upper_bound(current_batch);
|
||||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||||
if (tasks.empty()) {
|
if (tasks.empty()) {
|
||||||
tasks = std::move(it->second);
|
tasks = std::move(it->second);
|
||||||
@ -85,7 +72,6 @@ void Worker::thread_fn() {
|
|||||||
auto task = std::move(tasks[i]);
|
auto task = std::move(tasks[i]);
|
||||||
task();
|
task();
|
||||||
}
|
}
|
||||||
worker_event_.wait(batch + 1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CUDA worker header (mlx/backend/cuda/worker.h):

@@ -5,6 +5,7 @@
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
 
+#include <condition_variable>
 #include <functional>
 #include <map>
 #include <mutex>
@ -24,38 +25,24 @@ class Worker {
|
|||||||
// Add a pending |task| that will run when consumed or commited.
|
// Add a pending |task| that will run when consumed or commited.
|
||||||
void add_task(std::function<void()> task);
|
void add_task(std::function<void()> task);
|
||||||
|
|
||||||
// Run pending tasks immediately in current thread.
|
|
||||||
void consume_in_this_thread();
|
|
||||||
|
|
||||||
// Put pending tasks in a batch.
|
|
||||||
void end_batch();
|
|
||||||
|
|
||||||
// Inform worker thread to run current batches now.
|
|
||||||
void commit();
|
|
||||||
|
|
||||||
// Inform worker thread to run current batches after kernels in |stream|
|
// Inform worker thread to run current batches after kernels in |stream|
|
||||||
// finish running.
|
// finish running.
|
||||||
void commit(cudaStream_t stream);
|
void commit(cudaStream_t stream);
|
||||||
|
|
||||||
// Return how many batches have been added but not committed yet.
|
|
||||||
size_t uncommited_batches() const {
|
|
||||||
return uncommited_batches_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void thread_fn();
|
static void signal(void*);
|
||||||
|
|
||||||
uint64_t batch_{0};
|
void thread_fn();
|
||||||
size_t uncommited_batches_{0};
|
std::mutex mtx_;
|
||||||
|
std::condition_variable cond_;
|
||||||
|
|
||||||
|
uint64_t committed_batch_{0};
|
||||||
|
uint64_t signaled_batch_{0};
|
||||||
|
|
||||||
// Cuda stream and event for signaling kernel completion.
|
// Cuda stream and event for signaling kernel completion.
|
||||||
CudaStream signal_stream_;
|
CudaStream signal_stream_;
|
||||||
CudaEvent signal_event_;
|
CudaEvent signal_event_;
|
||||||
|
|
||||||
// Worker thread.
|
|
||||||
SharedEvent worker_event_;
|
|
||||||
std::thread worker_;
|
|
||||||
std::mutex worker_mutex_;
|
|
||||||
bool stop_{false};
|
bool stop_{false};
|
||||||
|
|
||||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||||
@ -63,6 +50,7 @@ class Worker {
|
|||||||
using Tasks = std::vector<std::function<void()>>;
|
using Tasks = std::vector<std::function<void()>>;
|
||||||
Tasks pending_tasks_;
|
Tasks pending_tasks_;
|
||||||
std::map<uint64_t, Tasks> worker_tasks_;
|
std::map<uint64_t, Tasks> worker_tasks_;
|
||||||
|
std::thread worker_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
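
A sketch of the new worker hand-off: commit() files the batch under a mutex and then asks CUDA to call a host function on a side stream once the submitted work finishes; that callback bumps a counter and notifies the worker thread through a condition variable, replacing the SharedEvent-based signaling. This is a standalone illustration under assumed names, not the MLX Worker class itself.

    #include <condition_variable>
    #include <cstdint>
    #include <cstdio>
    #include <mutex>
    #include <thread>
    #include <cuda_runtime.h>

    struct Notifier {
      std::mutex mtx;
      std::condition_variable cond;
      std::uint64_t signaled_batch{0};
      bool stop{false};
    };

    // cudaHostFn_t callback: runs on a CUDA-internal thread, so it must not
    // call any CUDA API; it only touches host-side synchronization state.
    static void signal_host_fn(void* data) {
      auto* n = static_cast<Notifier*>(data);
      {
        std::lock_guard<std::mutex> lock(n->mtx);
        n->signaled_batch++;
      }
      n->cond.notify_one();
    }

    int main() {
      Notifier n;
      std::thread worker([&n] {
        std::uint64_t seen = 0;
        std::unique_lock<std::mutex> lk(n.mtx);
        n.cond.wait(lk, [&] { return n.signaled_batch > seen || n.stop; });
        seen = n.signaled_batch;
        std::printf("worker woke up for batch %llu\n",
                    static_cast<unsigned long long>(seen));
      });

      cudaStream_t stream;
      cudaStreamCreate(&stream);
      // ... enqueue kernels on `stream` here ...
      cudaLaunchHostFunc(stream, signal_host_fn, &n); // fires after prior work
      cudaStreamSynchronize(stream);
      worker.join();
      cudaStreamDestroy(stream);
      return 0;
    }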
Metal allocator (MetalAllocator::malloc):

@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
 
   auto pool = metal::new_scoped_memory_pool();
 
-  // If we have a lot of memory pressure or are over the maximum cache size,
-  // try to reclaim memory from the cache
+  // If we have a lot of memory pressure try to reclaim memory from the cache
   if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
     num_resources_ -=
         buffer_cache_.release_cached_buffers(mem_required - gc_limit_);