5#include <Metal/Metal.hpp>
10#include <shared_mutex>
12#include <unordered_map>
13#include <unordered_set>
18namespace fs = std::filesystem;
26 std::string mtllib_path;
27 std::string lib_ext = lib_name +
".metallib";
31 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
32 mtllib_path = mtllib.c_str();
39 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
50 enc.concurrent_ =
true;
53 enc.concurrent_ =
false;
54 enc.prev_outputs_.insert(
55 enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
56 enc.concurrent_outputs_.clear();
69 void set_buffer(
const MTL::Buffer* buf,
int idx, int64_t offset = 0);
72 enc_->setComputePipelineState(kernel);
76 enc_->waitForFence(fence);
80 enc_->updateFence(fence);
85 enc_->setBytes(vec.data(), nelems *
sizeof(T), idx);
94 return enc_->setBytes(v, n *
sizeof(T), idx);
99 return enc_->setBytes(&v,
sizeof(T), idx);
108 std::unordered_set<const void*>&
inputs() {
121 MTL::ComputeCommandEncoder* enc_;
122 bool needs_barrier_{
false};
123 bool concurrent_{
false};
124 std::unordered_set<MTL::Resource*> prev_outputs_;
125 std::unordered_set<MTL::Resource*> next_outputs_;
126 std::unordered_set<MTL::Resource*> concurrent_outputs_;
127 std::unordered_set<const void*> all_inputs_;
128 std::unordered_set<const void*> all_outputs_;
149 std::unordered_map<const void*, std::shared_ptr<Fence>>
outputs;
160 std::unique_ptr<CommandEncoder>
encoder{
nullptr};
188 const std::string& lib_name,
189 const std::string& lib_path);
194 if (
auto it = library_map_.find(lib_name); it == library_map_.end()) {
200 const std::string& name,
201 const std::function<std::string(
void)>& builder);
204 const std::string& base_name,
205 MTL::Library* mtl_lib,
206 const std::string& hash_name =
"",
208 const std::vector<MTL::Function*>& linked_functions = {});
211 const std::string& base_name,
212 const std::string& lib_name =
"mlx",
213 const std::string& hash_name =
"",
215 const std::vector<MTL::Function*>& linked_functions = {});
218 const std::vector<MTL::ArgumentDescriptor*>& arg_descs)
const;
228 return stream_map_.find(index)->second;
230 MTL::Library* get_library_cache_(
const std::string& name);
232 MTL::Library* get_library_(
const std::string& name);
233 MTL::Library* build_library_(
const std::string& source_string);
235 MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
237 MTL::Function* get_function_(
238 const std::string& name,
239 const std::string& specialized_name,
241 MTL::Library* mtl_lib);
243 MTL::LinkedFunctions* get_linked_functions_(
244 const std::vector<MTL::Function*>& funcs);
246 MTL::ComputePipelineState* get_kernel_(
247 const std::string& name,
248 const MTL::Function* mtl_function);
250 MTL::ComputePipelineState* get_kernel_(
251 const std::string& name,
252 const MTL::Function* mtl_function,
253 const MTL::LinkedFunctions* linked_functions);
255 MTL::ComputePipelineState* get_kernel_(
256 const std::string& base_name,
257 MTL::Library* mtl_lib,
258 const std::string& hash_name,
260 const std::vector<MTL::Function*>& linked_functions = {});
262 MTL::Device* device_;
263 std::unordered_map<int32_t, DeviceStream> stream_map_;
265 std::shared_mutex kernel_mtx_;
266 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
268 std::shared_mutex library_mtx_;
269 std::unordered_map<std::string, MTL::Library*> library_map_;
270 const MTL::ResidencySet* residency_set_{
nullptr};
272 int max_ops_per_buffer_;
273 int max_mb_per_buffer_;
ConcurrentContext(CommandEncoder &enc)
Definition device.h:49
CommandEncoder(DeviceStream &stream)
~ConcurrentContext()
Definition device.h:52
ConcurrentContext(CommandEncoder &enc)
Definition device.h:49