diff --git a/mlx/io/load.h b/mlx/io/load.h index 6bcebde6f..8b5dd95b6 100644 --- a/mlx/io/load.h +++ b/mlx/io/load.h @@ -99,11 +99,9 @@ class ParallelFileReader : public Reader { return "file " + label_; } - private: - static ThreadPool& thread_pool(); - private: static constexpr size_t batch_size_ = 1 << 25; + static ThreadPool& thread_pool(); int fd_; std::string label_; }; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 4c462a74f..f9a5de031 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -42,8 +42,8 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. -std::vector& detail::InTracing::trace_stack() { - static std::vector trace_stack_; +std::vector>& detail::InTracing::trace_stack() { + static std::vector> trace_stack_; return trace_stack_; } int detail::InTracing::grad_counter{0}; diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index f61b25c56..46851fa3d 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -26,7 +26,7 @@ struct InTracing { } ~InTracing() { grad_counter -= trace_stack().back().second; - trace_stack.pop_back(); + trace_stack().pop_back(); } static bool in_tracing() { @@ -43,7 +43,7 @@ struct InTracing { private: static int grad_counter; - static std::vector& trace_stack(); + static std::vector>& trace_stack(); }; struct RetainGraph {