make general autorelease pool part of metal device

This commit is contained in:
Ronan Collobert 2023-12-18 20:45:29 -08:00
parent e8deca84e0
commit a813bdda0a
3 changed files with 42 additions and 39 deletions

View File

@ -288,15 +288,15 @@ Device& device(mlx::core::Device) {
return metal_device_; return metal_device_;
} }
NS::AutoreleasePool*& thread_autorelease_pool() { NS::AutoreleasePool*& Device::g_thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p = static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init(); NS::AutoreleasePool::alloc()->init();
return p; return p;
} }
void new_stream(Stream stream) { void new_stream(Stream stream) {
thread_autorelease_pool();
if (stream.device == mlx::core::Device::gpu) { if (stream.device == mlx::core::Device::gpu) {
device(stream.device).g_thread_autorelease_pool();
device(stream.device).new_queue(stream.index); device(stream.device).new_queue(stream.index);
} }
} }

View File

@ -66,6 +66,8 @@ class Device {
MTL::ArgumentEncoder* argument_encoder( MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const; const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
NS::AutoreleasePool*& g_thread_autorelease_pool();
private: private:
NS::AutoreleasePool* pool_; NS::AutoreleasePool* pool_;
MTL::Device* device_; MTL::Device* device_;
@ -78,6 +80,5 @@ class Device {
}; };
Device& device(mlx::core::Device); Device& device(mlx::core::Device);
NS::AutoreleasePool*& thread_autorelease_pool();
} // namespace mlx::core::metal } // namespace mlx::core::metal

View File

@ -48,42 +48,44 @@ std::function<void()> make_task(
std::vector<std::shared_future<void>> deps, std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p, std::shared_ptr<std::promise<void>> p,
bool retain_graph) { bool retain_graph) {
auto task = auto task = [retain_graph,
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { arr,
for (auto& d : deps) { deps = std::move(deps),
d.wait(); p = std::move(p)]() mutable {
} for (auto& d : deps) {
auto s = arr.primitive().stream(); d.wait();
auto command_buffer = increment_command_buffer(s); }
arr.primitive().eval_gpu(arr.inputs(), arr); auto s = arr.primitive().stream();
if (p) { auto command_buffer = increment_command_buffer(s);
metal::device(s.device).end_encoding(s.index); arr.primitive().eval_gpu(arr.inputs(), arr);
scheduler::notify_new_task(s); if (p) {
command_buffer->addCompletedHandler( metal::device(s.device).end_encoding(s.index);
[retain_graph, s, arr, p = std::move(p)]( scheduler::notify_new_task(s);
MTL::CommandBuffer*) mutable { command_buffer->addCompletedHandler(
if (!retain_graph) { [retain_graph, s, arr, p = std::move(p)](
arr.detach(); MTL::CommandBuffer*) mutable {
} if (!retain_graph) {
p->set_value(); arr.detach();
// Signal this thread to clear the pool on a synchroniztion. }
scheduler::enqueue(s, []() { p->set_value();
thread_autorelease_pool()->release(); // Signal this thread to clear the pool on a synchroniztion.
thread_autorelease_pool() = scheduler::enqueue(s, [s]() {
NS::AutoreleasePool::alloc()->init(); metal::device(s.device).g_thread_autorelease_pool()->release();
}); metal::device(s.device).g_thread_autorelease_pool() =
scheduler::notify_task_completion(s); NS::AutoreleasePool::alloc()->init();
}); });
metal::device(s.device).commit_command_buffer(s.index); scheduler::notify_task_completion(s);
} else { });
command_buffer->addCompletedHandler( metal::device(s.device).commit_command_buffer(s.index);
[retain_graph, s, arr](MTL::CommandBuffer*) mutable { } else {
if (!retain_graph) { command_buffer->addCompletedHandler(
arr.detach(); [retain_graph, s, arr](MTL::CommandBuffer*) mutable {
} if (!retain_graph) {
}); arr.detach();
} }
}; });
}
};
return task; return task;
} }