new_memory_pool -> new_scoped_memory_pool

This commit is contained in:
Ronan Collobert 2023-12-22 10:51:49 -08:00
parent 5ccc1fb314
commit d9478d0eb0
4 changed files with 7 additions and 7 deletions

View File

@ -20,7 +20,7 @@ namespace mlx::core::metal {
namespace { namespace {
// Catch things related to the main-thread static variables // Catch things related to the main-thread static variables
static std::shared_ptr<void> global_memory_pool = new_memory_pool(); static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
// TODO nicer way to set this or possibly expose as an environment variable // TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12; static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
@ -114,7 +114,7 @@ MTL::Library* load_library(
} // namespace } // namespace
Device::Device() { Device::Device() {
auto pool = new_memory_pool(); auto pool = new_scoped_memory_pool();
device_ = load_device(); device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}}; library_map_ = {{"mlx", load_library(device_)}};
} }
@ -244,7 +244,7 @@ void Device::register_library(
MTL::ComputePipelineState* Device::get_kernel( MTL::ComputePipelineState* Device::get_kernel(
const std::string& name, const std::string& name,
const std::string& lib_name /* = "mlx" */) { const std::string& lib_name /* = "mlx" */) {
auto pool = new_memory_pool(); auto pool = new_scoped_memory_pool();
// Look for cached kernel // Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second; return it->second;
@ -291,7 +291,7 @@ Device& device(mlx::core::Device) {
return metal_device; return metal_device;
} }
std::shared_ptr<void> new_memory_pool() { std::shared_ptr<void> new_scoped_memory_pool() {
auto dtor = [](void* ptr) { auto dtor = [](void* ptr) {
static_cast<NS::AutoreleasePool*>(ptr)->release(); static_cast<NS::AutoreleasePool*>(ptr)->release();
}; };

View File

@ -50,7 +50,7 @@ std::function<void()> make_task(
bool retain_graph) { bool retain_graph) {
auto task = auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
auto pool = new_memory_pool(); auto pool = new_scoped_memory_pool();
for (auto& d : deps) { for (auto& d : deps) {
d.wait(); d.wait();
} }

View File

@ -20,7 +20,7 @@ constexpr bool is_available() {
} }
void new_stream(Stream stream); void new_stream(Stream stream);
std::shared_ptr<void> new_memory_pool(); std::shared_ptr<void> new_scoped_memory_pool();
std::function<void()> make_task( std::function<void()> make_task(
array& arr, array& arr,

View File

@ -35,7 +35,7 @@ struct StreamThread {
} }
void thread_fn() { void thread_fn() {
auto thread_pool = metal::new_memory_pool(); auto thread_pool = metal::new_scoped_memory_pool();
metal::new_stream(stream); metal::new_stream(stream);
while (true) { while (true) {
std::function<void()> task; std::function<void()> task;