diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index b305257f0..4c462a74f 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -42,7 +42,10 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. -std::vector> detail::InTracing::trace_stack{}; +std::vector& detail::InTracing::trace_stack() { + static std::vector trace_stack_; + return trace_stack_; +} int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 7f62c406b..f61b25c56 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -22,19 +22,19 @@ std::vector vmap_replace( struct InTracing { explicit InTracing(bool dynamic = false, bool grad = false) { grad_counter += grad; - trace_stack.push_back({dynamic, grad}); + trace_stack().push_back({dynamic, grad}); } ~InTracing() { - grad_counter -= trace_stack.back().second; + grad_counter -= trace_stack().back().second; trace_stack.pop_back(); } static bool in_tracing() { - return !trace_stack.empty(); + return !trace_stack().empty(); } static bool in_dynamic_tracing() { // compile is always and only the outer-most transform - return in_tracing() && trace_stack.front().first; + return in_tracing() && trace_stack().front().first; } static bool in_grad_tracing() { @@ -43,7 +43,7 @@ struct InTracing { private: static int grad_counter; - static std::vector> trace_stack; + static std::vector& trace_stack(); }; struct RetainGraph {