diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 9da9c14e8..e389e0df5 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -40,7 +40,10 @@ struct CompilerCache { std::shared_mutex mtx; }; -static CompilerCache cache{}; +static CompilerCache& cache() { + static CompilerCache cache_; + return cache_; +}; // GPU compile is always available if the GPU is available and since we are in // this file CPU compile is also available. @@ -56,14 +59,16 @@ void* compile( const std::string& kernel_name, const std::function& source_builder) { { - std::shared_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::shared_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } } - std::unique_lock lock(cache.mtx); - if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { + std::unique_lock lock(cache().mtx); + if (auto it = cache().kernels.find(kernel_name); + it != cache().kernels.end()) { return it->second; } std::string source_code = source_builder(); @@ -120,10 +125,10 @@ void* compile( } // load library - cache.libs.emplace_back(shared_lib_path); + cache().libs.emplace_back(shared_lib_path); // Load function - void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str()); + void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str()); if (!fun) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to load compiled function " @@ -131,7 +136,7 @@ void* compile( << dlerror(); throw std::runtime_error(msg.str()); } - cache.kernels.insert({kernel_name, fun}); + cache().kernels.insert({kernel_name, fun}); return fun; } diff --git a/mlx/device.cpp b/mlx/device.cpp index e635782e2..20d8675d8 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -5,11 +5,14 @@ namespace mlx::core { -static Device default_device_{ - metal::is_available() ? Device::gpu : Device::cpu}; +Device& mutable_default_device() { + static Device default_device{ + metal::is_available() ? Device::gpu : Device::cpu}; + return default_device; +} const Device& default_device() { - return default_device_; + return mutable_default_device(); } void set_default_device(const Device& d) { @@ -17,7 +20,7 @@ void set_default_device(const Device& d) { throw std::invalid_argument( "[set_default_device] Cannot set gpu device without gpu backend."); } - default_device_ = d; + mutable_default_device() = d; } bool operator==(const Device& lhs, const Device& rhs) { diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp index 59d91c007..2f9053f4d 100644 --- a/mlx/io/load.cpp +++ b/mlx/io/load.cpp @@ -335,7 +335,10 @@ ThreadPool& thread_pool() { return pool_; } -ThreadPool ParallelFileReader::thread_pool_{4}; +ThreadPool& ParallelFileReader::thread_pool() { + static ThreadPool thread_pool{4}; + return thread_pool; +} void ParallelFileReader::read(char* data, size_t n) { while (n != 0) { @@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) { break; } else { size_t m = batch_size_; - futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); + futs.emplace_back( + ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data)); data += m; n -= m; offset += m; diff --git a/mlx/io/load.h b/mlx/io/load.h index 138098e82..8b5dd95b6 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -101,7 +101,7 @@ class ParallelFileReader : public Reader { private: static constexpr size_t batch_size_ = 1 << 25; - static ThreadPool thread_pool_; + static ThreadPool& thread_pool(); int fd_; std::string label_; }; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index b305257f0..f9a5de031 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -42,7 +42,10 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. -std::vector> detail::InTracing::trace_stack{}; +std::vector>& detail::InTracing::trace_stack() { + static std::vector> trace_stack_; + return trace_stack_; +} int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 7f62c406b..46851fa3d 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -22,19 +22,19 @@ std::vector vmap_replace( struct InTracing { explicit InTracing(bool dynamic = false, bool grad = false) { grad_counter += grad; - trace_stack.push_back({dynamic, grad}); + trace_stack().push_back({dynamic, grad}); } ~InTracing() { - grad_counter -= trace_stack.back().second; - trace_stack.pop_back(); + grad_counter -= trace_stack().back().second; + trace_stack().pop_back(); } static bool in_tracing() { - return !trace_stack.empty(); + return !trace_stack().empty(); } static bool in_dynamic_tracing() { // compile is always and only the outer-most transform - return in_tracing() && trace_stack.front().first; + return in_tracing() && trace_stack().front().first; } static bool in_grad_tracing() { @@ -43,7 +43,7 @@ struct InTracing { private: static int grad_counter; - static std::vector> trace_stack; + static std::vector>& trace_stack(); }; struct RetainGraph {