MLX
Loading...
Searching...
No Matches
transforms_impl.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5namespace mlx::core::detail {
6
7std::pair<std::vector<array>, std::vector<array>> vmap_trace(
8 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
9 const std::vector<array>& inputs,
10 const std::vector<int>& in_axes);
11
12std::vector<array> vmap_replace(
13 const std::vector<array>& inputs,
14 const std::vector<array>& s_inputs,
15 const std::vector<array>& s_outputs,
16 const std::vector<int>& in_axes,
17 const std::vector<int>& out_axes);
18
19// This is not part of the general C++ API as calling with a bad id is a bad
20// idea.
21std::function<std::vector<array>(const std::vector<array>&)> compile(
22 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
23 std::uintptr_t fun_id,
24 bool shapeless = false,
25 std::vector<uint64_t> constants = {});
26
// Erase the cached compiled functions for `fun_id`.
// NOTE(review): presumably the same id that was passed to compile() -- confirm.
void compile_erase(std::uintptr_t fun_id);
29
// Scope guard that marks a tracing region. Create an InTracing object for the
// duration of a tracing operation to signal to the rest of the codebase that
// tracing is in progress, so evals should not throw away the graph.
//
// Construction increments a counter shared by all instances and destruction
// decrements it, so guards can be nested; in_tracing() reports true while the
// counter is positive.
//
// NOTE(review): the implicitly generated copy operations do not touch the
// counter, so destroying a copy would decrement without a matching increment.
// NOTE(review): the counter is a plain static int with no synchronization --
// confirm that tracing only ever happens on one thread.
struct InTracing {
  // Enter a tracing region.
  InTracing() {
    tracing_counter++;
  }
  // Leave the tracing region entered by the constructor.
  ~InTracing() {
    tracing_counter--;
  }

  // True while the tracing counter is positive.
  static bool in_tracing() {
    return tracing_counter > 0;
  }

 private:
  // Shared nesting counter. Only declared here; as a non-inline static data
  // member it needs a definition in a source file.
  static int tracing_counter;
};
48
49} // namespace mlx::core::detail
Definition ops.h:8
std::vector< array > vmap_replace(const std::vector< array > &inputs, const std::vector< array > &s_inputs, const std::vector< array > &s_outputs, const std::vector< int > &in_axes, const std::vector< int > &out_axes)
std::function< std::vector< array >(const std::vector< array > &)> compile(const std::function< std::vector< array >(const std::vector< array > &)> &fun, std::uintptr_t fun_id, bool shapeless=false, std::vector< uint64_t > constants={})
std::pair< std::vector< array >, std::vector< array > > vmap_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, const std::vector< int > &in_axes)
void compile_erase(std::uintptr_t fun_id)
mlx::core::detail::InTracing
Definition transforms_impl.h:33
InTracing()
Definition transforms_impl.h:34
~InTracing()
Definition transforms_impl.h:37
static bool in_tracing()
Definition transforms_impl.h:41