From 83d320c2ead82632cb2188448242880f558278f5 Mon Sep 17 00:00:00 2001 From: Hanyu Deng Date: Thu, 10 Apr 2025 12:00:52 +0800 Subject: [PATCH] Remove static initializer InTracing::trace_stack --- mlx/transforms.cpp | 5 ++++- mlx/transforms_impl.h | 10 +++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index b305257f0..4c462a74f 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -42,7 +42,10 @@ class Synchronizer : public Primitive { // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. -std::vector> detail::InTracing::trace_stack{}; +std::vector& detail::InTracing::trace_stack() { + static std::vector trace_stack_; + return trace_stack_; +} int detail::InTracing::grad_counter{0}; int detail::RetainGraph::tracing_counter{0}; diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index 7f62c406b..f61b25c56 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -22,19 +22,19 @@ std::vector vmap_replace( struct InTracing { explicit InTracing(bool dynamic = false, bool grad = false) { grad_counter += grad; - trace_stack.push_back({dynamic, grad}); + trace_stack().push_back({dynamic, grad}); } ~InTracing() { - grad_counter -= trace_stack.back().second; + grad_counter -= trace_stack().back().second; trace_stack.pop_back(); } static bool in_tracing() { - return !trace_stack.empty(); + return !trace_stack().empty(); } static bool in_dynamic_tracing() { // compile is always and only the outer-most transform - return in_tracing() && trace_stack.front().first; + return in_tracing() && trace_stack().front().first; } static bool in_grad_tracing() { @@ -43,7 +43,7 @@ struct InTracing { private: static int grad_counter; - static std::vector> trace_stack; + static std::vector& trace_stack(); }; struct RetainGraph {