5#include <Metal/Metal.hpp> 
   10#include <shared_mutex> 
   12#include <unordered_map> 
   13#include <unordered_set> 
   18namespace fs = std::filesystem;
 
   26  std::string mtllib_path;
 
   27  std::string lib_ext = lib_name + 
".metallib";
 
   31    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
 
   32    mtllib_path = mtllib.c_str();
 
 
   39    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
   48      enc.concurrent_ = 
true;
 
 
   51      enc.concurrent_ = 
false;
 
   52      enc.prev_outputs_.insert(
 
   53          enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
 
   54      enc.concurrent_outputs_.clear();
 
 
 
   67  void set_buffer(
const MTL::Buffer* buf, 
int idx, int64_t offset = 0);
 
   70    enc_->setComputePipelineState(kernel);
 
 
   74    enc_->waitForFence(fence);
 
 
   78    enc_->updateFence(fence);
 
 
   83    enc_->setBytes(vec.data(), nelems * 
sizeof(T), idx);
 
 
   92    return enc_->setBytes(v, n * 
sizeof(T), idx);
 
 
   97    return enc_->setBytes(&v, 
sizeof(T), idx);
 
 
  106  std::unordered_set<const void*>& 
inputs() {
 
 
  118  MTL::ComputeCommandEncoder* enc_;
 
  119  bool needs_barrier_{
false};
 
  120  bool concurrent_{
false};
 
  121  std::unordered_set<MTL::Resource*> prev_outputs_;
 
  122  std::unordered_set<MTL::Resource*> next_outputs_;
 
  123  std::unordered_set<MTL::Resource*> concurrent_outputs_;
 
  124  std::unordered_set<const void*> all_inputs_;
 
  125  std::unordered_set<const void*> all_outputs_;
 
 
  146  std::unordered_map<const void*, std::shared_ptr<Fence>> 
outputs;
 
  157  std::unique_ptr<CommandEncoder> 
encoder{
nullptr};
 
 
  186      const std::string& lib_name,
 
  187      const std::string& lib_path);
 
  192    if (
auto it = library_map_.find(lib_name); it == library_map_.end()) {
 
 
  198      const std::string& name,
 
  199      const std::function<std::string(
void)>& builder);
 
  202      const std::string& base_name,
 
  203      MTL::Library* mtl_lib,
 
  204      const std::string& hash_name = 
"",
 
  206      const std::vector<MTL::Function*>& linked_functions = {});
 
  209      const std::string& base_name,
 
  210      const std::string& lib_name = 
"mlx",
 
  211      const std::string& hash_name = 
"",
 
  213      const std::vector<MTL::Function*>& linked_functions = {});
 
  216      const std::vector<MTL::ArgumentDescriptor*>& arg_descs) 
const;
 
  226    return stream_map_.find(index)->second;
 
  228  MTL::Library* get_library_cache_(
const std::string& name);
 
  230  MTL::Library* get_library_(
const std::string& name);
 
  231  MTL::Library* build_library_(
const std::string& source_string);
 
  233  MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
 
  235  MTL::Function* get_function_(
 
  236      const std::string& name,
 
  237      const std::string& specialized_name,
 
  239      MTL::Library* mtl_lib);
 
  241  MTL::LinkedFunctions* get_linked_functions_(
 
  242      const std::vector<MTL::Function*>& funcs);
 
  244  MTL::ComputePipelineState* get_kernel_(
 
  245      const std::string& name,
 
  246      const MTL::Function* mtl_function);
 
  248  MTL::ComputePipelineState* get_kernel_(
 
  249      const std::string& name,
 
  250      const MTL::Function* mtl_function,
 
  251      const MTL::LinkedFunctions* linked_functions);
 
  253  MTL::ComputePipelineState* get_kernel_(
 
  254      const std::string& base_name,
 
  255      MTL::Library* mtl_lib,
 
  256      const std::string& hash_name,
 
  258      const std::vector<MTL::Function*>& linked_functions = {});
 
  260  MTL::Device* device_;
 
  261  std::unordered_map<int32_t, DeviceStream> stream_map_;
 
  263  std::shared_mutex kernel_mtx_;
 
  264  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
 
  266  std::shared_mutex library_mtx_;
 
  267  std::unordered_map<std::string, MTL::Library*> library_map_;
 
  268  const MTL::ResidencySet* residency_set_{
nullptr};
 
 
ConcurrentContext(CommandEncoder &enc)
Definition device.h:47
 
CommandEncoder(MTL::CommandBuffer *cbuf)
 
~ConcurrentContext()
Definition device.h:50
 
ConcurrentContext(CommandEncoder &enc)
Definition device.h:47