5#include <Metal/Metal.hpp>
10#include <shared_mutex>
12#include <unordered_map>
13#include <unordered_set>
18namespace fs = std::filesystem;
26 std::string mtllib_path;
27 std::string lib_ext = lib_name +
".metallib";
31 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
32 mtllib_path = mtllib.c_str();
39 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
48 enc.concurrent_ =
true;
51 enc.concurrent_ =
false;
53 enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
54 enc.concurrent_outputs_.clear();
76 std::unordered_set<const void*>&
inputs() {
81 std::unordered_set<const void*>
outputs() {
86 MTL::ComputeCommandEncoder* enc_;
87 bool concurrent_{
false};
88 std::unordered_set<MTL::Resource*> outputs_;
89 std::unordered_set<MTL::Resource*> concurrent_outputs_;
90 std::unordered_set<const void*> all_inputs_;
91 std::unordered_set<const void*> all_outputs_;
112 std::unordered_map<const void*, std::shared_ptr<Fence>>
outputs;
123 std::unique_ptr<CommandEncoder>
encoder{
nullptr};
148 const std::string& lib_name,
149 const std::string& lib_path);
154 if (
auto it = library_map_.find(lib_name); it == library_map_.end()) {
160 const std::string& name,
161 const std::function<std::string(
void)>& builder);
164 const std::string& base_name,
165 MTL::Library* mtl_lib,
166 const std::string& hash_name =
"",
168 const std::vector<MTL::Function*>& linked_functions = {});
171 const std::string& base_name,
172 const std::string& lib_name =
"mlx",
173 const std::string& hash_name =
"",
175 const std::vector<MTL::Function*>& linked_functions = {});
178 const std::vector<MTL::ArgumentDescriptor*>& arg_descs)
const;
188 return stream_map_.find(index)->second;
190 MTL::Library* get_library_cache_(
const std::string& name);
192 MTL::Library* get_library_(
const std::string& name);
193 MTL::Library* build_library_(
const std::string& source_string);
195 MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
197 MTL::Function* get_function_(
198 const std::string& name,
199 const std::string& specialized_name,
201 MTL::Library* mtl_lib);
203 MTL::LinkedFunctions* get_linked_functions_(
204 const std::vector<MTL::Function*>& funcs);
206 MTL::ComputePipelineState* get_kernel_(
207 const std::string& name,
208 const MTL::Function* mtl_function);
210 MTL::ComputePipelineState* get_kernel_(
211 const std::string& name,
212 const MTL::Function* mtl_function,
213 const MTL::LinkedFunctions* linked_functions);
215 MTL::ComputePipelineState* get_kernel_(
216 const std::string& base_name,
217 MTL::Library* mtl_lib,
218 const std::string& hash_name,
220 const std::vector<MTL::Function*>& linked_functions = {});
222 MTL::Device* device_;
223 std::unordered_map<int32_t, DeviceStream> stream_map_;
225 std::shared_mutex kernel_mtx_;
226 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
228 std::shared_mutex library_mtx_;
229 std::unordered_map<std::string, MTL::Library*> library_map_;
230 const MTL::ResidencySet* residency_set_{
nullptr};
~ConcurrentContext()
Definition device.h:50
ConcurrentContext(CommandEncoder &enc)
Definition device.h:47