MLX
 
Loading...
Searching...
No Matches
device.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <Metal/Metal.hpp>
6#include <dlfcn.h>
7#include <filesystem>
8#include <functional>
9#include <mutex>
10#include <shared_mutex>
11#include <string>
12#include <unordered_map>
13#include <unordered_set>
14
15#include "mlx/array.h"
16#include "mlx/device.h"
17
18namespace fs = std::filesystem;
19
20namespace mlx::core::metal {
21
// Returns the path of `<lib_name>.metallib` in the same directory as the
// binary (executable or shared library) that dladdr() reports as containing
// this function. Returns an empty string if dladdr() fails or does not report
// a file name.
//
// Note, this function must be left inline in a header so that it is not
// dynamically linked: the intent is for dladdr() to resolve the binary that
// includes this header rather than the MLX library itself.
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
  Dl_info info;
  // dladdr() takes an object pointer. Converting a function pointer to void*
  // is conditionally-supported in ISO C++ but well-defined on POSIX systems.
  void* self_addr = reinterpret_cast<void*>(&get_colocated_mtllib_path);
  if (dladdr(self_addr, &info) == 0 || info.dli_fname == nullptr) {
    return {};
  }
  // Keep the directory of the containing binary and swap its file name for
  // the metallib name (replace_filename == remove_filename() then /=).
  return std::filesystem::path(info.dli_fname)
      .replace_filename(lib_name + ".metallib")
      .string();
}
37
38using MTLFCList =
39 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
40
41struct DeviceStream;
42
44 explicit CommandEncoder(DeviceStream& stream);
47
50 enc.concurrent_ = true;
51 }
53 enc.concurrent_ = false;
54 enc.prev_outputs_.insert(
55 enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
56 enc.concurrent_outputs_.clear();
57 }
58
59 private:
60 CommandEncoder& enc;
61 };
62
63 void set_input_array(const array& a, int idx, int64_t offset = 0);
64 void set_output_array(array& a, int idx, int64_t offset = 0);
66 void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
67 void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
69 void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
70
71 void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
72 enc_->setComputePipelineState(kernel);
73 }
74
75 void wait_for_fence(MTL::Fence* fence) {
76 enc_->waitForFence(fence);
77 }
78
79 void update_fence(MTL::Fence* fence) {
80 enc_->updateFence(fence);
81 }
82
83 template <typename T>
84 void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
85 enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
86 }
// Copies every element of `vec` into argument slot `idx`.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
  const auto count = vec.size();
  set_vector_bytes(vec, count, idx);
}
91
92 template <typename T>
93 void set_bytes(const T* v, int n, int idx) {
94 return enc_->setBytes(v, n * sizeof(T), idx);
95 }
96
97 template <typename T>
98 void set_bytes(const T& v, int idx) {
99 return enc_->setBytes(&v, sizeof(T), idx);
100 }
101
106
107 // Inputs to all kernels in the encoder including temporaries
108 std::unordered_set<const void*>& inputs() {
109 return all_inputs_;
110 };
111
112 // Outputs of all kernels in the encoder including temporaries
113 std::unordered_set<const void*> outputs() {
114 return all_outputs_;
115 };
116
117 void barrier();
118
119 private:
120 DeviceStream& stream_;
121 MTL::ComputeCommandEncoder* enc_;
122 bool needs_barrier_{false};
123 bool concurrent_{false};
124 std::unordered_set<MTL::Resource*> prev_outputs_;
125 std::unordered_set<MTL::Resource*> next_outputs_;
126 std::unordered_set<MTL::Resource*> concurrent_outputs_;
127 std::unordered_set<const void*> all_inputs_;
128 std::unordered_set<const void*> all_outputs_;
129};
130
131struct Fence {
132 Fence(MTL::Fence* fence) : fence(fence) {}
134 fence->release();
135 }
136 MTL::Fence* fence;
137};
138
140 DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
142 queue->release();
143 if (buffer != nullptr) {
144 buffer->release();
145 }
146 };
147 MTL::CommandQueue* queue;
148 // A map of prior command encoder outputs to their corresponding fence
149 std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
150 // Used to allow thread-safe access to the outputs map
151 std::mutex fence_mtx;
152
153 // Data updated between command buffers
154 MTL::CommandBuffer* buffer{nullptr};
156 size_t buffer_sizes{0};
157
158 // The command encoder, fence, and temporaries are updated between command
159 // encoders
160 std::unique_ptr<CommandEncoder> encoder{nullptr};
161 std::shared_ptr<Fence> fence;
162 std::vector<array> temporaries;
163};
164
165class Device {
166 public:
168 Device(const Device&) = delete;
169 Device& operator=(const Device&) = delete;
171
172 MTL::Device* mtl_device() {
173 return device_;
174 };
175
176 const std::string& get_architecture() {
177 return arch_;
178 }
179
180 void new_queue(int index);
181 MTL::CommandBuffer* get_command_buffer(int index);
183 void commit_command_buffer(int index);
185 void end_encoding(int index);
186
188 const std::string& lib_name,
189 const std::string& lib_path);
190
191 // Note, this should remain in the header so that it is not dynamically
192 // linked
193 void register_library(const std::string& lib_name) {
194 if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
195 register_library(lib_name, get_colocated_mtllib_path(lib_name));
196 }
197 }
198
199 MTL::Library* get_library(
200 const std::string& name,
201 const std::function<std::string(void)>& builder);
202
203 MTL::ComputePipelineState* get_kernel(
204 const std::string& base_name,
205 MTL::Library* mtl_lib,
206 const std::string& hash_name = "",
207 const MTLFCList& func_consts = {},
208 const std::vector<MTL::Function*>& linked_functions = {});
209
210 MTL::ComputePipelineState* get_kernel(
211 const std::string& base_name,
212 const std::string& lib_name = "mlx",
213 const std::string& hash_name = "",
214 const MTLFCList& func_consts = {},
215 const std::vector<MTL::Function*>& linked_functions = {});
216
217 MTL::ArgumentEncoder* argument_encoder(
218 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
219
220 // Record temporary arrays for the given stream index
221 void add_temporary(array arr, int index);
222 void add_temporaries(std::vector<array> arrays, int index);
223
224 void set_residency_set(const MTL::ResidencySet* residency_set);
225
226 private:
227 DeviceStream& get_stream_(int index) {
228 return stream_map_.find(index)->second;
229 }
230 MTL::Library* get_library_cache_(const std::string& name);
231
232 MTL::Library* get_library_(const std::string& name);
233 MTL::Library* build_library_(const std::string& source_string);
234
235 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
236
237 MTL::Function* get_function_(
238 const std::string& name,
239 const std::string& specialized_name,
240 const MTLFCList& func_consts,
241 MTL::Library* mtl_lib);
242
243 MTL::LinkedFunctions* get_linked_functions_(
244 const std::vector<MTL::Function*>& funcs);
245
246 MTL::ComputePipelineState* get_kernel_(
247 const std::string& name,
248 const MTL::Function* mtl_function);
249
250 MTL::ComputePipelineState* get_kernel_(
251 const std::string& name,
252 const MTL::Function* mtl_function,
253 const MTL::LinkedFunctions* linked_functions);
254
255 MTL::ComputePipelineState* get_kernel_(
256 const std::string& base_name,
257 MTL::Library* mtl_lib,
258 const std::string& hash_name,
259 const MTLFCList& func_consts = {},
260 const std::vector<MTL::Function*>& linked_functions = {});
261
262 MTL::Device* device_;
263 std::unordered_map<int32_t, DeviceStream> stream_map_;
264
265 std::shared_mutex kernel_mtx_;
266 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
267
268 std::shared_mutex library_mtx_;
269 std::unordered_map<std::string, MTL::Library*> library_map_;
270 const MTL::ResidencySet* residency_set_{nullptr};
271 std::string arch_;
272 int max_ops_per_buffer_;
273 int max_mb_per_buffer_;
274};
275
277
278} // namespace mlx::core::metal
Definition array.h:24
Definition device.h:165
void set_residency_set(const MTL::ResidencySet *residency_set)
bool command_buffer_needs_commit(int index)
MTL::Device * mtl_device()
Definition device.h:172
void register_library(const std::string &lib_name, const std::string &lib_path)
MTL::CommandBuffer * get_command_buffer(int index)
void end_encoding(int index)
const std::string & get_architecture()
Definition device.h:176
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
void add_temporaries(std::vector< array > arrays, int index)
MTL::Library * get_library(const std::string &name, const std::function< std::string(void)> &builder)
void new_queue(int index)
void commit_command_buffer(int index)
void register_library(const std::string &lib_name)
Definition device.h:193
Device(const Device &)=delete
void add_temporary(array arr, int index)
Device & operator=(const Device &)=delete
MTL::ComputePipelineState * get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
CommandEncoder & get_command_encoder(int index)
Definition allocator.h:13
std::string get_colocated_mtllib_path(const std::string &lib_name)
Definition device.h:24
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:38
Device & device(mlx::core::Device)
ConcurrentContext(CommandEncoder &enc)
Definition device.h:49
CommandEncoder(DeviceStream &stream)
Definition device.h:7
ConcurrentContext(CommandEncoder &enc)
Definition device.h:49
Definition device.h:43
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims)
std::unordered_set< const void * > & inputs()
Definition device.h:108
CommandEncoder & operator=(const CommandEncoder &)=delete
ConcurrentContext start_concurrent()
Definition device.h:102
void set_vector_bytes(const std::vector< T > &vec, size_t nelems, int idx)
Definition device.h:84
void set_output_array(array &a, int idx, int64_t offset=0)
void set_compute_pipeline_state(MTL::ComputePipelineState *kernel)
Definition device.h:71
CommandEncoder(DeviceStream &stream)
void set_vector_bytes(const std::vector< T > &vec, int idx)
Definition device.h:88
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims)
void set_bytes(const T *v, int n, int idx)
Definition device.h:93
void set_input_array(const array &a, int idx, int64_t offset=0)
void set_bytes(const T &v, int idx)
Definition device.h:98
CommandEncoder(const CommandEncoder &)=delete
void set_buffer(const MTL::Buffer *buf, int idx, int64_t offset=0)
void update_fence(MTL::Fence *fence)
Definition device.h:79
std::unordered_set< const void * > outputs()
Definition device.h:113
void wait_for_fence(MTL::Fence *fence)
Definition device.h:75
Definition device.h:139
~DeviceStream()
Definition device.h:141
std::unordered_map< const void *, std::shared_ptr< Fence > > outputs
Definition device.h:149
DeviceStream(MTL::CommandQueue *queue)
Definition device.h:140
std::unique_ptr< CommandEncoder > encoder
Definition device.h:160
std::mutex fence_mtx
Definition device.h:151
MTL::CommandQueue * queue
Definition device.h:147
std::shared_ptr< Fence > fence
Definition device.h:161
MTL::CommandBuffer * buffer
Definition device.h:154
int buffer_ops
Definition device.h:155
size_t buffer_sizes
Definition device.h:156
std::vector< array > temporaries
Definition device.h:162
Fence(MTL::Fence *fence)
Definition device.h:132
~Fence()
Definition device.h:133
MTL::Fence * fence
Definition device.h:136