mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
make general autorelease pool part of metal device
This commit is contained in:
parent
e8deca84e0
commit
a813bdda0a
@ -288,15 +288,15 @@ Device& device(mlx::core::Device) {
|
||||
return metal_device_;
|
||||
}
|
||||
|
||||
NS::AutoreleasePool*& thread_autorelease_pool() {
|
||||
NS::AutoreleasePool*& Device::g_thread_autorelease_pool() {
|
||||
static thread_local NS::AutoreleasePool* p =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
return p;
|
||||
}
|
||||
|
||||
void new_stream(Stream stream) {
|
||||
thread_autorelease_pool();
|
||||
if (stream.device == mlx::core::Device::gpu) {
|
||||
device(stream.device).g_thread_autorelease_pool();
|
||||
device(stream.device).new_queue(stream.index);
|
||||
}
|
||||
}
|
||||
|
@ -66,6 +66,8 @@ class Device {
|
||||
MTL::ArgumentEncoder* argument_encoder(
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||
|
||||
NS::AutoreleasePool*& g_thread_autorelease_pool();
|
||||
|
||||
private:
|
||||
NS::AutoreleasePool* pool_;
|
||||
MTL::Device* device_;
|
||||
@ -78,6 +80,5 @@ class Device {
|
||||
};
|
||||
|
||||
Device& device(mlx::core::Device);
|
||||
NS::AutoreleasePool*& thread_autorelease_pool();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@ -48,42 +48,44 @@ std::function<void()> make_task(
|
||||
std::vector<std::shared_future<void>> deps,
|
||||
std::shared_ptr<std::promise<void>> p,
|
||||
bool retain_graph) {
|
||||
auto task =
|
||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr, p = std::move(p)](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
// Signal this thread to clear the pool on a synchroniztion.
|
||||
scheduler::enqueue(s, []() {
|
||||
thread_autorelease_pool()->release();
|
||||
thread_autorelease_pool() =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
});
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
auto task = [retain_graph,
|
||||
arr,
|
||||
deps = std::move(deps),
|
||||
p = std::move(p)]() mutable {
|
||||
for (auto& d : deps) {
|
||||
d.wait();
|
||||
}
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr, p = std::move(p)](
|
||||
MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
p->set_value();
|
||||
// Signal this thread to clear the pool on a synchroniztion.
|
||||
scheduler::enqueue(s, [s]() {
|
||||
metal::device(s.device).g_thread_autorelease_pool()->release();
|
||||
metal::device(s.device).g_thread_autorelease_pool() =
|
||||
NS::AutoreleasePool::alloc()->init();
|
||||
});
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||
if (!retain_graph) {
|
||||
arr.detach();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
return task;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user