mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Remove static initializers (#2059)
* Remove static initializers in device.cpp, load.cpp, pocketfft.h
* Remove static initializer InTracing::trace_stack
* Remove static initializer of CompilerCache cache
* Revert changes in pocketfft.h
* Remove duplicate private section of thread_pool()
This commit is contained in:
parent
fbc89e3ced
commit
86984cad68
@ -40,7 +40,10 @@ struct CompilerCache {
|
|||||||
std::shared_mutex mtx;
|
std::shared_mutex mtx;
|
||||||
};
|
};
|
||||||
|
|
||||||
static CompilerCache cache{};
|
static CompilerCache& cache() {
|
||||||
|
static CompilerCache cache_;
|
||||||
|
return cache_;
|
||||||
|
};
|
||||||
|
|
||||||
// GPU compile is always available if the GPU is available and since we are in
|
// GPU compile is always available if the GPU is available and since we are in
|
||||||
// this file CPU compile is also available.
|
// this file CPU compile is also available.
|
||||||
@ -56,14 +59,16 @@ void* compile(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const std::function<std::string(void)>& source_builder) {
|
const std::function<std::string(void)>& source_builder) {
|
||||||
{
|
{
|
||||||
std::shared_lock lock(cache.mtx);
|
std::shared_lock lock(cache().mtx);
|
||||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
if (auto it = cache().kernels.find(kernel_name);
|
||||||
|
it != cache().kernels.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_lock lock(cache.mtx);
|
std::unique_lock lock(cache().mtx);
|
||||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
if (auto it = cache().kernels.find(kernel_name);
|
||||||
|
it != cache().kernels.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
std::string source_code = source_builder();
|
std::string source_code = source_builder();
|
||||||
@ -120,10 +125,10 @@ void* compile(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// load library
|
// load library
|
||||||
cache.libs.emplace_back(shared_lib_path);
|
cache().libs.emplace_back(shared_lib_path);
|
||||||
|
|
||||||
// Load function
|
// Load function
|
||||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
||||||
if (!fun) {
|
if (!fun) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||||
@ -131,7 +136,7 @@ void* compile(
|
|||||||
<< dlerror();
|
<< dlerror();
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
cache.kernels.insert({kernel_name, fun});
|
cache().kernels.insert({kernel_name, fun});
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,11 +5,14 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
static Device default_device_{
|
Device& mutable_default_device() {
|
||||||
|
static Device default_device{
|
||||||
metal::is_available() ? Device::gpu : Device::cpu};
|
metal::is_available() ? Device::gpu : Device::cpu};
|
||||||
|
return default_device;
|
||||||
|
}
|
||||||
|
|
||||||
const Device& default_device() {
|
const Device& default_device() {
|
||||||
return default_device_;
|
return mutable_default_device();
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_default_device(const Device& d) {
|
void set_default_device(const Device& d) {
|
||||||
@ -17,7 +20,7 @@ void set_default_device(const Device& d) {
|
|||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[set_default_device] Cannot set gpu device without gpu backend.");
|
"[set_default_device] Cannot set gpu device without gpu backend.");
|
||||||
}
|
}
|
||||||
default_device_ = d;
|
mutable_default_device() = d;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool operator==(const Device& lhs, const Device& rhs) {
|
bool operator==(const Device& lhs, const Device& rhs) {
|
||||||
|
@ -335,7 +335,10 @@ ThreadPool& thread_pool() {
|
|||||||
return pool_;
|
return pool_;
|
||||||
}
|
}
|
||||||
|
|
||||||
ThreadPool ParallelFileReader::thread_pool_{4};
|
ThreadPool& ParallelFileReader::thread_pool() {
|
||||||
|
static ThreadPool thread_pool{4};
|
||||||
|
return thread_pool;
|
||||||
|
}
|
||||||
|
|
||||||
void ParallelFileReader::read(char* data, size_t n) {
|
void ParallelFileReader::read(char* data, size_t n) {
|
||||||
while (n != 0) {
|
while (n != 0) {
|
||||||
@ -371,7 +374,8 @@ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
|
|||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
size_t m = batch_size_;
|
size_t m = batch_size_;
|
||||||
futs.emplace_back(thread_pool_.enqueue(readfn, offset, m, data));
|
futs.emplace_back(
|
||||||
|
ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data));
|
||||||
data += m;
|
data += m;
|
||||||
n -= m;
|
n -= m;
|
||||||
offset += m;
|
offset += m;
|
||||||
|
@ -101,7 +101,7 @@ class ParallelFileReader : public Reader {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr size_t batch_size_ = 1 << 25;
|
static constexpr size_t batch_size_ = 1 << 25;
|
||||||
static ThreadPool thread_pool_;
|
static ThreadPool& thread_pool();
|
||||||
int fd_;
|
int fd_;
|
||||||
std::string label_;
|
std::string label_;
|
||||||
};
|
};
|
||||||
|
@ -42,7 +42,10 @@ class Synchronizer : public Primitive {
|
|||||||
// are currently under a function transformation and the retain_graph()
|
// are currently under a function transformation and the retain_graph()
|
||||||
// function which returns true if we are forced to retain the graph during
|
// function which returns true if we are forced to retain the graph during
|
||||||
// evaluation.
|
// evaluation.
|
||||||
std::vector<std::pair<char, char>> detail::InTracing::trace_stack{};
|
std::vector<std::pair<char, char>>& detail::InTracing::trace_stack() {
|
||||||
|
static std::vector<std::pair<char, char>> trace_stack_;
|
||||||
|
return trace_stack_;
|
||||||
|
}
|
||||||
int detail::InTracing::grad_counter{0};
|
int detail::InTracing::grad_counter{0};
|
||||||
int detail::RetainGraph::tracing_counter{0};
|
int detail::RetainGraph::tracing_counter{0};
|
||||||
|
|
||||||
|
@ -22,19 +22,19 @@ std::vector<array> vmap_replace(
|
|||||||
struct InTracing {
|
struct InTracing {
|
||||||
explicit InTracing(bool dynamic = false, bool grad = false) {
|
explicit InTracing(bool dynamic = false, bool grad = false) {
|
||||||
grad_counter += grad;
|
grad_counter += grad;
|
||||||
trace_stack.push_back({dynamic, grad});
|
trace_stack().push_back({dynamic, grad});
|
||||||
}
|
}
|
||||||
~InTracing() {
|
~InTracing() {
|
||||||
grad_counter -= trace_stack.back().second;
|
grad_counter -= trace_stack().back().second;
|
||||||
trace_stack.pop_back();
|
trace_stack().pop_back();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool in_tracing() {
|
static bool in_tracing() {
|
||||||
return !trace_stack.empty();
|
return !trace_stack().empty();
|
||||||
}
|
}
|
||||||
static bool in_dynamic_tracing() {
|
static bool in_dynamic_tracing() {
|
||||||
// compile is always and only the outer-most transform
|
// compile is always and only the outer-most transform
|
||||||
return in_tracing() && trace_stack.front().first;
|
return in_tracing() && trace_stack().front().first;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool in_grad_tracing() {
|
static bool in_grad_tracing() {
|
||||||
@ -43,7 +43,7 @@ struct InTracing {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
static int grad_counter;
|
static int grad_counter;
|
||||||
static std::vector<std::pair<char, char>> trace_stack;
|
static std::vector<std::pair<char, char>>& trace_stack();
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RetainGraph {
|
struct RetainGraph {
|
||||||
|
Loading…
Reference in New Issue
Block a user