MLX
Loading...
Searching...
No Matches
device.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <Metal/Metal.hpp>
6#include <dlfcn.h>
7#include <filesystem>
8#include <functional>
9#include <mutex>
10#include <shared_mutex>
11#include <string>
12#include <unordered_map>
13#include <unordered_set>
14
15#include "mlx/array.h"
16#include "mlx/device.h"
17
18namespace fs = std::filesystem;
19
20namespace mlx::core::metal {
21
22// Note, this function must be left inline in a header so that it is not
23// dynamically linked.
24inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
25 Dl_info info;
26 std::string mtllib_path;
27 std::string lib_ext = lib_name + ".metallib";
28
29 int success = dladdr((void*)get_colocated_mtllib_path, &info);
30 if (success) {
31 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
32 mtllib_path = mtllib.c_str();
33 }
34
35 return mtllib_path;
36}
37
// List type taken by the `func_consts` parameters of get_kernel(),
// get_kernel_() and get_function_() in this header. Each entry is a
// (const void*, MTL::DataType, NS::UInteger) tuple.
// NOTE(review): presumably (pointer to the constant's value, its Metal data
// type, its function-constant index) — confirm against the code that consumes
// these entries.
38using MTLFCList =
39 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
40
42 CommandEncoder(MTL::CommandBuffer* cbuf);
45
48 enc.concurrent_ = true;
49 }
51 enc.concurrent_ = false;
52 enc.prev_outputs_.insert(
53 enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
54 enc.concurrent_outputs_.clear();
55 }
56
57 private:
58 CommandEncoder& enc;
59 };
60
61 MTL::ComputeCommandEncoder* operator->() {
62 return enc_;
63 }
64
65 void set_input_array(const array& a, int idx, int64_t offset = 0);
66 void set_output_array(array& a, int idx, int64_t offset = 0);
67 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
68 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
70
75
76 // Inputs to all kernels in the encoder including temporaries
77 std::unordered_set<const void*>& inputs() {
78 return all_inputs_;
79 };
80
81 // Outputs of all kernels in the encoder including temporaries
82 std::unordered_set<const void*> outputs() {
83 return all_outputs_;
84 };
85
86 private:
87 MTL::ComputeCommandEncoder* enc_;
88 bool needs_barrier_{false};
89 bool concurrent_{false};
90 std::unordered_set<MTL::Resource*> prev_outputs_;
91 std::unordered_set<MTL::Resource*> next_outputs_;
92 std::unordered_set<MTL::Resource*> concurrent_outputs_;
93 std::unordered_set<const void*> all_inputs_;
94 std::unordered_set<const void*> all_outputs_;
95};
96
97struct Fence {
98 Fence(MTL::Fence* fence) : fence(fence) {}
100 fence->release();
101 }
102 MTL::Fence* fence;
103};
104
106 DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
108 queue->release();
109 if (buffer != nullptr) {
110 buffer->release();
111 }
112 };
113 MTL::CommandQueue* queue;
114 // A map of prior command encoder outputs to their corresponding fence
115 std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
116 // Used to allow thread-safe access to the outputs map
117 std::mutex fence_mtx;
118
119 // The buffer and buffer op count are updated
120 // between command buffers
121 MTL::CommandBuffer* buffer{nullptr};
123
124 // The command encoder, fence, and temporaries are updated between command
125 // encoders
126 std::unique_ptr<CommandEncoder> encoder{nullptr};
127 std::shared_ptr<Fence> fence;
128 std::vector<array> temporaries;
129};
130
131class Device {
132 public:
134 Device(const Device&) = delete;
135 Device& operator=(const Device&) = delete;
137
138 MTL::Device* mtl_device() {
139 return device_;
140 };
141
142 const std::string& get_architecture() {
143 return arch_;
144 }
145
146 void new_queue(int index);
147 MTL::CommandBuffer* get_command_buffer(int index);
150 void commit_command_buffer(int index);
152 void end_encoding(int index);
153
155 const std::string& lib_name,
156 const std::string& lib_path);
157
158 // Note, this should remain in the header so that it is not dynamically
159 // linked
160 void register_library(const std::string& lib_name) {
161 if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
162 register_library(lib_name, get_colocated_mtllib_path(lib_name));
163 }
164 }
165
166 MTL::Library* get_library(
167 const std::string& name,
168 const std::function<std::string(void)>& builder);
169
170 MTL::ComputePipelineState* get_kernel(
171 const std::string& base_name,
172 MTL::Library* mtl_lib,
173 const std::string& hash_name = "",
174 const MTLFCList& func_consts = {},
175 const std::vector<MTL::Function*>& linked_functions = {});
176
177 MTL::ComputePipelineState* get_kernel(
178 const std::string& base_name,
179 const std::string& lib_name = "mlx",
180 const std::string& hash_name = "",
181 const MTLFCList& func_consts = {},
182 const std::vector<MTL::Function*>& linked_functions = {});
183
184 MTL::ArgumentEncoder* argument_encoder(
185 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
186
187 // Record temporary arrays for the given stream index
188 void add_temporary(array arr, int index);
189 void add_temporaries(std::vector<array> arrays, int index);
190
191 void set_residency_set(const MTL::ResidencySet* residency_set);
192
193 private:
194 DeviceStream& get_stream_(int index) {
195 return stream_map_.find(index)->second;
196 }
197 MTL::Library* get_library_cache_(const std::string& name);
198
199 MTL::Library* get_library_(const std::string& name);
200 MTL::Library* build_library_(const std::string& source_string);
201
202 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
203
204 MTL::Function* get_function_(
205 const std::string& name,
206 const std::string& specialized_name,
207 const MTLFCList& func_consts,
208 MTL::Library* mtl_lib);
209
210 MTL::LinkedFunctions* get_linked_functions_(
211 const std::vector<MTL::Function*>& funcs);
212
213 MTL::ComputePipelineState* get_kernel_(
214 const std::string& name,
215 const MTL::Function* mtl_function);
216
217 MTL::ComputePipelineState* get_kernel_(
218 const std::string& name,
219 const MTL::Function* mtl_function,
220 const MTL::LinkedFunctions* linked_functions);
221
222 MTL::ComputePipelineState* get_kernel_(
223 const std::string& base_name,
224 MTL::Library* mtl_lib,
225 const std::string& hash_name,
226 const MTLFCList& func_consts = {},
227 const std::vector<MTL::Function*>& linked_functions = {});
228
229 MTL::Device* device_;
230 std::unordered_map<int32_t, DeviceStream> stream_map_;
231
232 std::shared_mutex kernel_mtx_;
233 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
234
235 std::shared_mutex library_mtx_;
236 std::unordered_map<std::string, MTL::Library*> library_map_;
237 const MTL::ResidencySet* residency_set_{nullptr};
238 std::string arch_;
239};
240
242
243} // namespace mlx::core::metal
Definition array.h:20
Definition device.h:131
void set_residency_set(const MTL::ResidencySet *residency_set)
int get_command_buffer_ops(int index)
MTL::Device * mtl_device()
Definition device.h:138
void register_library(const std::string &lib_name, const std::string &lib_path)
MTL::CommandBuffer * get_command_buffer(int index)
void end_encoding(int index)
const std::string & get_architecture()
Definition device.h:142
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
void add_temporaries(std::vector< array > arrays, int index)
MTL::Library * get_library(const std::string &name, const std::function< std::string(void)> &builder)
void increment_command_buffer_ops(int index)
void new_queue(int index)
void commit_command_buffer(int index)
void register_library(const std::string &lib_name)
Definition device.h:160
Device(const Device &)=delete
void add_temporary(array arr, int index)
Device & operator=(const Device &)=delete
MTL::ComputePipelineState * get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
CommandEncoder & get_command_encoder(int index)
Definition allocator.h:13
std::string get_colocated_mtllib_path(const std::string &lib_name)
Definition device.h:24
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:38
Device & device(mlx::core::Device)
Definition device.h:7
ConcurrentContext(CommandEncoder &enc)
Definition device.h:47
Definition device.h:41
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
CommandEncoder(MTL::CommandBuffer *cbuf)
std::unordered_set< const void * > & inputs()
Definition device.h:77
CommandEncoder & operator=(const CommandEncoder &)=delete
ConcurrentContext start_concurrent()
Definition device.h:71
void set_output_array(array &a, int idx, int64_t offset=0)
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
MTL::ComputeCommandEncoder * operator->()
Definition device.h:61
void set_input_array(const array &a, int idx, int64_t offset=0)
CommandEncoder(const CommandEncoder &)=delete
std::unordered_set< const void * > outputs()
Definition device.h:82
Definition device.h:105
~DeviceStream()
Definition device.h:107
std::unordered_map< const void *, std::shared_ptr< Fence > > outputs
Definition device.h:115
DeviceStream(MTL::CommandQueue *queue)
Definition device.h:106
std::unique_ptr< CommandEncoder > encoder
Definition device.h:126
std::mutex fence_mtx
Definition device.h:117
MTL::CommandQueue * queue
Definition device.h:113
std::shared_ptr< Fence > fence
Definition device.h:127
MTL::CommandBuffer * buffer
Definition device.h:121
int buffer_ops
Definition device.h:122
std::vector< array > temporaries
Definition device.h:128
Definition device.h:97
Fence(MTL::Fence *fence)
Definition device.h:98
~Fence()
Definition device.h:99
MTL::Fence * fence
Definition device.h:102