Revisit autorelease memory pools (#260)

* make general autorelease pool part of metal device

* make things simpler

* no metal backend support

* new_memory_pool -> new_scoped_memory_pool
This commit is contained in:
Ronan Collobert 2023-12-22 11:01:26 -08:00 committed by GitHub
parent d35fa1db41
commit cd3616a463
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 28 additions and 27 deletions

View File

@ -17,10 +17,11 @@ namespace fs = std::filesystem;
namespace mlx::core::metal {
static Device metal_device_;
namespace {
// Catch things related to the main-thread static variables
static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
@ -112,29 +113,29 @@ MTL::Library* load_library(
} // namespace
Device::Device()
: pool_(NS::AutoreleasePool::alloc()->init()),
device_(load_device()),
library_map_({{"mlx", load_library(device_)}}) {}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
}
Device::~Device() {
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
e.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
device_->release();
pool_->release();
}
void Device::new_queue(int index) {
@ -243,6 +244,7 @@ void Device::register_library(
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
@ -285,17 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel(
}
Device& device(mlx::core::Device) {
return metal_device_;
static Device metal_device;
return metal_device;
}
NS::AutoreleasePool*& thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init();
return p;
std::shared_ptr<void> new_scoped_memory_pool() {
auto dtor = [](void* ptr) {
static_cast<NS::AutoreleasePool*>(ptr)->release();
};
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
}
void new_stream(Stream stream) {
thread_autorelease_pool();
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}

View File

@ -67,7 +67,6 @@ class Device {
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
private:
NS::AutoreleasePool* pool_;
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
@ -78,6 +77,5 @@ class Device {
};
Device& device(mlx::core::Device);
NS::AutoreleasePool*& thread_autorelease_pool();
} // namespace mlx::core::metal

View File

@ -50,6 +50,7 @@ std::function<void()> make_task(
bool retain_graph) {
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
auto pool = new_scoped_memory_pool();
for (auto& d : deps) {
d.wait();
}
@ -66,12 +67,6 @@ std::function<void()> make_task(
arr.detach();
}
p->set_value();
// Signal this thread to clear the pool on a synchroniztion.
scheduler::enqueue(s, []() {
thread_autorelease_pool()->release();
thread_autorelease_pool() =
NS::AutoreleasePool::alloc()->init();
});
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);

View File

@ -20,6 +20,7 @@ constexpr bool is_available() {
}
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task(
array& arr,

View File

@ -7,6 +7,9 @@
namespace mlx::core::metal {
void new_stream(Stream) {}
std::shared_ptr<void> new_memory_pool() {
return nullptr;
}
std::function<void()> make_task(
array& arr,

View File

@ -35,6 +35,7 @@ struct StreamThread {
}
void thread_fn() {
auto thread_pool = metal::new_scoped_memory_pool();
metal::new_stream(stream);
while (true) {
std::function<void()> task;