mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
new_memory_pool -> new_scoped_memory_pool
This commit is contained in:
parent
5ccc1fb314
commit
d9478d0eb0
@ -20,7 +20,7 @@ namespace mlx::core::metal {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Catch things related to the main-thread static variables
|
// Catch things related to the main-thread static variables
|
||||||
static std::shared_ptr<void> global_memory_pool = new_memory_pool();
|
static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
|
||||||
|
|
||||||
// TODO nicer way to set this or possibly expose as an environment variable
|
// TODO nicer way to set this or possibly expose as an environment variable
|
||||||
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||||
@ -114,7 +114,7 @@ MTL::Library* load_library(
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Device::Device() {
|
Device::Device() {
|
||||||
auto pool = new_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
device_ = load_device();
|
device_ = load_device();
|
||||||
library_map_ = {{"mlx", load_library(device_)}};
|
library_map_ = {{"mlx", load_library(device_)}};
|
||||||
}
|
}
|
||||||
@ -244,7 +244,7 @@ void Device::register_library(
|
|||||||
MTL::ComputePipelineState* Device::get_kernel(
|
MTL::ComputePipelineState* Device::get_kernel(
|
||||||
const std::string& name,
|
const std::string& name,
|
||||||
const std::string& lib_name /* = "mlx" */) {
|
const std::string& lib_name /* = "mlx" */) {
|
||||||
auto pool = new_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
// Look for cached kernel
|
// Look for cached kernel
|
||||||
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
@ -291,7 +291,7 @@ Device& device(mlx::core::Device) {
|
|||||||
return metal_device;
|
return metal_device;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<void> new_memory_pool() {
|
std::shared_ptr<void> new_scoped_memory_pool() {
|
||||||
auto dtor = [](void* ptr) {
|
auto dtor = [](void* ptr) {
|
||||||
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
static_cast<NS::AutoreleasePool*>(ptr)->release();
|
||||||
};
|
};
|
||||||
|
@ -50,7 +50,7 @@ std::function<void()> make_task(
|
|||||||
bool retain_graph) {
|
bool retain_graph) {
|
||||||
auto task =
|
auto task =
|
||||||
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
|
||||||
auto pool = new_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
for (auto& d : deps) {
|
for (auto& d : deps) {
|
||||||
d.wait();
|
d.wait();
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ constexpr bool is_available() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void new_stream(Stream stream);
|
void new_stream(Stream stream);
|
||||||
std::shared_ptr<void> new_memory_pool();
|
std::shared_ptr<void> new_scoped_memory_pool();
|
||||||
|
|
||||||
std::function<void()> make_task(
|
std::function<void()> make_task(
|
||||||
array& arr,
|
array& arr,
|
||||||
|
@ -35,7 +35,7 @@ struct StreamThread {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void thread_fn() {
|
void thread_fn() {
|
||||||
auto thread_pool = metal::new_memory_pool();
|
auto thread_pool = metal::new_scoped_memory_pool();
|
||||||
metal::new_stream(stream);
|
metal::new_stream(stream);
|
||||||
while (true) {
|
while (true) {
|
||||||
std::function<void()> task;
|
std::function<void()> task;
|
||||||
|
Loading…
Reference in New Issue
Block a user