make general autorelease pool part of metal device

This commit is contained in:
Ronan Collobert 2023-12-18 20:45:29 -08:00
parent e8deca84e0
commit a813bdda0a
3 changed files with 42 additions and 39 deletions

View File

@ -288,15 +288,15 @@ Device& device(mlx::core::Device) {
return metal_device_;
}
NS::AutoreleasePool*& thread_autorelease_pool() {
NS::AutoreleasePool*& Device::g_thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init();
return p;
}
void new_stream(Stream stream) {
thread_autorelease_pool();
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).g_thread_autorelease_pool();
device(stream.device).new_queue(stream.index);
}
}

View File

@ -66,6 +66,8 @@ class Device {
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
NS::AutoreleasePool*& g_thread_autorelease_pool();
private:
NS::AutoreleasePool* pool_;
MTL::Device* device_;
@ -78,6 +80,5 @@ class Device {
};
Device& device(mlx::core::Device);
NS::AutoreleasePool*& thread_autorelease_pool();
} // namespace mlx::core::metal

View File

@ -48,42 +48,44 @@ std::function<void()> make_task(
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph) {
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)](
MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
p->set_value();
// Signal this thread to clear the pool on a synchroniztion.
scheduler::enqueue(s, []() {
thread_autorelease_pool()->release();
thread_autorelease_pool() =
NS::AutoreleasePool::alloc()->init();
});
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
auto task = [retain_graph,
arr,
deps = std::move(deps),
p = std::move(p)]() mutable {
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)](
MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
p->set_value();
// Signal this thread to clear the pool on a synchroniztion.
scheduler::enqueue(s, [s]() {
metal::device(s.device).g_thread_autorelease_pool()->release();
metal::device(s.device).g_thread_autorelease_pool() =
NS::AutoreleasePool::alloc()->init();
});
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
return task;
}