MLX
 
Loading...
Searching...
No Matches
compile_impl.h
Go to the documentation of this file.
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mlx/array.h"

9namespace mlx::core::detail {
10
11// This is not part of the general C++ API as calling with a bad id is a bad
12// idea.
13std::function<std::vector<array>(const std::vector<array>&)> compile(
14 std::function<std::vector<array>(const std::vector<array>&)> fun,
15 std::uintptr_t fun_id,
16 bool shapeless = false,
17 std::vector<uint64_t> constants = {});
18
19// Erase cached compile functions
20void compile_erase(std::uintptr_t fun_id);
21
22// Clear the compiler cache causing a recompilation of all compiled functions
23// when called again.
25
27
28std::pair<std::vector<array>, std::vector<array>> compile_trace(
29 const std::function<std::vector<array>(const std::vector<array>&)>& fun,
30 const std::vector<array>& inputs,
31 bool shapeless);
32
34 std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
35
36// Traverses the graph to build a tape and a map of array ids to their parents
37std::pair<std::vector<array>, ParentsMap> compile_dfs(
38 const std::vector<array>& inputs,
39 const std::vector<array>& outputs,
40 const std::vector<array>& original_inputs);
41
42// Simplify the tape.
44 std::vector<array>& tape,
45 ParentsMap& parents_map,
46 std::vector<array>& outputs,
47 int passes);
48
49std::vector<array> compile_replace(
50 const std::vector<array>& tape,
51 const std::vector<array>& trace_inputs,
52 const std::vector<array>& trace_outputs,
53 const std::vector<array>& inputs,
54 bool shapeless);
55
56void compile_validate_shapeless(const std::vector<array>& tape);
57
58} // namespace mlx::core::detail
Definition binary_ops.h:7
void compile_validate_shapeless(const std::vector< array > &tape)
void compile_simplify(std::vector< array > &tape, ParentsMap &parents_map, std::vector< array > &outputs, int passes)
void compile_clear_cache()
std::pair< std::vector< array >, ParentsMap > compile_dfs(const std::vector< array > &inputs, const std::vector< array > &outputs, const std::vector< array > &original_inputs)
std::vector< array > compile_replace(const std::vector< array > &tape, const std::vector< array > &trace_inputs, const std::vector< array > &trace_outputs, const std::vector< array > &inputs, bool shapeless)
void compile_erase(std::uintptr_t fun_id)
std::unordered_map< std::uintptr_t, std::vector< std::pair< array, int > > > ParentsMap
Definition compile_impl.h:33
std::pair< std::vector< array >, std::vector< array > > compile_trace(const std::function< std::vector< array >(const std::vector< array > &)> &fun, const std::vector< array > &inputs, bool shapeless)
bool compile_available_for_device(const Device &device)
std::function< std::vector< array >(const std::vector< array > &)> compile(std::function< std::vector< array >(const std::vector< array > &)> fun, std::uintptr_t fun_id, bool shapeless=false, std::vector< uint64_t > constants={})
Definition device.h:7