// Copyright © 2023-2024 Apple Inc. #pragma once namespace mlx::core::detail { std::pair, std::vector> vmap_trace( const std::function(const std::vector&)>& fun, const std::vector& inputs, const std::vector& in_axes); std::vector vmap_replace( const std::vector& inputs, const std::vector& s_inputs, const std::vector& s_outputs, const std::vector& in_axes, const std::vector& out_axes); // Create an InTracing object during tracing operations to signify to the rest // of the codebase that we are during tracing so evals should not throw away // the graph. struct InTracing { InTracing() { tracing_counter++; } ~InTracing() { tracing_counter--; } static bool in_tracing() { return tracing_counter > 0; } private: static int tracing_counter; }; struct RetainGraph { RetainGraph() { tracing_counter++; } ~RetainGraph() { tracing_counter--; } static bool retain_graph() { return tracing_counter > 0; } private: static int tracing_counter; }; } // namespace mlx::core::detail