MLX
Loading...
Searching...
No Matches
transforms_impl.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5namespace mlx::core::detail {
6
7std::pair<std::vector<array>, std::vector<array>> vmap_trace(
8 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
9 const std::vector<array>& inputs,
10 const std::vector<int>& in_axes);
11
12std::vector<array> vmap_replace(
13 const std::vector<array>& inputs,
14 const std::vector<array>& s_inputs,
15 const std::vector<array>& s_outputs,
16 const std::vector<int>& in_axes,
17 const std::vector<int>& out_axes);
18
// Create an InTracing object during tracing operations to signify to the rest
// of the codebase that we are tracing, so evals should not throw away the
// graph.
//
// The constructor increments a shared counter and the destructor decrements
// it, so guards may be nested; in_tracing() reports whether at least one guard
// is currently alive. Copying is disabled: an implicitly generated copy would
// not increment the counter, but its destructor would still decrement it,
// leaving the counter unbalanced.
struct InTracing {
  InTracing() {
    tracing_counter++;
  }
  ~InTracing() {
    tracing_counter--;
  }

  InTracing(const InTracing&) = delete;
  InTracing& operator=(const InTracing&) = delete;

  // Returns true while at least one InTracing guard is alive.
  static bool in_tracing() {
    return tracing_counter > 0;
  }

 private:
  // Number of live InTracing guards. Declared here only; it must be defined
  // out of line in exactly one translation unit. NOTE(review): a plain int,
  // so guards created concurrently on different threads would race.
  static int tracing_counter;
};
37
// Scope guard that, while alive, makes retain_graph() report true. Guards may
// be nested: the constructor increments a shared counter and the destructor
// decrements it. NOTE(review): by its name this presumably tells eval to keep
// the computation graph instead of discarding it — confirm at call sites.
//
// Copying is disabled: an implicitly generated copy would skip the increment
// while its destructor still performed the decrement, unbalancing the counter.
struct RetainGraph {
  RetainGraph() {
    tracing_counter++;
  }
  ~RetainGraph() {
    tracing_counter--;
  }

  RetainGraph(const RetainGraph&) = delete;
  RetainGraph& operator=(const RetainGraph&) = delete;

  // Returns true while at least one RetainGraph guard is alive.
  static bool retain_graph() {
    return tracing_counter > 0;
  }

 private:
  // Number of live RetainGraph guards. This is a separate static member from
  // InTracing's counter despite the shared name; it must be defined out of
  // line in exactly one translation unit. NOTE(review): a plain int, so guards
  // created concurrently on different threads would race.
  static int tracing_counter;
};
53
54} // namespace mlx::core::detail
Definition ops.h:8
std::vector< array > vmap_replace(const std::vector< array > &inputs, const std::vector< array > &s_inputs, const std::vector< array > &s_outputs, const std::vector< int > &in_axes, const std::vector< int > &out_axes)
std::pair< std::vector< array >, std::vector< array > > vmap_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, const std::vector< int > &in_axes)
Definition transforms_impl.h:22
InTracing()
Definition transforms_impl.h:23
~InTracing()
Definition transforms_impl.h:26
static bool in_tracing()
Definition transforms_impl.h:30
Definition transforms_impl.h:38
static bool retain_graph()
Definition transforms_impl.h:46
~RetainGraph()
Definition transforms_impl.h:42
RetainGraph()
Definition transforms_impl.h:39