5#include <Metal/Metal.hpp> 
   10#include <shared_mutex> 
   12#include <unordered_map> 
   13#include <unordered_set> 
   18namespace fs = std::filesystem;
 
   26  std::string mtllib_path;
 
   27  std::string lib_ext = lib_name + 
".metallib";
 
   31    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
 
   32    mtllib_path = mtllib.c_str();
 
 
   39    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
   48      enc.concurrent_ = 
true;
 
 
   51      enc.concurrent_ = 
false;
 
   52      enc.prev_outputs_.insert(
 
   53          enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
 
   54      enc.concurrent_outputs_.clear();
 
 
 
   68    enc_->setComputePipelineState(kernel);
 
 
   72    enc_->waitForFence(fence);
 
 
   76    enc_->updateFence(fence);
 
 
   81    enc_->setBytes(vec.data(), nelems * 
sizeof(T), idx);
 
 
   90    return enc_->setBytes(v, n * 
sizeof(T), idx);
 
 
   95    return enc_->setBytes(&v, 
sizeof(T), idx);
 
 
  104  std::unordered_set<const void*>& 
inputs() {
 
 
  114  MTL::ComputeCommandEncoder* enc_;
 
  115  bool needs_barrier_{
false};
 
  116  bool concurrent_{
false};
 
  117  std::unordered_set<MTL::Resource*> prev_outputs_;
 
  118  std::unordered_set<MTL::Resource*> next_outputs_;
 
  119  std::unordered_set<MTL::Resource*> concurrent_outputs_;
 
  120  std::unordered_set<const void*> all_inputs_;
 
  121  std::unordered_set<const void*> all_outputs_;
 
 
  142  std::unordered_map<const void*, std::shared_ptr<Fence>> 
outputs;
 
  153  std::unique_ptr<CommandEncoder> 
encoder{
nullptr};
 
 
  182      const std::string& lib_name,
 
  183      const std::string& lib_path);
 
  188    if (
auto it = library_map_.find(lib_name); it == library_map_.end()) {
 
 
  194      const std::string& name,
 
  195      const std::function<std::string(
void)>& builder);
 
  198      const std::string& base_name,
 
  199      MTL::Library* mtl_lib,
 
  200      const std::string& hash_name = 
"",
 
  202      const std::vector<MTL::Function*>& linked_functions = {});
 
  205      const std::string& base_name,
 
  206      const std::string& lib_name = 
"mlx",
 
  207      const std::string& hash_name = 
"",
 
  209      const std::vector<MTL::Function*>& linked_functions = {});
 
  212      const std::vector<MTL::ArgumentDescriptor*>& arg_descs) 
const;
 
  222    return stream_map_.find(index)->second;
 
  224  MTL::Library* get_library_cache_(
const std::string& name);
 
  226  MTL::Library* get_library_(
const std::string& name);
 
  227  MTL::Library* build_library_(
const std::string& source_string);
 
  229  MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
 
  231  MTL::Function* get_function_(
 
  232      const std::string& name,
 
  233      const std::string& specialized_name,
 
  235      MTL::Library* mtl_lib);
 
  237  MTL::LinkedFunctions* get_linked_functions_(
 
  238      const std::vector<MTL::Function*>& funcs);
 
  240  MTL::ComputePipelineState* get_kernel_(
 
  241      const std::string& name,
 
  242      const MTL::Function* mtl_function);
 
  244  MTL::ComputePipelineState* get_kernel_(
 
  245      const std::string& name,
 
  246      const MTL::Function* mtl_function,
 
  247      const MTL::LinkedFunctions* linked_functions);
 
  249  MTL::ComputePipelineState* get_kernel_(
 
  250      const std::string& base_name,
 
  251      MTL::Library* mtl_lib,
 
  252      const std::string& hash_name,
 
  254      const std::vector<MTL::Function*>& linked_functions = {});
 
  256  MTL::Device* device_;
 
  257  std::unordered_map<int32_t, DeviceStream> stream_map_;
 
  259  std::shared_mutex kernel_mtx_;
 
  260  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
 
  262  std::shared_mutex library_mtx_;
 
  263  std::unordered_map<std::string, MTL::Library*> library_map_;
 
  264  const MTL::ResidencySet* residency_set_{
nullptr};
 
 
~ConcurrentContext()
Definition device.h:50
 
ConcurrentContext(CommandEncoder &enc)
Definition device.h:47