5#include <Metal/Metal.hpp> 
    9#include <unordered_map> 
   10#include <unordered_set> 
   18namespace fs = std::filesystem;
 
   24  std::string mtllib_path;
 
   25  std::string lib_ext = lib_name + 
".metallib";
 
   29    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
 
   30    mtllib_path = mtllib.c_str();
 
 
   37    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
   41    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
 
 
   49      enc.concurrent = 
true;
 
 
   52      enc.concurrent = 
false;
 
   54          enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
 
   55      enc.concurrent_outputs.clear();
 
 
 
   68        static_cast<MTL::Resource*
>(
const_cast<void*
>(a.
buffer().
ptr()));
 
   69    if (
auto it = outputs.find(r_buf); it != outputs.end()) {
 
   71      enc->memoryBarrier(&r_buf, 1);
 
   76    auto a_buf = 
static_cast<const MTL::Buffer*
>(a.
buffer().
ptr());
 
   77    auto base_offset = a.
data<
char>() -
 
   78        static_cast<char*
>(
const_cast<MTL::Buffer*
>(a_buf)->contents());
 
   79    base_offset += offset;
 
   80    enc->setBuffer(a_buf, base_offset, idx);
 
 
   86    auto buf = 
static_cast<MTL::Resource*
>(a.
buffer().
ptr());
 
   88      concurrent_outputs.insert(
buf);
 
 
  109  int num_dispatches{0};
 
  110  MTL::CommandBuffer* cbuf;
 
  111  MTL::ComputeCommandEncoder* enc;
 
  112  bool concurrent{
false};
 
  113  std::unordered_set<MTL::Resource*> outputs;
 
  114  std::unordered_set<MTL::Resource*> concurrent_outputs;
 
 
  137      const std::string& lib_name,
 
  138      const std::string& lib_path);
 
  140      const std::string& lib_name,
 
  141      const std::function<std::string(
const std::string&)>& lib_path_func =
 
  147      const std::string& name,
 
  148      const std::string& source_string,
 
  152      const std::string& name,
 
  153      const MTL::StitchedLibraryDescriptor* desc,
 
  157      const std::string& base_name,
 
  158      MTL::Library* mtl_lib,
 
  159      const std::string& specialized_name = 
"",
 
  163      const std::string& base_name,
 
  164      const std::string& lib_name = 
"mlx",
 
  165      const std::string& specialized_name = 
"",
 
  169      const std::string& base_name,
 
  170      MTL::Library* mtl_lib,
 
  171      const std::string& hash_name = 
"",
 
  173      const std::vector<MTL::Function*>& linked_functions = {});
 
  176      const std::string& base_name,
 
  177      const std::string& lib_name = 
"mlx",
 
  178      const std::string& hash_name = 
"",
 
  180      const std::vector<MTL::Function*>& linked_functions = {});
 
  183      const std::vector<MTL::ArgumentDescriptor*>& arg_descs) 
const;
 
  186  MTL::Library* get_library_cache_(
const std::string& name);
 
  188  MTL::Library* get_library_(
const std::string& source_string);
 
  189  MTL::Library* get_library_(
const MTL::StitchedLibraryDescriptor* desc);
 
  191  MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
 
  193  MTL::Function* get_function_(
 
  194      const std::string& name,
 
  195      const std::string& specialized_name,
 
  197      MTL::Library* mtl_lib);
 
  199  MTL::LinkedFunctions* get_linked_functions_(
 
  200      const std::vector<MTL::Function*>& funcs);
 
  202  MTL::ComputePipelineState* get_kernel_(
 
  203      const std::string& name,
 
  204      const MTL::Function* mtl_function);
 
  206  MTL::ComputePipelineState* get_kernel_(
 
  207      const std::string& name,
 
  208      const MTL::Function* mtl_function,
 
  209      const MTL::LinkedFunctions* linked_functions);
 
  211  MTL::Device* device_;
 
  212  std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
 
  213  std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
 
  214  std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
 
  215  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
 
  216  std::unordered_map<std::string, MTL::Library*> library_map_;
 
 
const void * ptr() const
Definition allocator.h:23
 
T * data()
Definition array.h:313
 
allocator::Buffer & buffer()
Definition array.h:299
 
~ConcurrentContext()
Definition device.h:51
 
ConcurrentContext(CommandEncoder &enc)
Definition device.h:48