5#include <Metal/Metal.hpp>
11#include <unordered_map>
12#include <unordered_set>
17namespace fs = std::filesystem;
25 std::string mtllib_path;
26 std::string lib_ext = lib_name +
".metallib";
30 auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
31 mtllib_path = mtllib.c_str();
38 std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
47 enc.concurrent =
true;
50 enc.concurrent =
false;
52 enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
53 enc.concurrent_outputs.clear();
78 int num_dispatches{0};
79 MTL::CommandBuffer* cbuf;
80 MTL::ComputeCommandEncoder* enc;
81 bool concurrent{
false};
82 std::unordered_set<MTL::Resource*> outputs;
83 std::unordered_set<MTL::Resource*> concurrent_outputs;
106 const std::string& lib_name,
107 const std::string& lib_path);
112 if (
auto it = library_map_.find(lib_name); it == library_map_.end()) {
120 const std::string& name,
121 const std::string& source_string,
125 const std::string& name,
126 const MTL::StitchedLibraryDescriptor* desc,
130 const std::string& base_name,
131 MTL::Library* mtl_lib,
132 const std::string& specialized_name =
"",
136 const std::string& base_name,
137 const std::string& lib_name =
"mlx",
138 const std::string& specialized_name =
"",
142 const std::string& base_name,
143 MTL::Library* mtl_lib,
144 const std::string& hash_name =
"",
146 const std::vector<MTL::Function*>& linked_functions = {});
149 const std::string& base_name,
150 const std::string& lib_name =
"mlx",
151 const std::string& hash_name =
"",
153 const std::vector<MTL::Function*>& linked_functions = {});
156 const std::vector<MTL::ArgumentDescriptor*>& arg_descs)
const;
159 MTL::Library* get_library_cache_(
const std::string& name);
161 MTL::Library* get_library_(
const std::string& source_string);
162 MTL::Library* get_library_(
const MTL::StitchedLibraryDescriptor* desc);
164 MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
166 MTL::Function* get_function_(
167 const std::string& name,
168 const std::string& specialized_name,
170 MTL::Library* mtl_lib);
172 MTL::LinkedFunctions* get_linked_functions_(
173 const std::vector<MTL::Function*>& funcs);
175 MTL::ComputePipelineState* get_kernel_(
176 const std::string& name,
177 const MTL::Function* mtl_function);
179 MTL::ComputePipelineState* get_kernel_(
180 const std::string& name,
181 const MTL::Function* mtl_function,
182 const MTL::LinkedFunctions* linked_functions);
184 MTL::Device* device_;
185 std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
186 std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
187 std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
188 std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
189 std::unordered_map<std::string, MTL::Library*> library_map_;
~ConcurrentContext()
Definition device.h:49
ConcurrentContext(CommandEncoder &enc)
Definition device.h:46