MLX
Loading...
Searching...
No Matches
device.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <Metal/Metal.hpp>
6#include <functional>
7#include <mutex>
8#include <string>
9#include <unordered_map>
10#include <unordered_set>
11
12#include <dlfcn.h>
13#include <filesystem>
14
15#include "mlx/array.h"
16#include "mlx/device.h"
17
18namespace fs = std::filesystem;
19
20namespace mlx::core::metal {
21
/**
 * Build the path of a `.metallib` file located in the same directory as the
 * binary image that contains this function.
 *
 * @param lib_name Library name without the ".metallib" extension.
 * @return The directory of the containing image joined with
 *         "<lib_name>.metallib", or an empty string when dladdr() cannot
 *         resolve the image.
 */
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
  Dl_info info{};

  // dladdr() returns 0 on failure; on success dli_fname names the file of the
  // image that holds the queried address (here: this very function).
  const int found = dladdr(
      reinterpret_cast<const void*>(&get_colocated_mtllib_path), &info);
  if (found == 0 || info.dli_fname == nullptr) {
    return {};
  }

  // Swap the image's file name for "<lib_name>.metallib".
  const auto mtllib =
      std::filesystem::path(info.dli_fname).remove_filename() /
      (lib_name + ".metallib");
  return mtllib.string();
}
35
36using MTLFCList =
37 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
38
40 CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
41 enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
42 enc->retain();
43 };
46
49 enc.concurrent = true;
50 }
52 enc.concurrent = false;
53 enc.outputs.insert(
54 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
55 enc.concurrent_outputs.clear();
56 }
57
58 private:
59 CommandEncoder& enc;
60 };
61
62 MTL::ComputeCommandEncoder* operator->() {
63 return enc;
64 }
65
66 void set_input_array(const array& a, int idx, int64_t offset = 0) {
67 auto r_buf =
68 static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
69 if (auto it = outputs.find(r_buf); it != outputs.end()) {
70 // Insert a barrier
71 enc->memoryBarrier(&r_buf, 1);
72
73 // Remove the output
74 outputs.erase(it);
75 }
76 auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
77 auto base_offset = a.data<char>() -
78 static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
79 base_offset += offset;
80 enc->setBuffer(a_buf, base_offset, idx);
81 }
82
83 void set_output_array(array& a, int idx, int64_t offset = 0) {
84 // Add barriers before adding the output to the output set
85 set_input_array(a, idx, offset);
86 auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
87 if (concurrent) {
88 concurrent_outputs.insert(buf);
89 } else {
90 outputs.insert(buf);
91 }
92 }
93
94 void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
95 void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
96
100
102 enc->endEncoding();
103 enc->release();
104 }
105
106 private:
107 void maybe_split();
108
109 int num_dispatches{0};
110 MTL::CommandBuffer* cbuf;
111 MTL::ComputeCommandEncoder* enc;
112 bool concurrent{false};
113 std::unordered_set<MTL::Resource*> outputs;
114 std::unordered_set<MTL::Resource*> concurrent_outputs;
115};
116
117class Device {
118 public:
120 Device(const Device&) = delete;
121 Device& operator=(const Device&) = delete;
123
124 MTL::Device* mtl_device() {
125 return device_;
126 };
127
128 void new_queue(int index);
129 MTL::CommandBuffer* get_command_buffer(int index);
132 void commit_command_buffer(int index);
134 void end_encoding(int index);
135
137 const std::string& lib_name,
138 const std::string& lib_path);
140 const std::string& lib_name,
141 const std::function<std::string(const std::string&)>& lib_path_func =
143
144 MTL::Library* get_library(const std::string& name);
145
146 MTL::Library* get_library(
147 const std::string& name,
148 const std::string& source_string,
149 bool cache = true);
150
151 MTL::Library* get_library(
152 const std::string& name,
153 const MTL::StitchedLibraryDescriptor* desc,
154 bool cache = true);
155
156 MTL::Function* get_function(
157 const std::string& base_name,
158 MTL::Library* mtl_lib,
159 const std::string& specialized_name = "",
160 const MTLFCList& func_consts = {});
161
162 MTL::Function* get_function(
163 const std::string& base_name,
164 const std::string& lib_name = "mlx",
165 const std::string& specialized_name = "",
166 const MTLFCList& func_consts = {});
167
168 MTL::ComputePipelineState* get_kernel(
169 const std::string& base_name,
170 MTL::Library* mtl_lib,
171 const std::string& hash_name = "",
172 const MTLFCList& func_consts = {},
173 const std::vector<MTL::Function*>& linked_functions = {});
174
175 MTL::ComputePipelineState* get_kernel(
176 const std::string& base_name,
177 const std::string& lib_name = "mlx",
178 const std::string& hash_name = "",
179 const MTLFCList& func_consts = {},
180 const std::vector<MTL::Function*>& linked_functions = {});
181
182 MTL::ArgumentEncoder* argument_encoder(
183 const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
184
185 private:
186 MTL::Library* get_library_cache_(const std::string& name);
187
188 MTL::Library* get_library_(const std::string& source_string);
189 MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
190
191 MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
192
193 MTL::Function* get_function_(
194 const std::string& name,
195 const std::string& specialized_name,
196 const MTLFCList& func_consts,
197 MTL::Library* mtl_lib);
198
199 MTL::LinkedFunctions* get_linked_functions_(
200 const std::vector<MTL::Function*>& funcs);
201
202 MTL::ComputePipelineState* get_kernel_(
203 const std::string& name,
204 const MTL::Function* mtl_function);
205
206 MTL::ComputePipelineState* get_kernel_(
207 const std::string& name,
208 const MTL::Function* mtl_function,
209 const MTL::LinkedFunctions* linked_functions);
210
211 MTL::Device* device_;
212 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
213 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
214 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
215 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
216 std::unordered_map<std::string, MTL::Library*> library_map_;
217 std::mutex mtx_;
218};
219
221
222} // namespace mlx::core::metal
MTL::Buffer * buf
Definition allocator.h:38
const void * ptr() const
Definition allocator.h:23
Definition array.h:20
T * data()
Definition array.h:313
allocator::Buffer & buffer()
Definition array.h:299
Definition device.h:117
int get_command_buffer_ops(int index)
MTL::Device * mtl_device()
Definition device.h:124
void register_library(const std::string &lib_name, const std::string &lib_path)
MTL::CommandBuffer * get_command_buffer(int index)
void end_encoding(int index)
MTL::ComputePipelineState * get_kernel(const std::string &base_name, MTL::Library *mtl_lib, const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
void register_library(const std::string &lib_name, const std::function< std::string(const std::string &)> &lib_path_func=get_colocated_mtllib_path)
MTL::ArgumentEncoder * argument_encoder(const std::vector< MTL::ArgumentDescriptor * > &arg_descs) const
void increment_command_buffer_ops(int index)
void new_queue(int index)
MTL::Library * get_library(const std::string &name)
MTL::Library * get_library(const std::string &name, const MTL::StitchedLibraryDescriptor *desc, bool cache=true)
void commit_command_buffer(int index)
MTL::Library * get_library(const std::string &name, const std::string &source_string, bool cache=true)
MTL::Function * get_function(const std::string &base_name, MTL::Library *mtl_lib, const std::string &specialized_name="", const MTLFCList &func_consts={})
Device(const Device &)=delete
MTL::Function * get_function(const std::string &base_name, const std::string &lib_name="mlx", const std::string &specialized_name="", const MTLFCList &func_consts={})
Device & operator=(const Device &)=delete
MTL::ComputePipelineState * get_kernel(const std::string &base_name, const std::string &lib_name="mlx", const std::string &hash_name="", const MTLFCList &func_consts={}, const std::vector< MTL::Function * > &linked_functions={})
CommandEncoder & get_command_encoder(int index)
Definition allocator.h:12
std::string get_colocated_mtllib_path(const std::string &lib_name)
Definition device.h:22
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:36
Device & device(mlx::core::Device)
Definition device.h:7
ConcurrentContext(CommandEncoder &enc)
Definition device.h:48
Definition device.h:39
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)
CommandEncoder(MTL::CommandBuffer *cbuf)
Definition device.h:40
CommandEncoder & operator=(const CommandEncoder &)=delete
ConcurrentContext start_concurrent()
Definition device.h:97
void set_output_array(array &a, int idx, int64_t offset=0)
Definition device.h:83
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)
~CommandEncoder()
Definition device.h:101
MTL::ComputeCommandEncoder * operator->()
Definition device.h:62
void set_input_array(const array &a, int idx, int64_t offset=0)
Definition device.h:66
CommandEncoder(const CommandEncoder &)=delete