MLX
Loading...
Searching...
No Matches
transforms_impl.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5namespace mlx::core::detail {
6
7std::pair<std::vector<array>, std::vector<array>> vmap_trace(
8 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
9 const std::vector<array>& inputs,
10 const std::vector<int>& in_axes);
11
12std::vector<array> vmap_replace(
13 const std::vector<array>& inputs,
14 const std::vector<array>& s_inputs,
15 const std::vector<array>& s_outputs,
16 const std::vector<int>& in_axes,
17 const std::vector<int>& out_axes);
18
19// This is not part of the general C++ API as calling with a bad id is a bad
20// idea.
21std::function<std::vector<array>(const std::vector<array>&)> compile(
22 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
23 std::uintptr_t fun_id,
24 bool shapeless = false,
25 std::vector<uint64_t> constants = {});
26
// Drops the cached compiled function(s) registered under `fun_id`.
auto compile_erase(std::uintptr_t fun_id) -> void;
29
// Clear the compiler cache, causing a recompilation of all compiled functions
// the next time they are called.
void compile_clear_cache();
// Create an InTracing object during tracing operations to signal to the rest
// of the codebase that tracing is in progress, so evals should not throw away
// the graph. InTracing::in_tracing() returns true while at least one InTracing
// object is alive.
//
// The guard is non-copyable. A compiler-generated copy constructor would not
// increment tracing_counter, yet every copy's destructor would decrement it,
// so the counter would end up negative once all copies were destroyed.
struct InTracing {
  InTracing() {
    tracing_counter++;
  }
  ~InTracing() {
    tracing_counter--;
  }

  InTracing(const InTracing&) = delete;
  InTracing& operator=(const InTracing&) = delete;

  static bool in_tracing() {
    return tracing_counter > 0;
  }

 private:
  // Number of live InTracing guards. Its definition is not in this header.
  // NOTE(review): plain int, not atomic — confirm guards are never created
  // concurrently from multiple threads.
  static int tracing_counter;
};
52
// Scope guard: while at least one RetainGraph object is alive,
// RetainGraph::retain_graph() returns true. Judging by the name, this is
// presumably consulted by eval to keep the graph alive — confirm against
// callers.
//
// The guard is non-copyable. A compiler-generated copy constructor would not
// increment tracing_counter, yet every copy's destructor would decrement it,
// so the counter would end up negative once all copies were destroyed.
struct RetainGraph {
  RetainGraph() {
    tracing_counter++;
  }
  ~RetainGraph() {
    tracing_counter--;
  }

  RetainGraph(const RetainGraph&) = delete;
  RetainGraph& operator=(const RetainGraph&) = delete;

  static bool retain_graph() {
    return tracing_counter > 0;
  }

 private:
  // Number of live RetainGraph guards. This counter belongs to RetainGraph
  // alone, and its definition is not in this header.
  static int tracing_counter;
};
68
69} // namespace mlx::core::detail
Definition ops.h:8
std::vector< array > vmap_replace(const std::vector< array > &inputs, const std::vector< array > &s_inputs, const std::vector< array > &s_outputs, const std::vector< int > &in_axes, const std::vector< int > &out_axes)
void compile_clear_cache()
std::function< std::vector< array >(const std::vector< array > &)> compile(const std::function< std::vector< array >(const std::vector< array > &)> &fun, std::uintptr_t fun_id, bool shapeless=false, std::vector< uint64_t > constants={})
std::pair< std::vector< array >, std::vector< array > > vmap_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, const std::vector< int > &in_axes)
void compile_erase(std::uintptr_t fun_id)
InTracing
Definition transforms_impl.h:37
InTracing()
Definition transforms_impl.h:38
~InTracing()
Definition transforms_impl.h:41
static bool in_tracing()
Definition transforms_impl.h:45
RetainGraph
Definition transforms_impl.h:53
static bool retain_graph()
Definition transforms_impl.h:61
~RetainGraph()
Definition transforms_impl.h:57
RetainGraph()
Definition transforms_impl.h:54