# Summary
# -------
# This patch replaces three namespace-/class-scope static objects with accessor
# functions that each own a function-local static, and updates every use that
# is visible in the hunks below.
#
#   mlx/3rdparty/pocketfft.h
#     - Removes the namespace-scope constant `max_threads`.
#     - Adds `inline const size_t &max_threads()`, which returns a reference to
#       a function-local `static thread_local const size_t` initialised from
#       `std::max(1u, std::thread::hardware_concurrency())`.
#     - The default `thread_pool()` constructor and `thread_map()` now call
#       `max_threads()`.
#
#   mlx/device.cpp
#     - Removes the file-static `default_device_`.
#     - Adds `Device& mutable_default_device()`, which owns a function-local
#       static `Device` initialised to `Device::gpu` when
#       `metal::is_available()` is true and to `Device::cpu` otherwise.
#     - `default_device()` returns that object; `set_default_device()` assigns
#       to it.
#
#   mlx/io/load.cpp, mlx/io/load.h
#     - Removes the static data member `ParallelFileReader::thread_pool_`
#       (previously constructed as `ThreadPool{4}`).
#     - Adds the private static member function
#       `ParallelFileReader::thread_pool()`, which owns a function-local static
#       `ThreadPool{4}`.
#     - The `enqueue` call in `read(char*, size_t, size_t)` goes through the
#       new accessor.
#
# Presumably the intent is to construct these objects on first use instead of
# during static initialisation; the patch text itself does not state the
# motivation -- confirm against the commit message.
#
# NOTE(review): `max_threads_` is declared `thread_local`, so every thread that
#   calls `max_threads()` gets its own copy, initialised on that thread's first
#   call. Confirm that per-thread storage is intended for this constant.
# NOTE(review): as written here, `mutable_default_device()` is a non-static
#   function inside `namespace mlx::core`. Confirm whether it is meant to be
#   visible outside device.cpp.
# NOTE(review): after this patch load.h has two consecutive `private:` labels
#   in `ParallelFileReader`.
# NOTE(review): load.cpp already has a free function `thread_pool()` (see the
#   first load.cpp hunk header); the new static member function shares that
#   name.
# NOTE(review): the line breaks, context-line indentation and blank context
#   lines below were laid out to match the `@@` hunk line counts. Verify the
#   exact whitespace against the target tree (or apply with a
#   whitespace-tolerant option) before relying on this file as a patch.
#
diff --git a/mlx/3rdparty/pocketfft.h b/mlx/3rdparty/pocketfft.h
index 03a45897a..dc15c61be 100644
--- a/mlx/3rdparty/pocketfft.h
+++ b/mlx/3rdparty/pocketfft.h
@@ -537,7 +537,11 @@ inline size_t &num_threads()
   static thread_local size_t num_threads_=1;
   return num_threads_;
   }
-static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency());
+inline const size_t &max_threads()
+  {
+  static thread_local const size_t max_threads_ = std::max(1u, std::thread::hardware_concurrency());
+  return max_threads_;
+  }
 
 class latch
   {
@@ -721,7 +725,7 @@ class thread_pool
       workers_(nthreads)
       { create_threads(); }
 
-    thread_pool(): thread_pool(max_threads) {}
+    thread_pool(): thread_pool(max_threads()) {}
 
     ~thread_pool() { shutdown(); }
 
@@ -786,7 +790,7 @@ template
 void thread_map(size_t nthreads, Func f)
   {
   if (nthreads == 0)
-    nthreads = max_threads;
+    nthreads = max_threads();
 
   if (nthreads == 1)
     { f(); return; }
diff --git a/mlx/device.cpp b/mlx/device.cpp
index e635782e2..20d8675d8 100644
--- a/mlx/device.cpp
+++ b/mlx/device.cpp
@@ -5,11 +5,14 @@
 
 namespace mlx::core {
 
-static Device default_device_{
-    metal::is_available() ? Device::gpu : Device::cpu};
+Device& mutable_default_device() {
+  static Device default_device{
+      metal::is_available() ? Device::gpu : Device::cpu};
+  return default_device;
+}
 
 const Device& default_device() {
-  return default_device_;
+  return mutable_default_device();
 }
 
 void set_default_device(const Device& d) {
@@ -17,7 +20,7 @@ void set_default_device(const Device& d) {
     throw std::invalid_argument(
         "[set_default_device] Cannot set gpu device without gpu backend.");
   }
-  default_device_ = d;
+  mutable_default_device() = d;
 }
 
 bool operator==(const Device& lhs, const Device& rhs) {
diff --git a/mlx/io/load.cpp b/mlx/io/load.cpp
index 59d91c007..2f9053f4d 100644
--- a/mlx/io/load.cpp
+++ b/mlx/io/load.cpp
@@ -335,7 +335,10 @@ ThreadPool& thread_pool() {
   return pool_;
 }
 
-ThreadPool ParallelFileReader::thread_pool_{4};
+ThreadPool& ParallelFileReader::thread_pool() {
+  static ThreadPool thread_pool{4};
+  return thread_pool;
+}
 
 void ParallelFileReader::read(char* data, size_t n) {
   while (n != 0) {
@@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
       break;
     } else {
       size_t m = batch_size_;
-      futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
+      futs.emplace_back(
+          ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data));
       data += m;
       n -= m;
       offset += m;
diff --git a/mlx/io/load.h b/mlx/io/load.h
index 138098e82..6bcebde6f 100644
--- a/mlx/io/load.h
+++ b/mlx/io/load.h
@@ -99,9 +99,11 @@ class ParallelFileReader : public Reader {
     return "file " + label_;
   }
 
+ private:
+  static ThreadPool& thread_pool();
+
  private:
   static constexpr size_t batch_size_ = 1 << 25;
-  static ThreadPool thread_pool_;
   int fd_;
   std::string label_;
 };