5#include <Metal/Metal.hpp>
9#include <unordered_map>
10#include <unordered_set>
18namespace fs = std::filesystem;
24 std::string mtllib_path;
25 std::string lib_ext = lib_name +
".metallib";
29 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
30 mtllib_path = mtllib.c_str();
37 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
41 enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
49 enc.concurrent =
true;
52 enc.concurrent =
false;
54 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
55 enc.concurrent_outputs.clear();
68 static_cast<MTL::Resource*
>(
const_cast<void*
>(a.
buffer().
ptr()));
69 if (
auto it = outputs.find(r_buf); it != outputs.end()) {
71 enc->memoryBarrier(&r_buf, 1);
76 auto a_buf =
static_cast<const MTL::Buffer*
>(a.
buffer().
ptr());
77 auto base_offset = a.
data<
char>() -
78 static_cast<char*
>(
const_cast<MTL::Buffer*
>(a_buf)->contents());
79 base_offset += offset;
80 enc->setBuffer(a_buf, base_offset, idx);
86 auto buf =
static_cast<MTL::Resource*
>(a.
buffer().
ptr());
88 concurrent_outputs.insert(
buf);
109 int num_dispatches{0};
110 MTL::CommandBuffer* cbuf;
111 MTL::ComputeCommandEncoder* enc;
112 bool concurrent{
false};
113 std::unordered_set<MTL::Resource*> outputs;
114 std::unordered_set<MTL::Resource*> concurrent_outputs;
137 const std::string& lib_name,
138 const std::string& lib_path);
140 const std::string& lib_name,
141 const std::function<std::string(
const std::string&)>& lib_path_func =
147 const std::string& name,
148 const std::string& source_string,
152 const std::string& name,
153 const MTL::StitchedLibraryDescriptor* desc,
157 const std::string& base_name,
158 MTL::Library* mtl_lib,
159 const std::string& specialized_name =
"",
163 const std::string& base_name,
164 const std::string& lib_name =
"mlx",
165 const std::string& specialized_name =
"",
169 const std::string& base_name,
170 MTL::Library* mtl_lib,
171 const std::string& hash_name =
"",
173 const std::vector<MTL::Function*>& linked_functions = {});
176 const std::string& base_name,
177 const std::string& lib_name =
"mlx",
178 const std::string& hash_name =
"",
180 const std::vector<MTL::Function*>& linked_functions = {});
183 const std::vector<MTL::ArgumentDescriptor*>& arg_descs)
const;
186 MTL::Library* get_library_cache_(
const std::string& name);
188 MTL::Library* get_library_(
const std::string& source_string);
189 MTL::Library* get_library_(
const MTL::StitchedLibraryDescriptor* desc);
191 MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
193 MTL::Function* get_function_(
194 const std::string& name,
195 const std::string& specialized_name,
197 MTL::Library* mtl_lib);
199 MTL::LinkedFunctions* get_linked_functions_(
200 const std::vector<MTL::Function*>& funcs);
202 MTL::ComputePipelineState* get_kernel_(
203 const std::string& name,
204 const MTL::Function* mtl_function);
206 MTL::ComputePipelineState* get_kernel_(
207 const std::string& name,
208 const MTL::Function* mtl_function,
209 const MTL::LinkedFunctions* linked_functions);
211 MTL::Device* device_;
212 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
213 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
214 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
215 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
216 std::unordered_map<std::string, MTL::Library*> library_map_;
const void * ptr() const
Definition allocator.h:23
T * data()
Definition array.h:313
allocator::Buffer & buffer()
Definition array.h:299
~ConcurrentContext()
Definition device.h:51
ConcurrentContext(CommandEncoder &enc)
Definition device.h:48