mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 15:41:13 +08:00
Revisit autorelease memory pools (#260)
* make general autorelease pool part of metal device * make things simpler * no metal backend support * new_memory_pool -> new_scoped_memory_pool
This commit is contained in:
parent
d35fa1db41
commit
cd3616a463
@ -17,10 +17,11 @@ namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
static Device metal_device_;
|
||||
|
||||
namespace {
|
||||
|
||||
// Catch things related to the main-thread static variables
|
||||
static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
|
||||
@ -112,29 +113,29 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
Device::Device()
|
||||
: pool_(NS::AutoreleasePool::alloc()->init()),
|
||||
device_(load_device()),
|
||||
library_map_({{"mlx", load_library(device_)}}) {}
|
||||
// Builds the Metal device wrapper: acquires the device and registers the
// default "mlx" library. A scoped memory pool is held for the duration of the
// constructor and released when it returns.
Device::Device() {
  const auto scoped_pool = new_scoped_memory_pool();

  device_ = load_device();

  auto* mlx_library = load_library(device_);
  library_map_ = {{"mlx", mlx_library}};
}
|
||||
|
||||
Device::~Device() {
|
||||
for (auto& q : queue_map_) {
|
||||
q.second->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
for (auto& l : library_map_) {
|
||||
l.second->release();
|
||||
}
|
||||
for (auto& b : buffer_map_) {
|
||||
b.second.second->release();
|
||||
}
|
||||
for (auto& e : encoder_map_) {
|
||||
e.second->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
for (auto& l : library_map_) {
|
||||
l.second->release();
|
||||
}
|
||||
device_->release();
|
||||
pool_->release();
|
||||
}
|
||||
|
||||
void Device::new_queue(int index) {
|
||||
@ -243,6 +244,7 @@ void Device::register_library(
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& name,
|
||||
const std::string& lib_name /* = "mlx" */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
@ -285,17 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
return metal_device_;
|
||||
static Device metal_device;
|
||||
return metal_device;
|
||||
}
|
||||
|
||||
NS::AutoreleasePool*& thread_autorelease_pool() {
|
||||
static thread_local NS::AutoreleasePool* p =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
return p;
|
||||
std::shared_ptr<void> new_scoped_memory_pool() {
|
||||
auto dtor = [](void* ptr) {
|
||||
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
||||
};
|
||||
return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
thread_autorelease_pool();
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
|
@ -67,7 +67,6 @@ class Device {
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||
|
||||
private:
|
||||
NS::AutoreleasePool* pool_;
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
@ -78,6 +77,5 @@ class Device {
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
NS::AutoreleasePool*& thread_autorelease_pool();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@ -50,6 +50,7 @@ std::function<void()> make_task(
|
||||
bool retain_graph) {
|
||||
auto task =
|
||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
@ -66,12 +67,6 @@ std::function<void()> make_task(
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
// Signal this thread to clear the pool on a synchronization.
|
||||
scheduler::enqueue(s, []() {
|
||||
thread_autorelease_pool()->release();
|
||||
thread_autorelease_pool() =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
});
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
|
@ -20,6 +20,7 @@ constexpr bool is_available() {
|
||||
}
|
||||
|
||||
void new_stream(Stream stream);
|
||||
std::shared_ptr<void> new_scoped_memory_pool();
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
|
@ -7,6 +7,9 @@
|
||||
namespace mlx::core::metal {
|
||||
|
||||
void new_stream(Stream) {}
|
||||
// Stub counterpart of the scoped memory pool factory: there is no pool to
// manage here, so an empty handle is returned. The name must match the
// `new_scoped_memory_pool` declaration that callers (e.g. the scheduler's
// stream thread) use; the previous `new_memory_pool` spelling left that
// symbol undefined.
std::shared_ptr<void> new_scoped_memory_pool() {
  return nullptr;
}
|
||||
|
||||
std::function<void()> make_task(
|
||||
array& arr,
|
||||
|
@ -35,6 +35,7 @@ struct StreamThread {
|
||||
}
|
||||
|
||||
void thread_fn() {
|
||||
auto thread_pool = metal::new_scoped_memory_pool();
|
||||
metal::new_stream(stream);
|
||||
while (true) {
|
||||
std::function<void()> task;
|
||||
|
Loading…
Reference in New Issue
Block a user