5#include <Metal/Metal.hpp>
10#include <shared_mutex>
12#include <unordered_map>
13#include <unordered_set>
18namespace fs = std::filesystem;
26 std::string mtllib_path;
27 std::string lib_ext = lib_name +
".metallib";
31 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
32 mtllib_path = mtllib.c_str();
39 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
48 enc.concurrent =
true;
51 enc.concurrent =
false;
53 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
54 enc.concurrent_outputs.clear();
79 int num_dispatches{0};
80 MTL::CommandBuffer* cbuf;
81 MTL::ComputeCommandEncoder* enc;
82 bool concurrent{
false};
83 std::unordered_set<MTL::Resource*> outputs;
84 std::unordered_set<MTL::Resource*> concurrent_outputs;
107 const std::string& lib_name,
108 const std::string& lib_path);
113 if (
auto it = library_map_.find(lib_name); it == library_map_.end()) {
119 const std::string& name,
120 const std::function<std::string(
void)>& builder);
123 const std::string& base_name,
124 MTL::Library* mtl_lib,
125 const std::string& hash_name =
"",
127 const std::vector<MTL::Function*>& linked_functions = {});
130 const std::string& base_name,
131 const std::string& lib_name =
"mlx",
132 const std::string& hash_name =
"",
134 const std::vector<MTL::Function*>& linked_functions = {});
137 const std::vector<MTL::ArgumentDescriptor*>& arg_descs)
const;
140 MTL::Library* get_library_cache_(
const std::string& name);
142 MTL::Library* get_library_(
const std::string& name);
143 MTL::Library* build_library_(
const std::string& source_string);
145 MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
147 MTL::Function* get_function_(
148 const std::string& name,
149 const std::string& specialized_name,
151 MTL::Library* mtl_lib);
153 MTL::LinkedFunctions* get_linked_functions_(
154 const std::vector<MTL::Function*>& funcs);
156 MTL::ComputePipelineState* get_kernel_(
157 const std::string& name,
158 const MTL::Function* mtl_function);
160 MTL::ComputePipelineState* get_kernel_(
161 const std::string& name,
162 const MTL::Function* mtl_function,
163 const MTL::LinkedFunctions* linked_functions);
165 MTL::ComputePipelineState* get_kernel_(
166 const std::string& base_name,
167 MTL::Library* mtl_lib,
168 const std::string& hash_name,
170 const std::vector<MTL::Function*>& linked_functions = {});
172 MTL::Device* device_;
173 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
174 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
175 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
177 std::shared_mutex kernel_mtx_;
178 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
180 std::shared_mutex library_mtx_;
181 std::unordered_map<std::string, MTL::Library*> library_map_;
~ConcurrentContext()
Definition device.h:50
ConcurrentContext(CommandEncoder &enc)
Definition device.h:47