mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
make general autorelease pool part of metal device
This commit is contained in:
parent
e8deca84e0
commit
a813bdda0a
@ -288,15 +288,15 @@ Device& device(mlx::core::Device) {
|
|||||||
return metal_device_;
|
return metal_device_;
|
||||||
}
|
}
|
||||||
|
|
||||||
NS::AutoreleasePool*& thread_autorelease_pool() {
|
NS::AutoreleasePool*& Device::g_thread_autorelease_pool() {
|
||||||
static thread_local NS::AutoreleasePool* p =
|
static thread_local NS::AutoreleasePool* p =
|
||||||
NS::AutoreleasePool::alloc()->init();
|
NS::AutoreleasePool::alloc()->init();
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
void new_stream(Stream stream) {
|
void new_stream(Stream stream) {
|
||||||
thread_autorelease_pool();
|
|
||||||
if (stream.device == mlx::core::Device::gpu) {
|
if (stream.device == mlx::core::Device::gpu) {
|
||||||
|
device(stream.device).g_thread_autorelease_pool();
|
||||||
device(stream.device).new_queue(stream.index);
|
device(stream.device).new_queue(stream.index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -66,6 +66,8 @@ class Device {
|
|||||||
MTL::ArgumentEncoder* argument_encoder(
|
MTL::ArgumentEncoder* argument_encoder(
|
||||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||||
|
|
||||||
|
NS::AutoreleasePool*& g_thread_autorelease_pool();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
NS::AutoreleasePool* pool_;
|
NS::AutoreleasePool* pool_;
|
||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
@ -78,6 +80,5 @@ class Device {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Device& device(mlx::core::Device);
|
Device& device(mlx::core::Device);
|
||||||
NS::AutoreleasePool*& thread_autorelease_pool();
|
|
||||||
|
|
||||||
} // namespace mlx::core::metal
|
} // namespace mlx::core::metal
|
||||||
|
@ -48,42 +48,44 @@ std::function<void()> make_task(
|
|||||||
std::vector<std::shared_future<void>> deps,
|
std::vector<std::shared_future<void>> deps,
|
||||||
std::shared_ptr<std::promise<void>> p,
|
std::shared_ptr<std::promise<void>> p,
|
||||||
bool retain_graph) {
|
bool retain_graph) {
|
||||||
auto task =
|
auto task = [retain_graph,
|
||||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
arr,
|
||||||
for (auto& d : deps) {
|
deps = std::move(deps),
|
||||||
d.wait();
|
p = std::move(p)]() mutable {
|
||||||
}
|
for (auto& d : deps) {
|
||||||
auto s = arr.primitive().stream();
|
d.wait();
|
||||||
auto command_buffer = increment_command_buffer(s);
|
}
|
||||||
arr.primitive().eval_gpu(arr.inputs(), arr);
|
auto s = arr.primitive().stream();
|
||||||
if (p) {
|
auto command_buffer = increment_command_buffer(s);
|
||||||
metal::device(s.device).end_encoding(s.index);
|
arr.primitive().eval_gpu(arr.inputs(), arr);
|
||||||
scheduler::notify_new_task(s);
|
if (p) {
|
||||||
command_buffer->addCompletedHandler(
|
metal::device(s.device).end_encoding(s.index);
|
||||||
[retain_graph, s, arr, p = std::move(p)](
|
scheduler::notify_new_task(s);
|
||||||
MTL::CommandBuffer*) mutable {
|
command_buffer->addCompletedHandler(
|
||||||
if (!retain_graph) {
|
[retain_graph, s, arr, p = std::move(p)](
|
||||||
arr.detach();
|
MTL::CommandBuffer*) mutable {
|
||||||
}
|
if (!retain_graph) {
|
||||||
p->set_value();
|
arr.detach();
|
||||||
// Signal this thread to clear the pool on a synchroniztion.
|
}
|
||||||
scheduler::enqueue(s, []() {
|
p->set_value();
|
||||||
thread_autorelease_pool()->release();
|
// Signal this thread to clear the pool on a synchroniztion.
|
||||||
thread_autorelease_pool() =
|
scheduler::enqueue(s, [s]() {
|
||||||
NS::AutoreleasePool::alloc()->init();
|
metal::device(s.device).g_thread_autorelease_pool()->release();
|
||||||
});
|
metal::device(s.device).g_thread_autorelease_pool() =
|
||||||
scheduler::notify_task_completion(s);
|
NS::AutoreleasePool::alloc()->init();
|
||||||
});
|
});
|
||||||
metal::device(s.device).commit_command_buffer(s.index);
|
scheduler::notify_task_completion(s);
|
||||||
} else {
|
});
|
||||||
command_buffer->addCompletedHandler(
|
metal::device(s.device).commit_command_buffer(s.index);
|
||||||
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
} else {
|
||||||
if (!retain_graph) {
|
command_buffer->addCompletedHandler(
|
||||||
arr.detach();
|
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
|
||||||
}
|
if (!retain_graph) {
|
||||||
});
|
arr.detach();
|
||||||
}
|
}
|
||||||
};
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
return task;
|
return task;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user