MLX
 
Loading...
Searching...
No Matches
transforms_impl.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5namespace mlx::core::detail {
6
7std::pair<std::vector<array>, std::vector<array>> vmap_trace(
8 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
9 const std::vector<array>& inputs,
10 const std::vector<int>& in_axes);
11
12std::vector<array> vmap_replace(
13 const std::vector<array>& inputs,
14 const std::vector<array>& s_inputs,
15 const std::vector<array>& s_outputs,
16 const std::vector<int>& in_axes,
17 const std::vector<int>& out_axes);
18
19// Create an InTracing object during tracing operations to signify to the rest
20// of the codebase that tracing is in progress, so evals should not throw away
21// the graph.
22struct InTracing {
23 explicit InTracing(bool dynamic = false) {
24 trace_stack.push_back(dynamic);
25 }
27 trace_stack.pop_back();
28 }
29
30 static bool in_tracing() {
31 return !trace_stack.empty();
32 }
33 static bool in_dynamic_tracing() {
34 // compile is always and only the outer-most transform
35 return in_tracing() && trace_stack.front();
36 }
37
38 private:
39 static std::vector<char> trace_stack;
40};
41
44 tracing_counter++;
45 }
47 tracing_counter--;
48 }
49
50 static bool retain_graph() {
51 return tracing_counter > 0;
52 }
53
54 private:
55 static int tracing_counter;
56};
57
60inline bool in_tracing() {
62}
63
69
70inline bool retain_graph() {
72}
73
74} // namespace mlx::core::detail
Definition binary_ops.h:7
std::vector< array > vmap_replace(const std::vector< array > &inputs, const std::vector< array > &s_inputs, const std::vector< array > &s_outputs, const std::vector< int > &in_axes, const std::vector< int > &out_axes)
bool retain_graph()
Definition transforms_impl.h:70
bool in_dynamic_tracing()
Return true if we are in a dynamic (shapeless) trace used for compiling or exporting graphs with dynamic shapes.
Definition transforms_impl.h:66
std::pair< std::vector< array >, std::vector< array > > vmap_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, const std::vector< int > &in_axes)
bool in_tracing()
Return true if we are currently performing a function transformation in order to keep the graph when evaluating tracer arrays.
Definition transforms_impl.h:60
InTracing(bool dynamic=false)
Definition transforms_impl.h:23
~InTracing()
Definition transforms_impl.h:26
static bool in_tracing()
Definition transforms_impl.h:30
static bool in_dynamic_tracing()
Definition transforms_impl.h:33
static bool retain_graph()
Definition transforms_impl.h:50
~RetainGraph()
Definition transforms_impl.h:46
RetainGraph()
Definition transforms_impl.h:43