5#include <Metal/Metal.hpp> 
   10#include <shared_mutex> 
   12#include <unordered_map> 
   13#include <unordered_set> 
   18namespace fs = std::filesystem;
 
   26  std::string mtllib_path;
 
   27  std::string lib_ext = lib_name + 
".metallib";
 
   31    auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
 
   32    mtllib_path = mtllib.c_str();
 
 
   39    std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
 
   48      enc.concurrent_ = 
true;
 
 
   51      enc.concurrent_ = 
false;
 
   52      enc.prev_outputs_.insert(
 
   53          enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
 
   54      enc.concurrent_outputs_.clear();
 
 
 
   77  std::unordered_set<const void*>& 
inputs() {
 
 
   82  std::unordered_set<const void*> 
outputs() {
 
 
   87  MTL::ComputeCommandEncoder* enc_;
 
   88  bool needs_barrier_{
false};
 
   89  bool concurrent_{
false};
 
   90  std::unordered_set<MTL::Resource*> prev_outputs_;
 
   91  std::unordered_set<MTL::Resource*> next_outputs_;
 
   92  std::unordered_set<MTL::Resource*> concurrent_outputs_;
 
   93  std::unordered_set<const void*> all_inputs_;
 
   94  std::unordered_set<const void*> all_outputs_;
 
 
  115  std::unordered_map<const void*, std::shared_ptr<Fence>> 
outputs;
 
  126  std::unique_ptr<CommandEncoder> 
encoder{
nullptr};
 
 
  155      const std::string& lib_name,
 
  156      const std::string& lib_path);
 
  161    if (
auto it = library_map_.find(lib_name); it == library_map_.end()) {
 
 
  167      const std::string& name,
 
  168      const std::function<std::string(
void)>& builder);
 
  171      const std::string& base_name,
 
  172      MTL::Library* mtl_lib,
 
  173      const std::string& hash_name = 
"",
 
  175      const std::vector<MTL::Function*>& linked_functions = {});
 
  178      const std::string& base_name,
 
  179      const std::string& lib_name = 
"mlx",
 
  180      const std::string& hash_name = 
"",
 
  182      const std::vector<MTL::Function*>& linked_functions = {});
 
  185      const std::vector<MTL::ArgumentDescriptor*>& arg_descs) 
const;
 
  195    return stream_map_.find(index)->second;
 
  197  MTL::Library* get_library_cache_(
const std::string& name);
 
  199  MTL::Library* get_library_(
const std::string& name);
 
  200  MTL::Library* build_library_(
const std::string& source_string);
 
  202  MTL::Function* get_function_(
const std::string& name, MTL::Library* mtl_lib);
 
  204  MTL::Function* get_function_(
 
  205      const std::string& name,
 
  206      const std::string& specialized_name,
 
  208      MTL::Library* mtl_lib);
 
  210  MTL::LinkedFunctions* get_linked_functions_(
 
  211      const std::vector<MTL::Function*>& funcs);
 
  213  MTL::ComputePipelineState* get_kernel_(
 
  214      const std::string& name,
 
  215      const MTL::Function* mtl_function);
 
  217  MTL::ComputePipelineState* get_kernel_(
 
  218      const std::string& name,
 
  219      const MTL::Function* mtl_function,
 
  220      const MTL::LinkedFunctions* linked_functions);
 
  222  MTL::ComputePipelineState* get_kernel_(
 
  223      const std::string& base_name,
 
  224      MTL::Library* mtl_lib,
 
  225      const std::string& hash_name,
 
  227      const std::vector<MTL::Function*>& linked_functions = {});
 
  229  MTL::Device* device_;
 
  230  std::unordered_map<int32_t, DeviceStream> stream_map_;
 
  232  std::shared_mutex kernel_mtx_;
 
  233  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
 
  235  std::shared_mutex library_mtx_;
 
  236  std::unordered_map<std::string, MTL::Library*> library_map_;
 
  237  const MTL::ResidencySet* residency_set_{
nullptr};
 
 
~ConcurrentContext()
Definition device.h:50
 
ConcurrentContext(CommandEncoder &enc)
Definition device.h:47