Remove static initializers (#2059)

* Remove static initializers in device.cpp, load.cpp, pocketfft.h

* Remove static initializer InTracing::trace_stack

* Remove static initializer of CompilerCache cache

* Revert changes in pocketfft.h

* Remove duplicate private section of thread_pool()

This commit is contained in:
hdeng-apple 2025-04-24 21:14:49 +08:00 committed by GitHub
parent fbc89e3ced
commit 86984cad68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 22 deletions

View File

@ -40,7 +40,10 @@ struct CompilerCache {
std::shared_mutex mtx; std::shared_mutex mtx;
}; };
static CompilerCache cache{}; static CompilerCache& cache() {
static CompilerCache cache_;
return cache_;
};
// GPU compile is always available if the GPU is available and since we are in // GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available. // this file CPU compile is also available.
@ -56,14 +59,16 @@ void* compile(
const std::string& kernel_name, const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) { const std::function<std::string(void)>& source_builder) {
{ {
std::shared_lock lock(cache.mtx); std::shared_lock lock(cache().mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second; return it->second;
} }
} }
std::unique_lock lock(cache.mtx); std::unique_lock lock(cache().mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) { if (auto it = cache().kernels.find(kernel_name);
it != cache().kernels.end()) {
return it->second; return it->second;
} }
std::string source_code = source_builder(); std::string source_code = source_builder();
@ -120,10 +125,10 @@ void* compile(
} }
// load library // load library
cache.libs.emplace_back(shared_lib_path); cache().libs.emplace_back(shared_lib_path);
// Load function // Load function
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str()); void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
if (!fun) { if (!fun) {
std::ostringstream msg; std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function " msg << "[Compile::eval_cpu] Failed to load compiled function "
@ -131,7 +136,7 @@ void* compile(
<< dlerror(); << dlerror();
throw std::runtime_error(msg.str()); throw std::runtime_error(msg.str());
} }
cache.kernels.insert({kernel_name, fun}); cache().kernels.insert({kernel_name, fun});
return fun; return fun;
} }

View File

@ -5,11 +5,14 @@
namespace mlx::core { namespace mlx::core {
static Device default_device_{ Device& mutable_default_device() {
metal::is_available() ? Device::gpu : Device::cpu}; static Device default_device{
metal::is_available() ? Device::gpu : Device::cpu};
return default_device;
}
const Device& default_device() { const Device& default_device() {
return default_device_; return mutable_default_device();
} }
void set_default_device(const Device& d) { void set_default_device(const Device& d) {
@ -17,7 +20,7 @@ void set_default_device(const Device& d) {
throw std::invalid_argument( throw std::invalid_argument(
"[set_default_device] Cannot set gpu device without gpu backend."); "[set_default_device] Cannot set gpu device without gpu backend.");
} }
default_device_ = d; mutable_default_device() = d;
} }
bool operator==(const Device& lhs, const Device& rhs) { bool operator==(const Device& lhs, const Device& rhs) {

View File

@ -335,7 +335,10 @@ ThreadPool& thread_pool() {
return pool_; return pool_;
} }
ThreadPool ParallelFileReader::thread_pool_{4}; ThreadPool& ParallelFileReader::thread_pool() {
static ThreadPool thread_pool{4};
return thread_pool;
}
void ParallelFileReader::read(char* data, size_t n) { void ParallelFileReader::read(char* data, size_t n) {
while (n != 0) { while (n != 0) {
@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
break; break;
} else { } else {
size_t m = batch_size_; size_t m = batch_size_;
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data)); futs.emplace_back(
ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data));
data += m; data += m;
n -= m; n -= m;
offset += m; offset += m;

View File

@ -101,7 +101,7 @@ class ParallelFileReader : public Reader {
private: private:
static constexpr size_t batch_size_ = 1 << 25; static constexpr size_t batch_size_ = 1 << 25;
static ThreadPool thread_pool_; static ThreadPool& thread_pool();
int fd_; int fd_;
std::string label_; std::string label_;
}; };

View File

@ -42,7 +42,10 @@ class Synchronizer : public Primitive {
// are currently under a function transformation and the retain_graph() // are currently under a function transformation and the retain_graph()
// function which returns true if we are forced to retain the graph during // function which returns true if we are forced to retain the graph during
// evaluation. // evaluation.
std::vector<std::pair<char, char>> detail::InTracing::trace_stack{}; std::vector<std::pair<char, char>>& detail::InTracing::trace_stack() {
static std::vector<std::pair<char, char>> trace_stack_;
return trace_stack_;
}
int detail::InTracing::grad_counter{0}; int detail::InTracing::grad_counter{0};
int detail::RetainGraph::tracing_counter{0}; int detail::RetainGraph::tracing_counter{0};

View File

@ -22,19 +22,19 @@ std::vector<array> vmap_replace(
struct InTracing { struct InTracing {
explicit InTracing(bool dynamic = false, bool grad = false) { explicit InTracing(bool dynamic = false, bool grad = false) {
grad_counter += grad; grad_counter += grad;
trace_stack.push_back({dynamic, grad}); trace_stack().push_back({dynamic, grad});
} }
~InTracing() { ~InTracing() {
grad_counter -= trace_stack.back().second; grad_counter -= trace_stack().back().second;
trace_stack.pop_back(); trace_stack().pop_back();
} }
static bool in_tracing() { static bool in_tracing() {
return !trace_stack.empty(); return !trace_stack().empty();
} }
static bool in_dynamic_tracing() { static bool in_dynamic_tracing() {
// compile is always and only the outer-most transform // compile is always and only the outer-most transform
return in_tracing() && trace_stack.front().first; return in_tracing() && trace_stack().front().first;
} }
static bool in_grad_tracing() { static bool in_grad_tracing() {
@ -43,7 +43,7 @@ struct InTracing {
private: private:
static int grad_counter; static int grad_counter;
static std::vector<std::pair<char, char>> trace_stack; static std::vector<std::pair<char, char>>& trace_stack();
}; };
struct RetainGraph { struct RetainGraph {